diff -r 5e3c1874797d -r 1b3b1687b9ed src/linops.rs --- a/src/linops.rs Wed Dec 11 20:45:17 2024 -0500 +++ b/src/linops.rs Fri Dec 13 22:37:12 2024 -0500 @@ -7,6 +7,7 @@ use crate::types::*; use serde::Serialize; pub use crate::mapping::Apply; +use crate::direct_product::Pair; /// Trait for linear operators on `X`. pub trait Linear : Apply @@ -96,7 +97,7 @@ fn preadjoint(&self) -> Self::Preadjoint<'_>; } -/// Adjointable operators $A: X → Y$ on between reflexive spaces $X$ and $Y$. +/// Adjointable operators $A: X → Y$ between reflexive spaces $X$ and $Y$. pub trait SimplyAdjointable : Adjointable>::Codomain> {} impl<'a,X,T> SimplyAdjointable for T where T : Adjointable>::Codomain> {} @@ -152,3 +153,266 @@ fn adjoint(&self) -> Self::Adjoint<'_> { IdOp::new() } } +/// “Row operator” $(S, T)$; $(S, T)(x, y)=Sx + Ty$. +pub struct RowOp(pub S, pub T); + +use std::ops::Add; + +impl Apply> for RowOp +where + S : Apply, + T : Apply, + S::Output : Add +{ + type Output = >::Output; + + fn apply(&self, Pair(a, b) : Pair) -> Self::Output { + self.0.apply(a) + self.1.apply(b) + } +} + +impl<'a, A, B, S, T> Apply<&'a Pair> for RowOp +where + S : Apply<&'a A>, + T : Apply<&'a B>, + S::Output : Add +{ + type Output = >::Output; + + fn apply(&self, Pair(ref a, ref b) : &'a Pair) -> Self::Output { + self.0.apply(a) + self.1.apply(b) + } +} + +impl Linear> for RowOp +where + RowOp : Apply, Output=D> + for<'a> Apply<&'a Pair, Output=D>, +{ + type Codomain = D; +} + +impl GEMV, Y> for RowOp +where + S : GEMV, + T : GEMV, + F : Num, + Self : Linear, Codomain=Y> +{ + fn gemv(&self, y : &mut Y, α : F, x : &Pair, β : F) { + self.0.gemv(y, α, &x.0, β); + self.1.gemv(y, α, &x.1, β); + } + + fn apply_mut(&self, y : &mut Y, x : &Pair){ + self.0.apply_mut(y, &x.0); + self.1.apply_mut(y, &x.1); + } + + /// Computes `y += Ax`, where `A` is `Self` + fn apply_add(&self, y : &mut Y, x : &Pair){ + self.0.apply_add(y, &x.0); + self.1.apply_add(y, &x.1); + } +} + + +/// “Column operator” $(S; T)$; $(S; T)x=(Sx, Tx)$. +pub struct ColOp(pub S, pub T); + +impl Apply for ColOp +where + S : for<'a> Apply<&'a A, Output=O>, + T : Apply, + A : std::borrow::Borrow, +{ + type Output = Pair; + + fn apply(&self, a : A) -> Self::Output { + Pair(self.0.apply(a.borrow()), self.1.apply(a)) + } +} + +impl Linear for ColOp +where + ColOp : Apply + for<'a> Apply<&'a A, Output=D>, +{ + type Codomain = D; +} + +impl GEMV> for ColOp +where + S : GEMV, + T : GEMV, + F : Num, + Self : Linear> +{ + fn gemv(&self, y : &mut Pair, α : F, x : &X, β : F) { + self.0.gemv(&mut y.0, α, x, β); + self.1.gemv(&mut y.1, α, x, β); + } + + fn apply_mut(&self, y : &mut Pair, x : &X){ + self.0.apply_mut(&mut y.0, x); + self.1.apply_mut(&mut y.1, x); + } + + /// Computes `y += Ax`, where `A` is `Self` + fn apply_add(&self, y : &mut Pair, x : &X){ + self.0.apply_add(&mut y.0, x); + self.1.apply_add(&mut y.1, x); + } +} + + +impl Adjointable,Yʹ> for RowOp +where + S : Adjointable, + T : Adjointable, + Self : Linear>, + for<'a> ColOp, T::Adjoint<'a>> : Linear, +{ + type AdjointCodomain = R; + type Adjoint<'a> = ColOp, T::Adjoint<'a>> where Self : 'a; + + fn adjoint(&self) -> Self::Adjoint<'_> { + ColOp(self.0.adjoint(), self.1.adjoint()) + } +} + +impl Preadjointable,Yʹ> for RowOp +where + S : Preadjointable, + T : Preadjointable, + Self : Linear>, + for<'a> ColOp, T::Preadjoint<'a>> : Adjointable, Codomain=R>, +{ + type PreadjointCodomain = R; + type Preadjoint<'a> = ColOp, T::Preadjoint<'a>> where Self : 'a; + + fn preadjoint(&self) -> Self::Preadjoint<'_> { + ColOp(self.0.preadjoint(), self.1.preadjoint()) + } +} + + +impl Adjointable> for ColOp +where + S : Adjointable, + T : Adjointable, + Self : Linear, + for<'a> RowOp, T::Adjoint<'a>> : Linear, Codomain=R>, +{ + type AdjointCodomain = R; + type Adjoint<'a> = RowOp, T::Adjoint<'a>> where Self : 'a; + + fn adjoint(&self) -> Self::Adjoint<'_> { + RowOp(self.0.adjoint(), self.1.adjoint()) + } +} + +impl Preadjointable> for ColOp +where + S : Preadjointable, + T : Preadjointable, + Self : Linear, + for<'a> RowOp, T::Preadjoint<'a>> : Adjointable, A, Codomain=R>, +{ + type PreadjointCodomain = R; + type Preadjoint<'a> = RowOp, T::Preadjoint<'a>> where Self : 'a; + + fn preadjoint(&self) -> Self::Preadjoint<'_> { + RowOp(self.0.preadjoint(), self.1.preadjoint()) + } +} + +/// Diagonal operator +pub struct DiagOp(pub S, pub T); + +impl Apply> for DiagOp +where + S : Apply, + T : Apply, +{ + type Output = Pair; + + fn apply(&self, Pair(a, b) : Pair) -> Self::Output { + Pair(self.0.apply(a), self.1.apply(b)) + } +} + +impl<'a, A, B, S, T> Apply<&'a Pair> for DiagOp +where + S : Apply<&'a A>, + T : Apply<&'a B>, +{ + type Output = Pair; + + fn apply(&self, Pair(ref a, ref b) : &'a Pair) -> Self::Output { + Pair(self.0.apply(a), self.1.apply(b)) + } +} + +impl Linear> for DiagOp +where + DiagOp : Apply, Output=D> + for<'a> Apply<&'a Pair, Output=D>, +{ + type Codomain = D; +} + +impl GEMV, Pair> for DiagOp +where + S : GEMV, + T : GEMV, + F : Num, + Self : Linear, Codomain=Pair> +{ + fn gemv(&self, y : &mut Pair, α : F, x : &Pair, β : F) { + self.0.gemv(&mut y.0, α, &x.0, β); + self.1.gemv(&mut y.1, α, &x.1, β); + } + + fn apply_mut(&self, y : &mut Pair, x : &Pair){ + self.0.apply_mut(&mut y.0, &x.0); + self.1.apply_mut(&mut y.1, &x.1); + } + + /// Computes `y += Ax`, where `A` is `Self` + fn apply_add(&self, y : &mut Pair, x : &Pair){ + self.0.apply_add(&mut y.0, &x.0); + self.1.apply_add(&mut y.1, &x.1); + } +} + +impl Adjointable,Pair> for DiagOp +where + S : Adjointable, + T : Adjointable, + Self : Linear>, + for<'a> DiagOp, T::Adjoint<'a>> : Linear, Codomain=R>, +{ + type AdjointCodomain = R; + type Adjoint<'a> = DiagOp, T::Adjoint<'a>> where Self : 'a; + + fn adjoint(&self) -> Self::Adjoint<'_> { + DiagOp(self.0.adjoint(), self.1.adjoint()) + } +} + +impl Preadjointable,Pair> for DiagOp +where + S : Preadjointable, + T : Preadjointable, + Self : Linear>, + for<'a> DiagOp, T::Preadjoint<'a>> : Adjointable, Pair, Codomain=R>, +{ + type PreadjointCodomain = R; + type Preadjoint<'a> = DiagOp, T::Preadjoint<'a>> where Self : 'a; + + fn preadjoint(&self) -> Self::Preadjoint<'_> { + DiagOp(self.0.preadjoint(), self.1.preadjoint()) + } +} + +/// Block operator +pub type BlockOp = ColOp, RowOp>; +