--- a/src/bisection_tree/btfn.rs Fri Apr 28 14:02:18 2023 +0300 +++ b/src/bisection_tree/btfn.rs Tue Jul 18 15:44:10 2023 +0300 @@ -420,36 +420,62 @@ } } -impl<'a, F : Float, G, BT, V, const N : usize> Differentiable<&'a Loc<F, N>> +impl<'a, F : Float, G, BT, V, W, const N : usize> Differentiable<&'a Loc<F, N>> for BTFN<F, G, BT, N> where BT : BTImpl<F, N>, G : SupportGenerator<F, N, Id=BT::Data>, - G::SupportType : LocalAnalysis<F, BT::Agg, N> + Differentiable<&'a Loc<F, N>, Output = V>, - V : Sum { + G::SupportType : LocalAnalysis<F, BT::Agg, N> + + Differentiable<&'a Loc<F, N>, Output = V> + + Apply<&'a Loc<F, N>, Output = W>, + V : Sum, + W : Sum { type Output = V; - fn differential(&self, x : &'a Loc<F, N>) -> Self::Output { + fn differential(&self, x : &'a Loc<F, N>) -> V { self.bt.iter_at(x) .map(|&d| self.generator.support_for(d).differential(x)) .sum() } + + fn linearisation_error_gen(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>, z : &'a Loc<F, N>) -> W { + self.bt.iter_at(x) + .map(|&d| self.generator.support_for(d).linearisation_error_gen(x, y, z)).sum() + } + + fn linearisation_error(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>) -> W { + self.bt.iter_at(x) + .map(|&d| self.generator.support_for(d).linearisation_error(x, y)).sum() + } } -impl<F : Float, G, BT, V, const N : usize> Differentiable<Loc<F, N>> +impl<F : Float, G, BT, V, W, const N : usize> Differentiable<Loc<F, N>> for BTFN<F, G, BT, N> where BT : BTImpl<F, N>, G : SupportGenerator<F, N, Id=BT::Data>, - G::SupportType : LocalAnalysis<F, BT::Agg, N> + Differentiable<Loc<F, N>, Output = V>, - V : Sum { + G::SupportType : LocalAnalysis<F, BT::Agg, N> + + Differentiable<Loc<F, N>, Output = V> + + Apply<Loc<F, N>, Output = W>, + V : Sum, + W : Sum { type Output = V; - fn differential(&self, x : Loc<F, N>) -> Self::Output { + fn differential(&self, x : Loc<F, N>) -> V { self.bt.iter_at(&x) .map(|&d| self.generator.support_for(d).differential(x)) .sum() } + + fn linearisation_error_gen(&self, x : Loc<F, N>, y : Loc<F, N>, z : Loc<F, N>) -> W { + self.bt.iter_at(&x) + .map(|&d| self.generator.support_for(d).linearisation_error_gen(x, y, z)).sum() + } + + fn linearisation_error(&self, x : Loc<F, N>, y : Loc<F, N>) -> W { + self.bt.iter_at(&x) + .map(|&d| self.generator.support_for(d).linearisation_error(x, y)).sum() + } } //