# HG changeset patch # User Tuomo Valkonen # Date 1689684250 -10800 # Node ID cf8ef9463664db457ca9074635d9eb387cb2e4e5 # Parent 9f2214c961cb7d5dd04b9fb0a97d1a66283d6be8 linearisation_error diff -r 9f2214c961cb -r cf8ef9463664 src/bisection_tree/btfn.rs --- a/src/bisection_tree/btfn.rs Fri Apr 28 14:02:18 2023 +0300 +++ b/src/bisection_tree/btfn.rs Tue Jul 18 15:44:10 2023 +0300 @@ -420,36 +420,62 @@ } } -impl<'a, F : Float, G, BT, V, const N : usize> Differentiable<&'a Loc> +impl<'a, F : Float, G, BT, V, W, const N : usize> Differentiable<&'a Loc> for BTFN where BT : BTImpl, G : SupportGenerator, - G::SupportType : LocalAnalysis + Differentiable<&'a Loc, Output = V>, - V : Sum { + G::SupportType : LocalAnalysis + + Differentiable<&'a Loc, Output = V> + + Apply<&'a Loc, Output = W>, + V : Sum, + W : Sum { type Output = V; - fn differential(&self, x : &'a Loc) -> Self::Output { + fn differential(&self, x : &'a Loc) -> V { self.bt.iter_at(x) .map(|&d| self.generator.support_for(d).differential(x)) .sum() } + + fn linearisation_error_gen(&self, x : &'a Loc, y : &'a Loc, z : &'a Loc) -> W { + self.bt.iter_at(x) + .map(|&d| self.generator.support_for(d).linearisation_error_gen(x, y, z)).sum() + } + + fn linearisation_error(&self, x : &'a Loc, y : &'a Loc) -> W { + self.bt.iter_at(x) + .map(|&d| self.generator.support_for(d).linearisation_error(x, y)).sum() + } } -impl Differentiable> +impl Differentiable> for BTFN where BT : BTImpl, G : SupportGenerator, - G::SupportType : LocalAnalysis + Differentiable, Output = V>, - V : Sum { + G::SupportType : LocalAnalysis + + Differentiable, Output = V> + + Apply, Output = W>, + V : Sum, + W : Sum { type Output = V; - fn differential(&self, x : Loc) -> Self::Output { + fn differential(&self, x : Loc) -> V { self.bt.iter_at(&x) .map(|&d| self.generator.support_for(d).differential(x)) .sum() } + + fn linearisation_error_gen(&self, x : Loc, y : Loc, z : Loc) -> W { + self.bt.iter_at(&x) + .map(|&d| self.generator.support_for(d).linearisation_error_gen(x, y, z)).sum() + } + + fn linearisation_error(&self, x : Loc, y : Loc) -> W { + self.bt.iter_at(&x) + .map(|&d| self.generator.support_for(d).linearisation_error(x, y)).sum() + } } // diff -r 9f2214c961cb -r cf8ef9463664 src/bisection_tree/either.rs --- a/src/bisection_tree/either.rs Fri Apr 28 14:02:18 2023 +0300 +++ b/src/bisection_tree/either.rs Tue Jul 18 15:44:10 2023 +0300 @@ -190,17 +190,34 @@ } } -impl Differentiable for EitherSupport -where S1 : Differentiable, - S2 : Differentiable { - type Output = F; +impl Differentiable for EitherSupport +where S1 : Differentiable + Apply, + S2 : Differentiable + Apply { + type Output = D; + #[inline] - fn differential(&self, x : X) -> F { + fn differential(&self, x : X) -> D { match self { EitherSupport::Left(ref a) => a.differential(x), EitherSupport::Right(ref b) => b.differential(x), } } + + #[inline] + fn linearisation_error(&self, x : X, y : X) -> F { + match self { + EitherSupport::Left(ref a) => a.linearisation_error(x, y), + EitherSupport::Right(ref b) => b.linearisation_error(x, y), + } + } + + #[inline] + fn linearisation_error_gen(&self, x : X, y : X, z : X) -> F { + match self { + EitherSupport::Left(ref a) => a.linearisation_error_gen(x, y, z), + EitherSupport::Right(ref b) => b.linearisation_error_gen(x, y, z), + } + } } macro_rules! make_either_scalarop_rhs { diff -r 9f2214c961cb -r cf8ef9463664 src/bisection_tree/support.rs --- a/src/bisection_tree/support.rs Fri Apr 28 14:02:18 2023 +0300 +++ b/src/bisection_tree/support.rs Tue Jul 18 15:44:10 2023 +0300 @@ -145,22 +145,48 @@ } } -impl<'a, T, V, F : Float, const N : usize> Differentiable<&'a Loc> for Shift -where T : Differentiable, Output=V> { +impl<'a, T, V, W, F : Float, const N : usize> Differentiable<&'a Loc> +for Shift +where T : Differentiable, Output=V> + Apply, Output=W> { type Output = V; #[inline] - fn differential(&self, x : &'a Loc) -> Self::Output { + fn differential(&self, x : &'a Loc) -> V { self.base_fn.differential(x - &self.shift) } + + #[inline] + fn linearisation_error(&self, x : &'a Loc, y : &'a Loc) -> W { + self.base_fn + .linearisation_error(x - &self.shift, y - &self.shift) + } + + #[inline] + fn linearisation_error_gen(&self, x : &'a Loc, y : &'a Loc, z : &'a Loc) -> W { + self.base_fn + .linearisation_error_gen(x - &self.shift, y - &self.shift, z - &self.shift) + } } -impl<'a, T, V, F : Float, const N : usize> Differentiable> for Shift -where T : Differentiable, Output=V> { +impl<'a, T, V, W, F : Float, const N : usize> Differentiable> +for Shift +where T : Differentiable, Output=V> + Apply, Output=W> { type Output = V; #[inline] - fn differential(&self, x : Loc) -> Self::Output { + fn differential(&self, x : Loc) -> V { self.base_fn.differential(x - &self.shift) } + + #[inline] + fn linearisation_error(&self, x : Loc, y : Loc) -> W { + self.base_fn + .linearisation_error(x - &self.shift, y - &self.shift) + } + + #[inline] + fn linearisation_error_gen(&self, x : Loc, y : Loc, z : Loc) -> W { + self.base_fn + .linearisation_error_gen(x - &self.shift, y - &self.shift, z - &self.shift) + } } impl<'a, T, F : Float, const N : usize> Support for Shift @@ -250,26 +276,52 @@ } } -impl<'a, T, V, F : Float, C, const N : usize> Differentiable<&'a Loc> for Weighted -where T : for<'b> Differentiable<&'b Loc, Output=V>, +impl<'a, T, V, W, F : Float, C, const N : usize> Differentiable<&'a Loc> for Weighted +where T : for<'b> Differentiable<&'b Loc, Output=V> + + for<'b> Apply<&'b Loc, Output=W>, V : std::ops::Mul, + W : std::ops::Mul, C : Constant { type Output = V; + #[inline] - fn differential(&self, x : &'a Loc) -> Self::Output { + fn differential(&self, x : &'a Loc) -> V { self.base_fn.differential(x) * self.weight.value() } + + #[inline] + fn linearisation_error(&self, x : &'a Loc, y : &'a Loc) -> W { + self.base_fn.linearisation_error(x, y) * self.weight.value() + } + + #[inline] + fn linearisation_error_gen(&self, x : &'a Loc, y : &'a Loc, z : &'a Loc) -> W { + self.base_fn.linearisation_error_gen(x, y, z) * self.weight.value() + } } -impl<'a, T, V, F : Float, C, const N : usize> Differentiable> for Weighted -where T : Differentiable, Output=V>, +impl<'a, T, V, W, F : Float, C, const N : usize> Differentiable> +for Weighted +where T : Differentiable, Output=V> + Apply, Output=W>, V : std::ops::Mul, + W : std::ops::Mul, C : Constant { type Output = V; + #[inline] - fn differential(&self, x : Loc) -> Self::Output { + fn differential(&self, x : Loc) -> V { self.base_fn.differential(x) * self.weight.value() } + + #[inline] + fn linearisation_error(&self, x : Loc, y : Loc) -> W { + self.base_fn.linearisation_error(x, y) * self.weight.value() + } + + #[inline] + fn linearisation_error_gen(&self, x : Loc, y : Loc, z : Loc) -> W { + self.base_fn.linearisation_error_gen(x, y, z) * self.weight.value() + } } impl<'a, T, F : Float, C, const N : usize> Support for Weighted diff -r 9f2214c961cb -r cf8ef9463664 src/mapping.rs --- a/src/mapping.rs Fri Apr 28 14:02:18 2023 +0300 +++ b/src/mapping.rs Tue Jul 18 15:44:10 2023 +0300 @@ -51,11 +51,21 @@ /// Trait for calculation the differential of `Self` as a mathematical function on `X`. -pub trait Differentiable { +pub trait Differentiable : Apply { type Output; /// Compute the differential of `self` at `x`. - fn differential(&self, x : X) -> Self::Output; + fn differential(&self, x : X) -> >::Output; + + /// Compute the linearisation error of `self` at `x` for `y`. + fn linearisation_error(&self, x : X, y : X) -> >::Output { + let z = x.clone(); + self.linearisation_error_gen(x, y, z) + } + + /// Compute the linearisation error of `self` at `x` for `y`, with + /// derivative calculated at `z` + fn linearisation_error_gen(&self, x : X, y : X, z : X) -> >::Output; } @@ -63,15 +73,15 @@ /// `Differential`. /// /// This is automatically implemented when the relevant [`Differentiate`] are implemented. -pub trait DifferentiableMapping +pub trait DifferentiableMapping : Mapping + Differentiable - + for<'a> Differentiable<&'a Domain, Output=Self::Differential>{ + + for<'a> Differentiable<&'a Domain, Output=Self::Differential> { type Differential; } -impl DifferentiableMapping for T +impl DifferentiableMapping for T where T : Mapping + Differentiable + for<'a> Differentiable<&'a Domain, Output=Differential> { @@ -111,7 +121,24 @@ type Output = M::Differential; - fn differential(&self, x : Domain) -> Self::Output { - self.components.iter().map(|c| c.differential(x)).sum() + fn differential(&self, x : Domain) -> M::Differential { + self.components + .iter() + .map(|c| c.differential(x)) + .sum() + } + + fn linearisation_error(&self, x : Domain, y : Domain) -> M::Codomain { + self.components + .iter() + .map(|c| c.linearisation_error(x, y)) + .sum() + } + + fn linearisation_error_gen(&self, x : Domain, y : Domain, z : Domain) -> M::Codomain { + self.components + .iter() + .map(|c| c.linearisation_error_gen(x, y, z)) + .sum() } }