src/euclidean/wrap.rs

branch
dev
changeset 183
d077dff509f1
parent 179
724413fc8d17
--- a/src/euclidean/wrap.rs	Fri Sep 05 13:30:53 2025 -0500
+++ b/src/euclidean/wrap.rs	Sat Sep 06 23:29:34 2025 -0500
@@ -6,14 +6,32 @@
 use crate::instance::Space;
 use crate::types::Float;
 
+pub trait WrapGuard<'a, F: Float> {
+    type View<'b>: Euclidean<F>
+    where
+        Self: 'b;
+    fn get_view(&self) -> Self::View<'_>;
+}
+
+pub trait WrapGuardMut<'a, F: Float> {
+    type ViewMut<'b>: Euclidean<F>
+    where
+        Self: 'b;
+    fn get_view_mut(&mut self) -> Self::ViewMut<'_>;
+}
+
 pub trait Wrapped: Space {
     type WrappedField: Float;
-    type Unwrapped: Euclidean<Self::WrappedField>;
-    type UnwrappedMut: Euclidean<Self::WrappedField>;
-    type UnwrappedOutput: Euclidean<Self::WrappedField>;
+    type Guard<'a>: WrapGuard<'a, Self::WrappedField>
+    where
+        Self: 'a;
+    type GuardMut<'a>: WrapGuardMut<'a, Self::WrappedField>
+    where
+        Self: 'a;
+    type UnwrappedOutput;
     type WrappedOutput;
-    fn get_view(&self) -> Self::Unwrapped;
-    fn get_view_mut(&mut self) -> Self::UnwrappedMut;
+    fn get_guard(&self) -> Self::Guard<'_>;
+    fn get_guard_mut(&mut self) -> Self::GuardMut<'_>;
     fn wrap(output: Self::UnwrappedOutput) -> Self::WrappedOutput;
 }
 
@@ -25,7 +43,8 @@
         impl<$($qual)*> $($trait)::+ for $type {
             type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
             fn $fn(self) -> Self::Output {
-                Self::wrap(self.get_view().$fn())
+                let a = self.get_guard();
+                Self::wrap(a.get_view().$fn())
             }
         }
     };
@@ -33,28 +52,36 @@
         impl<$($qual)*> $($trait)::+<$type> for $type {
             type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
             fn $fn(self, other: $type) -> Self::Output {
-                Self::wrap(self.get_view().$fn(other.get_view()))
+                let a = self.get_guard();
+                let b = other.get_guard();
+                Self::wrap(a.get_view().$fn(b.get_view()))
             }
         }
 
         impl<'a, $($qual)*> $($trait)::+<$type> for &'a $type {
             type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
             fn $fn(self, other: $type) -> Self::Output {
-                <$type>::wrap(self.get_view().$fn(other.get_view()))
+                let a = self.get_guard();
+                let b = other.get_guard();
+                <$type>::wrap(a.get_view().$fn(b.get_view()))
             }
         }
 
         impl<'a, 'b, $($qual)*> $($trait)::+<&'b $type> for &'a $type {
             type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
             fn $fn(self, other: &'b $type) -> Self::Output {
-                <$type>::wrap(self.get_view().$fn(other.get_view()))
+                let a = self.get_guard();
+                let b = other.get_guard();
+                <$type>::wrap(a.get_view().$fn(b.get_view()))
             }
         }
 
         impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type {
             type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
             fn $fn(self, other: &'b $type) -> Self::Output {
-                Self::wrap(self.get_view().$fn(other.get_view()))
+                let a = self.get_guard();
+                let b = other.get_guard();
+                Self::wrap(a.get_view().$fn(b.get_view()))
             }
         }
     };
@@ -66,7 +93,8 @@
         {
             type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
             fn $fn(self, t: $F) -> Self::Output {
-                Self::wrap(self.get_view().$fn(t))
+                let a = self.get_guard();
+                Self::wrap(a.get_view().$fn(t))
             }
         }
 
@@ -77,7 +105,8 @@
         {
             type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
             fn $fn(self, t: $F) -> Self::Output {
-                <$type>::wrap(self.get_view().$fn(t))
+                let a = self.get_guard();
+                <$type>::wrap(a.get_view().$fn(t))
             }
         }
 
@@ -91,20 +120,25 @@
         {
             type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
             fn $fn(self, rhs: $type) -> Self::Output {
-                <$type>::wrap(self.$fn(rhs.get_view()))
+                let b = rhs.get_guard();
+                <$type>::wrap(self.$fn(b.get_view()))
             }
         }
     };
     (impl_binary_mut $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
         impl<$($qual)*> $($trait)::+<$type> for $type {
             fn $fn(&mut self, rhs: $type) {
-                self.get_view_mut().$fn(rhs.get_view())
+                let mut a = self.get_guard_mut();
+                let b = rhs.get_guard();
+                a.get_view_mut().$fn(b.get_view())
             }
         }
 
         impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type {
             fn $fn(&mut self, rhs: &'b $type) {
-                self.get_view_mut().$fn(rhs.get_view())
+                let mut a = self.get_guard_mut();
+                let b = rhs.get_guard();
+                a.get_view_mut().$fn(b.get_view())
             }
         }
     };
@@ -116,7 +150,8 @@
         // //     $type::UnwrappedMut: $($trait)::+<$($trait)::+<F>>,
         {
             fn $fn(&mut self, t: $F) {
-                self.get_view_mut().$fn(t)
+                let mut a = self.get_guard_mut();
+                a.get_view_mut().$fn(t)
             }
         }
     };
@@ -141,14 +176,19 @@
         impl<$($qual)*> $crate::norms::Norm<$crate::norms::L2, $F> for $type
         {
             fn norm(&self, p : $crate::norms::L2) -> $F {
-                $crate::norms::Norm::norm(&self.get_view(), p)
+                let a = self.get_guard();
+                $crate::norms::Norm::norm(&a.get_view(), p)
             }
         }
 
         impl<$($qual)*> $crate::norms::Dist<$crate::norms::L2, $F> for $type
         {
             fn dist<I: $crate::instance::Instance<Self>>(&self, other : I, p : $crate::norms::L2) -> $F {
-                other.eval_ref(|x| self.get_view().dist(x.get_view(), p))
+                other.eval_ref(|other| {
+                    let a = self.get_guard();
+                    let b = other.get_guard();
+                    a.get_view().dist(b.get_view(), p)
+                })
             }
         }
 
@@ -189,15 +229,24 @@
             type PrincipalE = Self;
 
             fn dot<I: $crate::instance::Instance<Self>>(&self, other: I) -> $F {
-                other.eval_decompose(|x| self.get_view().dot(&x.get_view()))
+                other.eval_decompose(|other| {
+                    let a = self.get_guard();
+                    let b = other.get_guard();
+                    a.get_view().dot(&b.get_view())
+                })
             }
 
             fn norm2_squared(&self) -> $F {
-                self.get_view().norm2_squared()
+                let a = self.get_guard();
+                a.get_view().norm2_squared()
             }
 
             fn dist2_squared<I: $crate::instance::Instance<Self>>(&self, other: I) -> $F {
-                other.eval_decompose(|x| self.get_view().dist2_squared(x.get_view()))
+                other.eval_decompose(|other| {
+                    let a = self.get_guard();
+                    let b = other.get_guard();
+                    a.get_view().dist2_squared(b.get_view())
+                })
             }
         }
 
@@ -213,7 +262,8 @@
 
             /// Return a similar zero as `self`.
             fn similar_origin(&self) -> Self::PrincipalV {
-                Self::wrap(self.get_view().similar_origin())
+                let a = self.get_guard();
+                Self::wrap(a.get_view().similar_origin())
             }
         }
 
@@ -225,26 +275,33 @@
         //     Self::Unwrapped: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
         {
              fn axpy<I: $crate::instance::Instance<Self>>(&mut self, α: $F, x: I, β: $F) {
-                x.eval_decompose(|v| {
-                    $crate::linops::AXPY::axpy(&mut self.get_view_mut(), α, v.get_view(), β)
+                x.eval_decompose(|other| {
+                    let mut a = self.get_guard_mut();
+                    let b = other.get_guard();
+                    $crate::linops::AXPY::axpy(&mut a.get_view_mut(), α, b.get_view(), β)
                 })
             }
 
             fn copy_from<I: $crate::instance::Instance<Self>>(&mut self, x: I) {
-                x.eval_decompose(|v| {
-                    $crate::linops::AXPY::copy_from(&mut self.get_view_mut(), v.get_view())
+                x.eval_decompose(|other| {
+                    let mut a = self.get_guard_mut();
+                    let b = other.get_guard();
+                    $crate::linops::AXPY::copy_from(&mut a.get_view_mut(), b.get_view())
                 })
             }
 
             fn scale_from<I: $crate::instance::Instance<Self>>(&mut self, α: $F, x: I) {
-                x.eval_decompose(|v| {
-                    $crate::linops::AXPY::scale_from(&mut self.get_view_mut(), α, v.get_view())
+                x.eval_decompose(|other| {
+                    let mut a = self.get_guard_mut();
+                    let b = other.get_guard();
+                    $crate::linops::AXPY::scale_from(&mut a.get_view_mut(), α, b.get_view())
                 })
             }
 
             /// Set self to zero.
             fn set_zero(&mut self) {
-                self.get_view_mut().set_zero()
+                let mut a = self.get_guard_mut();
+                a.get_view_mut().set_zero()
             }
         }
 

mercurial