diff -r 1f19c6bbf07b -r 3868555d135c src/euclidean/wrap.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/euclidean/wrap.rs Fri May 15 14:46:30 2026 -0500 @@ -0,0 +1,313 @@ +/*! +Wrappers for implemention [`Euclidean`] operations. +*/ + +use crate::euclidean::Euclidean; +use crate::instance::Space; +use crate::types::Float; + +pub trait WrapGuard<'a, F: Float> { + type View<'b>: Euclidean + where + Self: 'b; + fn get_view(&self) -> Self::View<'_>; +} + +pub trait WrapGuardMut<'a, F: Float> { + type ViewMut<'b>: Euclidean + where + Self: 'b; + fn get_view_mut(&mut self) -> Self::ViewMut<'_>; +} + +pub trait Wrapped: Space { + type WrappedField: Float; + type Guard<'a>: WrapGuard<'a, Self::WrappedField> + where + Self: 'a; + type GuardMut<'a>: WrapGuardMut<'a, Self::WrappedField> + where + Self: 'a; + type UnwrappedOutput; + type WrappedOutput; + fn get_guard(&self) -> Self::Guard<'_>; + fn get_guard_mut(&mut self) -> Self::GuardMut<'_>; + fn wrap(output: Self::UnwrappedOutput) -> Self::WrappedOutput; +} + +#[macro_export] +macro_rules! wrap { + // Rust macros are totally fucked up. $trait:path does not work, have to + // manually code paths through $($trait:ident)::+. + (impl_unary $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => { + impl<$($qual)*> $($trait)::+ for $type { + type Output = ::WrappedOutput; + fn $fn(self) -> Self::Output { + let a = self.get_guard(); + Self::wrap(a.get_view().$fn()) + } + } + }; + (impl_binary $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => { + impl<$($qual)*> $($trait)::+<$type> for $type { + type Output = ::WrappedOutput; + fn $fn(self, other: $type) -> Self::Output { + let a = self.get_guard(); + let b = other.get_guard(); + Self::wrap(a.get_view().$fn(b.get_view())) + } + } + + impl<'a, $($qual)*> $($trait)::+<$type> for &'a $type { + type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput; + fn $fn(self, other: $type) -> Self::Output { + let a = self.get_guard(); + let b = other.get_guard(); + <$type>::wrap(a.get_view().$fn(b.get_view())) + } + } + + impl<'a, 'b, $($qual)*> $($trait)::+<&'b $type> for &'a $type { + type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput; + fn $fn(self, other: &'b $type) -> Self::Output { + let a = self.get_guard(); + let b = other.get_guard(); + <$type>::wrap(a.get_view().$fn(b.get_view())) + } + } + + impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type { + type Output = ::WrappedOutput; + fn $fn(self, other: &'b $type) -> Self::Output { + let a = self.get_guard(); + let b = other.get_guard(); + Self::wrap(a.get_view().$fn(b.get_view())) + } + } + }; + (impl_scalar $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => { + impl<$($qual)*> $($trait)::+<$F> for $type + // where + // $type: $crate::euclidean::wrap::Wrapped, + // //$type::Unwrapped: $($trait)::+, + { + type Output = ::WrappedOutput; + fn $fn(self, t: $F) -> Self::Output { + let a = self.get_guard(); + Self::wrap(a.get_view().$fn(t)) + } + } + + impl<'a, $($qual)*> $($trait)::+<$F> for &'a $type + // where + // $type: $crate::euclidean::wrap::Wrapped, + // //$type::Unwrapped: $($trait)::+, + { + type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput; + fn $fn(self, t: $F) -> Self::Output { + let a = self.get_guard(); + <$type>::wrap(a.get_view().$fn(t)) + } + } + + }; + (impl_scalar_lhs $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => { + impl<$($qual)*> $($trait)::+<$type> for $F + // where + // $type: $crate::euclidean::wrap::Wrapped, + // // where + // // $F: $($trait)::+<$type::Unwrapped>, + { + type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput; + fn $fn(self, rhs: $type) -> Self::Output { + let b = rhs.get_guard(); + <$type>::wrap(self.$fn(b.get_view())) + } + } + }; + (impl_binary_mut $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => { + impl<$($qual)*> $($trait)::+<$type> for $type { + fn $fn(&mut self, rhs: $type) { + let mut a = self.get_guard_mut(); + let b = rhs.get_guard(); + a.get_view_mut().$fn(b.get_view()) + } + } + + impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type { + fn $fn(&mut self, rhs: &'b $type) { + let mut a = self.get_guard_mut(); + let b = rhs.get_guard(); + a.get_view_mut().$fn(b.get_view()) + } + } + }; + (impl_scalar_mut $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => { + impl<$($qual)*> $($trait)::+<$F> for $type + // where + // $type: $crate::euclidean::wrap::Wrapped, + // // where + // // $type::UnwrappedMut: $($trait)::+<$($trait)::+>, + { + fn $fn(&mut self, t: $F) { + let mut a = self.get_guard_mut(); + a.get_view_mut().$fn(t) + } + } + }; + // ($type:ty) => { + // $crate::wrap!(imp<> do $type); + // }; + ($F:ty; $type:ty where $($qual:tt)*) => { + + $crate::wrap!(impl_unary $type, std::ops::Neg, neg where $($qual)*); + $crate::wrap!(impl_binary $type, std::ops::Add, add where $($qual)*); + $crate::wrap!(impl_binary $type, std::ops::Sub, sub where $($qual)*); + $crate::wrap!(impl_scalar $F, $type, std::ops::Mul, mul where $($qual)*); + $crate::wrap!(impl_scalar $F, $type, std::ops::Div, div where $($qual)*); + $crate::wrap!(impl_scalar_lhs $F, $type, std::ops::Mul, mul where $($qual)*); + $crate::wrap!(impl_binary_mut $type, std::ops::AddAssign, add_assign where $($qual)*); + $crate::wrap!(impl_binary_mut $type, std::ops::SubAssign, sub_assign where $($qual)*); + $crate::wrap!(impl_scalar_mut $F, $type, std::ops::MulAssign, mul_assign where $($qual)*); + $crate::wrap!(impl_scalar_mut $F, $type, std::ops::DivAssign, div_assign where $($qual)*); + + $crate::self_ownable!($type where $($qual)*); + + impl<$($qual)*> $crate::norms::Norm<$crate::norms::L2, $F> for $type + { + fn norm(&self, p : $crate::norms::L2) -> $F { + let a = self.get_guard(); + $crate::norms::Norm::norm(&a.get_view(), p) + } + } + + impl<$($qual)*> $crate::norms::Dist<$crate::norms::L2, $F> for $type + { + fn dist>(&self, other : I, p : $crate::norms::L2) -> $F { + other.eval_ref(|other| { + let a = self.get_guard(); + let b = other.get_guard(); + a.get_view().dist(b.get_view(), p) + }) + } + } + + impl<$($qual)*> $crate::norms::Normed<$F> for $type { + type NormExp = $crate::norms::L2; + + fn norm_exponent(&self) -> Self::NormExp { + $crate::norms::L2 + } + } + + impl<$($qual)*> $crate::norms::HasDual<$F> for $type { + type DualSpace = Self; + + fn dual_origin(&self) -> Self { + $crate::linops::VectorSpace::similar_origin(self) + } + } + + impl<$($qual)*> $crate::euclidean::Euclidean<$F> for $type + // where + // Self: $crate::euclidean::wrap::Wrapped + // + Sized + // + std::ops::Mul::Owned> + // + std::ops::MulAssign + // + std::ops::Div::Owned> + // + std::ops::DivAssign + // + std::ops::Add::Owned> + // + std::ops::Sub::Owned> + // + for<'b> std::ops::Add<&'b Self, Output = ::Owned> + // + for<'b> std::ops::Sub<&'b Self, Output = ::Owned> + // + std::ops::AddAssign + // + for<'b> std::ops::AddAssign<&'b Self> + // + std::ops::SubAssign + // + for<'b> std::ops::SubAssign<&'b Self> + // + std::ops::Neg::Owned>, + { + type PrincipalE = Self; + + fn dot>(&self, other: I) -> $F { + other.eval_decompose(|other| { + let a = self.get_guard(); + let b = other.get_guard(); + a.get_view().dot(&b.get_view()) + }) + } + + fn norm2_squared(&self) -> $F { + let a = self.get_guard(); + a.get_view().norm2_squared() + } + + fn dist2_squared>(&self, other: I) -> $F { + other.eval_decompose(|other| { + let a = self.get_guard(); + let b = other.get_guard(); + a.get_view().dist2_squared(b.get_view()) + }) + } + } + + impl<$($qual)*> $crate::linops::VectorSpace for $type + // where + // Self : $crate::euclidean::wrap::Wrapped, + // Self::Unwrapped : $crate::linops::AXPY, + // Self: std::ops::MulAssign + std::ops::DivAssign, + // Self::Unwrapped: std::ops::MulAssign + std::ops::DivAssign, + { + type Field = $F; + type PrincipalV = Self; + + /// Return a similar zero as `self`. + fn similar_origin(&self) -> Self::PrincipalV { + let a = self.get_guard(); + Self::wrap(a.get_view().similar_origin()) + } + } + + impl<$($qual)*> $crate::linops::AXPY for $type + // where + // Self : $crate::euclidean::wrap::Wrapped, + // Self::Unwrapped : $crate::linops::AXPY, + // Self: std::ops::MulAssign + std::ops::DivAssign, + // Self::Unwrapped: std::ops::MulAssign + std::ops::DivAssign, + { + fn axpy>(&mut self, α: $F, x: I, β: $F) { + x.eval_decompose(|other| { + let mut a = self.get_guard_mut(); + let b = other.get_guard(); + $crate::linops::AXPY::axpy(&mut a.get_view_mut(), α, b.get_view(), β) + }) + } + + fn copy_from>(&mut self, x: I) { + x.eval_decompose(|other| { + let mut a = self.get_guard_mut(); + let b = other.get_guard(); + $crate::linops::AXPY::copy_from(&mut a.get_view_mut(), b.get_view()) + }) + } + + fn scale_from>(&mut self, α: $F, x: I) { + x.eval_decompose(|other| { + let mut a = self.get_guard_mut(); + let b = other.get_guard(); + $crate::linops::AXPY::scale_from(&mut a.get_view_mut(), α, b.get_view()) + }) + } + + /// Set self to zero. + fn set_zero(&mut self) { + let mut a = self.get_guard_mut(); + a.get_view_mut().set_zero() + } + } + + impl<$($qual)*> $crate::instance::Space for $type { + type Decomp = $crate::instance::BasicDecomposition; + type Principal = Self; + } + }; +}