src/bisection_tree/btfn.rs

branch
dev
changeset 77
cf8ef9463664
parent 29
7fd0984743b5
--- a/src/bisection_tree/btfn.rs	Fri Apr 28 14:02:18 2023 +0300
+++ b/src/bisection_tree/btfn.rs	Tue Jul 18 15:44:10 2023 +0300
@@ -420,36 +420,62 @@
     }
 }
 
-impl<'a, F : Float, G, BT, V, const N : usize> Differentiable<&'a Loc<F, N>>
+impl<'a, F : Float, G, BT, V, W, const N : usize> Differentiable<&'a Loc<F, N>>
 for BTFN<F, G, BT, N>
 where BT : BTImpl<F, N>,
       G : SupportGenerator<F, N, Id=BT::Data>,
-      G::SupportType : LocalAnalysis<F, BT::Agg, N> + Differentiable<&'a Loc<F, N>, Output = V>,
-      V : Sum {
+      G::SupportType : LocalAnalysis<F, BT::Agg, N>
+                       + Differentiable<&'a Loc<F, N>, Output = V>
+                       + Apply<&'a Loc<F, N>, Output = W>,
+      V : Sum,
+      W : Sum {
 
     type Output = V;
 
-    fn differential(&self, x : &'a Loc<F, N>) -> Self::Output {
+    fn differential(&self, x : &'a Loc<F, N>) -> V {
         self.bt.iter_at(x)
             .map(|&d| self.generator.support_for(d).differential(x))
             .sum()
     }
+
+    fn linearisation_error_gen(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>, z : &'a Loc<F, N>) -> W {
+        self.bt.iter_at(x)
+            .map(|&d| self.generator.support_for(d).linearisation_error_gen(x, y, z)).sum()
+    }
+
+    fn linearisation_error(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>) -> W {
+        self.bt.iter_at(x)
+            .map(|&d| self.generator.support_for(d).linearisation_error(x, y)).sum()
+    }
 }
 
-impl<F : Float, G, BT, V, const N : usize> Differentiable<Loc<F, N>>
+impl<F : Float, G, BT, V, W, const N : usize> Differentiable<Loc<F, N>>
 for BTFN<F, G, BT, N>
 where BT : BTImpl<F, N>,
       G : SupportGenerator<F, N, Id=BT::Data>,
-      G::SupportType : LocalAnalysis<F, BT::Agg, N> + Differentiable<Loc<F, N>, Output = V>,
-      V : Sum {
+      G::SupportType : LocalAnalysis<F, BT::Agg, N>
+                       + Differentiable<Loc<F, N>, Output = V>
+                       + Apply<Loc<F, N>, Output = W>,
+      V : Sum,
+      W : Sum {
 
     type Output = V;
 
-    fn differential(&self, x : Loc<F, N>) -> Self::Output {
+    fn differential(&self, x : Loc<F, N>) -> V {
         self.bt.iter_at(&x)
             .map(|&d| self.generator.support_for(d).differential(x))
             .sum()
     }
+
+    fn linearisation_error_gen(&self, x : Loc<F, N>, y : Loc<F, N>, z : Loc<F, N>) -> W {
+        self.bt.iter_at(&x)
+            .map(|&d| self.generator.support_for(d).linearisation_error_gen(x, y, z)).sum()
+    }
+
+    fn linearisation_error(&self, x : Loc<F, N>, y : Loc<F, N>) -> W {
+        self.bt.iter_at(&x)
+            .map(|&d| self.generator.support_for(d).linearisation_error(x, y)).sum()
+    }
 }
 
 //

mercurial