/*!
Arithmetic of [`Mapping`]s: scalar weighting via [`Weighted`] and finite sums via
[`MappingSum`].
*/
| 4 |
4 |
| |
5 use crate::instance::{Instance, Space}; |
| |
6 use crate::mapping::{DifferentiableImpl, DifferentiableMapping, Mapping}; |
| |
7 use crate::types::*; |
| 5 use serde::Serialize; |
8 use serde::Serialize; |
| 6 use crate::types::*; |
|
| 7 use crate::instance::{Space, Instance}; |
|
| 8 use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping}; |
|
| 9 |
9 |
| 10 /// A trait for encoding constant [`Float`] values |
10 /// A trait for encoding constant [`Float`] values |
| 11 pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> { |
11 pub trait Constant: Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> { |
| 12 /// The type of the value |
12 /// The type of the value |
| 13 type Type : Float; |
13 type Type: Float; |
| 14 /// Returns the value of the constant |
14 /// Returns the value of the constant |
| 15 fn value(&self) -> Self::Type; |
15 fn value(&self) -> Self::Type; |
| 16 } |
16 } |
| 17 |
17 |
| 18 impl<F : Float> Constant for F { |
18 impl<F: Float> Constant for F { |
| 19 type Type = F; |
19 type Type = F; |
| 20 #[inline] |
20 #[inline] |
| 21 fn value(&self) -> F { *self } |
21 fn value(&self) -> F { |
| |
22 *self |
| |
23 } |
| 22 } |
24 } |
| 23 |
25 |
| 24 /// Weighting of a [`Mapping`] by scalar multiplication. |
26 /// Weighting of a [`Mapping`] by scalar multiplication. |
| 25 #[derive(Copy,Clone,Debug,Serialize)] |
27 #[derive(Copy, Clone, Debug, Serialize)] |
| 26 pub struct Weighted<T, C : Constant> { |
28 pub struct Weighted<T, C: Constant> { |
| 27 /// The weight |
29 /// The weight |
| 28 pub weight : C, |
30 pub weight: C, |
| 29 /// The base [`Mapping`] being weighted. |
31 /// The base [`Mapping`] being weighted. |
| 30 pub base_fn : T, |
32 pub base_fn: T, |
| 31 } |
33 } |
| 32 |
34 |
| 33 impl<T, C> Weighted<T, C> |
35 impl<T, C> Weighted<T, C> |
| 34 where |
36 where |
| 35 C : Constant, |
37 C: Constant, |
| 36 { |
38 { |
| 37 /// Construct from an iterator. |
39 /// Construct from an iterator. |
| 38 pub fn new(weight : C, base_fn : T) -> Self { |
40 pub fn new(weight: C, base_fn: T) -> Self { |
| 39 Weighted{ weight, base_fn } |
41 Weighted { weight, base_fn } |
| 40 } |
42 } |
| 41 } |
43 } |
| 42 |
44 |
| 43 impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C> |
45 impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C> |
| 44 where |
46 where |
| 45 F : Float, |
47 F: Float, |
| 46 D : Space, |
48 D: Space, |
| 47 T : Mapping<D, Codomain=V>, |
49 T: Mapping<D, Codomain = V>, |
| 48 V : Space + ClosedMul<F>, |
50 V: Space + ClosedMul<F>, |
| 49 C : Constant<Type=F> |
51 C: Constant<Type = F>, |
| 50 { |
52 { |
| 51 type Codomain = V; |
53 type Codomain = V; |
| 52 |
54 |
| 53 #[inline] |
55 #[inline] |
| 54 fn apply<I : Instance<D>>(&self, x : I) -> Self::Codomain { |
56 fn apply<I: Instance<D>>(&self, x: I) -> Self::Codomain { |
| 55 self.base_fn.apply(x) * self.weight.value() |
57 self.base_fn.apply(x) * self.weight.value() |
| 56 } |
58 } |
| 57 } |
59 } |
| 58 |
60 |
| 59 impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C> |
61 impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C> |
| 60 where |
62 where |
| 61 F : Float, |
63 F: Float, |
| 62 D : Space, |
64 D: Space, |
| 63 T : DifferentiableMapping<D, DerivativeDomain=V>, |
65 T: DifferentiableMapping<D, DerivativeDomain = V>, |
| 64 V : Space + std::ops::Mul<F, Output=V>, |
66 V: Space + std::ops::Mul<F, Output = V>, |
| 65 C : Constant<Type=F> |
67 C: Constant<Type = F>, |
| 66 { |
68 { |
| 67 type Derivative = V; |
69 type Derivative = V; |
| 68 |
70 |
| 69 #[inline] |
71 #[inline] |
| 70 fn differential_impl<I : Instance<D>>(&self, x : I) -> Self::Derivative { |
72 fn differential_impl<I: Instance<D>>(&self, x: I) -> Self::Derivative { |
| 71 self.base_fn.differential(x) * self.weight.value() |
73 self.base_fn.differential(x) * self.weight.value() |
| 72 } |
74 } |
| 73 } |
75 } |
| 74 |
76 |
| 75 /// A sum of [`Mapping`]s. |
77 /// A sum of [`Mapping`]s. |
| 76 #[derive(Serialize, Debug, Clone)] |
78 #[derive(Serialize, Debug, Clone)] |
| 77 pub struct MappingSum<M>(Vec<M>); |
79 pub struct MappingSum<M>(Vec<M>); |
| 78 |
80 |
| 79 impl< M> MappingSum<M> { |
81 impl<M> MappingSum<M> { |
| 80 /// Construct from an iterator. |
82 /// Construct from an iterator. |
| 81 pub fn new<I : IntoIterator<Item = M>>(iter : I) -> Self { |
83 pub fn new<I: IntoIterator<Item = M>>(iter: I) -> Self { |
| 82 MappingSum(iter.into_iter().collect()) |
84 MappingSum(iter.into_iter().collect()) |
| 83 } |
85 } |
| 84 |
86 |
| 85 /// Iterate over the component functions of the sum |
87 /// Iterate over the component functions of the sum |
| 86 pub fn iter(&self) -> std::slice::Iter<'_, M> { |
88 pub fn iter(&self) -> std::slice::Iter<'_, M> { |
| 88 } |
90 } |
| 89 } |
91 } |
| 90 |
92 |
| 91 impl<Domain, M> Mapping<Domain> for MappingSum<M> |
93 impl<Domain, M> Mapping<Domain> for MappingSum<M> |
| 92 where |
94 where |
| 93 Domain : Space + Clone, |
95 Domain: Space + Clone, |
| 94 M : Mapping<Domain>, |
96 M: Mapping<Domain>, |
| 95 M::Codomain : std::iter::Sum + Clone |
97 M::Codomain: std::iter::Sum + Clone, |
| 96 { |
98 { |
| 97 type Codomain = M::Codomain; |
99 type Codomain = M::Codomain; |
| 98 |
100 |
| 99 fn apply<I : Instance<Domain>>(&self, x : I) -> Self::Codomain { |
101 fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain { |
| 100 let xr = x.ref_instance(); |
102 let xr = x.ref_instance(); |
| 101 self.0.iter().map(|c| c.apply(xr)).sum() |
103 self.0.iter().map(|c| c.apply(xr)).sum() |
| 102 } |
104 } |
| 103 } |
105 } |
| 104 |
106 |
| 105 impl<Domain, M> DifferentiableImpl<Domain> for MappingSum< M> |
107 impl<Domain, M> DifferentiableImpl<Domain> for MappingSum<M> |
| 106 where |
108 where |
| 107 Domain : Space + Clone, |
109 Domain: Space, |
| 108 M : DifferentiableMapping<Domain>, |
110 M: DifferentiableMapping<Domain>, |
| 109 M :: DerivativeDomain : std::iter::Sum |
111 M::DerivativeDomain: std::iter::Sum, |
| 110 { |
112 { |
| 111 type Derivative = M::DerivativeDomain; |
113 type Derivative = M::DerivativeDomain; |
| 112 |
114 |
| 113 fn differential_impl<I : Instance<Domain>>(&self, x : I) -> Self::Derivative { |
115 fn differential_impl<I: Instance<Domain>>(&self, x: I) -> Self::Derivative { |
| 114 let xr = x.ref_instance(); |
116 let xr = x.ref_instance(); |
| 115 self.0.iter().map(|c| c.differential(xr)).sum() |
117 self.0.iter().map(|c| c.differential(xr)).sum() |
| 116 } |
118 } |
| 117 } |
119 } |