src/bisection_tree/support.rs

branch
dev
changeset 30
9f2214c961cb
parent 15
e03ce15643da
child 47
a0db98c16ab5
child 77
cf8ef9463664
equal deleted inserted replaced
29:7fd0984743b5 30:9f2214c961cb
4 */ 4 */
5 use serde::Serialize; 5 use serde::Serialize;
6 use std::ops::{MulAssign,DivAssign,Neg}; 6 use std::ops::{MulAssign,DivAssign,Neg};
7 use crate::types::{Float, Num}; 7 use crate::types::{Float, Num};
8 use crate::maputil::map2; 8 use crate::maputil::map2;
9 use crate::mapping::Apply; 9 use crate::mapping::{Apply, Differentiable};
10 use crate::sets::Cube; 10 use crate::sets::Cube;
11 use crate::loc::Loc; 11 use crate::loc::Loc;
12 use super::aggregator::Bounds; 12 use super::aggregator::Bounds;
13 use crate::norms::{Norm, L1, L2, Linfinity}; 13 use crate::norms::{Norm, L1, L2, Linfinity};
14 14
140 where T : Apply<Loc<F, N>, Output=V> { 140 where T : Apply<Loc<F, N>, Output=V> {
141 type Output = V; 141 type Output = V;
142 #[inline] 142 #[inline]
143 fn apply(&self, x : Loc<F, N>) -> Self::Output { 143 fn apply(&self, x : Loc<F, N>) -> Self::Output {
144 self.base_fn.apply(x - &self.shift) 144 self.base_fn.apply(x - &self.shift)
145 }
146 }
147
/// Differential of a [`Shift`]ed function at a borrowed point `x`.
///
/// `Apply` for `Shift` evaluates `base_fn` at `x - shift`. Correspondingly, this
/// evaluates the base function's differential at `x - shift`; the shift has
/// identity Jacobian, so no further scaling is applied.
/// The base function only needs to be differentiable at owned points `Loc<F, N>`:
/// `x - &self.shift` produces the owned point that is forwarded to it.
148 impl<'a, T, V, F : Float, const N : usize> Differentiable<&'a Loc<F, N>> for Shift<T,F,N>
149 where T : Differentiable<Loc<F, N>, Output=V> {
150 type Output = V;
151 #[inline]
152 fn differential(&self, x : &'a Loc<F, N>) -> Self::Output {
153 self.base_fn.differential(x - &self.shift)
154 }
155 }
156
/// Differential of a [`Shift`]ed function at an owned point `x`.
///
/// Evaluates the base function's differential at `x - shift`, mirroring how
/// `Apply` for `Shift` evaluates `base_fn` at `x - shift`.
// NOTE(review): the lifetime parameter `'a` is declared but not used anywhere in
// this impl; it could be dropped.
157 impl<'a, T, V, F : Float, const N : usize> Differentiable<Loc<F, N>> for Shift<T,F,N>
158 where T : Differentiable<Loc<F, N>, Output=V> {
159 type Output = V;
160 #[inline]
161 fn differential(&self, x : Loc<F, N>) -> Self::Output {
162 self.base_fn.differential(x - &self.shift)
145 } 163 }
146 } 164 }
147 165
148 impl<'a, T, F : Float, const N : usize> Support<F,N> for Shift<T,F,N> 166 impl<'a, T, F : Float, const N : usize> Support<F,N> for Shift<T,F,N>
149 where T : Support<F, N> { 167 where T : Support<F, N> {
227 C : Constant<Type=F> { 245 C : Constant<Type=F> {
228 type Output = V; 246 type Output = V;
229 #[inline] 247 #[inline]
230 fn apply(&self, x : Loc<F, N>) -> Self::Output { 248 fn apply(&self, x : Loc<F, N>) -> Self::Output {
231 self.base_fn.apply(x) * self.weight.value() 249 self.base_fn.apply(x) * self.weight.value()
250 }
251 }
252
/// Differential of a [`Weighted`] function at a borrowed point `x`.
///
/// Returns the base function's differential at `x`, multiplied by the constant
/// weight `self.weight.value()`. This matches `Apply` for `Weighted`, which scales
/// `base_fn.apply(x)` by the same weight.
// NOTE(review): the bound requires `T` to be differentiable at borrowed points of
// *every* lifetime (`for<'b>`). The body only differentiates at `&'a Loc<F, N>`, so
// `T : Differentiable<&'a Loc<F, N>, Output=V>` would appear to suffice — confirm
// whether the higher-ranked bound is needed elsewhere before relaxing it.
253 impl<'a, T, V, F : Float, C, const N : usize> Differentiable<&'a Loc<F, N>> for Weighted<T, C>
254 where T : for<'b> Differentiable<&'b Loc<F, N>, Output=V>,
255 V : std::ops::Mul<F, Output=V>,
256 C : Constant<Type=F> {
257 type Output = V;
258 #[inline]
259 fn differential(&self, x : &'a Loc<F, N>) -> Self::Output {
260 self.base_fn.differential(x) * self.weight.value()
261 }
262 }
263
/// Differential of a [`Weighted`] function at an owned point `x`.
///
/// Returns the base function's differential at `x`, multiplied by the constant
/// weight `self.weight.value()`. The output type `V` must support
/// multiplication by the scalar `F`.
// NOTE(review): the lifetime parameter `'a` is declared but not used anywhere in
// this impl; it could be dropped.
264 impl<'a, T, V, F : Float, C, const N : usize> Differentiable<Loc<F, N>> for Weighted<T, C>
265 where T : Differentiable<Loc<F, N>, Output=V>,
266 V : std::ops::Mul<F, Output=V>,
267 C : Constant<Type=F> {
268 type Output = V;
269 #[inline]
270 fn differential(&self, x : Loc<F, N>) -> Self::Output {
271 self.base_fn.differential(x) * self.weight.value()
232 } 272 }
233 } 273 }
234 274
235 impl<'a, T, F : Float, C, const N : usize> Support<F,N> for Weighted<T, C> 275 impl<'a, T, F : Float, C, const N : usize> Support<F,N> for Weighted<T, C>
236 where T : Support<F, N>, 276 where T : Support<F, N>,

mercurial