diff -r 3f9a03f95457 -r d6009939e832 src/euclidean/wrap.rs --- a/src/euclidean/wrap.rs Sat Aug 30 22:25:28 2025 -0500 +++ b/src/euclidean/wrap.rs Sat Aug 30 22:43:37 2025 -0500 @@ -2,118 +2,102 @@ Wrappers for implemention [`Euclidean`] operations. */ -use super::Euclidean; -use crate::instance::{Decomposition, DecompositionMut, Instance, InstanceMut, MyCow}; -use crate::linops::AXPY; -use crate::loc::Loc; -use crate::mapping::Space; -use crate::norms::{HasDual, Norm, NormExponent, Normed, PairNorm, L2}; -use crate::types::{Float, Num}; -use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use serde::{Deserialize, Serialize}; -use std::clone::Clone; +pub trait Wrapped { + type Unwrapped; + type UnwrappedMut; + type UnwrappedOutput; + fn get_view(&self) -> Self::Unwrapped; + fn get_view_mut(&mut self) -> Self::UnwrappedMut; + fn wrap(output: Self::UnwrappedOutput) -> Self; +} -macro_rules! impl_unary { - ($type:ty, $trait:ident, $fn:ident) => { - impl $trait for $type { +#[macro_export] +macro_rules! wrap { + (impl_unary<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => { + impl<$($qual)*> $trait for $type { type Output = Self::WrappedOutput; fn $fn(self) -> Self::Output { - $type::wrap(self.get_view().$fn()) + Self::wrap(self.get_view().$fn()) } } }; -} - -macro_rules! impl_binary { - ($type:ty, $trait:ident, $fn:ident) => { - impl $trait<$type> for $type { + (impl_binary<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => { + impl<$($qual)*> $trait<$type> for $type { type Output = Self::WrappedOutput; fn $fn(self, other: $type) -> Self::Output { - $type::wrap(self.get_view().$fn(other.get_view())) + Self::wrap(self.get_view().$fn(other.get_view())) } } - impl<'a> $trait<$type> for &'a $type { + impl<'a, $($qual)*> $trait<$type> for &'a $type { type Output = Self::WrappedOutput; fn $fn(self, other: $type) -> Self::Output { - $type::wrap(self.get_view().$fn(other.get_view())) + Self::wrap(self.get_view().$fn(other.get_view())) } } - impl<'a, 'b> $trait<&'b $type> for &'a $type { + impl<'a, 'b, $($qual)*> $trait<&'b $type> for &'a $type { type Output = Self::WrappedOutput; fn $fn(self, other: $type) -> Self::Output { - $type::wrap(self.get_view().$fn(other.get_view())) + Self::wrap(self.get_view().$fn(other.get_view())) } } - impl<'b> $trait<&'b $type> for $type { + impl<'b, $($qual)*> $trait<&'b $type> for $type { type Output = Self::WrappedOutput; fn $fn(self, other: $type) -> Self::Output { - $type::wrap(self.get_view().$fn(other.get_view())) + Self::wrap(self.get_view().$fn(other.get_view())) } } }; -} - -macro_rules! impl_scalar { - ($type:ty, $trait:ident, $fn:ident) => { - impl $trait for $type, + (impl_scalar<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => { + impl<$($qual)*, F: Num> $trait for $type, where $type::Unwrapped: $trait, { type Output = Self::WrappedOutput; fn $fn(self, t: F) -> Self::Output { - $type::wrap(self.get_view().$fn(t)) + Self::wrap(self.get_view().$fn(t)) } } - impl<'a, F: Num> $trait for &'a $type, + impl<'a, $($qual)*, F: Num> $trait for &'a $type, where $type::Unwrapped: $trait, { type Output = Self::WrappedOutput; fn $fn(self, t: F) -> Self::Output { - $type::wrap(self.get_view().$fn(t)) + Self::wrap(self.get_view().$fn(t)) } } }; -} - -macro_rules! impl_scalar_lhs { - ($type:ty, $trait:ident, $fn:ident, $F:ty) => { - impl $trait<$type> for $F + (impl_scalar_lhs<$($qual:tt)*> $type:ty, $trait:path, $fn:ident, $F:ty) => { + impl<$($qual)*> $trait<$type> for $F where $F: $type::Unwrapped, { type Output = Self::WrappedOutput; fn $fn(self, rhs: $type) -> Self::Output { - $type::wrap(self.$fn(rhs.get_view())) + Self::wrap(self.$fn(rhs.get_view())) } } }; -} - -macro_rules! impl_binary_mut { - ($type:ty, $trait:ident, $fn:ident) => { - impl $trait<$type> for $type { + (impl_binary_mut<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => { + impl<$($qual)*> $trait<$type> for $type { fn $fn(&mut self, rhs: $type) { self.get_view_mut().$fn(rhs.get_view()) } } - impl<'b> $trait<&'b $type> for $type { + impl<'b, $($qual)*> $trait<&'b $type> for $type { fn $fn(&mut self, rhs: $type) { self.get_view_mut().$fn(rhs.get_view()) } } }; -} - -macro_rules! impl_scalar_mut { - ($type:ty, $trait:ident, $fn:ident) => { - impl $trait for $type + (impl_scalar_mut<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => { + impl<$($qual)*> $trait for $type where $type::UnwrappedMut: $trait, { @@ -122,45 +106,46 @@ } } }; -} - -macro_rules! wrap { ($type:ty) => { - impl_unary!($type, Neg, neg); - impl_binary!($type, Add, add); - impl_binary!($type, Sub, sub); - impl_scalar!($type, Mul, mul); - impl_scalar!($type, Div, div); - impl_scalar_lhs!($type, Mul, mul, f32); - impl_scalar_lhs!($type, Mul, mul, f64); - impl_scalar_lhs!($type, Div, div, f32); - impl_scalar_lhs!($type, Div, div, f64); - impl_binary_mut!($type, AddAssign, add_assign); - impl_binary_mut!($type, SubAssign, sub_assign); - impl_scalar_mut!($type, MulAssign, mul_assign); - impl_scalar_mut!($type, DivAssign, div_assign); + $crate::wrap!(imp<> do $type); + }; + (imp<$($qual:tt)*> $type:ty) => { + $crate::wrap!(impl_unary<$($qual)*> $type, std::ops::Neg, neg); + $crate::wrap!(impl_binary<$($qual)*> $type, std::ops::Add, add); + $crate::wrap!(impl_binary<$($qual)*> $type, std::ops::Sub, sub); + $crate::wrap!(impl_scalar<$($qual)*> $type, std::ops::Mul, mul); + $crate::wrap!(impl_scalar<$($qual)*> $type, std::ops::Div, div); + $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Mul, mul, f32); + $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Mul, mul, f64); + $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Div, div, f32); + $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Div, div, f64); + $crate::wrap!(impl_binary_mut<$($qual)*> $type, std::ops::AddAssign, add_assign); + $crate::wrap!(impl_binary_mut<$($qual)*> $type, std::ops::SubAssign, sub_assign); + $crate::wrap!(impl_scalar_mut<$($qual)*> $type, std::ops::MulAssign, mul_assign); + $crate::wrap!(impl_scalar_mut<$($qual)*> $type, std::ops::DivAssign, div_assign); /// We only support 'closed' `Euclidean` `Pair`s, as more general ones cause /// compiler overflows. - impl Euclidean for $type + impl<$($qual)* F: $crate::types::Float> $crate::euclidean::Euclidean for $type where //Pair: Euclidean, - Self: Sized - + Mul::Owned> - + MulAssign - + Div::Owned> - + DivAssign - + Add::Owned> - + Sub::Owned> - + for<'b> Add<&'b Self, Output = ::Owned> - + for<'b> Sub<&'b Self, Output = ::Owned> - + AddAssign - + for<'b> AddAssign<&'b Self> - + SubAssign - + for<'b> SubAssign<&'b Self> - + Neg::Owned>, + Self: $crate::euclidean::wrap::Wrapped + + Sized + + std::ops::Mul::Owned> + + std::ops::MulAssign + + std::ops::Div::Owned> + + std::ops::DivAssign + + std::ops::Add::Owned> + + std::ops::Sub::Owned> + + for<'b> std::ops::Add<&'b Self, Output = ::Owned> + + for<'b> std::ops::Sub<&'b Self, Output = ::Owned> + + std::ops::AddAssign + + for<'b> std::ops::AddAssign<&'b Self> + + std::ops::SubAssign + + for<'b> std::ops::SubAssign<&'b Self> + + std::ops::Neg::Owned>, { - fn dot>(&self, other: I) -> F { + fn dot>(&self, other: I) -> F { other.eval_decompose(|x| self.get_view().dot(x.get_view())) } @@ -168,34 +153,34 @@ self.get_view().norm2_squared() } - fn dist2_squared>(&self, other: I) -> F { + fn dist2_squared>(&self, other: I) -> F { other.eval_decompose(|x| self.get_view().dist2_squared(x.get_view())) } } - impl AXPY for $type + impl<$($qual)* F : $crate::types::Float> $crate::linops::AXPY for $type where - Self::Unwrapped : AXPY, - F: Num, - Self: MulAssign + DivAssign, - Self::Unwrapped: MulAssign + DivAssign, + Self : $crate::euclidean::wrap::Wrapped, + Self::Unwrapped : $crate::linops::AXPY, + Self: std::ops::MulAssign + std::ops::DivAssign, + Self::Unwrapped: std::ops::MulAssign + std::ops::DivAssign, { type Field = F; - type Owned = Pair; + type Owned = Self; - fn axpy>(&mut self, α: F, x: I, β: F) { + fn axpy>(&mut self, α: F, x: I, β: F) { x.eval_decompose(|v| { self.get_mut_view().axpy(α, v.get_view(), β) }) } - fn copy_from>(&mut self, x: I) { + fn copy_from>(&mut self, x: I) { x.eval_decompose(|Pair(u, v)| { self.get_mut_view().copy_from(v.get_view()) }) } - fn scale_from>(&mut self, α: F, x: I) { + fn scale_from>(&mut self, α: F, x: I) { x.eval_decompose(|v| { self.get_mut_view().scale_from(α, v.get_view()) }) @@ -203,7 +188,7 @@ /// Return a similar zero as `self`. fn similar_origin(&self) -> Self::Owned { - $type::wrap(self.get_view().similar_origin()) + Self::wrap(self.get_view().similar_origin()) } /// Set self to zero. @@ -212,8 +197,9 @@ } } - impl Space for $type { - type Decomp = $type::Unwrapped::Decomp + impl<$($qual)*> $crate::instance::Space for $type + where Self : $crate::euclidean::wrap::Wrapped { + type Decomp = Self::Unwrapped::Decomp; } }; }