Sun, 18 May 2025 23:15:50 -0500
Basic Pair pyo3 conversions
/*! Direct products of the form $A \times B$. TODO: This could be easily much more generic if `derive_more` could derive arithmetic operations on references. */ use crate::euclidean::Euclidean; use crate::instance::{Decomposition, DecompositionMut, Instance, InstanceMut, MyCow}; use crate::linops::AXPY; use crate::loc::Loc; use crate::mapping::Space; use crate::norms::{HasDual, Norm, NormExponent, Normed, PairNorm, L2}; use crate::types::{Float, Num}; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use serde::{Deserialize, Serialize}; use std::clone::Clone; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub struct Pair<A, B>(pub A, pub B); impl<A, B> Pair<A, B> { pub fn new(a: A, b: B) -> Pair<A, B> { Pair(a, b) } } impl<A, B> From<(A, B)> for Pair<A, B> { #[inline] fn from((a, b): (A, B)) -> Pair<A, B> { Pair(a, b) } } impl<A, B> From<Pair<A, B>> for (A, B) { #[inline] fn from(Pair(a, b): Pair<A, B>) -> (A, B) { (a, b) } } macro_rules! impl_unary { ($trait:ident, $fn:ident) => { impl<A, B> $trait for Pair<A, B> where A: $trait, B: $trait, { type Output = Pair<A::Output, B::Output>; fn $fn(self) -> Self::Output { let Pair(a, b) = self; Pair(a.$fn(), b.$fn()) } } // Compiler overflow // impl<'a, A, B> $trait for &'a Pair<A, B> // where // &'a A: $trait, // &'a B: $trait, // { // type Output = Pair<<&'a A as $trait>::Output, <&'a B as $trait>::Output>; // fn $fn(self) -> Self::Output { // let Pair(ref a, ref b) = self; // Pair(a.$fn(), b.$fn()) // } // } }; } impl_unary!(Neg, neg); macro_rules! impl_binary { ($trait:ident, $fn:ident) => { impl<A, B, C, D> $trait<Pair<C, D>> for Pair<A, B> where A: $trait<C>, B: $trait<D>, { type Output = Pair<A::Output, B::Output>; fn $fn(self, Pair(c, d): Pair<C, D>) -> Self::Output { let Pair(a, b) = self; Pair(a.$fn(c), b.$fn(d)) } } impl<'a, A, B, C, D> $trait<Pair<C, D>> for &'a Pair<A, B> where &'a A: $trait<C>, &'a B: $trait<D>, { type Output = Pair<<&'a A as $trait<C>>::Output, <&'a B as $trait<D>>::Output>; fn $fn(self, Pair(c, d): Pair<C, D>) -> Self::Output { let Pair(ref a, ref b) = self; Pair(a.$fn(c), b.$fn(d)) } } impl<'a, 'b, A, B, C, D> $trait<&'b Pair<C, D>> for &'a Pair<A, B> where &'a A: $trait<&'b C>, &'a B: $trait<&'b D>, { type Output = Pair<<&'a A as $trait<&'b C>>::Output, <&'a B as $trait<&'b D>>::Output>; fn $fn(self, Pair(ref c, ref d): &'b Pair<C, D>) -> Self::Output { let Pair(ref a, ref b) = self; Pair(a.$fn(c), b.$fn(d)) } } impl<'b, A, B, C, D> $trait<&'b Pair<C, D>> for Pair<A, B> where A: $trait<&'b C>, B: $trait<&'b D>, { type Output = Pair<<A as $trait<&'b C>>::Output, <B as $trait<&'b D>>::Output>; fn $fn(self, Pair(ref c, ref d): &'b Pair<C, D>) -> Self::Output { let Pair(a, b) = self; Pair(a.$fn(c), b.$fn(d)) } } }; } impl_binary!(Add, add); impl_binary!(Sub, sub); macro_rules! impl_scalar { ($trait:ident, $fn:ident) => { impl<A, B, F: Num> $trait<F> for Pair<A, B> where A: $trait<F>, B: $trait<F>, { type Output = Pair<A::Output, B::Output>; fn $fn(self, t: F) -> Self::Output { let Pair(a, b) = self; Pair(a.$fn(t), b.$fn(t)) } } impl<'a, A, B, F: Num> $trait<F> for &'a Pair<A, B> where &'a A: $trait<F>, &'a B: $trait<F>, { type Output = Pair<<&'a A as $trait<F>>::Output, <&'a B as $trait<F>>::Output>; fn $fn(self, t: F) -> Self::Output { let Pair(ref a, ref b) = self; Pair(a.$fn(t), b.$fn(t)) } } // impl<'a, 'b, A, B> $trait<&'b $F> for &'a Pair<A, B> // where // &'a A: $trait<&'b $F>, // &'a B: $trait<&'b $F>, // { // type Output = // Pair<<&'a A as $trait<&'b $F>>::Output, <&'a B as $trait<&'b $F>>::Output>; // fn $fn(self, t: &'b $F) -> Self::Output { // let Pair(ref a, ref b) = self; // Pair(a.$fn(t), b.$fn(t)) // } // } // impl<'b, A, B> $trait<&'b $F> for Pair<A, B> // where // A: $trait<&'b $F>, // B: $trait<&'b $F>, // { // type Output = Pair<<A as $trait<&'b $F>>::Output, <B as $trait<&'b $F>>::Output>; // fn $fn(self, t: &'b $F) -> Self::Output { // let Pair(a, b) = self; // Pair(a.$fn(t), b.$fn(t)) // } // } }; } impl_scalar!(Mul, mul); impl_scalar!(Div, div); macro_rules! impl_scalar_lhs { ($trait:ident, $fn:ident, $F:ty) => { impl<A, B> $trait<Pair<A, B>> for $F where $F: $trait<A> + $trait<B>, { type Output = Pair<<$F as $trait<A>>::Output, <$F as $trait<B>>::Output>; fn $fn(self, Pair(a, b): Pair<A, B>) -> Self::Output { Pair(self.$fn(a), self.$fn(b)) } } // Compiler overflow: // // impl<'a, A, B> $trait<&'a Pair<A, B>> for $F // where // $F: $trait<&'a A> + $trait<&'a B>, // { // type Output = Pair<<$F as $trait<&'a A>>::Output, <$F as $trait<&'a B>>::Output>; // fn $fn(self, Pair(a, b): &'a Pair<A, B>) -> Self::Output { // Pair(self.$fn(a), self.$fn(b)) // } // } }; } impl_scalar_lhs!(Mul, mul, f32); impl_scalar_lhs!(Mul, mul, f64); impl_scalar_lhs!(Div, div, f32); impl_scalar_lhs!(Div, div, f64); macro_rules! impl_binary_mut { ($trait:ident, $fn:ident) => { impl<'a, A, B, C, D> $trait<Pair<C, D>> for Pair<A, B> where A: $trait<C>, B: $trait<D>, { fn $fn(&mut self, Pair(c, d): Pair<C, D>) { let Pair(ref mut a, ref mut b) = self; a.$fn(c); b.$fn(d); } } impl<'a, 'b, A, B, C, D> $trait<&'b Pair<C, D>> for Pair<A, B> where A: $trait<&'b C>, B: $trait<&'b D>, { fn $fn(&mut self, Pair(ref c, ref d): &'b Pair<C, D>) { let Pair(ref mut a, ref mut b) = self; a.$fn(c); b.$fn(d); } } }; } impl_binary_mut!(AddAssign, add_assign); impl_binary_mut!(SubAssign, sub_assign); macro_rules! impl_scalar_mut { ($trait:ident, $fn:ident) => { impl<'a, A, B, F: Num> $trait<F> for Pair<A, B> where A: $trait<F>, B: $trait<F>, { fn $fn(&mut self, t: F) { let Pair(ref mut a, ref mut b) = self; a.$fn(t); b.$fn(t); } } }; } impl_scalar_mut!(MulAssign, mul_assign); impl_scalar_mut!(DivAssign, div_assign); /// We only support 'closed' `Euclidean` `Pair`s, as more general ones cause /// compiler overflows. impl<A, B, F: Float> Euclidean<F> for Pair<A, B> where A: Euclidean<F>, B: Euclidean<F>, //Pair<A, B>: Euclidean<F>, Self: Sized + Mul<F, Output = <Self as AXPY>::Owned> + MulAssign<F> + Div<F, Output = <Self as AXPY>::Owned> + DivAssign<F> + Add<Self, Output = <Self as AXPY>::Owned> + Sub<Self, Output = <Self as AXPY>::Owned> + for<'b> Add<&'b Self, Output = <Self as AXPY>::Owned> + for<'b> Sub<&'b Self, Output = <Self as AXPY>::Owned> + AddAssign<Self> + for<'b> AddAssign<&'b Self> + SubAssign<Self> + for<'b> SubAssign<&'b Self> + Neg<Output = <Self as AXPY>::Owned>, { fn dot<I: Instance<Self>>(&self, other: I) -> F { other.eval_decompose(|Pair(u, v)| self.0.dot(u) + self.1.dot(v)) } fn norm2_squared(&self) -> F { self.0.norm2_squared() + self.1.norm2_squared() } fn dist2_squared<I: Instance<Self>>(&self, other: I) -> F { other.eval_decompose(|Pair(u, v)| self.0.dist2_squared(u) + self.1.dist2_squared(v)) } } impl<F, A, B, U, V> AXPY<Pair<U, V>> for Pair<A, B> where U: Space, V: Space, A: AXPY<U, Field = F>, B: AXPY<V, Field = F>, F: Num, Self: MulAssign<F> + DivAssign<F>, Pair<A, B>: MulAssign<F> + DivAssign<F>, //A::Owned: MulAssign<F>, //B::Owned: MulAssign<F>, //Pair<A::Owned, B::Owned>: AXPY<Pair<U, V>, Field = F>, { type Field = F; type Owned = Pair<A::Owned, B::Owned>; fn axpy<I: Instance<Pair<U, V>>>(&mut self, α: F, x: I, β: F) { x.eval_decompose(|Pair(u, v)| { self.0.axpy(α, u, β); self.1.axpy(α, v, β); }) } fn copy_from<I: Instance<Pair<U, V>>>(&mut self, x: I) { x.eval_decompose(|Pair(u, v)| { self.0.copy_from(u); self.1.copy_from(v); }) } fn scale_from<I: Instance<Pair<U, V>>>(&mut self, α: F, x: I) { x.eval_decompose(|Pair(u, v)| { self.0.scale_from(α, u); self.1.scale_from(α, v); }) } /// Return a similar zero as `self`. fn similar_origin(&self) -> Self::Owned { Pair(self.0.similar_origin(), self.1.similar_origin()) } /// Set self to zero. fn set_zero(&mut self) { self.0.set_zero(); self.1.set_zero(); } } /// [`Decomposition`] for working with [`Pair`]s. #[derive(Copy, Clone, Debug)] pub struct PairDecomposition<D, Q>(D, Q); impl<A: Space, B: Space> Space for Pair<A, B> { type Decomp = PairDecomposition<A::Decomp, B::Decomp>; } impl<A, B, D, Q> Decomposition<Pair<A, B>> for PairDecomposition<D, Q> where A: Space, B: Space, D: Decomposition<A>, Q: Decomposition<B>, { type Decomposition<'b> = Pair<D::Decomposition<'b>, Q::Decomposition<'b>> where Pair<A, B>: 'b; type Reference<'b> = Pair<D::Reference<'b>, Q::Reference<'b>> where Pair<A, B>: 'b; #[inline] fn lift<'b>(Pair(u, v): Self::Reference<'b>) -> Self::Decomposition<'b> { Pair(D::lift(u), Q::lift(v)) } } impl<A, B, U, V, D, Q> Instance<Pair<A, B>, PairDecomposition<D, Q>> for Pair<U, V> where A: Space, B: Space, D: Decomposition<A>, Q: Decomposition<B>, U: Instance<A, D>, V: Instance<B, Q>, { fn eval_decompose<'b, R>( self, f: impl FnOnce(Pair<D::Decomposition<'b>, Q::Decomposition<'b>>) -> R, ) -> R where Pair<A, B>: 'b, Self: 'b, { self.0 .eval_decompose(|a| self.1.eval_decompose(|b| f(Pair(a, b)))) } fn eval_ref_decompose<'b, R>( &'b self, f: impl FnOnce(Pair<D::Reference<'b>, Q::Reference<'b>>) -> R, ) -> R where Pair<A, B>: 'b, Self: 'b, { self.0 .eval_ref_decompose(|a| self.1.eval_ref_decompose(|b| f(Pair(a, b)))) } #[inline] fn cow<'b>(self) -> MyCow<'b, Pair<A, B>> where Self: 'b, { MyCow::Owned(Pair(self.0.own(), self.1.own())) } #[inline] fn own(self) -> Pair<A, B> { Pair(self.0.own(), self.1.own()) } } impl<'a, A, B, U, V, D, Q> Instance<Pair<A, B>, PairDecomposition<D, Q>> for &'a Pair<U, V> where A: Space, B: Space, D: Decomposition<A>, Q: Decomposition<B>, U: Instance<A, D>, V: Instance<B, Q>, &'a U: Instance<A, D>, &'a V: Instance<B, Q>, { fn eval_decompose<'b, R>( self, f: impl FnOnce(Pair<D::Decomposition<'b>, Q::Decomposition<'b>>) -> R, ) -> R where Pair<A, B>: 'b, Self: 'b, { self.0.eval_ref_decompose(|a| { self.1 .eval_ref_decompose(|b| f(Pair(D::lift(a), Q::lift(b)))) }) } fn eval_ref_decompose<'b, R>( &'b self, f: impl FnOnce(Pair<D::Reference<'b>, Q::Reference<'b>>) -> R, ) -> R where Pair<A, B>: 'b, Self: 'b, { self.0 .eval_ref_decompose(|a| self.1.eval_ref_decompose(|b| f(Pair(a, b)))) } #[inline] fn cow<'b>(self) -> MyCow<'b, Pair<A, B>> where Self: 'b, { MyCow::Owned(self.own()) } #[inline] fn own(self) -> Pair<A, B> { let Pair(ref u, ref v) = self; Pair(u.own(), v.own()) } } impl<A, B, D, Q> DecompositionMut<Pair<A, B>> for PairDecomposition<D, Q> where A: Space, B: Space, D: DecompositionMut<A>, Q: DecompositionMut<B>, { type ReferenceMut<'b> = Pair<D::ReferenceMut<'b>, Q::ReferenceMut<'b>> where Pair<A, B>: 'b; } impl<A, B, U, V, D, Q> InstanceMut<Pair<A, B>, PairDecomposition<D, Q>> for Pair<U, V> where A: Space, B: Space, D: DecompositionMut<A>, Q: DecompositionMut<B>, U: InstanceMut<A, D>, V: InstanceMut<B, Q>, { #[inline] fn ref_instance_mut( &mut self, ) -> <PairDecomposition<D, Q> as DecompositionMut<Pair<A, B>>>::ReferenceMut<'_> { Pair(self.0.ref_instance_mut(), self.1.ref_instance_mut()) } } impl<'a, A, B, U, V, D, Q> InstanceMut<Pair<A, B>, PairDecomposition<D, Q>> for &'a mut Pair<U, V> where A: Space, B: Space, D: DecompositionMut<A>, Q: DecompositionMut<B>, U: InstanceMut<A, D>, V: InstanceMut<B, Q>, { #[inline] fn ref_instance_mut( &mut self, ) -> <PairDecomposition<D, Q> as DecompositionMut<Pair<A, B>>>::ReferenceMut<'_> { Pair(self.0.ref_instance_mut(), self.1.ref_instance_mut()) } } impl<F, A, B, ExpA, ExpB, ExpJ> Norm<PairNorm<ExpA, ExpB, ExpJ>, F> for Pair<A, B> where F: Num, ExpA: NormExponent, ExpB: NormExponent, ExpJ: NormExponent, A: Norm<ExpA, F>, B: Norm<ExpB, F>, Loc<2, F>: Norm<ExpJ, F>, { fn norm(&self, PairNorm(expa, expb, expj): PairNorm<ExpA, ExpB, ExpJ>) -> F { Loc([self.0.norm(expa), self.1.norm(expb)]).norm(expj) } } impl<F: Float, A, B> Normed<F> for Pair<A, B> where A: Normed<F>, B: Normed<F>, { type NormExp = PairNorm<A::NormExp, B::NormExp, L2>; #[inline] fn norm_exponent(&self) -> Self::NormExp { PairNorm(self.0.norm_exponent(), self.1.norm_exponent(), L2) } #[inline] fn is_zero(&self) -> bool { self.0.is_zero() && self.1.is_zero() } } impl<F: Float, A, B> HasDual<F> for Pair<A, B> where A: HasDual<F>, B: HasDual<F>, { type DualSpace = Pair<A::DualSpace, B::DualSpace>; fn dual_origin(&self) -> <Self::DualSpace as AXPY>::Owned { Pair(self.0.dual_origin(), self.1.dual_origin()) } } #[cfg(feature = "pyo3")] mod python { use super::Pair; use pyo3::conversion::FromPyObject; use pyo3::types::{PyAny, PyTuple}; use pyo3::{Bound, IntoPyObject, PyErr, PyResult, Python}; impl<'py, A, B> IntoPyObject<'py> for Pair<A, B> where A: IntoPyObject<'py>, B: IntoPyObject<'py>, { type Target = PyTuple; type Error = PyErr; type Output = Bound<'py, Self::Target>; fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> { (self.0, self.1).into_pyobject(py) } } impl<'a, 'py, A, B> IntoPyObject<'py> for &'a mut Pair<A, B> where &'a mut A: IntoPyObject<'py>, &'a mut B: IntoPyObject<'py>, { type Target = PyTuple; type Error = PyErr; type Output = Bound<'py, Self::Target>; fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> { (&mut self.0, &mut self.1).into_pyobject(py) } } impl<'a, 'py, A, B> IntoPyObject<'py> for &'a Pair<A, B> where &'a A: IntoPyObject<'py>, &'a B: IntoPyObject<'py>, { type Target = PyTuple; type Error = PyErr; type Output = Bound<'py, Self::Target>; fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> { (&self.0, &self.1).into_pyobject(py) } } impl<'py, A, B> FromPyObject<'py> for Pair<A, B> where A: Clone + FromPyObject<'py>, B: Clone + FromPyObject<'py>, { fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> { FromPyObject::extract_bound(ob).map(|(a, b)| Pair(a, b)) } } }