src/euclidean/wrap.rs

changeset 198
3868555d135c
parent 183
d077dff509f1
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/euclidean/wrap.rs	Fri May 15 14:46:30 2026 -0500
@@ -0,0 +1,313 @@
+/*!
+Wrappers for implemention [`Euclidean`] operations.
+*/
+
+use crate::euclidean::Euclidean;
+use crate::instance::Space;
+use crate::types::Float;
+
+pub trait WrapGuard<'a, F: Float> {
+    type View<'b>: Euclidean<F>
+    where
+        Self: 'b;
+    fn get_view(&self) -> Self::View<'_>;
+}
+
+pub trait WrapGuardMut<'a, F: Float> {
+    type ViewMut<'b>: Euclidean<F>
+    where
+        Self: 'b;
+    fn get_view_mut(&mut self) -> Self::ViewMut<'_>;
+}
+
+pub trait Wrapped: Space {
+    type WrappedField: Float;
+    type Guard<'a>: WrapGuard<'a, Self::WrappedField>
+    where
+        Self: 'a;
+    type GuardMut<'a>: WrapGuardMut<'a, Self::WrappedField>
+    where
+        Self: 'a;
+    type UnwrappedOutput;
+    type WrappedOutput;
+    fn get_guard(&self) -> Self::Guard<'_>;
+    fn get_guard_mut(&mut self) -> Self::GuardMut<'_>;
+    fn wrap(output: Self::UnwrappedOutput) -> Self::WrappedOutput;
+}
+
+#[macro_export]
+macro_rules! wrap {
+    // Rust macros are totally fucked up. $trait:path does not work, have to
+    // manually code paths through $($trait:ident)::+.
+    (impl_unary $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
+        impl<$($qual)*> $($trait)::+ for $type {
+            type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
+            fn $fn(self) -> Self::Output {
+                let a = self.get_guard();
+                Self::wrap(a.get_view().$fn())
+            }
+        }
+    };
+    (impl_binary $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
+        impl<$($qual)*> $($trait)::+<$type> for $type {
+            type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
+            fn $fn(self, other: $type) -> Self::Output {
+                let a = self.get_guard();
+                let b = other.get_guard();
+                Self::wrap(a.get_view().$fn(b.get_view()))
+            }
+        }
+
+        impl<'a, $($qual)*> $($trait)::+<$type> for &'a $type {
+            type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
+            fn $fn(self, other: $type) -> Self::Output {
+                let a = self.get_guard();
+                let b = other.get_guard();
+                <$type>::wrap(a.get_view().$fn(b.get_view()))
+            }
+        }
+
+        impl<'a, 'b, $($qual)*> $($trait)::+<&'b $type> for &'a $type {
+            type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
+            fn $fn(self, other: &'b $type) -> Self::Output {
+                let a = self.get_guard();
+                let b = other.get_guard();
+                <$type>::wrap(a.get_view().$fn(b.get_view()))
+            }
+        }
+
+        impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type {
+            type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
+            fn $fn(self, other: &'b $type) -> Self::Output {
+                let a = self.get_guard();
+                let b = other.get_guard();
+                Self::wrap(a.get_view().$fn(b.get_view()))
+            }
+        }
+    };
+    (impl_scalar $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
+        impl<$($qual)*> $($trait)::+<$F> for $type
+        // where
+        //     $type: $crate::euclidean::wrap::Wrapped<WrappedField = F>,
+        //     //$type::Unwrapped: $($trait)::+<F>,
+        {
+            type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
+            fn $fn(self, t: $F) -> Self::Output {
+                let a = self.get_guard();
+                Self::wrap(a.get_view().$fn(t))
+            }
+        }
+
+        impl<'a, $($qual)*> $($trait)::+<$F> for &'a $type
+        // where
+        //     $type: $crate::euclidean::wrap::Wrapped<WrappedField = F>,
+        //     //$type::Unwrapped: $($trait)::+<F>,
+        {
+            type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
+            fn $fn(self, t: $F) -> Self::Output {
+                let a = self.get_guard();
+                <$type>::wrap(a.get_view().$fn(t))
+            }
+        }
+
+    };
+    (impl_scalar_lhs $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
+        impl<$($qual)*> $($trait)::+<$type> for $F
+        // where
+        //     $type: $crate::euclidean::wrap::Wrapped<WrappedField = $F>,
+        // // where
+        // //     $F: $($trait)::+<$type::Unwrapped>,
+        {
+            type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
+            fn $fn(self, rhs: $type) -> Self::Output {
+                let b = rhs.get_guard();
+                <$type>::wrap(self.$fn(b.get_view()))
+            }
+        }
+    };
+    (impl_binary_mut $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
+        impl<$($qual)*> $($trait)::+<$type> for $type {
+            fn $fn(&mut self, rhs: $type) {
+                let mut a = self.get_guard_mut();
+                let b = rhs.get_guard();
+                a.get_view_mut().$fn(b.get_view())
+            }
+        }
+
+        impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type {
+            fn $fn(&mut self, rhs: &'b $type) {
+                let mut a = self.get_guard_mut();
+                let b = rhs.get_guard();
+                a.get_view_mut().$fn(b.get_view())
+            }
+        }
+    };
+    (impl_scalar_mut $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
+        impl<$($qual)*> $($trait)::+<$F> for $type
+        // where
+        //     $type: $crate::euclidean::wrap::Wrapped<WrappedField = F>,
+        // // where
+        // //     $type::UnwrappedMut: $($trait)::+<$($trait)::+<F>>,
+        {
+            fn $fn(&mut self, t: $F) {
+                let mut a = self.get_guard_mut();
+                a.get_view_mut().$fn(t)
+            }
+        }
+    };
+    // ($type:ty) => {
+    //     $crate::wrap!(imp<> do $type);
+    // };
+    ($F:ty; $type:ty where $($qual:tt)*) => {
+
+        $crate::wrap!(impl_unary $type, std::ops::Neg, neg where $($qual)*);
+        $crate::wrap!(impl_binary $type, std::ops::Add, add where $($qual)*);
+        $crate::wrap!(impl_binary $type, std::ops::Sub, sub where $($qual)*);
+        $crate::wrap!(impl_scalar $F, $type, std::ops::Mul, mul where $($qual)*);
+        $crate::wrap!(impl_scalar $F, $type, std::ops::Div, div where $($qual)*);
+        $crate::wrap!(impl_scalar_lhs $F, $type, std::ops::Mul, mul where $($qual)*);
+        $crate::wrap!(impl_binary_mut $type, std::ops::AddAssign, add_assign where $($qual)*);
+        $crate::wrap!(impl_binary_mut $type, std::ops::SubAssign, sub_assign where $($qual)*);
+        $crate::wrap!(impl_scalar_mut $F, $type, std::ops::MulAssign, mul_assign where $($qual)*);
+        $crate::wrap!(impl_scalar_mut $F, $type, std::ops::DivAssign, div_assign where $($qual)*);
+
+        $crate::self_ownable!($type where $($qual)*);
+
+        impl<$($qual)*> $crate::norms::Norm<$crate::norms::L2, $F> for $type
+        {
+            fn norm(&self, p : $crate::norms::L2) -> $F {
+                let a = self.get_guard();
+                $crate::norms::Norm::norm(&a.get_view(), p)
+            }
+        }
+
+        impl<$($qual)*> $crate::norms::Dist<$crate::norms::L2, $F> for $type
+        {
+            fn dist<I: $crate::instance::Instance<Self>>(&self, other : I, p : $crate::norms::L2) -> $F {
+                other.eval_ref(|other| {
+                    let a = self.get_guard();
+                    let b = other.get_guard();
+                    a.get_view().dist(b.get_view(), p)
+                })
+            }
+        }
+
+        impl<$($qual)*> $crate::norms::Normed<$F> for $type {
+            type NormExp = $crate::norms::L2;
+
+            fn norm_exponent(&self) -> Self::NormExp {
+                $crate::norms::L2
+            }
+        }
+
+        impl<$($qual)*> $crate::norms::HasDual<$F> for $type {
+            type DualSpace = Self;
+
+            fn dual_origin(&self) -> Self {
+                $crate::linops::VectorSpace::similar_origin(self)
+            }
+        }
+
+        impl<$($qual)*> $crate::euclidean::Euclidean<$F> for $type
+        // where
+        //     Self: $crate::euclidean::wrap::Wrapped<WrappedField = $F>
+        //         + Sized
+        //         + std::ops::Mul<F, Output = <Self as $crate::linops::AXPY>::Owned>
+        //         + std::ops::MulAssign<F>
+        //         + std::ops::Div<F, Output = <Self as $crate::linops::AXPY>::Owned>
+        //         + std::ops::DivAssign<F>
+        //         + std::ops::Add<Self, Output = <Self as $crate::linops::AXPY>::Owned>
+        //         + std::ops::Sub<Self, Output = <Self as $crate::linops::AXPY>::Owned>
+        //         + for<'b> std::ops::Add<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned>
+        //         + for<'b> std::ops::Sub<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned>
+        //         + std::ops::AddAssign<Self>
+        //         + for<'b> std::ops::AddAssign<&'b Self>
+        //         + std::ops::SubAssign<Self>
+        //         + for<'b> std::ops::SubAssign<&'b Self>
+        //         + std::ops::Neg<Output = <Self as $crate::linops::AXPY>::Owned>,
+        {
+            type PrincipalE = Self;
+
+            fn dot<I: $crate::instance::Instance<Self>>(&self, other: I) -> $F {
+                other.eval_decompose(|other| {
+                    let a = self.get_guard();
+                    let b = other.get_guard();
+                    a.get_view().dot(&b.get_view())
+                })
+            }
+
+            fn norm2_squared(&self) -> $F {
+                let a = self.get_guard();
+                a.get_view().norm2_squared()
+            }
+
+            fn dist2_squared<I: $crate::instance::Instance<Self>>(&self, other: I) -> $F {
+                other.eval_decompose(|other| {
+                    let a = self.get_guard();
+                    let b = other.get_guard();
+                    a.get_view().dist2_squared(b.get_view())
+                })
+            }
+        }
+
+        impl<$($qual)*> $crate::linops::VectorSpace for $type
+        // where
+        //     Self : $crate::euclidean::wrap::Wrapped<WrappedField = $F>,
+        //     Self::Unwrapped : $crate::linops::AXPY<Field = F>,
+        //     Self: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
+        //     Self::Unwrapped: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
+        {
+            type Field = $F;
+            type PrincipalV = Self;
+
+            /// Return a similar zero as `self`.
+            fn similar_origin(&self) -> Self::PrincipalV {
+                let a = self.get_guard();
+                Self::wrap(a.get_view().similar_origin())
+            }
+        }
+
+        impl<$($qual)*> $crate::linops::AXPY for $type
+        // where
+        //     Self : $crate::euclidean::wrap::Wrapped<WrappedField = $F>,
+        //     Self::Unwrapped : $crate::linops::AXPY<Field = F>,
+        //     Self: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
+        //     Self::Unwrapped: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
+        {
+             fn axpy<I: $crate::instance::Instance<Self>>(&mut self, α: $F, x: I, β: $F) {
+                x.eval_decompose(|other| {
+                    let mut a = self.get_guard_mut();
+                    let b = other.get_guard();
+                    $crate::linops::AXPY::axpy(&mut a.get_view_mut(), α, b.get_view(), β)
+                })
+            }
+
+            fn copy_from<I: $crate::instance::Instance<Self>>(&mut self, x: I) {
+                x.eval_decompose(|other| {
+                    let mut a = self.get_guard_mut();
+                    let b = other.get_guard();
+                    $crate::linops::AXPY::copy_from(&mut a.get_view_mut(), b.get_view())
+                })
+            }
+
+            fn scale_from<I: $crate::instance::Instance<Self>>(&mut self, α: $F, x: I) {
+                x.eval_decompose(|other| {
+                    let mut a = self.get_guard_mut();
+                    let b = other.get_guard();
+                    $crate::linops::AXPY::scale_from(&mut a.get_view_mut(), α, b.get_view())
+                })
+            }
+
+            /// Set self to zero.
+            fn set_zero(&mut self) {
+                let mut a = self.get_guard_mut();
+                a.get_view_mut().set_zero()
+            }
+        }
+
+        impl<$($qual)*> $crate::instance::Space for $type {
+            type Decomp = $crate::instance::BasicDecomposition;
+            type Principal = Self;
+        }
+    };
+}

mercurial