Implement Differentiate for BTFN dev

Fri, 28 Apr 2023 08:32:15 +0300

author
Tuomo Valkonen <tuomov@iki.fi>
date
Fri, 28 Apr 2023 08:32:15 +0300
branch
dev
changeset 27
00029c20c0ee
parent 26
cc68841e758f
child 28
331345346e7b

Implement Differentiate for BTFN

src/bisection_tree/btfn.rs file | annotate | diff | comparison | revisions
src/bisection_tree/either.rs file | annotate | diff | comparison | revisions
--- a/src/bisection_tree/btfn.rs	Fri Apr 28 08:27:17 2023 +0300
+++ b/src/bisection_tree/btfn.rs	Fri Apr 28 08:32:15 2023 +0300
@@ -4,7 +4,7 @@
 use std::marker::PhantomData;
 use std::sync::Arc;
 use crate::types::Float;
-use crate::mapping::{Apply, Mapping};
+use crate::mapping::{Apply, Mapping, Differentiate};
 //use crate::linops::{Apply, Linear};
 use crate::sets::Set;
 use crate::sets::Cube;
@@ -386,10 +386,8 @@
 
 make_btfn_unaryop!(Neg, neg);
 
-
-
 //
-// Mapping
+// Apply, Mapping, Differentiate
 //
 
 impl<'a, F : Float, G, BT, V, const N : usize> Apply<&'a Loc<F, N>>
@@ -422,6 +420,42 @@
     }
 }
 
+impl<'a, F : Float, G, BT, V, const N : usize> Differentiate<&'a Loc<F, N>>
+for BTFN<F, G, BT, N>
+where BT : BTImpl<F, N>,
+      G : SupportGenerator<F, N, Id=BT::Data>,
+      G::SupportType : LocalAnalysis<F, BT::Agg, N> + Differentiate<&'a Loc<F, N>, Output = V>,
+      V : Sum {
+
+    type Output = V;
+
+    fn differential(&self, x : &'a Loc<F, N>) -> Self::Output {
+        self.bt.iter_at(x)
+            .map(|&d| self.generator.support_for(d).differential(x))
+            .sum()
+    }
+}
+
+impl<F : Float, G, BT, V, const N : usize> Differentiate<Loc<F, N>>
+for BTFN<F, G, BT, N>
+where BT : BTImpl<F, N>,
+      G : SupportGenerator<F, N, Id=BT::Data>,
+      G::SupportType : LocalAnalysis<F, BT::Agg, N> + Differentiate<Loc<F, N>, Output = V>,
+      V : Sum {
+
+    type Output = V;
+
+    fn differential(&self, x : Loc<F, N>) -> Self::Output {
+        self.bt.iter_at(&x)
+            .map(|&d| self.generator.support_for(d).differential(x))
+            .sum()
+    }
+}
+
+//
+// GlobalAnalysis
+//
+
 impl<F : Float, G, BT, const N : usize> GlobalAnalysis<F, BT::Agg>
 for BTFN<F, G, BT, N>
 where BT : BTImpl<F, N>,
--- a/src/bisection_tree/either.rs	Fri Apr 28 08:27:17 2023 +0300
+++ b/src/bisection_tree/either.rs	Fri Apr 28 08:32:15 2023 +0300
@@ -3,8 +3,8 @@
 use std::sync::Arc;
 
 use crate::types::*;
-use crate::mapping::Apply;
-use crate::iter::{Mappable,MapF,MapZ};
+use crate::mapping::{Apply, Differentiate};
+use crate::iter::{Mappable, MapF, MapZ};
 use crate::sets::Cube;
 use crate::loc::Loc;
 
@@ -190,6 +190,19 @@
     }
 }
 
+impl<F, S1, S2, X> Differentiate<X> for EitherSupport<S1, S2>
+where S1 : Differentiate<X, Output=F>,
+      S2 : Differentiate<X, Output=F> {
+    type Output = F;
+    #[inline]
+    fn differential(&self, x : X) -> F {
+        match self {
+            EitherSupport::Left(ref a) => a.differential(x),
+            EitherSupport::Right(ref b) => b.differential(x),
+        }
+    }
+}
+
 macro_rules! make_either_scalarop_rhs {
     ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => {
         impl<F : Float, G1, G2>

mercurial