src/euclidean/wrap.rs

Mon, 01 Sep 2025 00:04:22 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Mon, 01 Sep 2025 00:04:22 -0500
branch
dev
changeset 149
2f1798c65fd6
parent 147
d6009939e832
child 174
53ab61a41d70
permissions
-rw-r--r--

wrap hacks

/*!
Wrappers for implementing [`Euclidean`] operations by delegating to an unwrapped view of a type.
*/

use crate::euclidean::Euclidean;
use crate::instance::Space;
use crate::types::Float;

/// A type whose [`Euclidean`] operations are implemented by delegating to an
/// "unwrapped" view of itself.
///
/// Read-only operations go through [`Wrapped::get_view`], in-place operations
/// through [`Wrapped::get_view_mut`], and results computed on the views are
/// turned back into the wrapping type with [`Wrapped::wrap`]. The [`wrap!`]
/// macro generates the operator, `Euclidean`, `AXPY` and `Space` impls from
/// these pieces.
pub trait Wrapped: Space {
    /// Scalar field over which all the views are Euclidean spaces.
    type WrappedField: Float;

    /// Read-only view of `self`.
    type Unwrapped: Euclidean<Self::WrappedField>;
    /// Produce the read-only view of `self`.
    fn get_view(&self) -> Self::Unwrapped;

    /// Mutable view of `self`.
    type UnwrappedMut: Euclidean<Self::WrappedField>;
    /// Produce the mutable view of `self`.
    fn get_view_mut(&mut self) -> Self::UnwrappedMut;

    /// Type of results computed on the unwrapped views.
    type UnwrappedOutput: Euclidean<Self::WrappedField>;
    /// Type obtained by re-wrapping an [`Wrapped::UnwrappedOutput`].
    type WrappedOutput;
    /// Convert a result computed on the views back into the wrapped form.
    fn wrap(output: Self::UnwrappedOutput) -> Self::WrappedOutput;
}

/// Implements [`Euclidean`], `AXPY`, [`Space`] and the arithmetic operator
/// traits for a [`Wrapped`] type by delegating every operation to the views
/// returned by [`Wrapped::get_view`] / [`Wrapped::get_view_mut`], converting
/// results back with [`Wrapped::wrap`].
///
/// # Usage
///
/// `wrap!(F; Type where GENERICS)`
///
/// * `F` — the scalar type used as the right-hand side of `Mul`/`Div`/
///   `MulAssign`/`DivAssign`, as the left-hand side of `Mul`, and as the
///   `Euclidean`/`AXPY` field.
/// * `Type` — the wrapping type.
/// * `GENERICS` — tokens spliced verbatim between `impl<` and `>` of every
///   generated impl. Despite the `where` keyword in the invocation syntax,
///   this is a generic parameter list, not a where-clause, so bounds on
///   `Type` itself cannot be expressed here.
///
/// The `impl_*` arms are internal helpers invoked by the main arm.
///
/// # Caveats
///
/// * `macro_rules!` resolves trait methods at the invocation site. The traits
///   whose methods the generated bodies call with method syntax (`Wrapped`,
///   the `std::ops` traits, `Euclidean`, `AXPY`, `Instance`) must therefore be
///   in scope where the macro is invoked.
/// * `impl_scalar_lhs` implements a foreign trait for `F`. The orphan rules
///   only accept this when `F` is a concrete type (e.g. `f64`), not a generic
///   parameter.
/// * The generated `AXPY` impl sets `Owned = Self`, but `similar_origin`
///   returns `Wrapped::wrap(..)`. The expansion therefore only type-checks
///   when `WrappedOutput = Self`.
#[macro_export]
macro_rules! wrap {
    // Trait paths are matched as `$($trait:ident)::+` instead of `$trait:path`
    // because generic arguments such as `<$type>` cannot be appended to a
    // captured `$trait:path` fragment in the expansion.

    // Unary operator (e.g. `Neg`) on an owned wrapped value.
    (impl_unary $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
        impl<$($qual)*> $($trait)::+ for $type {
            type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self) -> Self::Output {
                Self::wrap(self.get_view().$fn())
            }
        }
    };
    // Binary operator (e.g. `Add`) for all four owned/borrowed operand combinations.
    (impl_binary $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
        impl<$($qual)*> $($trait)::+<$type> for $type {
            type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, other: $type) -> Self::Output {
                Self::wrap(self.get_view().$fn(other.get_view()))
            }
        }

        impl<'a, $($qual)*> $($trait)::+<$type> for &'a $type {
            type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, other: $type) -> Self::Output {
                <$type>::wrap(self.get_view().$fn(other.get_view()))
            }
        }

        impl<'a, 'b, $($qual)*> $($trait)::+<&'b $type> for &'a $type {
            type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, other: &'b $type) -> Self::Output {
                <$type>::wrap(self.get_view().$fn(other.get_view()))
            }
        }

        impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type {
            type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, other: &'b $type) -> Self::Output {
                Self::wrap(self.get_view().$fn(other.get_view()))
            }
        }
    };
    // Scalar right-hand operand (e.g. `x * t`), for owned and borrowed `x`.
    (impl_scalar $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
        impl<$($qual)*> $($trait)::+<$F> for $type {
            type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, t: $F) -> Self::Output {
                Self::wrap(self.get_view().$fn(t))
            }
        }

        impl<'a, $($qual)*> $($trait)::+<$F> for &'a $type {
            type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, t: $F) -> Self::Output {
                <$type>::wrap(self.get_view().$fn(t))
            }
        }
    };
    // Scalar left-hand operand (e.g. `t * x`). Implements a foreign trait for
    // `$F`, which the orphan rules only allow when `$F` is a concrete type.
    (impl_scalar_lhs $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
        impl<$($qual)*> $($trait)::+<$type> for $F {
            type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
            fn $fn(self, rhs: $type) -> Self::Output {
                <$type>::wrap(self.$fn(rhs.get_view()))
            }
        }
    };
    // Compound assignment (e.g. `AddAssign`) with an owned or borrowed wrapped rhs.
    (impl_binary_mut $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
        impl<$($qual)*> $($trait)::+<$type> for $type {
            fn $fn(&mut self, rhs: $type) {
                self.get_view_mut().$fn(rhs.get_view())
            }
        }

        impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type {
            fn $fn(&mut self, rhs: &'b $type) {
                self.get_view_mut().$fn(rhs.get_view())
            }
        }
    };
    // Compound assignment (e.g. `MulAssign`) with a scalar rhs.
    (impl_scalar_mut $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
        impl<$($qual)*> $($trait)::+<$F> for $type {
            fn $fn(&mut self, t: $F) {
                self.get_view_mut().$fn(t)
            }
        }
    };
    // Main entry point: generates the operator, `Euclidean`, `AXPY` and `Space` impls.
    ($F:ty; $type:ty where $($qual:tt)*) => {
        $crate::wrap!(impl_unary $type, std::ops::Neg, neg where $($qual)*);
        $crate::wrap!(impl_binary $type, std::ops::Add, add where $($qual)*);
        $crate::wrap!(impl_binary $type, std::ops::Sub, sub where $($qual)*);
        $crate::wrap!(impl_scalar $F, $type, std::ops::Mul, mul where $($qual)*);
        $crate::wrap!(impl_scalar $F, $type, std::ops::Div, div where $($qual)*);
        $crate::wrap!(impl_scalar_lhs $F, $type, std::ops::Mul, mul where $($qual)*);
        $crate::wrap!(impl_binary_mut $type, std::ops::AddAssign, add_assign where $($qual)*);
        $crate::wrap!(impl_binary_mut $type, std::ops::SubAssign, sub_assign where $($qual)*);
        $crate::wrap!(impl_scalar_mut $F, $type, std::ops::MulAssign, mul_assign where $($qual)*);
        $crate::wrap!(impl_scalar_mut $F, $type, std::ops::DivAssign, div_assign where $($qual)*);

        impl<$($qual)*> $crate::euclidean::Euclidean<$F> for $type {
            fn dot<I: $crate::instance::Instance<Self>>(&self, other: I) -> $F {
                other.eval_decompose(|x| self.get_view().dot(&x.get_view()))
            }

            fn norm2_squared(&self) -> $F {
                self.get_view().norm2_squared()
            }

            fn dist2_squared<I: $crate::instance::Instance<Self>>(&self, other: I) -> $F {
                other.eval_decompose(|x| self.get_view().dist2_squared(x.get_view()))
            }
        }

        impl<$($qual)*> $crate::linops::AXPY for $type {
            type Field = $F;
            type Owned = Self;

            fn axpy<I: $crate::instance::Instance<Self>>(&mut self, α: $F, x: I, β: $F) {
                x.eval_decompose(|v| {
                    self.get_view_mut().axpy(α, v.get_view(), β)
                })
            }

            fn copy_from<I: $crate::instance::Instance<Self>>(&mut self, x: I) {
                x.eval_decompose(|v| {
                    self.get_view_mut().copy_from(v.get_view())
                })
            }

            fn scale_from<I: $crate::instance::Instance<Self>>(&mut self, α: $F, x: I) {
                x.eval_decompose(|v| {
                    self.get_view_mut().scale_from(α, v.get_view())
                })
            }

            /// Return a similar zero as `self`.
            fn similar_origin(&self) -> Self::Owned {
                Self::wrap(self.get_view().similar_origin())
            }

            /// Set self to zero.
            fn set_zero(&mut self) {
                self.get_view_mut().set_zero()
            }
        }

        impl<$($qual)*> $crate::instance::Space for $type {
            type Decomp = <<Self as $crate::euclidean::wrap::Wrapped>::Unwrapped as $crate::instance::Space>::Decomp;
        }
    };
}

mercurial