Tue, 18 Jul 2023 15:44:10 +0300
linearisation_error
--- a/src/bisection_tree/btfn.rs Fri Apr 28 14:02:18 2023 +0300 +++ b/src/bisection_tree/btfn.rs Tue Jul 18 15:44:10 2023 +0300 @@ -420,36 +420,62 @@ } } -impl<'a, F : Float, G, BT, V, const N : usize> Differentiable<&'a Loc<F, N>> +impl<'a, F : Float, G, BT, V, W, const N : usize> Differentiable<&'a Loc<F, N>> for BTFN<F, G, BT, N> where BT : BTImpl<F, N>, G : SupportGenerator<F, N, Id=BT::Data>, - G::SupportType : LocalAnalysis<F, BT::Agg, N> + Differentiable<&'a Loc<F, N>, Output = V>, - V : Sum { + G::SupportType : LocalAnalysis<F, BT::Agg, N> + + Differentiable<&'a Loc<F, N>, Output = V> + + Apply<&'a Loc<F, N>, Output = W>, + V : Sum, + W : Sum { type Output = V; - fn differential(&self, x : &'a Loc<F, N>) -> Self::Output { + fn differential(&self, x : &'a Loc<F, N>) -> V { self.bt.iter_at(x) .map(|&d| self.generator.support_for(d).differential(x)) .sum() } + + fn linearisation_error_gen(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>, z : &'a Loc<F, N>) -> W { + self.bt.iter_at(x) + .map(|&d| self.generator.support_for(d).linearisation_error_gen(x, y, z)).sum() + } + + fn linearisation_error(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>) -> W { + self.bt.iter_at(x) + .map(|&d| self.generator.support_for(d).linearisation_error(x, y)).sum() + } } -impl<F : Float, G, BT, V, const N : usize> Differentiable<Loc<F, N>> +impl<F : Float, G, BT, V, W, const N : usize> Differentiable<Loc<F, N>> for BTFN<F, G, BT, N> where BT : BTImpl<F, N>, G : SupportGenerator<F, N, Id=BT::Data>, - G::SupportType : LocalAnalysis<F, BT::Agg, N> + Differentiable<Loc<F, N>, Output = V>, - V : Sum { + G::SupportType : LocalAnalysis<F, BT::Agg, N> + + Differentiable<Loc<F, N>, Output = V> + + Apply<Loc<F, N>, Output = W>, + V : Sum, + W : Sum { type Output = V; - fn differential(&self, x : Loc<F, N>) -> Self::Output { + fn differential(&self, x : Loc<F, N>) -> V { self.bt.iter_at(&x) .map(|&d| self.generator.support_for(d).differential(x)) .sum() } + + fn linearisation_error_gen(&self, x : Loc<F, N>, y : Loc<F, N>, z : Loc<F, N>) -> W { + self.bt.iter_at(&x) + .map(|&d| self.generator.support_for(d).linearisation_error_gen(x, y, z)).sum() + } + + fn linearisation_error(&self, x : Loc<F, N>, y : Loc<F, N>) -> W { + self.bt.iter_at(&x) + .map(|&d| self.generator.support_for(d).linearisation_error(x, y)).sum() + } } //
--- a/src/bisection_tree/either.rs Fri Apr 28 14:02:18 2023 +0300 +++ b/src/bisection_tree/either.rs Tue Jul 18 15:44:10 2023 +0300 @@ -190,17 +190,34 @@ } } -impl<F, S1, S2, X> Differentiable<X> for EitherSupport<S1, S2> -where S1 : Differentiable<X, Output=F>, - S2 : Differentiable<X, Output=F> { - type Output = F; +impl<F, D, S1, S2, X : Clone> Differentiable<X> for EitherSupport<S1, S2> +where S1 : Differentiable<X, Output=D> + Apply<X, Output = F>, + S2 : Differentiable<X, Output=D> + Apply<X, Output = F> { + type Output = D; + #[inline] - fn differential(&self, x : X) -> F { + fn differential(&self, x : X) -> D { match self { EitherSupport::Left(ref a) => a.differential(x), EitherSupport::Right(ref b) => b.differential(x), } } + + #[inline] + fn linearisation_error(&self, x : X, y : X) -> F { + match self { + EitherSupport::Left(ref a) => a.linearisation_error(x, y), + EitherSupport::Right(ref b) => b.linearisation_error(x, y), + } + } + + #[inline] + fn linearisation_error_gen(&self, x : X, y : X, z : X) -> F { + match self { + EitherSupport::Left(ref a) => a.linearisation_error_gen(x, y, z), + EitherSupport::Right(ref b) => b.linearisation_error_gen(x, y, z), + } + } } macro_rules! make_either_scalarop_rhs {
--- a/src/bisection_tree/support.rs Fri Apr 28 14:02:18 2023 +0300 +++ b/src/bisection_tree/support.rs Tue Jul 18 15:44:10 2023 +0300 @@ -145,22 +145,48 @@ } } -impl<'a, T, V, F : Float, const N : usize> Differentiable<&'a Loc<F, N>> for Shift<T,F,N> -where T : Differentiable<Loc<F, N>, Output=V> { +impl<'a, T, V, W, F : Float, const N : usize> Differentiable<&'a Loc<F, N>> +for Shift<T,F,N> +where T : Differentiable<Loc<F, N>, Output=V> + Apply<Loc<F, N>, Output=W> { type Output = V; #[inline] - fn differential(&self, x : &'a Loc<F, N>) -> Self::Output { + fn differential(&self, x : &'a Loc<F, N>) -> V { self.base_fn.differential(x - &self.shift) } + + #[inline] + fn linearisation_error(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>) -> W { + self.base_fn + .linearisation_error(x - &self.shift, y - &self.shift) + } + + #[inline] + fn linearisation_error_gen(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>, z : &'a Loc<F, N>) -> W { + self.base_fn + .linearisation_error_gen(x - &self.shift, y - &self.shift, z - &self.shift) + } } -impl<'a, T, V, F : Float, const N : usize> Differentiable<Loc<F, N>> for Shift<T,F,N> -where T : Differentiable<Loc<F, N>, Output=V> { +impl<'a, T, V, W, F : Float, const N : usize> Differentiable<Loc<F, N>> +for Shift<T,F,N> +where T : Differentiable<Loc<F, N>, Output=V> + Apply<Loc<F, N>, Output=W> { type Output = V; #[inline] - fn differential(&self, x : Loc<F, N>) -> Self::Output { + fn differential(&self, x : Loc<F, N>) -> V { self.base_fn.differential(x - &self.shift) } + + #[inline] + fn linearisation_error(&self, x : Loc<F, N>, y : Loc<F, N>) -> W { + self.base_fn + .linearisation_error(x - &self.shift, y - &self.shift) + } + + #[inline] + fn linearisation_error_gen(&self, x : Loc<F, N>, y : Loc<F, N>, z : Loc<F, N>) -> W { + self.base_fn + .linearisation_error_gen(x - &self.shift, y - &self.shift, z - &self.shift) + } } impl<'a, T, F : Float, const N : usize> Support<F,N> for Shift<T,F,N> @@ -250,26 +276,52 @@ } } -impl<'a, T, V, F : Float, C, const N : usize> Differentiable<&'a Loc<F, N>> for Weighted<T, C> -where T : for<'b> Differentiable<&'b Loc<F, N>, Output=V>, +impl<'a, T, V, W, F : Float, C, const N : usize> Differentiable<&'a Loc<F, N>> for Weighted<T, C> +where T : for<'b> Differentiable<&'b Loc<F, N>, Output=V> + + for<'b> Apply<&'b Loc<F, N>, Output=W>, V : std::ops::Mul<F, Output=V>, + W : std::ops::Mul<F, Output=W>, C : Constant<Type=F> { type Output = V; + #[inline] - fn differential(&self, x : &'a Loc<F, N>) -> Self::Output { + fn differential(&self, x : &'a Loc<F, N>) -> V { self.base_fn.differential(x) * self.weight.value() } + + #[inline] + fn linearisation_error(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>) -> W { + self.base_fn.linearisation_error(x, y) * self.weight.value() + } + + #[inline] + fn linearisation_error_gen(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>, z : &'a Loc<F, N>) -> W { + self.base_fn.linearisation_error_gen(x, y, z) * self.weight.value() + } } -impl<'a, T, V, F : Float, C, const N : usize> Differentiable<Loc<F, N>> for Weighted<T, C> -where T : Differentiable<Loc<F, N>, Output=V>, +impl<'a, T, V, W, F : Float, C, const N : usize> Differentiable<Loc<F, N>> +for Weighted<T, C> +where T : Differentiable<Loc<F, N>, Output=V> + Apply<Loc<F, N>, Output=W>, V : std::ops::Mul<F, Output=V>, + W : std::ops::Mul<F, Output=W>, C : Constant<Type=F> { type Output = V; + #[inline] - fn differential(&self, x : Loc<F, N>) -> Self::Output { + fn differential(&self, x : Loc<F, N>) -> V { self.base_fn.differential(x) * self.weight.value() } + + #[inline] + fn linearisation_error(&self, x : Loc<F, N>, y : Loc<F, N>) -> W { + self.base_fn.linearisation_error(x, y) * self.weight.value() + } + + #[inline] + fn linearisation_error_gen(&self, x : Loc<F, N>, y : Loc<F, N>, z : Loc<F, N>) -> W { + self.base_fn.linearisation_error_gen(x, y, z) * self.weight.value() + } } impl<'a, T, F : Float, C, const N : usize> Support<F,N> for Weighted<T, C>
--- a/src/mapping.rs Fri Apr 28 14:02:18 2023 +0300 +++ b/src/mapping.rs Tue Jul 18 15:44:10 2023 +0300 @@ -51,11 +51,21 @@ /// Trait for calculation the differential of `Self` as a mathematical function on `X`. -pub trait Differentiable<X> { +pub trait Differentiable<X : Clone> : Apply<X> { type Output; /// Compute the differential of `self` at `x`. - fn differential(&self, x : X) -> Self::Output; + fn differential(&self, x : X) -> <Self as Differentiable<X>>::Output; + + /// Compute the linearisation error of `self` at `x` for `y`. + fn linearisation_error(&self, x : X, y : X) -> <Self as Apply<X>>::Output { + let z = x.clone(); + self.linearisation_error_gen(x, y, z) + } + + /// Compute the linearisation error of `self` at `x` for `y`, with + /// derivative calculated at `z` + fn linearisation_error_gen(&self, x : X, y : X, z : X) -> <Self as Apply<X>>::Output; } @@ -63,15 +73,15 @@ /// `Differential`. /// /// This is automatically implemented when the relevant [`Differentiate`] are implemented. -pub trait DifferentiableMapping<Domain> +pub trait DifferentiableMapping<Domain : Clone> : Mapping<Domain> + Differentiable<Domain, Output=Self::Differential> - + for<'a> Differentiable<&'a Domain, Output=Self::Differential>{ + + for<'a> Differentiable<&'a Domain, Output=Self::Differential> { type Differential; } -impl<Domain, Differential, T> DifferentiableMapping<Domain> for T +impl<Domain : Clone, Differential, T> DifferentiableMapping<Domain> for T where T : Mapping<Domain> + Differentiable<Domain, Output=Differential> + for<'a> Differentiable<&'a Domain, Output=Differential> { @@ -111,7 +121,24 @@ type Output = M::Differential; - fn differential(&self, x : Domain) -> Self::Output { - self.components.iter().map(|c| c.differential(x)).sum() + fn differential(&self, x : Domain) -> M::Differential { + self.components + .iter() + .map(|c| c.differential(x)) + .sum() + } + + fn linearisation_error(&self, x : Domain, y : Domain) -> M::Codomain { + self.components + .iter() + .map(|c| c.linearisation_error(x, y)) + .sum() + } + + fn linearisation_error_gen(&self, x : Domain, y : Domain, z : Domain) -> M::Codomain { + self.components + .iter() + .map(|c| c.linearisation_error_gen(x, y, z)) + .sum() } }