diff -r 9f2214c961cb -r cf8ef9463664 src/bisection_tree/support.rs --- a/src/bisection_tree/support.rs Fri Apr 28 14:02:18 2023 +0300 +++ b/src/bisection_tree/support.rs Tue Jul 18 15:44:10 2023 +0300 @@ -145,22 +145,48 @@ } } -impl<'a, T, V, F : Float, const N : usize> Differentiable<&'a Loc> for Shift -where T : Differentiable, Output=V> { +impl<'a, T, V, W, F : Float, const N : usize> Differentiable<&'a Loc> +for Shift +where T : Differentiable, Output=V> + Apply, Output=W> { type Output = V; #[inline] - fn differential(&self, x : &'a Loc) -> Self::Output { + fn differential(&self, x : &'a Loc) -> V { self.base_fn.differential(x - &self.shift) } + + #[inline] + fn linearisation_error(&self, x : &'a Loc, y : &'a Loc) -> W { + self.base_fn + .linearisation_error(x - &self.shift, y - &self.shift) + } + + #[inline] + fn linearisation_error_gen(&self, x : &'a Loc, y : &'a Loc, z : &'a Loc) -> W { + self.base_fn + .linearisation_error_gen(x - &self.shift, y - &self.shift, z - &self.shift) + } } -impl<'a, T, V, F : Float, const N : usize> Differentiable> for Shift -where T : Differentiable, Output=V> { +impl<'a, T, V, W, F : Float, const N : usize> Differentiable> +for Shift +where T : Differentiable, Output=V> + Apply, Output=W> { type Output = V; #[inline] - fn differential(&self, x : Loc) -> Self::Output { + fn differential(&self, x : Loc) -> V { self.base_fn.differential(x - &self.shift) } + + #[inline] + fn linearisation_error(&self, x : Loc, y : Loc) -> W { + self.base_fn + .linearisation_error(x - &self.shift, y - &self.shift) + } + + #[inline] + fn linearisation_error_gen(&self, x : Loc, y : Loc, z : Loc) -> W { + self.base_fn + .linearisation_error_gen(x - &self.shift, y - &self.shift, z - &self.shift) + } } impl<'a, T, F : Float, const N : usize> Support for Shift @@ -250,26 +276,52 @@ } } -impl<'a, T, V, F : Float, C, const N : usize> Differentiable<&'a Loc> for Weighted -where T : for<'b> Differentiable<&'b Loc, Output=V>, +impl<'a, T, V, W, F : Float, C, const N : usize> Differentiable<&'a Loc> for Weighted +where T : for<'b> Differentiable<&'b Loc, Output=V> + + for<'b> Apply<&'b Loc, Output=W>, V : std::ops::Mul, + W : std::ops::Mul, C : Constant { type Output = V; + #[inline] - fn differential(&self, x : &'a Loc) -> Self::Output { + fn differential(&self, x : &'a Loc) -> V { self.base_fn.differential(x) * self.weight.value() } + + #[inline] + fn linearisation_error(&self, x : &'a Loc, y : &'a Loc) -> W { + self.base_fn.linearisation_error(x, y) * self.weight.value() + } + + #[inline] + fn linearisation_error_gen(&self, x : &'a Loc, y : &'a Loc, z : &'a Loc) -> W { + self.base_fn.linearisation_error_gen(x, y, z) * self.weight.value() + } } -impl<'a, T, V, F : Float, C, const N : usize> Differentiable> for Weighted -where T : Differentiable, Output=V>, +impl<'a, T, V, W, F : Float, C, const N : usize> Differentiable> +for Weighted +where T : Differentiable, Output=V> + Apply, Output=W>, V : std::ops::Mul, + W : std::ops::Mul, C : Constant { type Output = V; + #[inline] - fn differential(&self, x : Loc) -> Self::Output { + fn differential(&self, x : Loc) -> V { self.base_fn.differential(x) * self.weight.value() } + + #[inline] + fn linearisation_error(&self, x : Loc, y : Loc) -> W { + self.base_fn.linearisation_error(x, y) * self.weight.value() + } + + #[inline] + fn linearisation_error_gen(&self, x : Loc, y : Loc, z : Loc) -> W { + self.base_fn.linearisation_error_gen(x, y, z) * self.weight.value() + } } impl<'a, T, F : Float, C, const N : usize> Support for Weighted