diff -r d14c877e14b7 -r b3c35d16affe src/direct_product.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/direct_product.rs Mon Feb 03 19:22:16 2025 -0500 @@ -0,0 +1,506 @@ +/*! +Direct products of the form $A \times B$. + +TODO: This could be easily much more generic if `derive_more` could derive arithmetic +operations on references. +*/ + +use core::ops::{Mul,MulAssign,Div,DivAssign,Add,AddAssign,Sub,SubAssign,Neg}; +use std::clone::Clone; +use serde::{Serialize, Deserialize}; +use crate::types::{Num, Float}; +use crate::{maybe_lifetime, maybe_ref}; +use crate::euclidean::Euclidean; +use crate::instance::{Instance, InstanceMut, Decomposition, DecompositionMut, MyCow}; +use crate::mapping::Space; +use crate::linops::AXPY; +use crate::loc::Loc; +use crate::norms::{Norm, PairNorm, NormExponent, Normed, HasDual, L2}; + +#[derive(Debug,Clone,Copy,PartialEq,Eq,Serialize,Deserialize)] +pub struct Pair (pub A, pub B); + +impl Pair { + pub fn new(a : A, b : B) -> Pair { Pair(a, b) } +} + +impl From<(A,B)> for Pair { + #[inline] + fn from((a, b) : (A, B)) -> Pair { Pair(a, b) } +} + +impl From> for (A,B) { + #[inline] + fn from(Pair(a, b) : Pair) -> (A,B) { (a, b) } +} + +macro_rules! impl_binop { + (($a : ty, $b : ty), $trait : ident, $fn : ident, $refl:ident, $refr:ident) => { + impl_binop!(@doit: $a, $b, $trait, $fn; + maybe_lifetime!($refl, &'l Pair<$a,$b>), + (maybe_lifetime!($refl, &'l $a), + maybe_lifetime!($refl, &'l $b)); + maybe_lifetime!($refr, &'r Pair), + (maybe_lifetime!($refr, &'r Ai), + maybe_lifetime!($refr, &'r Bi)); + $refl, $refr); + }; + + (@doit: $a:ty, $b:ty, + $trait:ident, $fn:ident; + $self:ty, ($aself:ty, $bself:ty); + $in:ty, ($ain:ty, $bin:ty); + $refl:ident, $refr:ident) => { + impl<'l, 'r, Ai, Bi> $trait<$in> + for $self + where $aself: $trait<$ain>, + $bself: $trait<$bin> { + type Output = Pair<<$aself as $trait<$ain>>::Output, + <$bself as $trait<$bin>>::Output>; + + #[inline] + fn $fn(self, y : $in) -> Self::Output { + Pair(maybe_ref!($refl, self.0).$fn(maybe_ref!($refr, y.0)), + maybe_ref!($refl, self.1).$fn(maybe_ref!($refr, y.1))) + } + } + }; +} + +macro_rules! impl_assignop { + (($a : ty, $b : ty), $trait : ident, $fn : ident, $refr:ident) => { + impl_assignop!(@doit: $a, $b, + $trait, $fn; + maybe_lifetime!($refr, &'r Pair), + (maybe_lifetime!($refr, &'r Ai), + maybe_lifetime!($refr, &'r Bi)); + $refr); + }; + (@doit: $a : ty, $b : ty, + $trait:ident, $fn:ident; + $in:ty, ($ain:ty, $bin:ty); + $refr:ident) => { + impl<'r, Ai, Bi> $trait<$in> + for Pair<$a,$b> + where $a: $trait<$ain>, + $b: $trait<$bin> { + #[inline] + fn $fn(&mut self, y : $in) -> () { + self.0.$fn(maybe_ref!($refr, y.0)); + self.1.$fn(maybe_ref!($refr, y.1)); + } + } + } +} + +macro_rules! impl_scalarop { + (($a : ty, $b : ty), $field : ty, $trait : ident, $fn : ident, $refl:ident) => { + impl_scalarop!(@doit: $field, + $trait, $fn; + maybe_lifetime!($refl, &'l Pair<$a,$b>), + (maybe_lifetime!($refl, &'l $a), + maybe_lifetime!($refl, &'l $b)); + $refl); + }; + (@doit: $field : ty, + $trait:ident, $fn:ident; + $self:ty, ($aself:ty, $bself:ty); + $refl:ident) => { + // Scalar as Rhs + impl<'l> $trait<$field> + for $self + where $aself: $trait<$field>, + $bself: $trait<$field> { + type Output = Pair<<$aself as $trait<$field>>::Output, + <$bself as $trait<$field>>::Output>; + #[inline] + fn $fn(self, a : $field) -> Self::Output { + Pair(maybe_ref!($refl, self.0).$fn(a), + maybe_ref!($refl, self.1).$fn(a)) + } + } + } +} + +// Not used due to compiler overflow +#[allow(unused_macros)] +macro_rules! impl_scalarlhs_op { + (($a : ty, $b : ty), $field : ty, $trait:ident, $fn:ident, $refr:ident) => { + impl_scalarlhs_op!(@doit: $trait, $fn, + maybe_lifetime!($refr, &'r Pair<$a,$b>), + (maybe_lifetime!($refr, &'r $a), + maybe_lifetime!($refr, &'r $b)); + $refr, $field); + }; + (@doit: $trait:ident, $fn:ident, + $in:ty, ($ain:ty, $bin:ty); + $refr:ident, $field:ty) => { + impl<'r> $trait<$in> + for $field + where $field : $trait<$ain> + + $trait<$bin> { + type Output = Pair<<$field as $trait<$ain>>::Output, + <$field as $trait<$bin>>::Output>; + #[inline] + fn $fn(self, x : $in) -> Self::Output { + Pair(self.$fn(maybe_ref!($refr, x.0)), + self.$fn(maybe_ref!($refr, x.1))) + } + } + }; +} + +macro_rules! impl_scalar_assignop { + (($a : ty, $b : ty), $field : ty, $trait : ident, $fn : ident) => { + impl<'r> $trait<$field> + for Pair<$a, $b> + where $a: $trait<$field>, $b: $trait<$field> { + #[inline] + fn $fn(&mut self, a : $field) -> () { + self.0.$fn(a); + self.1.$fn(a); + } + } + } +} + +macro_rules! impl_unaryop { + (($a : ty, $b : ty), $trait:ident, $fn:ident, $refl:ident) => { + impl_unaryop!(@doit: $trait, $fn; + maybe_lifetime!($refl, &'l Pair<$a,$b>), + (maybe_lifetime!($refl, &'l $a), + maybe_lifetime!($refl, &'l $b)); + $refl); + }; + (@doit: $trait:ident, $fn:ident; + $self:ty, ($aself:ty, $bself:ty); + $refl : ident) => { + impl<'l> $trait + for $self + where $aself: $trait, + $bself: $trait { + type Output = Pair<<$aself as $trait>::Output, + <$bself as $trait>::Output>; + #[inline] + fn $fn(self) -> Self::Output { + Pair(maybe_ref!($refl, self.0).$fn(), + maybe_ref!($refl, self.1).$fn()) + } + } + } +} + +#[macro_export] +macro_rules! impl_pair_vectorspace_ops { + (($a:ty, $b:ty), $field:ty) => { + impl_pair_vectorspace_ops!(@binary, ($a, $b), Add, add); + impl_pair_vectorspace_ops!(@binary, ($a, $b), Sub, sub); + impl_pair_vectorspace_ops!(@assign, ($a, $b), AddAssign, add_assign); + impl_pair_vectorspace_ops!(@assign, ($a, $b), SubAssign, sub_assign); + impl_pair_vectorspace_ops!(@scalar, ($a, $b), $field, Mul, mul); + impl_pair_vectorspace_ops!(@scalar, ($a, $b), $field, Div, div); + // Compiler overflow + // $( + // impl_pair_vectorspace_ops!(@scalar_lhs, ($a, $b), $field, $impl_scalarlhs_op, Mul, mul); + // )* + impl_pair_vectorspace_ops!(@scalar_assign, ($a, $b), $field, MulAssign, mul_assign); + impl_pair_vectorspace_ops!(@scalar_assign, ($a, $b), $field, DivAssign, div_assign); + impl_pair_vectorspace_ops!(@unary, ($a, $b), Neg, neg); + }; + (@binary, ($a : ty, $b : ty), $trait : ident, $fn : ident) => { + impl_binop!(($a, $b), $trait, $fn, ref, ref); + impl_binop!(($a, $b), $trait, $fn, ref, noref); + impl_binop!(($a, $b), $trait, $fn, noref, ref); + impl_binop!(($a, $b), $trait, $fn, noref, noref); + }; + (@assign, ($a : ty, $b : ty), $trait : ident, $fn :ident) => { + impl_assignop!(($a, $b), $trait, $fn, ref); + impl_assignop!(($a, $b), $trait, $fn, noref); + }; + (@scalar, ($a : ty, $b : ty), $field : ty, $trait : ident, $fn :ident) => { + impl_scalarop!(($a, $b), $field, $trait, $fn, ref); + impl_scalarop!(($a, $b), $field, $trait, $fn, noref); + }; + (@scalar_lhs, ($a : ty, $b : ty), $field : ty, $trait : ident, $fn : ident) => { + impl_scalarlhs_op!(($a, $b), $field, $trait, $fn, ref); + impl_scalarlhs_op!(($a, $b), $field, $trait, $fn, noref); + }; + (@scalar_assign, ($a : ty, $b : ty), $field : ty, $trait : ident, $fn : ident) => { + impl_scalar_assignop!(($a, $b), $field, $trait, $fn); + }; + (@unary, ($a : ty, $b : ty), $trait : ident, $fn : ident) => { + impl_unaryop!(($a, $b), $trait, $fn, ref); + impl_unaryop!(($a, $b), $trait, $fn, noref); + }; +} + +impl_pair_vectorspace_ops!((f32, f32), f32); +impl_pair_vectorspace_ops!((f64, f64), f64); + +type PairOutput = Pair<>::Output, >::Output>; + +impl Euclidean for Pair +where + A : Euclidean, + B : Euclidean, + F : Float, + PairOutput : Euclidean, + Self : Sized + + Mul> + MulAssign + + Div> + DivAssign + + Add> + + Sub> + + for<'b> Add<&'b Self, Output=PairOutput> + + for<'b> Sub<&'b Self, Output=PairOutput> + + AddAssign + for<'b> AddAssign<&'b Self> + + SubAssign + for<'b> SubAssign<&'b Self> + + Neg> +{ + type Output = PairOutput; + + fn dot>(&self, other : I) -> F { + let Pair(u, v) = other.decompose(); + self.0.dot(u) + self.1.dot(v) + } + + fn norm2_squared(&self) -> F { + self.0.norm2_squared() + self.1.norm2_squared() + } + + fn dist2_squared>(&self, other : I) -> F { + let Pair(u, v) = other.decompose(); + self.0.dist2_squared(u) + self.1.dist2_squared(v) + } +} + +impl AXPY> for Pair +where + U : Space, + V : Space, + A : AXPY, + B : AXPY, + F : Num, + Self : MulAssign, + Pair : MulAssign, + Pair : AXPY>, +{ + + type Owned = Pair; + + fn axpy>>(&mut self, α : F, x : I, β : F) { + let Pair(u, v) = x.decompose(); + self.0.axpy(α, u, β); + self.1.axpy(α, v, β); + } + + fn copy_from>>(&mut self, x : I) { + let Pair(u, v) = x.decompose(); + self.0.copy_from(u); + self.1.copy_from(v); + } + + fn scale_from>>(&mut self, α : F, x : I) { + let Pair(u, v) = x.decompose(); + self.0.scale_from(α, u); + self.1.scale_from(α, v); + } + + /// Return a similar zero as `self`. + fn similar_origin(&self) -> Self::Owned { + Pair(self.0.similar_origin(), self.1.similar_origin()) + } + + /// Set self to zero. + fn set_zero(&mut self) { + self.0.set_zero(); + self.1.set_zero(); + } +} + +/// [`Decomposition`] for working with [`Pair`]s. +#[derive(Copy, Clone, Debug)] +pub struct PairDecomposition(D, Q); + +impl Space for Pair { + type Decomp = PairDecomposition; +} + +impl Decomposition> for PairDecomposition +where + A : Space, + B : Space, + D : Decomposition, + Q : Decomposition, +{ + type Decomposition<'b> = Pair, Q::Decomposition<'b>> where Pair : 'b; + type Reference<'b> = Pair, Q::Reference<'b>> where Pair : 'b; + + #[inline] + fn lift<'b>(Pair(u, v) : Self::Reference<'b>) -> Self::Decomposition<'b> { + Pair(D::lift(u), Q::lift(v)) + } +} + +impl Instance, PairDecomposition> for Pair +where + A : Space, + B : Space, + D : Decomposition, + Q : Decomposition, + U : Instance, + V : Instance, +{ + #[inline] + fn decompose<'b>(self) + -> as Decomposition>>::Decomposition<'b> + where Self : 'b, Pair : 'b + { + Pair(self.0.decompose(), self.1.decompose()) + } + + #[inline] + fn ref_instance(&self) + -> as Decomposition>>::Reference<'_> + { + Pair(self.0.ref_instance(), self.1.ref_instance()) + } + + #[inline] + fn cow<'b>(self) -> MyCow<'b, Pair> where Self : 'b{ + MyCow::Owned(Pair(self.0.own(), self.1.own())) + } + + #[inline] + fn own(self) -> Pair { + Pair(self.0.own(), self.1.own()) + } +} + + +impl<'a, A, B, U, V, D, Q> Instance, PairDecomposition> for &'a Pair +where + A : Space, + B : Space, + D : Decomposition, + Q : Decomposition, + U : Instance, + V : Instance, + &'a U : Instance, + &'a V : Instance, +{ + #[inline] + fn decompose<'b>(self) + -> as Decomposition>>::Decomposition<'b> + where Self : 'b, Pair : 'b + { + Pair(D::lift(self.0.ref_instance()), Q::lift(self.1.ref_instance())) + } + + #[inline] + fn ref_instance(&self) + -> as Decomposition>>::Reference<'_> + { + Pair(self.0.ref_instance(), self.1.ref_instance()) + } + + #[inline] + fn cow<'b>(self) -> MyCow<'b, Pair> where Self : 'b { + MyCow::Owned(self.own()) + } + + #[inline] + fn own(self) -> Pair { + let Pair(ref u, ref v) = self; + Pair(u.own(), v.own()) + } + +} + +impl DecompositionMut> for PairDecomposition +where + A : Space, + B : Space, + D : DecompositionMut, + Q : DecompositionMut, +{ + type ReferenceMut<'b> = Pair, Q::ReferenceMut<'b>> where Pair : 'b; +} + +impl InstanceMut, PairDecomposition> for Pair +where + A : Space, + B : Space, + D : DecompositionMut, + Q : DecompositionMut, + U : InstanceMut, + V : InstanceMut, +{ + #[inline] + fn ref_instance_mut(&mut self) + -> as DecompositionMut>>::ReferenceMut<'_> + { + Pair(self.0.ref_instance_mut(), self.1.ref_instance_mut()) + } +} + +impl<'a, A, B, U, V, D, Q> InstanceMut, PairDecomposition> for &'a mut Pair +where + A : Space, + B : Space, + D : DecompositionMut, + Q : DecompositionMut, + U : InstanceMut, + V : InstanceMut, +{ + #[inline] + fn ref_instance_mut(&mut self) + -> as DecompositionMut>>::ReferenceMut<'_> + { + Pair(self.0.ref_instance_mut(), self.1.ref_instance_mut()) + } +} + + +impl Norm> +for Pair +where + F : Num, + ExpA : NormExponent, + ExpB : NormExponent, + ExpJ : NormExponent, + A : Norm, + B : Norm, + Loc : Norm, +{ + fn norm(&self, PairNorm(expa, expb, expj) : PairNorm) -> F { + Loc([self.0.norm(expa), self.1.norm(expb)]).norm(expj) + } +} + + +impl Normed for Pair +where + A : Normed, + B : Normed, +{ + type NormExp = PairNorm; + + #[inline] + fn norm_exponent(&self) -> Self::NormExp { + PairNorm(self.0.norm_exponent(), self.1.norm_exponent(), L2) + } + + #[inline] + fn is_zero(&self) -> bool { + self.0.is_zero() && self.1.is_zero() + } +} + +impl HasDual for Pair +where + A : HasDual, + B : HasDual, + +{ + type DualSpace = Pair; +}