src/direct_product.rs

Wed, 03 Sep 2025 08:40:17 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Wed, 03 Sep 2025 08:40:17 -0500
branch
dev
changeset 161
5df5258332d1
parent 155
45d03cf92c23
child 162
bea0c3841ced
permissions
-rw-r--r--

try

/*!
Direct products of the form $A \times B$.

TODO: This could easily be made much more generic if `derive_more` could derive arithmetic
operations on references.
*/

use crate::euclidean::Euclidean;
use crate::instance::{Decomposition, DecompositionMut, Instance, InstanceMut, MyCow, Ownable};
use crate::linops::{VectorSpace, AXPY};
use crate::loc::Loc;
use crate::mapping::Space;
use crate::norms::{HasDual, Norm, NormExponent, Normed, PairNorm, L2};
use crate::types::{Float, Num};
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use serde::{Deserialize, Serialize};
use std::clone::Clone;

/// An element $(a, b)$ of the direct product $A \times B$.
///
/// The arithmetic operators and the space/norm traits in this module are
/// implemented componentwise on this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct Pair<A, B>(
    /// The component in the first factor $A$.
    pub A,
    /// The component in the second factor $B$.
    pub B,
);

impl<A, B> Pair<A, B> {
    pub fn new(a: A, b: B) -> Pair<A, B> {
        Pair(a, b)
    }
}

impl<A, B> From<(A, B)> for Pair<A, B> {
    #[inline]
    fn from((a, b): (A, B)) -> Pair<A, B> {
        Pair(a, b)
    }
}

impl<A, B> From<Pair<A, B>> for (A, B) {
    #[inline]
    fn from(Pair(a, b): Pair<A, B>) -> (A, B) {
        (a, b)
    }
}

// Implements the unary operator trait `$trait` (method `$fn`) for `Pair` by
// applying it to each component separately.
//
// Only the by-value form `op Pair<A, B>` is generated. The analogous impl for
// `&Pair<A, B>` (bounded by `&A: $trait`, `&B: $trait`) overflowed the compiler,
// so it is deliberately left out.
macro_rules! impl_unary {
    ($trait:ident, $fn:ident) => {
        impl<A, B> $trait for Pair<A, B>
        where
            A: $trait,
            B: $trait,
        {
            type Output = Pair<A::Output, B::Output>;
            fn $fn(self) -> Self::Output {
                Pair(self.0.$fn(), self.1.$fn())
            }
        }
    };
}

impl_unary!(Neg, neg);

// Implements the binary operator trait `$trait` (method `$fn`) for `Pair`
// componentwise, for all four owned/borrowed operand combinations:
// `Pair op Pair`, `&Pair op Pair`, `&Pair op &Pair` and `Pair op &Pair`.
// Borrowed operands are destructured through match ergonomics, so the
// components are passed on as `&A`, `&B`, `&C`, `&D` respectively.
macro_rules! impl_binary {
    ($trait:ident, $fn:ident) => {
        // Pair op Pair
        impl<A, B, C, D> $trait<Pair<C, D>> for Pair<A, B>
        where
            A: $trait<C>,
            B: $trait<D>,
        {
            type Output = Pair<A::Output, B::Output>;
            fn $fn(self, rhs: Pair<C, D>) -> Self::Output {
                Pair(self.0.$fn(rhs.0), self.1.$fn(rhs.1))
            }
        }

        // &Pair op Pair
        impl<'a, A, B, C, D> $trait<Pair<C, D>> for &'a Pair<A, B>
        where
            &'a A: $trait<C>,
            &'a B: $trait<D>,
        {
            type Output = Pair<<&'a A as $trait<C>>::Output, <&'a B as $trait<D>>::Output>;
            fn $fn(self, rhs: Pair<C, D>) -> Self::Output {
                let Pair(lhs_a, lhs_b) = self;
                let Pair(rhs_c, rhs_d) = rhs;
                Pair(lhs_a.$fn(rhs_c), lhs_b.$fn(rhs_d))
            }
        }

        // &Pair op &Pair
        impl<'a, 'b, A, B, C, D> $trait<&'b Pair<C, D>> for &'a Pair<A, B>
        where
            &'a A: $trait<&'b C>,
            &'a B: $trait<&'b D>,
        {
            type Output = Pair<<&'a A as $trait<&'b C>>::Output, <&'a B as $trait<&'b D>>::Output>;
            fn $fn(self, rhs: &'b Pair<C, D>) -> Self::Output {
                let Pair(lhs_a, lhs_b) = self;
                let Pair(rhs_c, rhs_d) = rhs;
                Pair(lhs_a.$fn(rhs_c), lhs_b.$fn(rhs_d))
            }
        }

        // Pair op &Pair
        impl<'b, A, B, C, D> $trait<&'b Pair<C, D>> for Pair<A, B>
        where
            A: $trait<&'b C>,
            B: $trait<&'b D>,
        {
            type Output = Pair<<A as $trait<&'b C>>::Output, <B as $trait<&'b D>>::Output>;
            fn $fn(self, rhs: &'b Pair<C, D>) -> Self::Output {
                let Pair(rhs_c, rhs_d) = rhs;
                Pair(self.0.$fn(rhs_c), self.1.$fn(rhs_d))
            }
        }
    };
}

impl_binary!(Add, add);
impl_binary!(Sub, sub);

macro_rules! impl_scalar {
    ($trait:ident, $fn:ident) => {
        impl<A, B, F: Num> $trait<F> for Pair<A, B>
        where
            A: $trait<F>,
            B: $trait<F>,
        {
            type Output = Pair<A::Output, B::Output>;
            fn $fn(self, t: F) -> Self::Output {
                let Pair(a, b) = self;
                Pair(a.$fn(t), b.$fn(t))
            }
        }

        impl<'a, A, B, F: Num> $trait<F> for &'a Pair<A, B>
        where
            &'a A: $trait<F>,
            &'a B: $trait<F>,
        {
            type Output = Pair<<&'a A as $trait<F>>::Output, <&'a B as $trait<F>>::Output>;
            fn $fn(self, t: F) -> Self::Output {
                let Pair(ref a, ref b) = self;
                Pair(a.$fn(t), b.$fn(t))
            }
        }

        // impl<'a, 'b, A, B> $trait<&'b $F> for &'a Pair<A, B>
        // where
        //     &'a A: $trait<&'b $F>,
        //     &'a B: $trait<&'b $F>,
        // {
        //     type Output =
        //         Pair<<&'a A as $trait<&'b $F>>::Output, <&'a B as $trait<&'b $F>>::Output>;
        //     fn $fn(self, t: &'b $F) -> Self::Output {
        //         let Pair(ref a, ref b) = self;
        //         Pair(a.$fn(t), b.$fn(t))
        //     }
        // }

        // impl<'b, A, B> $trait<&'b $F> for Pair<A, B>
        // where
        //     A: $trait<&'b $F>,
        //     B: $trait<&'b $F>,
        // {
        //     type Output = Pair<<A as $trait<&'b $F>>::Output, <B as $trait<&'b $F>>::Output>;
        //     fn $fn(self, t: &'b $F) -> Self::Output {
        //         let Pair(a, b) = self;
        //         Pair(a.$fn(t), b.$fn(t))
        //     }
        // }
    };
}

impl_scalar!(Mul, mul);
impl_scalar!(Div, div);

macro_rules! impl_scalar_lhs {
    ($trait:ident, $fn:ident, $F:ty) => {
        impl<A, B> $trait<Pair<A, B>> for $F
        where
            $F: $trait<A> + $trait<B>,
        {
            type Output = Pair<<$F as $trait<A>>::Output, <$F as $trait<B>>::Output>;
            fn $fn(self, Pair(a, b): Pair<A, B>) -> Self::Output {
                Pair(self.$fn(a), self.$fn(b))
            }
        }

        // Compiler overflow:
        //
        // impl<'a, A, B> $trait<&'a Pair<A, B>> for $F
        // where
        //     $F: $trait<&'a A> + $trait<&'a B>,
        // {
        //     type Output = Pair<<$F as $trait<&'a A>>::Output, <$F as $trait<&'a B>>::Output>;
        //     fn $fn(self, Pair(a, b): &'a Pair<A, B>) -> Self::Output {
        //         Pair(self.$fn(a), self.$fn(b))
        //     }
        // }
    };
}

impl_scalar_lhs!(Mul, mul, f32);
impl_scalar_lhs!(Mul, mul, f64);
impl_scalar_lhs!(Div, div, f32);
impl_scalar_lhs!(Div, div, f64);

macro_rules! impl_binary_mut {
    ($trait:ident, $fn:ident) => {
        impl<'a, A, B, C, D> $trait<Pair<C, D>> for Pair<A, B>
        where
            A: $trait<C>,
            B: $trait<D>,
        {
            fn $fn(&mut self, Pair(c, d): Pair<C, D>) {
                let Pair(ref mut a, ref mut b) = self;
                a.$fn(c);
                b.$fn(d);
            }
        }

        impl<'a, 'b, A, B, C, D> $trait<&'b Pair<C, D>> for Pair<A, B>
        where
            A: $trait<&'b C>,
            B: $trait<&'b D>,
        {
            fn $fn(&mut self, Pair(ref c, ref d): &'b Pair<C, D>) {
                let Pair(ref mut a, ref mut b) = self;
                a.$fn(c);
                b.$fn(d);
            }
        }
    };
}

impl_binary_mut!(AddAssign, add_assign);
impl_binary_mut!(SubAssign, sub_assign);

// Implements the scalar compound-assignment operator trait `$trait`
// (method `$fn`) as `Pair op= F`. The same scalar `t` is applied in place to
// the first component and then to the second; `t` is used twice, relying on
// `F: Num` being copyable.
macro_rules! impl_scalar_mut {
    ($trait:ident, $fn:ident) => {
        impl<A, B, F: Num> $trait<F> for Pair<A, B>
        where
            A: $trait<F>,
            B: $trait<F>,
        {
            fn $fn(&mut self, t: F) {
                self.0.$fn(t);
                self.1.$fn(t);
            }
        }
    };
}

impl_scalar_mut!(MulAssign, mul_assign);
impl_scalar_mut!(DivAssign, div_assign);

/// Trait for ownable-by-consumption objects
impl<A, B> Ownable for Pair<A, B>
where
    A: Ownable,
    B: Ownable,
{
    type OwnedVariant = Pair<A::OwnedVariant, B::OwnedVariant>;

    #[inline]
    fn into_owned(self) -> Self::OwnedVariant {
        Pair(self.0.into_owned(), self.1.into_owned())
    }

    /// Returns an owned instance of a reference.
    fn clone_owned(&self) -> Self::OwnedVariant {
        Pair(self.0.clone_owned(), self.1.clone_owned())
    }
}

/// Componentwise [`Euclidean`] structure on $A \times B$.
///
/// The inner product is $\langle (a, b), (u, v) \rangle = \langle a, u \rangle + \langle b, v \rangle$.
/// The squared norm and the squared distance are likewise the sums of the component values.
///
/// Only this "closed" form is implemented, in which each component is itself
/// `Euclidean<F>`. Versions that instead bounded `Self` directly by the arithmetic
/// operator traits (`Mul<F>`, `Add<Self>`, `for<'b> AddAssign<&'b Self>`, …) ran
/// into compiler overflows.
impl<A, B, F: Float> Euclidean<F> for Pair<A, B>
where
    A: Euclidean<F>,
    B: Euclidean<F>,
{
    type OwnedEuclidean = Pair<A::OwnedEuclidean, B::OwnedEuclidean>;

    fn dot<I: Instance<Self>>(&self, other: I) -> F {
        other.eval_decompose(|Pair(x0, x1)| {
            let first = self.0.dot(x0);
            let second = self.1.dot(x1);
            first + second
        })
    }

    fn norm2_squared(&self) -> F {
        let first = self.0.norm2_squared();
        let second = self.1.norm2_squared();
        first + second
    }

    fn dist2_squared<I: Instance<Self>>(&self, other: I) -> F {
        other.eval_decompose(|Pair(x0, x1)| {
            let first = self.0.dist2_squared(x0);
            let second = self.1.dist2_squared(x1);
            first + second
        })
    }
}

/// Componentwise [`VectorSpace`] structure on $A \times B$, for two factors
/// sharing the same field `F`.
impl<F, A, B> VectorSpace for Pair<A, B>
where
    F: Num,
    A: VectorSpace<Field = F>,
    B: VectorSpace<Field = F>,
{
    type Field = F;
    type Owned = Pair<A::Owned, B::Owned>;

    /// Returns the pair of the components' `similar_origin` values, i.e., a zero
    /// "similar" to `self`.
    fn similar_origin(&self) -> Self::Owned {
        let first = self.0.similar_origin();
        let second = self.1.similar_origin();
        Pair(first, second)
    }
}

/// Componentwise [`AXPY`] on pairs.
///
/// Each operation decomposes `x` into its two components. It applies the operation
/// to the first component of `self` with the first component of `x`, and then to the
/// second component with the second. Both factors must share the field `F`.
impl<F, A, B, U, V> AXPY<Pair<U, V>> for Pair<A, B>
where
    F: Num,
    U: Space,
    V: Space,
    A: AXPY<U, Field = F>,
    B: AXPY<V, Field = F>,
{
    fn axpy<I: Instance<Pair<U, V>>>(&mut self, α: F, x: I, β: F) {
        x.eval_decompose(|Pair(x0, x1)| {
            self.0.axpy(α, x0, β);
            self.1.axpy(α, x1, β);
        })
    }

    fn copy_from<I: Instance<Pair<U, V>>>(&mut self, x: I) {
        x.eval_decompose(|Pair(x0, x1)| {
            self.0.copy_from(x0);
            self.1.copy_from(x1);
        })
    }

    fn scale_from<I: Instance<Pair<U, V>>>(&mut self, α: F, x: I) {
        x.eval_decompose(|Pair(x0, x1)| {
            self.0.scale_from(α, x0);
            self.1.scale_from(α, x1);
        })
    }

    /// Sets both components to zero, the first one first.
    fn set_zero(&mut self) {
        self.0.set_zero();
        self.1.set_zero();
    }
}

/// The [`Decomposition`] used for [`Pair`]s.
///
/// It combines a decomposition `D` of the first factor with a decomposition `Q`
/// of the second factor.
#[derive(Copy, Clone, Debug)]
pub struct PairDecomposition<D, Q>(D, Q);

/// A pair of [`Space`]s is itself a [`Space`], decomposed via [`PairDecomposition`].
impl<A, B> Space for Pair<A, B>
where
    A: Space,
    B: Space,
{
    type OwnedSpace = Pair<A::OwnedSpace, B::OwnedSpace>;
    type Decomp = PairDecomposition<A::Decomp, B::Decomp>;
}

/// Componentwise [`Decomposition`] of `Pair<A, B>`.
///
/// Owned instances, decompositions and references are all pairs of the
/// corresponding component types of `D` and `Q`.
impl<A, B, D, Q> Decomposition<Pair<A, B>> for PairDecomposition<D, Q>
where
    A: Space,
    B: Space,
    D: Decomposition<A>,
    Q: Decomposition<B>,
{
    type OwnedInstance = Pair<D::OwnedInstance, Q::OwnedInstance>;

    type Decomposition<'b>
        = Pair<D::Decomposition<'b>, Q::Decomposition<'b>>
    where
        Pair<A, B>: 'b;

    type Reference<'b>
        = Pair<D::Reference<'b>, Q::Reference<'b>>
    where
        Pair<A, B>: 'b;

    /// Lifts a pair of references into a pair of decompositions, one component at a time.
    #[inline]
    fn lift<'b>(reference: Self::Reference<'b>) -> Self::Decomposition<'b> {
        let Pair(first, second) = reference;
        Pair(D::lift(first), Q::lift(second))
    }
}

/// A pair of instances is an instance of the product space.
///
/// If `U` is an instance of `A` (under decomposition `D`) and `V` is an instance of
/// `B` (under `Q`), then `Pair<U, V>` is an instance of `Pair<A, B>` under
/// [`PairDecomposition<D, Q>`].
impl<A, B, U, V, D, Q> Instance<Pair<A, B>, PairDecomposition<D, Q>> for Pair<U, V>
where
    A: Space,
    B: Space,
    D: Decomposition<A>,
    Q: Decomposition<B>,
    U: Instance<A, D>,
    V: Instance<B, Q>,
{
    /// Consumes `self`, decomposes both components, and passes the pair of
    /// decompositions to `f`.
    fn eval_decompose<'b, R>(
        self,
        f: impl FnOnce(Pair<D::Decomposition<'b>, Q::Decomposition<'b>>) -> R,
    ) -> R
    where
        Pair<A, B>: 'b,
        Self: 'b,
    {
        // Each component hands its decomposition to a callback, so the calls are
        // nested: the first component is decomposed first, the second inside its
        // callback, and `f` receives both innermost.
        self.0
            .eval_decompose(|a| self.1.eval_decompose(|b| f(Pair(a, b))))
    }

    /// Borrows both components via their `eval_ref_decompose` and passes the pair
    /// of references to `f`.
    fn eval_ref_decompose<'b, R>(
        &'b self,
        f: impl FnOnce(Pair<D::Reference<'b>, Q::Reference<'b>>) -> R,
    ) -> R
    where
        Pair<A, B>: 'b,
        Self: 'b,
    {
        // Same nesting as in `eval_decompose`, but borrowing instead of consuming.
        self.0
            .eval_ref_decompose(|a| self.1.eval_ref_decompose(|b| f(Pair(a, b))))
    }

    /// Always yields `MyCow::Owned`, built by converting each component with `own`.
    /// A by-value pair is consumed here, so there is nothing to borrow from.
    #[inline]
    fn cow<'b>(self) -> MyCow<'b, Pair<D::OwnedInstance, Q::OwnedInstance>>
    where
        Self: 'b,
    {
        MyCow::Owned(Pair(self.0.own(), self.1.own()))
    }

    /// Converts each component into its owned instance.
    #[inline]
    fn own(self) -> Pair<D::OwnedInstance, Q::OwnedInstance> {
        Pair(self.0.own(), self.1.own())
    }
}

/// A borrowed pair of instances is an instance of the product space.
///
/// The bounds `U: Instance<A, D>` / `V: Instance<B, Q>` are used when calling
/// `eval_ref_decompose` on the fields. The bounds `&'a U: Instance<A, D>` /
/// `&'a V: Instance<B, Q>` are used when calling `own` on the borrowed fields.
impl<'a, A, B, U, V, D, Q> Instance<Pair<A, B>, PairDecomposition<D, Q>> for &'a Pair<U, V>
where
    A: Space,
    B: Space,
    D: Decomposition<A>,
    Q: Decomposition<B>,
    U: Instance<A, D>,
    V: Instance<B, Q>,
    &'a U: Instance<A, D>,
    &'a V: Instance<B, Q>,
{
    /// Passes a pair of decompositions of the borrowed components to `f`.
    fn eval_decompose<'b, R>(
        self,
        f: impl FnOnce(Pair<D::Decomposition<'b>, Q::Decomposition<'b>>) -> R,
    ) -> R
    where
        Pair<A, B>: 'b,
        Self: 'b,
    {
        // `self` is only a borrow, so the components are accessed by reference
        // (`eval_ref_decompose`). The resulting references are converted into
        // decompositions with `D::lift` / `Q::lift` before calling `f`.
        self.0.eval_ref_decompose(|a| {
            self.1
                .eval_ref_decompose(|b| f(Pair(D::lift(a), Q::lift(b))))
        })
    }

    /// Borrows both components via their `eval_ref_decompose` and passes the pair
    /// of references to `f`.
    fn eval_ref_decompose<'b, R>(
        &'b self,
        f: impl FnOnce(Pair<D::Reference<'b>, Q::Reference<'b>>) -> R,
    ) -> R
    where
        Pair<A, B>: 'b,
        Self: 'b,
    {
        self.0
            .eval_ref_decompose(|a| self.1.eval_ref_decompose(|b| f(Pair(a, b))))
    }

    /// Always yields `MyCow::Owned`, obtained from [`Instance::own`] below.
    #[inline]
    fn cow<'b>(self) -> MyCow<'b, Pair<D::OwnedInstance, Q::OwnedInstance>>
    where
        Self: 'b,
    {
        MyCow::Owned(self.own())
    }

    /// Calls `own` on each borrowed component (`&'a U`, `&'a V`) and pairs the results.
    #[inline]
    fn own(self) -> Pair<D::OwnedInstance, Q::OwnedInstance> {
        // NOTE(review): explicit `ref` while the default binding mode is already
        // by-reference (matching on `self: &Pair`) is rejected in the 2024 edition.
        // `let Pair(u, v) = self;` binds the same `&'a U`, `&'a V`.
        let Pair(ref u, ref v) = self;
        Pair(u.own(), v.own())
    }
}

/// Componentwise [`DecompositionMut`] of `Pair<A, B>`: the mutable reference type
/// is the pair of the components' `ReferenceMut` types.
impl<A: Space, B: Space, D, Q> DecompositionMut<Pair<A, B>> for PairDecomposition<D, Q>
where
    D: DecompositionMut<A>,
    Q: DecompositionMut<B>,
{
    type ReferenceMut<'b>
        = Pair<D::ReferenceMut<'b>, Q::ReferenceMut<'b>>
    where
        Pair<A, B>: 'b;
}

/// An owned pair of mutable instances is a mutable instance of `Pair<A, B>`.
/// Its mutable reference is the pair of the components' mutable references.
impl<A, B, U, V, D, Q> InstanceMut<Pair<A, B>, PairDecomposition<D, Q>> for Pair<U, V>
where
    A: Space,
    B: Space,
    D: DecompositionMut<A>,
    Q: DecompositionMut<B>,
    U: InstanceMut<A, D>,
    V: InstanceMut<B, Q>,
{
    #[inline]
    fn ref_instance_mut(
        &mut self,
    ) -> <PairDecomposition<D, Q> as DecompositionMut<Pair<A, B>>>::ReferenceMut<'_> {
        let first = self.0.ref_instance_mut();
        let second = self.1.ref_instance_mut();
        Pair(first, second)
    }
}

/// A mutably borrowed pair of mutable instances is likewise a mutable instance of
/// `Pair<A, B>`. The fields are reached through the borrow.
impl<'a, A, B, U, V, D, Q> InstanceMut<Pair<A, B>, PairDecomposition<D, Q>> for &'a mut Pair<U, V>
where
    A: Space,
    B: Space,
    D: DecompositionMut<A>,
    Q: DecompositionMut<B>,
    U: InstanceMut<A, D>,
    V: InstanceMut<B, Q>,
{
    #[inline]
    fn ref_instance_mut(
        &mut self,
    ) -> <PairDecomposition<D, Q> as DecompositionMut<Pair<A, B>>>::ReferenceMut<'_> {
        let first = self.0.ref_instance_mut();
        let second = self.1.ref_instance_mut();
        Pair(first, second)
    }
}

/// Norm of a pair under [`PairNorm`].
///
/// The components are measured with `ExpA` and `ExpB`, respectively. The two
/// resulting values are collected into a `Loc<2, F>`, whose norm is then taken
/// with the joint exponent `ExpJ`.
impl<F, A, B, ExpA, ExpB, ExpJ> Norm<PairNorm<ExpA, ExpB, ExpJ>, F> for Pair<A, B>
where
    F: Num,
    ExpA: NormExponent,
    ExpB: NormExponent,
    ExpJ: NormExponent,
    A: Norm<ExpA, F>,
    B: Norm<ExpB, F>,
    Loc<2, F>: Norm<ExpJ, F>,
{
    fn norm(&self, exponent: PairNorm<ExpA, ExpB, ExpJ>) -> F {
        let PairNorm(exp_first, exp_second, exp_joint) = exponent;
        let component_norms = Loc([self.0.norm(exp_first), self.1.norm(exp_second)]);
        component_norms.norm(exp_joint)
    }
}

/// Default norm of a pair.
///
/// Each component uses its own default exponent, and the two component norms are
/// combined with [`L2`].
impl<F: Float, A, B> Normed<F> for Pair<A, B>
where
    A: Normed<F>,
    B: Normed<F>,
{
    type NormExp = PairNorm<A::NormExp, B::NormExp, L2>;

    #[inline]
    fn norm_exponent(&self) -> Self::NormExp {
        let first = self.0.norm_exponent();
        let second = self.1.norm_exponent();
        PairNorm(first, second, L2)
    }

    /// A pair is zero when both components report zero. The second component is
    /// only checked if the first one is zero.
    #[inline]
    fn is_zero(&self) -> bool {
        if !self.0.is_zero() {
            return false;
        }
        self.1.is_zero()
    }
}

/// The dual of $A \times B$ is taken componentwise as the pair of the duals.
impl<F: Float, A, B> HasDual<F> for Pair<A, B>
where
    A: HasDual<F>,
    B: HasDual<F>,
{
    type DualSpace = Pair<A::DualSpace, B::DualSpace>;

    /// Returns the pair of the components' dual origins.
    fn dual_origin(&self) -> <Self::DualSpace as VectorSpace>::Owned {
        let first = self.0.dual_origin();
        let second = self.1.dual_origin();
        Pair(first, second)
    }
}

#[cfg(feature = "pyo3")]
mod python {
    //! Conversions between [`Pair`] and two-element Python tuples.
    use super::Pair;
    use pyo3::conversion::FromPyObject;
    use pyo3::types::{PyAny, PyTuple};
    use pyo3::{Bound, IntoPyObject, PyErr, PyResult, Python};

    /// Converts `Pair(a, b)` into the Python tuple `(a, b)`.
    impl<'py, A, B> IntoPyObject<'py> for Pair<A, B>
    where
        A: IntoPyObject<'py>,
        B: IntoPyObject<'py>,
    {
        type Target = PyTuple;
        type Error = PyErr;
        type Output = Bound<'py, Self::Target>;

        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            (self.0, self.1).into_pyobject(py)
        }
    }

    /// Converts a mutably borrowed pair into a Python tuple of the converted
    /// mutable component references.
    impl<'a, 'py, A, B> IntoPyObject<'py> for &'a mut Pair<A, B>
    where
        &'a mut A: IntoPyObject<'py>,
        &'a mut B: IntoPyObject<'py>,
    {
        type Target = PyTuple;
        type Error = PyErr;
        type Output = Bound<'py, Self::Target>;

        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            (&mut self.0, &mut self.1).into_pyobject(py)
        }
    }

    /// Converts a borrowed pair into a Python tuple of the converted component references.
    impl<'a, 'py, A, B> IntoPyObject<'py> for &'a Pair<A, B>
    where
        &'a A: IntoPyObject<'py>,
        &'a B: IntoPyObject<'py>,
    {
        type Target = PyTuple;
        type Error = PyErr;
        type Output = Bound<'py, Self::Target>;

        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            (&self.0, &self.1).into_pyobject(py)
        }
    }

    /// Extracts a two-element Python tuple `(a, b)` into `Pair(a, b)`.
    ///
    /// Only `FromPyObject` is required of the components: tuple extraction does
    /// not need `Clone`, so the former `Clone` bounds were unnecessary restrictions.
    impl<'py, A, B> FromPyObject<'py> for Pair<A, B>
    where
        A: FromPyObject<'py>,
        B: FromPyObject<'py>,
    {
        fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
            <(A, B) as FromPyObject<'py>>::extract_bound(ob).map(|(a, b)| Pair(a, b))
        }
    }
}

mercurial