src/euclidean/wrap.rs

Sat, 06 Sep 2025 23:29:34 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Sat, 06 Sep 2025 23:29:34 -0500
branch
dev
changeset 183
d077dff509f1
parent 179
724413fc8d17
permissions
-rw-r--r--

wrap guard interface

/*!
Wrappers for implementing [`Euclidean`] operations.
*/

use crate::euclidean::Euclidean;
use crate::instance::Space;
use crate::types::Float;

/// Read-only guard through which a wrapped object exposes a [`Euclidean`] view.
///
/// A guard is obtained from [`Wrapped::get_guard`]; the actual operations are
/// performed on the view returned by [`WrapGuard::get_view`].
///
/// NOTE(review): the lifetime parameter `'a` is not used inside this trait
/// declaration; presumably it is the lifetime of the borrow of the wrapped
/// object the guard was created from (cf. `Wrapped::Guard<'a>`) — confirm.
pub trait WrapGuard<'a, F: Float> {
    /// The view type handed out by the guard. It implements [`Euclidean<F>`]
    /// and may borrow from the guard for the lifetime `'b`.
    type View<'b>: Euclidean<F>
    where
        Self: 'b;
    /// Returns a view borrowing from this guard.
    fn get_view(&self) -> Self::View<'_>;
}

/// Mutable counterpart of [`WrapGuard`]: a guard through which a wrapped
/// object exposes a [`Euclidean`] view that is obtained via `&mut self`.
///
/// A guard is obtained from [`Wrapped::get_guard_mut`]; in-place operations are
/// performed on the view returned by [`WrapGuardMut::get_view_mut`].
///
/// NOTE(review): the lifetime parameter `'a` is not used inside this trait
/// declaration; presumably it is the lifetime of the mutable borrow of the
/// wrapped object the guard was created from (cf. `Wrapped::GuardMut<'a>`) — confirm.
pub trait WrapGuardMut<'a, F: Float> {
    /// The view type handed out by the guard. It implements [`Euclidean<F>`]
    /// and may borrow from the guard for the lifetime `'b`.
    type ViewMut<'b>: Euclidean<F>
    where
        Self: 'b;
    /// Returns a view borrowing mutably from this guard.
    fn get_view_mut(&mut self) -> Self::ViewMut<'_>;
}

/// A [`Space`] whose vector-space operations are delegated to an underlying
/// [`Euclidean`] view, accessed through guard objects.
///
/// The [`wrap!`](crate::wrap) macro uses this trait to implement arithmetic
/// operators and the norm / linear-algebra traits for the implementing type:
/// it obtains a guard, performs the operation on the guard's view, and — for
/// operations that produce a new value — converts the result back with
/// [`Wrapped::wrap`].
pub trait Wrapped: Space {
    /// The scalar field of the views handed out by the guards.
    type WrappedField: Float;
    /// Guard type returned by [`Wrapped::get_guard`], borrowing `self` for `'a`.
    type Guard<'a>: WrapGuard<'a, Self::WrappedField>
    where
        Self: 'a;
    /// Guard type returned by [`Wrapped::get_guard_mut`], mutably borrowing
    /// `self` for `'a`.
    type GuardMut<'a>: WrapGuardMut<'a, Self::WrappedField>
    where
        Self: 'a;
    /// The type accepted by [`Wrapped::wrap`]. In the `wrap!` macro this is
    /// what operations on the views produce.
    type UnwrappedOutput;
    /// The type produced by [`Wrapped::wrap`]. The `wrap!` macro uses it as
    /// the `Output` of the generated operator implementations.
    type WrappedOutput;
    /// Returns a guard giving access to a read-only view of `self`.
    fn get_guard(&self) -> Self::Guard<'_>;
    /// Returns a guard giving access to a mutable view of `self`.
    fn get_guard_mut(&mut self) -> Self::GuardMut<'_>;
    /// Converts the result of an operation on views into the wrapped output type.
    fn wrap(output: Self::UnwrappedOutput) -> Self::WrappedOutput;
}

/// Implements arithmetic operators and the norm / vector-space traits for a
/// type by delegating to the views of its `Wrapped` guards.
///
/// The public entry point is
///
/// ```ignore
/// wrap!(F; Type where GENERICS);
/// ```
///
/// where `F` is the scalar type, `Type` is the type to implement the traits
/// for, and `GENERICS` are the tokens placed verbatim inside `impl<...>` of
/// every generated implementation. Note that despite the `where` keyword in
/// the macro syntax these are generic parameters (with optional bounds), not
/// a `where` clause.
///
/// The entry point generates:
///
/// * `Neg`, `Add`, `Sub` (the binary operators for all owned / reference
///   combinations of the two operands);
/// * `Mul<F>` and `Div<F>` for `Type` and `&Type`, and `Mul<Type>` for `F`;
/// * `AddAssign`, `SubAssign` (owned and reference right-hand side),
///   `MulAssign<F>` and `DivAssign<F>`;
/// * whatever `self_ownable!` generates;
/// * `Norm<L2, F>`, `Dist<L2, F>`, `Normed<F>`, `HasDual<F>`, `Euclidean<F>`,
///   `VectorSpace`, `AXPY`, and `Space`.
///
/// All other arms (`impl_unary`, `impl_binary`, `impl_scalar`,
/// `impl_scalar_lhs`, `impl_binary_mut`, `impl_scalar_mut`) are helpers
/// invoked by the entry point.
///
/// The generated code calls `get_guard`, `get_guard_mut` and `wrap` on
/// `Type`, and `get_view` / `get_view_mut` on the guards, so `Type` has to
/// implement `Wrapped`.
///
/// NOTE(review): these calls use method syntax, so presumably `Wrapped`,
/// `WrapGuard`, `WrapGuardMut` and the traits providing the delegated methods
/// (e.g. `Euclidean`, `Dist`, `AXPY`, `VectorSpace`) must be in scope where
/// the macro is invoked — confirm.
#[macro_export]
macro_rules! wrap {
    // `$trait:path` does not work here, so trait paths have to be matched
    // manually as `$($trait:ident)::+`.

    // Unary operator `$trait` with method `$fn`, for an owned `$type`: applies
    // the operation to the view of `self` and wraps the result.
    (impl_unary $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
        impl<$($qual)*> $($trait)::+ for $type {
            type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self) -> Self::Output {
                let a = self.get_guard();
                Self::wrap(a.get_view().$fn())
            }
        }
    };
    // Binary operator `$trait` with method `$fn` between two `$type` values.
    // Four implementations are generated, one for each combination of owned
    // and borrowed operands. Each applies the operation to the views of both
    // operands and wraps the result.
    (impl_binary $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
        // owned ∘ owned
        impl<$($qual)*> $($trait)::+<$type> for $type {
            type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, other: $type) -> Self::Output {
                let a = self.get_guard();
                let b = other.get_guard();
                Self::wrap(a.get_view().$fn(b.get_view()))
            }
        }

        // reference ∘ owned
        impl<'a, $($qual)*> $($trait)::+<$type> for &'a $type {
            type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, other: $type) -> Self::Output {
                let a = self.get_guard();
                let b = other.get_guard();
                <$type>::wrap(a.get_view().$fn(b.get_view()))
            }
        }

        // reference ∘ reference
        impl<'a, 'b, $($qual)*> $($trait)::+<&'b $type> for &'a $type {
            type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, other: &'b $type) -> Self::Output {
                let a = self.get_guard();
                let b = other.get_guard();
                <$type>::wrap(a.get_view().$fn(b.get_view()))
            }
        }

        // owned ∘ reference
        impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type {
            type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, other: &'b $type) -> Self::Output {
                let a = self.get_guard();
                let b = other.get_guard();
                Self::wrap(a.get_view().$fn(b.get_view()))
            }
        }
    };
    // Operator `$trait` with method `$fn` taking a scalar `$F` on the right,
    // for both an owned `$type` and a reference to it: applies the operation
    // to the view of `self` and wraps the result.
    (impl_scalar $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
        impl<$($qual)*> $($trait)::+<$F> for $type
        // where
        //     $type: $crate::euclidean::wrap::Wrapped<WrappedField = F>,
        //     //$type::Unwrapped: $($trait)::+<F>,
        {
            type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, t: $F) -> Self::Output {
                let a = self.get_guard();
                Self::wrap(a.get_view().$fn(t))
            }
        }

        impl<'a, $($qual)*> $($trait)::+<$F> for &'a $type
        // where
        //     $type: $crate::euclidean::wrap::Wrapped<WrappedField = F>,
        //     //$type::Unwrapped: $($trait)::+<F>,
        {
            type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, t: $F) -> Self::Output {
                let a = self.get_guard();
                <$type>::wrap(a.get_view().$fn(t))
            }
        }

    };
    // Operator `$trait` with method `$fn` with the scalar `$F` on the left and
    // an owned `$type` on the right: applies the scalar's operation to the
    // view of the right-hand side and wraps the result. Only the owned
    // right-hand side is covered; no implementation for `&$type` is generated.
    (impl_scalar_lhs $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
        impl<$($qual)*> $($trait)::+<$type> for $F
        // where
        //     $type: $crate::euclidean::wrap::Wrapped<WrappedField = $F>,
        // // where
        // //     $F: $($trait)::+<$type::Unwrapped>,
        {
            type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, rhs: $type) -> Self::Output {
                let b = rhs.get_guard();
                <$type>::wrap(self.$fn(b.get_view()))
            }
        }
    };
    // In-place binary operator `$trait` with method `$fn`, for an owned and a
    // borrowed right-hand side: applies the operation to the mutable view of
    // `self` with the view of `rhs` as argument.
    (impl_binary_mut $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
        impl<$($qual)*> $($trait)::+<$type> for $type {
            fn $fn(&mut self, rhs: $type) {
                let mut a = self.get_guard_mut();
                let b = rhs.get_guard();
                a.get_view_mut().$fn(b.get_view())
            }
        }

        impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type {
            fn $fn(&mut self, rhs: &'b $type) {
                let mut a = self.get_guard_mut();
                let b = rhs.get_guard();
                a.get_view_mut().$fn(b.get_view())
            }
        }
    };
    // In-place operator `$trait` with method `$fn` taking a scalar `$F`:
    // applies the operation to the mutable view of `self`.
    (impl_scalar_mut $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
        impl<$($qual)*> $($trait)::+<$F> for $type
        // where
        //     $type: $crate::euclidean::wrap::Wrapped<WrappedField = F>,
        // // where
        // //     $type::UnwrappedMut: $($trait)::+<$($trait)::+<F>>,
        {
            fn $fn(&mut self, t: $F) {
                let mut a = self.get_guard_mut();
                a.get_view_mut().$fn(t)
            }
        }
    };
    // ($type:ty) => {
    //     $crate::wrap!(imp<> do $type);
    // };

    // Main entry point: `wrap!(F; Type where GENERICS)`. Generates all the
    // operator and trait implementations for `$type` over the field `$F`.
    ($F:ty; $type:ty where $($qual:tt)*) => {

        // Arithmetic operators, via the helper arms above.
        $crate::wrap!(impl_unary $type, std::ops::Neg, neg where $($qual)*);
        $crate::wrap!(impl_binary $type, std::ops::Add, add where $($qual)*);
        $crate::wrap!(impl_binary $type, std::ops::Sub, sub where $($qual)*);
        $crate::wrap!(impl_scalar $F, $type, std::ops::Mul, mul where $($qual)*);
        $crate::wrap!(impl_scalar $F, $type, std::ops::Div, div where $($qual)*);
        $crate::wrap!(impl_scalar_lhs $F, $type, std::ops::Mul, mul where $($qual)*);
        $crate::wrap!(impl_binary_mut $type, std::ops::AddAssign, add_assign where $($qual)*);
        $crate::wrap!(impl_binary_mut $type, std::ops::SubAssign, sub_assign where $($qual)*);
        $crate::wrap!(impl_scalar_mut $F, $type, std::ops::MulAssign, mul_assign where $($qual)*);
        $crate::wrap!(impl_scalar_mut $F, $type, std::ops::DivAssign, div_assign where $($qual)*);

        $crate::self_ownable!($type where $($qual)*);

        // The L2 norm is delegated to the view of `self`.
        impl<$($qual)*> $crate::norms::Norm<$crate::norms::L2, $F> for $type
        {
            fn norm(&self, p : $crate::norms::L2) -> $F {
                let a = self.get_guard();
                $crate::norms::Norm::norm(&a.get_view(), p)
            }
        }

        // The L2 distance is delegated to the views of `self` and `other`.
        // NOTE(review): this uses `eval_ref`, whereas `dot`, `dist2_squared`
        // and the `AXPY` methods below use `eval_decompose` — confirm that
        // the difference is intentional.
        impl<$($qual)*> $crate::norms::Dist<$crate::norms::L2, $F> for $type
        {
            fn dist<I: $crate::instance::Instance<Self>>(&self, other : I, p : $crate::norms::L2) -> $F {
                other.eval_ref(|other| {
                    let a = self.get_guard();
                    let b = other.get_guard();
                    a.get_view().dist(b.get_view(), p)
                })
            }
        }

        // The norm exponent is always `L2`.
        impl<$($qual)*> $crate::norms::Normed<$F> for $type {
            type NormExp = $crate::norms::L2;

            fn norm_exponent(&self) -> Self::NormExp {
                $crate::norms::L2
            }
        }

        // The type is declared to be its own dual space; the dual origin is
        // obtained from `VectorSpace::similar_origin`.
        impl<$($qual)*> $crate::norms::HasDual<$F> for $type {
            type DualSpace = Self;

            fn dual_origin(&self) -> Self {
                $crate::linops::VectorSpace::similar_origin(self)
            }
        }

        // Inner product and squared norm / distance are delegated to the views.
        impl<$($qual)*> $crate::euclidean::Euclidean<$F> for $type
        // where
        //     Self: $crate::euclidean::wrap::Wrapped<WrappedField = $F>
        //         + Sized
        //         + std::ops::Mul<F, Output = <Self as $crate::linops::AXPY>::Owned>
        //         + std::ops::MulAssign<F>
        //         + std::ops::Div<F, Output = <Self as $crate::linops::AXPY>::Owned>
        //         + std::ops::DivAssign<F>
        //         + std::ops::Add<Self, Output = <Self as $crate::linops::AXPY>::Owned>
        //         + std::ops::Sub<Self, Output = <Self as $crate::linops::AXPY>::Owned>
        //         + for<'b> std::ops::Add<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned>
        //         + for<'b> std::ops::Sub<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned>
        //         + std::ops::AddAssign<Self>
        //         + for<'b> std::ops::AddAssign<&'b Self>
        //         + std::ops::SubAssign<Self>
        //         + for<'b> std::ops::SubAssign<&'b Self>
        //         + std::ops::Neg<Output = <Self as $crate::linops::AXPY>::Owned>,
        {
            type PrincipalE = Self;

            fn dot<I: $crate::instance::Instance<Self>>(&self, other: I) -> $F {
                other.eval_decompose(|other| {
                    let a = self.get_guard();
                    let b = other.get_guard();
                    a.get_view().dot(&b.get_view())
                })
            }

            fn norm2_squared(&self) -> $F {
                let a = self.get_guard();
                a.get_view().norm2_squared()
            }

            fn dist2_squared<I: $crate::instance::Instance<Self>>(&self, other: I) -> $F {
                other.eval_decompose(|other| {
                    let a = self.get_guard();
                    let b = other.get_guard();
                    a.get_view().dist2_squared(b.get_view())
                })
            }
        }

        // The origin is computed on the view and wrapped back.
        impl<$($qual)*> $crate::linops::VectorSpace for $type
        // where
        //     Self : $crate::euclidean::wrap::Wrapped<WrappedField = $F>,
        //     Self::Unwrapped : $crate::linops::AXPY<Field = F>,
        //     Self: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
        //     Self::Unwrapped: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
        {
            type Field = $F;
            type PrincipalV = Self;

            /// Return a similar zero as `self`.
            fn similar_origin(&self) -> Self::PrincipalV {
                let a = self.get_guard();
                Self::wrap(a.get_view().similar_origin())
            }
        }

        // In-place linear operations are delegated to the mutable view of
        // `self`, with the view of the other operand as argument.
        impl<$($qual)*> $crate::linops::AXPY for $type
        // where
        //     Self : $crate::euclidean::wrap::Wrapped<WrappedField = $F>,
        //     Self::Unwrapped : $crate::linops::AXPY<Field = F>,
        //     Self: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
        //     Self::Unwrapped: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
        {
             fn axpy<I: $crate::instance::Instance<Self>>(&mut self, α: $F, x: I, β: $F) {
                x.eval_decompose(|other| {
                    let mut a = self.get_guard_mut();
                    let b = other.get_guard();
                    $crate::linops::AXPY::axpy(&mut a.get_view_mut(), α, b.get_view(), β)
                })
            }

            fn copy_from<I: $crate::instance::Instance<Self>>(&mut self, x: I) {
                x.eval_decompose(|other| {
                    let mut a = self.get_guard_mut();
                    let b = other.get_guard();
                    $crate::linops::AXPY::copy_from(&mut a.get_view_mut(), b.get_view())
                })
            }

            fn scale_from<I: $crate::instance::Instance<Self>>(&mut self, α: $F, x: I) {
                x.eval_decompose(|other| {
                    let mut a = self.get_guard_mut();
                    let b = other.get_guard();
                    $crate::linops::AXPY::scale_from(&mut a.get_view_mut(), α, b.get_view())
                })
            }

            /// Set self to zero.
            fn set_zero(&mut self) {
                let mut a = self.get_guard_mut();
                a.get_view_mut().set_zero()
            }
        }

        // The type is its own principal space and uses the basic decomposition.
        impl<$($qual)*> $crate::instance::Space for $type {
            type Decomp = $crate::instance::BasicDecomposition;
            type Principal = Self;
        }
    };
}

mercurial