src/bisection_tree/support.rs

branch
dev
changeset 77
cf8ef9463664
parent 30
9f2214c961cb
equal deleted inserted replaced
30:9f2214c961cb 77:cf8ef9463664
143 fn apply(&self, x : Loc<F, N>) -> Self::Output { 143 fn apply(&self, x : Loc<F, N>) -> Self::Output {
144 self.base_fn.apply(x - &self.shift) 144 self.base_fn.apply(x - &self.shift)
145 } 145 }
146 } 146 }
147 147
148 impl<'a, T, V, F : Float, const N : usize> Differentiable<&'a Loc<F, N>> for Shift<T,F,N> 148 impl<'a, T, V, W, F : Float, const N : usize> Differentiable<&'a Loc<F, N>>
149 where T : Differentiable<Loc<F, N>, Output=V> { 149 for Shift<T,F,N>
150 type Output = V; 150 where T : Differentiable<Loc<F, N>, Output=V> + Apply<Loc<F, N>, Output=W> {
151 #[inline] 151 type Output = V;
152 fn differential(&self, x : &'a Loc<F, N>) -> Self::Output { 152 #[inline]
153 fn differential(&self, x : &'a Loc<F, N>) -> V {
153 self.base_fn.differential(x - &self.shift) 154 self.base_fn.differential(x - &self.shift)
154 } 155 }
155 } 156
156 157 #[inline]
157 impl<'a, T, V, F : Float, const N : usize> Differentiable<Loc<F, N>> for Shift<T,F,N> 158 fn linearisation_error(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>) -> W {
158 where T : Differentiable<Loc<F, N>, Output=V> { 159 self.base_fn
159 type Output = V; 160 .linearisation_error(x - &self.shift, y - &self.shift)
160 #[inline] 161 }
161 fn differential(&self, x : Loc<F, N>) -> Self::Output { 162
163 #[inline]
164 fn linearisation_error_gen(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>, z : &'a Loc<F, N>) -> W {
165 self.base_fn
166 .linearisation_error_gen(x - &self.shift, y - &self.shift, z - &self.shift)
167 }
168 }
169
170 impl<'a, T, V, W, F : Float, const N : usize> Differentiable<Loc<F, N>>
171 for Shift<T,F,N>
172 where T : Differentiable<Loc<F, N>, Output=V> + Apply<Loc<F, N>, Output=W> {
173 type Output = V;
174 #[inline]
175 fn differential(&self, x : Loc<F, N>) -> V {
162 self.base_fn.differential(x - &self.shift) 176 self.base_fn.differential(x - &self.shift)
177 }
178
179 #[inline]
180 fn linearisation_error(&self, x : Loc<F, N>, y : Loc<F, N>) -> W {
181 self.base_fn
182 .linearisation_error(x - &self.shift, y - &self.shift)
183 }
184
185 #[inline]
186 fn linearisation_error_gen(&self, x : Loc<F, N>, y : Loc<F, N>, z : Loc<F, N>) -> W {
187 self.base_fn
188 .linearisation_error_gen(x - &self.shift, y - &self.shift, z - &self.shift)
163 } 189 }
164 } 190 }
165 191
166 impl<'a, T, F : Float, const N : usize> Support<F,N> for Shift<T,F,N> 192 impl<'a, T, F : Float, const N : usize> Support<F,N> for Shift<T,F,N>
167 where T : Support<F, N> { 193 where T : Support<F, N> {
248 fn apply(&self, x : Loc<F, N>) -> Self::Output { 274 fn apply(&self, x : Loc<F, N>) -> Self::Output {
249 self.base_fn.apply(x) * self.weight.value() 275 self.base_fn.apply(x) * self.weight.value()
250 } 276 }
251 } 277 }
252 278
253 impl<'a, T, V, F : Float, C, const N : usize> Differentiable<&'a Loc<F, N>> for Weighted<T, C> 279 impl<'a, T, V, W, F : Float, C, const N : usize> Differentiable<&'a Loc<F, N>> for Weighted<T, C>
254 where T : for<'b> Differentiable<&'b Loc<F, N>, Output=V>, 280 where T : for<'b> Differentiable<&'b Loc<F, N>, Output=V>
281 + for<'b> Apply<&'b Loc<F, N>, Output=W>,
255 V : std::ops::Mul<F, Output=V>, 282 V : std::ops::Mul<F, Output=V>,
256 C : Constant<Type=F> { 283 W : std::ops::Mul<F, Output=W>,
257 type Output = V; 284 C : Constant<Type=F> {
258 #[inline] 285 type Output = V;
259 fn differential(&self, x : &'a Loc<F, N>) -> Self::Output { 286
287 #[inline]
288 fn differential(&self, x : &'a Loc<F, N>) -> V {
260 self.base_fn.differential(x) * self.weight.value() 289 self.base_fn.differential(x) * self.weight.value()
261 } 290 }
262 } 291
263 292 #[inline]
264 impl<'a, T, V, F : Float, C, const N : usize> Differentiable<Loc<F, N>> for Weighted<T, C> 293 fn linearisation_error(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>) -> W {
265 where T : Differentiable<Loc<F, N>, Output=V>, 294 self.base_fn.linearisation_error(x, y) * self.weight.value()
295 }
296
297 #[inline]
298 fn linearisation_error_gen(&self, x : &'a Loc<F, N>, y : &'a Loc<F, N>, z : &'a Loc<F, N>) -> W {
299 self.base_fn.linearisation_error_gen(x, y, z) * self.weight.value()
300 }
301 }
302
303 impl<'a, T, V, W, F : Float, C, const N : usize> Differentiable<Loc<F, N>>
304 for Weighted<T, C>
305 where T : Differentiable<Loc<F, N>, Output=V> + Apply<Loc<F, N>, Output=W>,
266 V : std::ops::Mul<F, Output=V>, 306 V : std::ops::Mul<F, Output=V>,
267 C : Constant<Type=F> { 307 W : std::ops::Mul<F, Output=W>,
268 type Output = V; 308 C : Constant<Type=F> {
269 #[inline] 309 type Output = V;
270 fn differential(&self, x : Loc<F, N>) -> Self::Output { 310
311 #[inline]
312 fn differential(&self, x : Loc<F, N>) -> V {
271 self.base_fn.differential(x) * self.weight.value() 313 self.base_fn.differential(x) * self.weight.value()
314 }
315
316 #[inline]
317 fn linearisation_error(&self, x : Loc<F, N>, y : Loc<F, N>) -> W {
318 self.base_fn.linearisation_error(x, y) * self.weight.value()
319 }
320
321 #[inline]
322 fn linearisation_error_gen(&self, x : Loc<F, N>, y : Loc<F, N>, z : Loc<F, N>) -> W {
323 self.base_fn.linearisation_error_gen(x, y, z) * self.weight.value()
272 } 324 }
273 } 325 }
274 326
275 impl<'a, T, F : Float, C, const N : usize> Support<F,N> for Weighted<T, C> 327 impl<'a, T, F : Float, C, const N : usize> Support<F,N> for Weighted<T, C>
276 where T : Support<F, N>, 328 where T : Support<F, N>,

mercurial