|
1 /*! |
|
2 Arithmetic of [`Mapping`]s. |
|
3 */ |
|
4 |
|
5 use serde::Serialize; |
|
6 use crate::types::*; |
|
7 use crate::instance::{Space, Instance}; |
|
8 use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping}; |
|
9 |
|
10 /// A trait for encoding constant [`Float`] values |
|
11 pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> { |
|
12 /// The type of the value |
|
13 type Type : Float; |
|
14 /// Returns the value of the constant |
|
15 fn value(&self) -> Self::Type; |
|
16 } |
|
17 |
|
18 impl<F : Float> Constant for F { |
|
19 type Type = F; |
|
20 #[inline] |
|
21 fn value(&self) -> F { *self } |
|
22 } |
|
23 |
|
24 /// Weighting of a [`Support`] and [`Apply`] by scalar multiplication; |
|
25 /// output of [`Support::weigh`]. |
|
26 #[derive(Copy,Clone,Debug,Serialize)] |
|
27 pub struct Weighted<T, C : Constant> { |
|
28 /// The weight |
|
29 pub weight : C, |
|
30 /// The base [`Support`] or [`Apply`] being weighted. |
|
31 pub base_fn : T, |
|
32 } |
|
33 |
|
34 impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C> |
|
35 where |
|
36 F : Float, |
|
37 D : Space, |
|
38 T : Mapping<D, Codomain=V>, |
|
39 V : Space + ClosedMul<F>, |
|
40 C : Constant<Type=F> |
|
41 { |
|
42 type Codomain = V; |
|
43 |
|
44 #[inline] |
|
45 fn apply<I : Instance<D>>(&self, x : I) -> Self::Codomain { |
|
46 self.base_fn.apply(x) * self.weight.value() |
|
47 } |
|
48 } |
|
49 |
|
50 impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C> |
|
51 where |
|
52 F : Float, |
|
53 D : Space, |
|
54 T : DifferentiableMapping<D, DerivativeDomain=V>, |
|
55 V : Space + std::ops::Mul<F, Output=V>, |
|
56 C : Constant<Type=F> |
|
57 { |
|
58 type Derivative = V; |
|
59 |
|
60 #[inline] |
|
61 fn differential_impl<I : Instance<D>>(&self, x : I) -> Self::Derivative { |
|
62 self.base_fn.differential(x) * self.weight.value() |
|
63 } |
|
64 } |