Implement Differentiable for Weighted and Shift dev

Fri, 28 Apr 2023 14:02:18 +0300

author
Tuomo Valkonen <tuomov@iki.fi>
date
Fri, 28 Apr 2023 14:02:18 +0300
branch
dev
changeset 30
9f2214c961cb
parent 29
7fd0984743b5
child 31
50a77e4efcbb
child 77
cf8ef9463664

Implement Differentiable for Weighted and Shift

src/bisection_tree/support.rs file | annotate | diff | comparison | revisions
--- a/src/bisection_tree/support.rs	Fri Apr 28 13:42:03 2023 +0300
+++ b/src/bisection_tree/support.rs	Fri Apr 28 14:02:18 2023 +0300
@@ -6,7 +6,7 @@
 use std::ops::{MulAssign,DivAssign,Neg};
 use crate::types::{Float, Num};
 use crate::maputil::map2;
-use crate::mapping::Apply;
+use crate::mapping::{Apply, Differentiable};
 use crate::sets::Cube;
 use crate::loc::Loc;
 use super::aggregator::Bounds;
@@ -145,6 +145,24 @@
     }
 }
 
+impl<'a, T, V, F : Float, const N : usize> Differentiable<&'a Loc<F, N>> for Shift<T,F,N>
+where T : Differentiable<Loc<F, N>, Output=V> {
+    type Output = V;
+    #[inline]
+    fn differential(&self, x : &'a Loc<F, N>) -> Self::Output {
+        self.base_fn.differential(x - &self.shift)
+    }
+}
+
+impl<'a, T, V, F : Float, const N : usize> Differentiable<Loc<F, N>> for Shift<T,F,N>
+where T : Differentiable<Loc<F, N>, Output=V> {
+    type Output = V;
+    #[inline]
+    fn differential(&self, x : Loc<F, N>) -> Self::Output {
+        self.base_fn.differential(x - &self.shift)
+    }
+}
+
 impl<'a, T, F : Float, const N : usize> Support<F,N> for Shift<T,F,N>
 where T : Support<F, N> {
     #[inline]
@@ -232,6 +250,28 @@
     }
 }
 
+impl<'a, T, V, F : Float, C, const N : usize> Differentiable<&'a Loc<F, N>> for Weighted<T, C>
+where T : for<'b> Differentiable<&'b Loc<F, N>, Output=V>,
+      V : std::ops::Mul<F, Output=V>,
+      C : Constant<Type=F> {
+    type Output = V;
+    #[inline]
+    fn differential(&self, x : &'a Loc<F, N>) -> Self::Output {
+        self.base_fn.differential(x) * self.weight.value()
+    }
+}
+
+impl<'a, T, V, F : Float, C, const N : usize> Differentiable<Loc<F, N>> for Weighted<T, C>
+where T : Differentiable<Loc<F, N>, Output=V>,
+      V : std::ops::Mul<F, Output=V>,
+      C : Constant<Type=F> {
+    type Output = V;
+    #[inline]
+    fn differential(&self, x : Loc<F, N>) -> Self::Output {
+        self.base_fn.differential(x) * self.weight.value()
+    }
+}
+
 impl<'a, T, F : Float, C, const N : usize> Support<F,N> for Weighted<T, C>
 where T : Support<F, N>,
       C : Constant<Type=F> {

mercurial