| |
1 /*! |
| |
2 Arithmetic of [`Mapping`]s. |
| |
3 */ |
| |
4 |
| |
5 use serde::Serialize; |
| |
6 use crate::types::*; |
| |
7 use crate::instance::{Space, Instance}; |
| |
8 use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping}; |
| |
9 |
| |
10 /// A trait for encoding constant [`Float`] values |
| |
11 pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> { |
| |
12 /// The type of the value |
| |
13 type Type : Float; |
| |
14 /// Returns the value of the constant |
| |
15 fn value(&self) -> Self::Type; |
| |
16 } |
| |
17 |
| |
18 impl<F : Float> Constant for F { |
| |
19 type Type = F; |
| |
20 #[inline] |
| |
21 fn value(&self) -> F { *self } |
| |
22 } |
| |
23 |
| |
24 /// Weighting of a [`Mapping`] by scalar multiplication. |
| |
25 #[derive(Copy,Clone,Debug,Serialize)] |
| |
26 pub struct Weighted<T, C : Constant> { |
| |
27 /// The weight |
| |
28 pub weight : C, |
| |
29 /// The base [`Mapping`] being weighted. |
| |
30 pub base_fn : T, |
| |
31 } |
| |
32 |
| |
33 impl<T, C> Weighted<T, C> |
| |
34 where |
| |
35 C : Constant, |
| |
36 { |
| |
37 /// Construct from an iterator. |
| |
38 pub fn new(weight : C, base_fn : T) -> Self { |
| |
39 Weighted{ weight, base_fn } |
| |
40 } |
| |
41 } |
| |
42 |
| |
43 impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C> |
| |
44 where |
| |
45 F : Float, |
| |
46 D : Space, |
| |
47 T : Mapping<D, Codomain=V>, |
| |
48 V : Space + ClosedMul<F>, |
| |
49 C : Constant<Type=F> |
| |
50 { |
| |
51 type Codomain = V; |
| |
52 |
| |
53 #[inline] |
| |
54 fn apply<I : Instance<D>>(&self, x : I) -> Self::Codomain { |
| |
55 self.base_fn.apply(x) * self.weight.value() |
| |
56 } |
| |
57 } |
| |
58 |
| |
59 impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C> |
| |
60 where |
| |
61 F : Float, |
| |
62 D : Space, |
| |
63 T : DifferentiableMapping<D, DerivativeDomain=V>, |
| |
64 V : Space + std::ops::Mul<F, Output=V>, |
| |
65 C : Constant<Type=F> |
| |
66 { |
| |
67 type Derivative = V; |
| |
68 |
| |
69 #[inline] |
| |
70 fn differential_impl<I : Instance<D>>(&self, x : I) -> Self::Derivative { |
| |
71 self.base_fn.differential(x) * self.weight.value() |
| |
72 } |
| |
73 } |
| |
74 |
| |
75 /// A sum of [`Mapping`]s. |
| |
76 #[derive(Serialize, Debug, Clone)] |
| |
77 pub struct MappingSum<M>(Vec<M>); |
| |
78 |
| |
79 impl< M> MappingSum<M> { |
| |
80 /// Construct from an iterator. |
| |
81 pub fn new<I : IntoIterator<Item = M>>(iter : I) -> Self { |
| |
82 MappingSum(iter.into_iter().collect()) |
| |
83 } |
| |
84 |
| |
85 /// Iterate over the component functions of the sum |
| |
86 pub fn iter(&self) -> std::slice::Iter<'_, M> { |
| |
87 self.0.iter() |
| |
88 } |
| |
89 } |
| |
90 |
| |
91 impl<Domain, M> Mapping<Domain> for MappingSum<M> |
| |
92 where |
| |
93 Domain : Space + Clone, |
| |
94 M : Mapping<Domain>, |
| |
95 M::Codomain : std::iter::Sum + Clone |
| |
96 { |
| |
97 type Codomain = M::Codomain; |
| |
98 |
| |
99 fn apply<I : Instance<Domain>>(&self, x : I) -> Self::Codomain { |
| |
100 let xr = x.ref_instance(); |
| |
101 self.0.iter().map(|c| c.apply(xr)).sum() |
| |
102 } |
| |
103 } |
| |
104 |
| |
105 impl<Domain, M> DifferentiableImpl<Domain> for MappingSum< M> |
| |
106 where |
| |
107 Domain : Space + Clone, |
| |
108 M : DifferentiableMapping<Domain>, |
| |
109 M :: DerivativeDomain : std::iter::Sum |
| |
110 { |
| |
111 type Derivative = M::DerivativeDomain; |
| |
112 |
| |
113 fn differential_impl<I : Instance<Domain>>(&self, x : I) -> Self::Derivative { |
| |
114 let xr = x.ref_instance(); |
| |
115 self.0.iter().map(|c| c.differential(xr)).sum() |
| |
116 } |
| |
117 } |