src/bisection_tree/support.rs

branch
dev
changeset 59
9226980e45a7
parent 47
a0db98c16ab5
child 68
c5f70e767511
equal deleted inserted replaced
58:1a38447a89fa 59:9226980e45a7
4 */ 4 */
5 use serde::Serialize; 5 use serde::Serialize;
6 use std::ops::{MulAssign,DivAssign,Neg}; 6 use std::ops::{MulAssign,DivAssign,Neg};
7 use crate::types::{Float, Num}; 7 use crate::types::{Float, Num};
8 use crate::maputil::map2; 8 use crate::maputil::map2;
9 use crate::mapping::{Apply, Differentiable}; 9 use crate::mapping::{
10 Instance, Mapping, DifferentiableImpl, DifferentiableMapping, Space
11 };
10 use crate::sets::Cube; 12 use crate::sets::Cube;
11 use crate::loc::Loc; 13 use crate::loc::Loc;
12 use super::aggregator::Bounds; 14 use super::aggregator::Bounds;
13 use crate::norms::{Norm, L1, L2, Linfinity}; 15 use crate::norms::{Norm, L1, L2, Linfinity};
14 16
125 pub struct Shift<T, F, const N : usize> { 127 pub struct Shift<T, F, const N : usize> {
126 shift : Loc<F, N>, 128 shift : Loc<F, N>,
127 base_fn : T, 129 base_fn : T,
128 } 130 }
129 131
130 impl<'a, T, V, F : Float, const N : usize> Apply<&'a Loc<F, N>> for Shift<T,F,N> 132 impl<'a, T, V : Space, F : Float, const N : usize> Mapping<Loc<F, N>> for Shift<T,F,N>
131 where T : Apply<Loc<F, N>, Output=V> { 133 where T : Mapping<Loc<F, N>, Codomain=V> {
132 type Output = V; 134 type Codomain = V;
133 #[inline] 135
134 fn apply(&self, x : &'a Loc<F, N>) -> Self::Output { 136 #[inline]
135 self.base_fn.apply(x - &self.shift) 137 fn apply<I : Instance<Loc<F, N>>>(&self, x : I) -> Self::Codomain {
136 } 138 self.base_fn.apply(x.own() - &self.shift)
137 } 139 }
138 140 }
139 impl<'a, T, V, F : Float, const N : usize> Apply<Loc<F, N>> for Shift<T,F,N> 141
140 where T : Apply<Loc<F, N>, Output=V> { 142 impl<'a, T, V : Space, F : Float, const N : usize> DifferentiableImpl<Loc<F, N>> for Shift<T,F,N>
141 type Output = V; 143 where T : DifferentiableMapping<Loc<F, N>, DerivativeDomain=V> {
142 #[inline]
143 fn apply(&self, x : Loc<F, N>) -> Self::Output {
144 self.base_fn.apply(x - &self.shift)
145 }
146 }
147
148 impl<'a, T, V, F : Float, const N : usize> Differentiable<&'a Loc<F, N>> for Shift<T,F,N>
149 where T : Differentiable<Loc<F, N>, Derivative=V> {
150 type Derivative = V; 144 type Derivative = V;
151 #[inline] 145
152 fn differential(&self, x : &'a Loc<F, N>) -> Self::Derivative { 146 #[inline]
153 self.base_fn.differential(x - &self.shift) 147 fn differential_impl<I : Instance<Loc<F, N>>>(&self, x : I) -> Self::Derivative {
154 } 148 self.base_fn.differential(x.own() - &self.shift)
155 }
156
157 impl<'a, T, V, F : Float, const N : usize> Differentiable<Loc<F, N>> for Shift<T,F,N>
158 where T : Differentiable<Loc<F, N>, Derivative=V> {
159 type Derivative = V;
160 #[inline]
161 fn differential(&self, x : Loc<F, N>) -> Self::Derivative {
162 self.base_fn.differential(x - &self.shift)
163 } 149 }
164 } 150 }
165 151
166 impl<'a, T, F : Float, const N : usize> Support<F,N> for Shift<T,F,N> 152 impl<'a, T, F : Float, const N : usize> Support<F,N> for Shift<T,F,N>
167 where T : Support<F, N> { 153 where T : Support<F, N> {
226 pub weight : C, 212 pub weight : C,
227 /// The base [`Support`] or [`Apply`] being weighted. 213 /// The base [`Support`] or [`Apply`] being weighted.
228 pub base_fn : T, 214 pub base_fn : T,
229 } 215 }
230 216
231 impl<'a, T, V, F : Float, C, const N : usize> Apply<&'a Loc<F, N>> for Weighted<T, C> 217 impl<'a, T, V, F : Float, C, const N : usize> Mapping<Loc<F, N>> for Weighted<T, C>
232 where T : for<'b> Apply<&'b Loc<F, N>, Output=V>, 218 where T : Mapping<Loc<F, N>, Codomain=V>,
233 V : std::ops::Mul<F,Output=V>, 219 V : Space + std::ops::Mul<F,Output=V>,
234 C : Constant<Type=F> { 220 C : Constant<Type=F> {
235 type Output = V; 221 type Codomain = V;
236 #[inline] 222
237 fn apply(&self, x : &'a Loc<F, N>) -> Self::Output { 223 #[inline]
224 fn apply<I : Instance<Loc<F, N>>>(&self, x : I) -> Self::Codomain {
238 self.base_fn.apply(x) * self.weight.value() 225 self.base_fn.apply(x) * self.weight.value()
239 } 226 }
240 } 227 }
241 228
242 impl<'a, T, V, F : Float, C, const N : usize> Apply<Loc<F, N>> for Weighted<T, C> 229 impl<'a, T, V, F : Float, C, const N : usize> DifferentiableImpl<Loc<F, N>> for Weighted<T, C>
243 where T : Apply<Loc<F, N>, Output=V>, 230 where T : DifferentiableMapping<Loc<F, N>, DerivativeDomain=V>,
244 V : std::ops::Mul<F,Output=V>, 231 V : Space + std::ops::Mul<F, Output=V>,
245 C : Constant<Type=F> {
246 type Output = V;
247 #[inline]
248 fn apply(&self, x : Loc<F, N>) -> Self::Output {
249 self.base_fn.apply(x) * self.weight.value()
250 }
251 }
252
253 impl<'a, T, V, F : Float, C, const N : usize> Differentiable<&'a Loc<F, N>> for Weighted<T, C>
254 where T : for<'b> Differentiable<&'b Loc<F, N>, Derivative=V>,
255 V : std::ops::Mul<F, Output=V>,
256 C : Constant<Type=F> { 232 C : Constant<Type=F> {
257 type Derivative = V; 233 type Derivative = V;
258 #[inline] 234
259 fn differential(&self, x : &'a Loc<F, N>) -> Self::Derivative { 235 #[inline]
260 self.base_fn.differential(x) * self.weight.value() 236 fn differential_impl<I : Instance<Loc<F, N>>>(&self, x : I) -> Self::Derivative {
261 }
262 }
263
264 impl<'a, T, V, F : Float, C, const N : usize> Differentiable<Loc<F, N>> for Weighted<T, C>
265 where T : Differentiable<Loc<F, N>, Derivative=V>,
266 V : std::ops::Mul<F, Output=V>,
267 C : Constant<Type=F> {
268 type Derivative = V;
269 #[inline]
270 fn differential(&self, x : Loc<F, N>) -> Self::Derivative {
271 self.base_fn.differential(x) * self.weight.value() 237 self.base_fn.differential(x) * self.weight.value()
272 } 238 }
273 } 239 }
274 240
275 impl<'a, T, F : Float, C, const N : usize> Support<F,N> for Weighted<T, C> 241 impl<'a, T, F : Float, C, const N : usize> Support<F,N> for Weighted<T, C>
378 pub struct Normalised<T>( 344 pub struct Normalised<T>(
379 /// The base [`Support`] or [`Apply`]. 345 /// The base [`Support`] or [`Apply`].
380 pub T 346 pub T
381 ); 347 );
382 348
383 impl<'a, T, F : Float, const N : usize> Apply<&'a Loc<F, N>> for Normalised<T> 349 impl<'a, T, F : Float, const N : usize> Mapping<Loc<F, N>> for Normalised<T>
384 where T : Norm<F, L1> + for<'b> Apply<&'b Loc<F, N>, Output=F> { 350 where T : Norm<F, L1> + Mapping<Loc<F,N>, Codomain=F> {
385 type Output = F; 351 type Codomain = F;
386 #[inline] 352
387 fn apply(&self, x : &'a Loc<F, N>) -> Self::Output { 353 #[inline]
388 let w = self.0.norm(L1); 354 fn apply<I : Instance<Loc<F, N>>>(&self, x : I) -> Self::Codomain {
389 if w == F::ZERO { F::ZERO } else { self.0.apply(x) / w }
390 }
391 }
392
393 impl<'a, T, F : Float, const N : usize> Apply<Loc<F, N>> for Normalised<T>
394 where T : Norm<F, L1> + Apply<Loc<F,N>, Output=F> {
395 type Output = F;
396 #[inline]
397 fn apply(&self, x : Loc<F, N>) -> Self::Output {
398 let w = self.0.norm(L1); 355 let w = self.0.norm(L1);
399 if w == F::ZERO { F::ZERO } else { self.0.apply(x) / w } 356 if w == F::ZERO { F::ZERO } else { self.0.apply(x) / w }
400 } 357 }
401 } 358 }
402 359

mercurial