src/euclidean/wrap.rs

Sat, 30 Aug 2025 22:43:37 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Sat, 30 Aug 2025 22:43:37 -0500
branch
dev
changeset 147
d6009939e832
parent 146
3f9a03f95457
child 149
2f1798c65fd6
permissions
-rw-r--r--

Bump pyo3

/*!
Wrappers for implementing [`Euclidean`](crate::euclidean::Euclidean) and related operations on
wrapper types by delegating them to an underlying view.
*/

/// Access to an underlying "view" of a wrapper type, through which the
/// [`wrap!`](crate::wrap) macro delegates arithmetic and linear-algebra operations.
///
/// NOTE(review): the view types carry no lifetime parameter, so a view cannot borrow from
/// `self`; confirm this is intended, in particular for `get_view_mut`, whose result is the
/// target of in-place updates.
pub trait Wrapped {
    /// Type of the read-only view handed out by [`Wrapped::get_view`].
    type Unwrapped;

    /// Produces a read-only view of `self`.
    fn get_view(&self) -> Self::Unwrapped;

    /// Type of the view handed out by [`Wrapped::get_view_mut`].
    type UnwrappedMut;

    /// Produces the view of `self` that in-place operations act on.
    fn get_view_mut(&mut self) -> Self::UnwrappedMut;

    /// Type of results computed on a view, which [`Wrapped::wrap`] accepts.
    type UnwrappedOutput;

    /// Turns a result computed on a view back into the wrapper type.
    fn wrap(output: Self::UnwrappedOutput) -> Self;
}

/// Implements the arithmetic operators together with the
/// [`Euclidean`](crate::euclidean::Euclidean), [`AXPY`](crate::linops::AXPY) and
/// [`Space`](crate::instance::Space) traits for a type implementing [`Wrapped`].
///
/// Every operation is delegated to the view returned by [`Wrapped::get_view`] or
/// [`Wrapped::get_view_mut`]; results computed on the view are turned back into the wrapper
/// with [`Wrapped::wrap`].
///
/// # Usage
///
/// * `wrap!(Type)` — for a type without generic parameters.
/// * `wrap!(imp<A, B> Type<A, B>)` — for a type with generic parameters. Only bare parameter
///   names are accepted (no bounds, no lifetimes); instead, every generated `impl` is
///   conditioned on `Type: Wrapped` and on the views supporting the delegated operation.
///   The parameter name `F` is reserved: the macro introduces it for the scalar field type.
///
/// The `impl_*` arms may also be invoked on their own to generate a single group of
/// implementations, e.g. `wrap!(impl_binary<> Type, std::ops::Add, add)`.
#[macro_export]
macro_rules! wrap {
    // Notes on the matchers and bodies:
    //  * Traits are matched as `$($trait:ident)::+` instead of `$trait:path`: a captured
    //    `path` fragment is opaque, so generic arguments such as `<$type>` cannot be appended.
    //  * Generic parameters are matched as `<$($qual:ident),*>`. A `<$($qual:tt)*>` matcher
    //    can never succeed, because `tt` also accepts the closing `>` and every invocation
    //    fails with a "local ambiguity" error.
    //  * Generated method bodies bring the traits they call into scope with `use … as _`,
    //    since the invoking module is not guaranteed to import them.

    // Unary operator (e.g. `Neg`) consuming the wrapper.
    (impl_unary<$($qual:ident),*> $type:ty, $($trait:ident)::+, $fn:ident) => {
        impl<$($qual),*> $($trait)::+ for $type
        where
            $type: $crate::euclidean::wrap::Wrapped,
            <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped: $($trait)::+<
                Output = <$type as $crate::euclidean::wrap::Wrapped>::UnwrappedOutput,
            >,
        {
            type Output = $type;

            fn $fn(self) -> Self::Output {
                use $crate::euclidean::wrap::Wrapped as _;
                use $($trait)::+ as _;
                <$type as $crate::euclidean::wrap::Wrapped>::wrap(self.get_view().$fn())
            }
        }
    };

    // Binary operator (e.g. `Add`) between two wrappers, for all four combinations of owned
    // and borrowed operands. The result is always an owned wrapper.
    (impl_binary<$($qual:ident),*> $type:ty, $($trait:ident)::+, $fn:ident) => {
        impl<$($qual),*> $($trait)::+<$type> for $type
        where
            $type: $crate::euclidean::wrap::Wrapped,
            <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped: $($trait)::+<
                <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped,
                Output = <$type as $crate::euclidean::wrap::Wrapped>::UnwrappedOutput,
            >,
        {
            type Output = $type;

            fn $fn(self, other: $type) -> Self::Output {
                use $crate::euclidean::wrap::Wrapped as _;
                use $($trait)::+ as _;
                <$type as $crate::euclidean::wrap::Wrapped>::wrap(
                    self.get_view().$fn(other.get_view()),
                )
            }
        }

        impl<'a, $($qual),*> $($trait)::+<$type> for &'a $type
        where
            $type: $crate::euclidean::wrap::Wrapped,
            <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped: $($trait)::+<
                <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped,
                Output = <$type as $crate::euclidean::wrap::Wrapped>::UnwrappedOutput,
            >,
        {
            type Output = $type;

            fn $fn(self, other: $type) -> Self::Output {
                use $crate::euclidean::wrap::Wrapped as _;
                use $($trait)::+ as _;
                <$type as $crate::euclidean::wrap::Wrapped>::wrap(
                    self.get_view().$fn(other.get_view()),
                )
            }
        }

        impl<'a, 'b, $($qual),*> $($trait)::+<&'b $type> for &'a $type
        where
            $type: $crate::euclidean::wrap::Wrapped,
            <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped: $($trait)::+<
                <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped,
                Output = <$type as $crate::euclidean::wrap::Wrapped>::UnwrappedOutput,
            >,
        {
            type Output = $type;

            fn $fn(self, other: &'b $type) -> Self::Output {
                use $crate::euclidean::wrap::Wrapped as _;
                use $($trait)::+ as _;
                <$type as $crate::euclidean::wrap::Wrapped>::wrap(
                    self.get_view().$fn(other.get_view()),
                )
            }
        }

        impl<'b, $($qual),*> $($trait)::+<&'b $type> for $type
        where
            $type: $crate::euclidean::wrap::Wrapped,
            <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped: $($trait)::+<
                <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped,
                Output = <$type as $crate::euclidean::wrap::Wrapped>::UnwrappedOutput,
            >,
        {
            type Output = $type;

            fn $fn(self, other: &'b $type) -> Self::Output {
                use $crate::euclidean::wrap::Wrapped as _;
                use $($trait)::+ as _;
                <$type as $crate::euclidean::wrap::Wrapped>::wrap(
                    self.get_view().$fn(other.get_view()),
                )
            }
        }
    };

    // Binary operator (e.g. `Mul`) with a scalar `F` on the right, for an owned or borrowed
    // wrapper on the left.
    (impl_scalar<$($qual:ident),*> $type:ty, $($trait:ident)::+, $fn:ident) => {
        impl<$($qual,)* F: $crate::types::Num> $($trait)::+<F> for $type
        where
            $type: $crate::euclidean::wrap::Wrapped,
            <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped: $($trait)::+<
                F,
                Output = <$type as $crate::euclidean::wrap::Wrapped>::UnwrappedOutput,
            >,
        {
            type Output = $type;

            fn $fn(self, t: F) -> Self::Output {
                use $crate::euclidean::wrap::Wrapped as _;
                use $($trait)::+ as _;
                <$type as $crate::euclidean::wrap::Wrapped>::wrap(self.get_view().$fn(t))
            }
        }

        impl<'a, $($qual,)* F: $crate::types::Num> $($trait)::+<F> for &'a $type
        where
            $type: $crate::euclidean::wrap::Wrapped,
            <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped: $($trait)::+<
                F,
                Output = <$type as $crate::euclidean::wrap::Wrapped>::UnwrappedOutput,
            >,
        {
            type Output = $type;

            fn $fn(self, t: F) -> Self::Output {
                use $crate::euclidean::wrap::Wrapped as _;
                use $($trait)::+ as _;
                <$type as $crate::euclidean::wrap::Wrapped>::wrap(self.get_view().$fn(t))
            }
        }
    };

    // Binary operator with the concrete scalar type `$F` on the left and an owned wrapper on
    // the right, e.g. `2.0 * x`.
    (impl_scalar_lhs<$($qual:ident),*> $type:ty, $($trait:ident)::+, $fn:ident, $F:ty) => {
        impl<$($qual),*> $($trait)::+<$type> for $F
        where
            $type: $crate::euclidean::wrap::Wrapped,
            $F: $($trait)::+<
                <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped,
                Output = <$type as $crate::euclidean::wrap::Wrapped>::UnwrappedOutput,
            >,
        {
            type Output = $type;

            fn $fn(self, rhs: $type) -> Self::Output {
                use $crate::euclidean::wrap::Wrapped as _;
                use $($trait)::+ as _;
                <$type as $crate::euclidean::wrap::Wrapped>::wrap(self.$fn(rhs.get_view()))
            }
        }
    };

    // Compound assignment (e.g. `AddAssign`) with an owned or borrowed wrapper on the right;
    // applied to the mutable view of `self`.
    (impl_binary_mut<$($qual:ident),*> $type:ty, $($trait:ident)::+, $fn:ident) => {
        impl<$($qual),*> $($trait)::+<$type> for $type
        where
            $type: $crate::euclidean::wrap::Wrapped,
            <$type as $crate::euclidean::wrap::Wrapped>::UnwrappedMut:
                $($trait)::+<<$type as $crate::euclidean::wrap::Wrapped>::Unwrapped>,
        {
            fn $fn(&mut self, rhs: $type) {
                use $crate::euclidean::wrap::Wrapped as _;
                use $($trait)::+ as _;
                self.get_view_mut().$fn(rhs.get_view())
            }
        }

        impl<'b, $($qual),*> $($trait)::+<&'b $type> for $type
        where
            $type: $crate::euclidean::wrap::Wrapped,
            <$type as $crate::euclidean::wrap::Wrapped>::UnwrappedMut:
                $($trait)::+<<$type as $crate::euclidean::wrap::Wrapped>::Unwrapped>,
        {
            fn $fn(&mut self, rhs: &'b $type) {
                use $crate::euclidean::wrap::Wrapped as _;
                use $($trait)::+ as _;
                self.get_view_mut().$fn(rhs.get_view())
            }
        }
    };

    // Compound assignment (e.g. `MulAssign`) with a scalar `F` on the right; applied to the
    // mutable view of `self`.
    (impl_scalar_mut<$($qual:ident),*> $type:ty, $($trait:ident)::+, $fn:ident) => {
        impl<$($qual,)* F: $crate::types::Num> $($trait)::+<F> for $type
        where
            $type: $crate::euclidean::wrap::Wrapped,
            <$type as $crate::euclidean::wrap::Wrapped>::UnwrappedMut: $($trait)::+<F>,
        {
            fn $fn(&mut self, t: F) {
                use $crate::euclidean::wrap::Wrapped as _;
                use $($trait)::+ as _;
                self.get_view_mut().$fn(t)
            }
        }
    };

    // Full set of implementations for a (possibly generic) wrapper type.
    (imp<$($qual:ident),*> $type:ty) => {
        $crate::wrap!(impl_unary<$($qual),*> $type, std::ops::Neg, neg);
        $crate::wrap!(impl_binary<$($qual),*> $type, std::ops::Add, add);
        $crate::wrap!(impl_binary<$($qual),*> $type, std::ops::Sub, sub);
        $crate::wrap!(impl_scalar<$($qual),*> $type, std::ops::Mul, mul);
        $crate::wrap!(impl_scalar<$($qual),*> $type, std::ops::Div, div);
        // NOTE(review): for a non-generic `$type` the where-clauses of these impls contain no
        // generic parameters, so a view lacking e.g. `f32 * view` makes the expansion fail to
        // compile rather than just omitting the impl; confirm all four are wanted.
        $crate::wrap!(impl_scalar_lhs<$($qual),*> $type, std::ops::Mul, mul, f32);
        $crate::wrap!(impl_scalar_lhs<$($qual),*> $type, std::ops::Mul, mul, f64);
        $crate::wrap!(impl_scalar_lhs<$($qual),*> $type, std::ops::Div, div, f32);
        $crate::wrap!(impl_scalar_lhs<$($qual),*> $type, std::ops::Div, div, f64);
        $crate::wrap!(impl_binary_mut<$($qual),*> $type, std::ops::AddAssign, add_assign);
        $crate::wrap!(impl_binary_mut<$($qual),*> $type, std::ops::SubAssign, sub_assign);
        $crate::wrap!(impl_scalar_mut<$($qual),*> $type, std::ops::MulAssign, mul_assign);
        $crate::wrap!(impl_scalar_mut<$($qual),*> $type, std::ops::DivAssign, div_assign);

        // `Euclidean` delegates to the unwrapped view. The bounds on `Self` only ask for the
        // "closed" operations (wrapper with wrapper, wrapper with scalar) generated above.
        impl<$($qual,)* F: $crate::types::Float> $crate::euclidean::Euclidean<F> for $type
        where
            Self: $crate::euclidean::wrap::Wrapped
                + Sized
                + std::ops::Mul<F, Output = <Self as $crate::linops::AXPY>::Owned>
                + std::ops::MulAssign<F>
                + std::ops::Div<F, Output = <Self as $crate::linops::AXPY>::Owned>
                + std::ops::DivAssign<F>
                + std::ops::Add<Self, Output = <Self as $crate::linops::AXPY>::Owned>
                + std::ops::Sub<Self, Output = <Self as $crate::linops::AXPY>::Owned>
                + for<'b> std::ops::Add<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned>
                + for<'b> std::ops::Sub<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned>
                + std::ops::AddAssign<Self>
                + for<'b> std::ops::AddAssign<&'b Self>
                + std::ops::SubAssign<Self>
                + for<'b> std::ops::SubAssign<&'b Self>
                + std::ops::Neg<Output = <Self as $crate::linops::AXPY>::Owned>,
            <Self as $crate::euclidean::wrap::Wrapped>::Unwrapped: $crate::euclidean::Euclidean<F>,
        {
            fn dot<I: $crate::instance::Instance<Self>>(&self, other: I) -> F {
                use $crate::euclidean::wrap::Wrapped as _;
                use $crate::euclidean::Euclidean as _;
                other.eval_decompose(|x| self.get_view().dot(x.get_view()))
            }

            fn norm2_squared(&self) -> F {
                use $crate::euclidean::wrap::Wrapped as _;
                use $crate::euclidean::Euclidean as _;
                self.get_view().norm2_squared()
            }

            fn dist2_squared<I: $crate::instance::Instance<Self>>(&self, other: I) -> F {
                use $crate::euclidean::wrap::Wrapped as _;
                use $crate::euclidean::Euclidean as _;
                other.eval_decompose(|x| self.get_view().dist2_squared(x.get_view()))
            }
        }

        // `AXPY` updates the mutable view in place; `similar_origin` re-wraps the view's
        // result, hence the `Owned = UnwrappedOutput` bound.
        // NOTE(review): `axpy`, `copy_from` and `scale_from` pass a read-only view
        // (`Unwrapped`) to methods of the mutable view (`UnwrappedMut`); confirm that
        // `Unwrapped` is an instance of `UnwrappedMut`, or add that bound for generic wrappers.
        impl<$($qual,)* F: $crate::types::Float> $crate::linops::AXPY for $type
        where
            Self: $crate::euclidean::wrap::Wrapped
                + std::ops::MulAssign<F>
                + std::ops::DivAssign<F>,
            <Self as $crate::euclidean::wrap::Wrapped>::Unwrapped: $crate::linops::AXPY<
                    Field = F,
                    Owned = <Self as $crate::euclidean::wrap::Wrapped>::UnwrappedOutput,
                > + std::ops::MulAssign<F>
                + std::ops::DivAssign<F>,
            <Self as $crate::euclidean::wrap::Wrapped>::UnwrappedMut:
                $crate::linops::AXPY<Field = F>,
        {
            type Field = F;
            type Owned = Self;

            fn axpy<I: $crate::instance::Instance<Self>>(&mut self, α: F, x: I, β: F) {
                use $crate::euclidean::wrap::Wrapped as _;
                use $crate::linops::AXPY as _;
                x.eval_decompose(|v| self.get_view_mut().axpy(α, v.get_view(), β))
            }

            fn copy_from<I: $crate::instance::Instance<Self>>(&mut self, x: I) {
                use $crate::euclidean::wrap::Wrapped as _;
                use $crate::linops::AXPY as _;
                x.eval_decompose(|v| self.get_view_mut().copy_from(v.get_view()))
            }

            fn scale_from<I: $crate::instance::Instance<Self>>(&mut self, α: F, x: I) {
                use $crate::euclidean::wrap::Wrapped as _;
                use $crate::linops::AXPY as _;
                x.eval_decompose(|v| self.get_view_mut().scale_from(α, v.get_view()))
            }

            /// Return a similar zero as `self`.
            fn similar_origin(&self) -> Self::Owned {
                use $crate::euclidean::wrap::Wrapped as _;
                use $crate::linops::AXPY as _;
                <Self as $crate::euclidean::wrap::Wrapped>::wrap(
                    self.get_view().similar_origin(),
                )
            }

            /// Set self to zero.
            fn set_zero(&mut self) {
                use $crate::euclidean::wrap::Wrapped as _;
                use $crate::linops::AXPY as _;
                self.get_view_mut().set_zero()
            }
        }

        // NOTE(review): the wrapper reuses the decomposition type of its unwrapped view, as
        // originally intended; confirm this satisfies the bounds `Space` places on `Decomp`.
        impl<$($qual),*> $crate::instance::Space for $type
        where
            $type: $crate::euclidean::wrap::Wrapped,
            <$type as $crate::euclidean::wrap::Wrapped>::Unwrapped: $crate::instance::Space,
        {
            type Decomp = <<$type as $crate::euclidean::wrap::Wrapped>::Unwrapped
                as $crate::instance::Space>::Decomp;
        }
    };

    // Catch-all entry point for non-generic types. It must stay last: `$type:ty` would
    // otherwise be tried on the keyword-prefixed forms above.
    ($type:ty) => {
        $crate::wrap!(imp<> $type);
    };
}

mercurial