|
1 /*! |
|
2 Arithmetic of [`Mapping`]s. |
|
3 */ |
|
4 |
|
5 use serde::Serialize; |
|
6 use crate::types::*; |
|
7 use crate::instance::{Space, Instance}; |
|
8 use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping}; |
|
9 |
|
10 /// A trait for encoding constant [`Float`] values |
|
11 pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> { |
|
12 /// The type of the value |
|
13 type Type : Float; |
|
14 /// Returns the value of the constant |
|
15 fn value(&self) -> Self::Type; |
|
16 } |
|
17 |
|
18 impl<F : Float> Constant for F { |
|
19 type Type = F; |
|
20 #[inline] |
|
21 fn value(&self) -> F { *self } |
|
22 } |
|
23 |
|
24 /// Weighting of a [`Mapping`] by scalar multiplication. |
|
25 #[derive(Copy,Clone,Debug,Serialize)] |
|
26 pub struct Weighted<T, C : Constant> { |
|
27 /// The weight |
|
28 pub weight : C, |
|
29 /// The base [`Mapping`] being weighted. |
|
30 pub base_fn : T, |
|
31 } |
|
32 |
|
33 impl<T, C> Weighted<T, C> |
|
34 where |
|
35 C : Constant, |
|
36 { |
|
37 /// Construct from an iterator. |
|
38 pub fn new(weight : C, base_fn : T) -> Self { |
|
39 Weighted{ weight, base_fn } |
|
40 } |
|
41 } |
|
42 |
|
43 impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C> |
|
44 where |
|
45 F : Float, |
|
46 D : Space, |
|
47 T : Mapping<D, Codomain=V>, |
|
48 V : Space + ClosedMul<F>, |
|
49 C : Constant<Type=F> |
|
50 { |
|
51 type Codomain = V; |
|
52 |
|
53 #[inline] |
|
54 fn apply<I : Instance<D>>(&self, x : I) -> Self::Codomain { |
|
55 self.base_fn.apply(x) * self.weight.value() |
|
56 } |
|
57 } |
|
58 |
|
59 impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C> |
|
60 where |
|
61 F : Float, |
|
62 D : Space, |
|
63 T : DifferentiableMapping<D, DerivativeDomain=V>, |
|
64 V : Space + std::ops::Mul<F, Output=V>, |
|
65 C : Constant<Type=F> |
|
66 { |
|
67 type Derivative = V; |
|
68 |
|
69 #[inline] |
|
70 fn differential_impl<I : Instance<D>>(&self, x : I) -> Self::Derivative { |
|
71 self.base_fn.differential(x) * self.weight.value() |
|
72 } |
|
73 } |
|
74 |
|
75 /// A sum of [`Mapping`]s. |
|
76 #[derive(Serialize, Debug, Clone)] |
|
77 pub struct MappingSum<M>(Vec<M>); |
|
78 |
|
79 impl< M> MappingSum<M> { |
|
80 /// Construct from an iterator. |
|
81 pub fn new<I : IntoIterator<Item = M>>(iter : I) -> Self { |
|
82 MappingSum(iter.into_iter().collect()) |
|
83 } |
|
84 |
|
85 /// Iterate over the component functions of the sum |
|
86 pub fn iter(&self) -> std::slice::Iter<'_, M> { |
|
87 self.0.iter() |
|
88 } |
|
89 } |
|
90 |
|
91 impl<Domain, M> Mapping<Domain> for MappingSum<M> |
|
92 where |
|
93 Domain : Space + Clone, |
|
94 M : Mapping<Domain>, |
|
95 M::Codomain : std::iter::Sum + Clone |
|
96 { |
|
97 type Codomain = M::Codomain; |
|
98 |
|
99 fn apply<I : Instance<Domain>>(&self, x : I) -> Self::Codomain { |
|
100 let xr = x.ref_instance(); |
|
101 self.0.iter().map(|c| c.apply(xr)).sum() |
|
102 } |
|
103 } |
|
104 |
|
105 impl<Domain, M> DifferentiableImpl<Domain> for MappingSum< M> |
|
106 where |
|
107 Domain : Space + Clone, |
|
108 M : DifferentiableMapping<Domain>, |
|
109 M :: DerivativeDomain : std::iter::Sum |
|
110 { |
|
111 type Derivative = M::DerivativeDomain; |
|
112 |
|
113 fn differential_impl<I : Instance<Domain>>(&self, x : I) -> Self::Derivative { |
|
114 let xr = x.ref_instance(); |
|
115 self.0.iter().map(|c| c.differential(xr)).sum() |
|
116 } |
|
117 } |