# HG changeset patch # User Tuomo Valkonen # Date 1682679738 -10800 # Node ID 9f2214c961cb7d5dd04b9fb0a97d1a66283d6be8 # Parent 7fd0984743b5c3d3b138226944fad6d7469bdf98 Implement Differentiable for Weighted and Shift diff -r 7fd0984743b5 -r 9f2214c961cb src/bisection_tree/support.rs --- a/src/bisection_tree/support.rs Fri Apr 28 13:42:03 2023 +0300 +++ b/src/bisection_tree/support.rs Fri Apr 28 14:02:18 2023 +0300 @@ -6,7 +6,7 @@ use std::ops::{MulAssign,DivAssign,Neg}; use crate::types::{Float, Num}; use crate::maputil::map2; -use crate::mapping::Apply; +use crate::mapping::{Apply, Differentiable}; use crate::sets::Cube; use crate::loc::Loc; use super::aggregator::Bounds; @@ -145,6 +145,24 @@ } } +impl<'a, T, V, F : Float, const N : usize> Differentiable<&'a Loc> for Shift +where T : Differentiable, Output=V> { + type Output = V; + #[inline] + fn differential(&self, x : &'a Loc) -> Self::Output { + self.base_fn.differential(x - &self.shift) + } +} + +impl<'a, T, V, F : Float, const N : usize> Differentiable> for Shift +where T : Differentiable, Output=V> { + type Output = V; + #[inline] + fn differential(&self, x : Loc) -> Self::Output { + self.base_fn.differential(x - &self.shift) + } +} + impl<'a, T, F : Float, const N : usize> Support for Shift where T : Support { #[inline] @@ -232,6 +250,28 @@ } } +impl<'a, T, V, F : Float, C, const N : usize> Differentiable<&'a Loc> for Weighted +where T : for<'b> Differentiable<&'b Loc, Output=V>, + V : std::ops::Mul, + C : Constant { + type Output = V; + #[inline] + fn differential(&self, x : &'a Loc) -> Self::Output { + self.base_fn.differential(x) * self.weight.value() + } +} + +impl<'a, T, V, F : Float, C, const N : usize> Differentiable> for Weighted +where T : Differentiable, Output=V>, + V : std::ops::Mul, + C : Constant { + type Output = V; + #[inline] + fn differential(&self, x : Loc) -> Self::Output { + self.base_fn.differential(x) * self.weight.value() + } +} + impl<'a, T, F : Float, C, const N : usize> Support for Weighted where T : Support, C : Constant {