Fri, 28 Apr 2023 14:02:18 +0300
Implement Differentiable for Weighted and Shift
src/bisection_tree/support.rs | file | annotate | diff | comparison | revisions |
--- a/src/bisection_tree/support.rs Fri Apr 28 13:42:03 2023 +0300 +++ b/src/bisection_tree/support.rs Fri Apr 28 14:02:18 2023 +0300 @@ -6,7 +6,7 @@ use std::ops::{MulAssign,DivAssign,Neg}; use crate::types::{Float, Num}; use crate::maputil::map2; -use crate::mapping::Apply; +use crate::mapping::{Apply, Differentiable}; use crate::sets::Cube; use crate::loc::Loc; use super::aggregator::Bounds; @@ -145,6 +145,24 @@ } } +impl<'a, T, V, F : Float, const N : usize> Differentiable<&'a Loc<F, N>> for Shift<T,F,N> +where T : Differentiable<Loc<F, N>, Output=V> { + type Output = V; + #[inline] + fn differential(&self, x : &'a Loc<F, N>) -> Self::Output { + self.base_fn.differential(x - &self.shift) + } +} + +impl<'a, T, V, F : Float, const N : usize> Differentiable<Loc<F, N>> for Shift<T,F,N> +where T : Differentiable<Loc<F, N>, Output=V> { + type Output = V; + #[inline] + fn differential(&self, x : Loc<F, N>) -> Self::Output { + self.base_fn.differential(x - &self.shift) + } +} + impl<'a, T, F : Float, const N : usize> Support<F,N> for Shift<T,F,N> where T : Support<F, N> { #[inline] @@ -232,6 +250,28 @@ } } +impl<'a, T, V, F : Float, C, const N : usize> Differentiable<&'a Loc<F, N>> for Weighted<T, C> +where T : for<'b> Differentiable<&'b Loc<F, N>, Output=V>, + V : std::ops::Mul<F, Output=V>, + C : Constant<Type=F> { + type Output = V; + #[inline] + fn differential(&self, x : &'a Loc<F, N>) -> Self::Output { + self.base_fn.differential(x) * self.weight.value() + } +} + +impl<'a, T, V, F : Float, C, const N : usize> Differentiable<Loc<F, N>> for Weighted<T, C> +where T : Differentiable<Loc<F, N>, Output=V>, + V : std::ops::Mul<F, Output=V>, + C : Constant<Type=F> { + type Output = V; + #[inline] + fn differential(&self, x : Loc<F, N>) -> Self::Output { + self.base_fn.differential(x) * self.weight.value() + } +} + impl<'a, T, F : Float, C, const N : usize> Support<F,N> for Weighted<T, C> where T : Support<F, N>, C : Constant<Type=F> {