src/euclidean/wrap.rs

branch
dev
changeset 147
d6009939e832
parent 146
3f9a03f95457
child 149
2f1798c65fd6
--- a/src/euclidean/wrap.rs	Sat Aug 30 22:25:28 2025 -0500
+++ b/src/euclidean/wrap.rs	Sat Aug 30 22:43:37 2025 -0500
@@ -2,118 +2,102 @@
 Wrappers for implemention [`Euclidean`] operations.
 */
 
-use super::Euclidean;
-use crate::instance::{Decomposition, DecompositionMut, Instance, InstanceMut, MyCow};
-use crate::linops::AXPY;
-use crate::loc::Loc;
-use crate::mapping::Space;
-use crate::norms::{HasDual, Norm, NormExponent, Normed, PairNorm, L2};
-use crate::types::{Float, Num};
-use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
-use serde::{Deserialize, Serialize};
-use std::clone::Clone;
+pub trait Wrapped {
+    type Unwrapped;
+    type UnwrappedMut;
+    type UnwrappedOutput;
+    fn get_view(&self) -> Self::Unwrapped;
+    fn get_view_mut(&mut self) -> Self::UnwrappedMut;
+    fn wrap(output: Self::UnwrappedOutput) -> Self;
+}
 
-macro_rules! impl_unary {
-    ($type:ty, $trait:ident, $fn:ident) => {
-        impl $trait for $type {
+#[macro_export]
+macro_rules! wrap {
+    (impl_unary<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => {
+        impl<$($qual)*> $trait for $type {
             type Output = Self::WrappedOutput;
             fn $fn(self) -> Self::Output {
-                $type::wrap(self.get_view().$fn())
+                Self::wrap(self.get_view().$fn())
             }
         }
     };
-}
-
-macro_rules! impl_binary {
-    ($type:ty, $trait:ident, $fn:ident) => {
-        impl $trait<$type> for $type {
+    (impl_binary<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => {
+        impl<$($qual)*> $trait<$type> for $type {
             type Output = Self::WrappedOutput;
             fn $fn(self, other: $type) -> Self::Output {
-                $type::wrap(self.get_view().$fn(other.get_view()))
+                Self::wrap(self.get_view().$fn(other.get_view()))
             }
         }
 
-        impl<'a> $trait<$type> for &'a $type {
+        impl<'a, $($qual)*> $trait<$type> for &'a $type {
             type Output = Self::WrappedOutput;
             fn $fn(self, other: $type) -> Self::Output {
-                $type::wrap(self.get_view().$fn(other.get_view()))
+                Self::wrap(self.get_view().$fn(other.get_view()))
             }
         }
 
-        impl<'a, 'b> $trait<&'b $type> for &'a $type {
+        impl<'a, 'b, $($qual)*> $trait<&'b $type> for &'a $type {
             type Output = Self::WrappedOutput;
             fn $fn(self, other: $type) -> Self::Output {
-                $type::wrap(self.get_view().$fn(other.get_view()))
+                Self::wrap(self.get_view().$fn(other.get_view()))
             }
         }
 
-        impl<'b> $trait<&'b $type> for $type {
+        impl<'b, $($qual)*> $trait<&'b $type> for $type {
             type Output = Self::WrappedOutput;
             fn $fn(self, other: $type) -> Self::Output {
-                $type::wrap(self.get_view().$fn(other.get_view()))
+                Self::wrap(self.get_view().$fn(other.get_view()))
             }
         }
     };
-}
-
-macro_rules! impl_scalar {
-    ($type:ty, $trait:ident, $fn:ident) => {
-        impl<F: Num> $trait<F> for $type,
+    (impl_scalar<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => {
+        impl<$($qual)*, F: Num> $trait<F> for $type,
         where
             $type::Unwrapped: $trait<F>,
         {
             type Output = Self::WrappedOutput;
             fn $fn(self, t: F) -> Self::Output {
-                $type::wrap(self.get_view().$fn(t))
+                Self::wrap(self.get_view().$fn(t))
             }
         }
 
-        impl<'a, F: Num> $trait<F> for &'a $type,
+        impl<'a, $($qual)*, F: Num> $trait<F> for &'a $type,
         where
             $type::Unwrapped: $trait<F>,
         {
             type Output = Self::WrappedOutput;
             fn $fn(self, t: F) -> Self::Output {
-                $type::wrap(self.get_view().$fn(t))
+                Self::wrap(self.get_view().$fn(t))
             }
         }
 
     };
-}
-
-macro_rules! impl_scalar_lhs {
-    ($type:ty, $trait:ident, $fn:ident, $F:ty) => {
-        impl<A, B> $trait<$type> for $F
+    (impl_scalar_lhs<$($qual:tt)*> $type:ty, $trait:path, $fn:ident, $F:ty) => {
+        impl<$($qual)*> $trait<$type> for $F
         where
             $F: $type::Unwrapped,
         {
             type Output = Self::WrappedOutput;
             fn $fn(self, rhs: $type) -> Self::Output {
-                $type::wrap(self.$fn(rhs.get_view()))
+                Self::wrap(self.$fn(rhs.get_view()))
             }
         }
     };
-}
-
-macro_rules! impl_binary_mut {
-    ($type:ty, $trait:ident, $fn:ident) => {
-        impl $trait<$type> for $type {
+    (impl_binary_mut<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => {
+        impl<$($qual)*> $trait<$type> for $type {
             fn $fn(&mut self, rhs: $type) {
                 self.get_view_mut().$fn(rhs.get_view())
             }
         }
 
-        impl<'b> $trait<&'b $type> for $type {
+        impl<'b, $($qual)*> $trait<&'b $type> for $type {
             fn $fn(&mut self, rhs: $type) {
                 self.get_view_mut().$fn(rhs.get_view())
             }
         }
     };
-}
-
-macro_rules! impl_scalar_mut {
-    ($type:ty, $trait:ident, $fn:ident) => {
-        impl $trait<F> for $type
+    (impl_scalar_mut<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => {
+        impl<$($qual)*> $trait<F> for $type
         where
             $type::UnwrappedMut: $trait<F>,
         {
@@ -122,45 +106,46 @@
             }
         }
     };
-}
-
-macro_rules! wrap {
     ($type:ty) => {
-        impl_unary!($type, Neg, neg);
-        impl_binary!($type, Add, add);
-        impl_binary!($type, Sub, sub);
-        impl_scalar!($type, Mul, mul);
-        impl_scalar!($type, Div, div);
-        impl_scalar_lhs!($type, Mul, mul, f32);
-        impl_scalar_lhs!($type, Mul, mul, f64);
-        impl_scalar_lhs!($type, Div, div, f32);
-        impl_scalar_lhs!($type, Div, div, f64);
-        impl_binary_mut!($type, AddAssign, add_assign);
-        impl_binary_mut!($type, SubAssign, sub_assign);
-        impl_scalar_mut!($type, MulAssign, mul_assign);
-        impl_scalar_mut!($type, DivAssign, div_assign);
+        $crate::wrap!(imp<> do $type);
+    };
+    (imp<$($qual:tt)*> $type:ty) => {
+        $crate::wrap!(impl_unary<$($qual)*> $type, std::ops::Neg, neg);
+        $crate::wrap!(impl_binary<$($qual)*> $type, std::ops::Add, add);
+        $crate::wrap!(impl_binary<$($qual)*> $type, std::ops::Sub, sub);
+        $crate::wrap!(impl_scalar<$($qual)*> $type, std::ops::Mul, mul);
+        $crate::wrap!(impl_scalar<$($qual)*> $type, std::ops::Div, div);
+        $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Mul, mul, f32);
+        $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Mul, mul, f64);
+        $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Div, div, f32);
+        $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Div, div, f64);
+        $crate::wrap!(impl_binary_mut<$($qual)*> $type, std::ops::AddAssign, add_assign);
+        $crate::wrap!(impl_binary_mut<$($qual)*> $type, std::ops::SubAssign, sub_assign);
+        $crate::wrap!(impl_scalar_mut<$($qual)*> $type, std::ops::MulAssign, mul_assign);
+        $crate::wrap!(impl_scalar_mut<$($qual)*> $type, std::ops::DivAssign, div_assign);
 
         /// We only support 'closed' `Euclidean` `Pair`s, as more general ones cause
         /// compiler overflows.
-        impl<F: Float> Euclidean<F> for $type
+        impl<$($qual)* F: $crate::types::Float> $crate::euclidean::Euclidean<F> for $type
         where
             //Pair<A, B>: Euclidean<F>,
-            Self: Sized
-                + Mul<F, Output = <Self as AXPY>::Owned>
-                + MulAssign<F>
-                + Div<F, Output = <Self as AXPY>::Owned>
-                + DivAssign<F>
-                + Add<Self, Output = <Self as AXPY>::Owned>
-                + Sub<Self, Output = <Self as AXPY>::Owned>
-                + for<'b> Add<&'b Self, Output = <Self as AXPY>::Owned>
-                + for<'b> Sub<&'b Self, Output = <Self as AXPY>::Owned>
-                + AddAssign<Self>
-                + for<'b> AddAssign<&'b Self>
-                + SubAssign<Self>
-                + for<'b> SubAssign<&'b Self>
-                + Neg<Output = <Self as AXPY>::Owned>,
+            Self: $crate::euclidean::wrap::Wrapped
+                + Sized
+                + std::ops::Mul<F, Output = <Self as $crate::linops::AXPY>::Owned>
+                + std::ops::MulAssign<F>
+                + std::ops::Div<F, Output = <Self as $crate::linops::AXPY>::Owned>
+                + std::ops::DivAssign<F>
+                + std::ops::Add<Self, Output = <Self as $crate::linops::AXPY>::Owned>
+                + std::ops::Sub<Self, Output = <Self as $crate::linops::AXPY>::Owned>
+                + for<'b> std::ops::Add<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned>
+                + for<'b> std::ops::Sub<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned>
+                + std::ops::AddAssign<Self>
+                + for<'b> std::ops::AddAssign<&'b Self>
+                + std::ops::SubAssign<Self>
+                + for<'b> std::ops::SubAssign<&'b Self>
+                + std::ops::Neg<Output = <Self as $crate::linops::AXPY>::Owned>,
         {
-            fn dot<I: Instance<Self>>(&self, other: I) -> F {
+            fn dot<I: $crate::instance::Instance<Self>>(&self, other: I) -> F {
                 other.eval_decompose(|x| self.get_view().dot(x.get_view()))
             }
 
@@ -168,34 +153,34 @@
                 self.get_view().norm2_squared()
             }
 
-            fn dist2_squared<I: Instance<Self>>(&self, other: I) -> F {
+            fn dist2_squared<I: $crate::instance::Instance<Self>>(&self, other: I) -> F {
                 other.eval_decompose(|x| self.get_view().dist2_squared(x.get_view()))
             }
         }
 
-        impl<F> AXPY for $type
+        impl<$($qual)* F : $crate::types::Float> $crate::linops::AXPY for $type
         where
-            Self::Unwrapped : AXPY<Field = $F>,
-            F: Num,
-            Self: MulAssign<F> + DivAssign<F>,
-            Self::Unwrapped: MulAssign<F> + DivAssign<F>,
+            Self : $crate::euclidean::wrap::Wrapped,
+            Self::Unwrapped : $crate::linops::AXPY<Field = F>,
+            Self: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
+            Self::Unwrapped: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
         {
             type Field = F;
-            type Owned = Pair<A::Owned, B::Owned>;
+            type Owned = Self;
 
-            fn axpy<I: Instance<Self>>(&mut self, α: F, x: I, β: F) {
+            fn axpy<I: $crate::instance::Instance<Self>>(&mut self, α: F, x: I, β: F) {
                 x.eval_decompose(|v| {
                     self.get_mut_view().axpy(α, v.get_view(), β)
                 })
             }
 
-            fn copy_from<I: Instance<Self>>(&mut self, x: I) {
+            fn copy_from<I: $crate::instance::Instance<Self>>(&mut self, x: I) {
                 x.eval_decompose(|Pair(u, v)| {
                     self.get_mut_view().copy_from(v.get_view())
                 })
             }
 
-            fn scale_from<I: Instance<Self>>(&mut self, α: F, x: I) {
+            fn scale_from<I: $crate::instance::Instance<Self>>(&mut self, α: F, x: I) {
                 x.eval_decompose(|v| {
                     self.get_mut_view().scale_from(α, v.get_view())
                 })
@@ -203,7 +188,7 @@
 
             /// Return a similar zero as `self`.
             fn similar_origin(&self) -> Self::Owned {
-                $type::wrap(self.get_view().similar_origin())
+                Self::wrap(self.get_view().similar_origin())
             }
 
             /// Set self to zero.
@@ -212,8 +197,9 @@
             }
         }
 
-        impl<A: Space, B: Space> Space for $type {
-            type Decomp = $type::Unwrapped::Decomp
+        impl<$($qual)*> $crate::instance::Space for $type
+        where Self : $crate::euclidean::wrap::Wrapped {
+            type Decomp = Self::Unwrapped::Decomp;
         }
     };
 }

mercurial