src/operator_arithmetic.rs

branch
dev
changeset 128
f75bf34adda0
parent 75
e9f4550cfa18
child 133
2b13f8a0c8ba
equal deleted inserted replaced
127:212f75931da0 128:f75bf34adda0
1 /*! 1 /*!
2 Arithmetic of [`Mapping`]s. 2 Arithmetic of [`Mapping`]s.
3 */ 3 */
4 4
5 use crate::instance::{Instance, Space};
6 use crate::mapping::{DifferentiableImpl, DifferentiableMapping, Mapping};
7 use crate::types::*;
5 use serde::Serialize; 8 use serde::Serialize;
6 use crate::types::*;
7 use crate::instance::{Space, Instance};
8 use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping};
9 9
10 /// A trait for encoding constant [`Float`] values 10 /// A trait for encoding constant [`Float`] values
11 pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> { 11 pub trait Constant: Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> {
12 /// The type of the value 12 /// The type of the value
13 type Type : Float; 13 type Type: Float;
14 /// Returns the value of the constant 14 /// Returns the value of the constant
15 fn value(&self) -> Self::Type; 15 fn value(&self) -> Self::Type;
16 } 16 }
17 17
18 impl<F : Float> Constant for F { 18 impl<F: Float> Constant for F {
19 type Type = F; 19 type Type = F;
20 #[inline] 20 #[inline]
21 fn value(&self) -> F { *self } 21 fn value(&self) -> F {
22 *self
23 }
22 } 24 }
23 25
24 /// Weighting of a [`Mapping`] by scalar multiplication. 26 /// Weighting of a [`Mapping`] by scalar multiplication.
25 #[derive(Copy,Clone,Debug,Serialize)] 27 #[derive(Copy, Clone, Debug, Serialize)]
26 pub struct Weighted<T, C : Constant> { 28 pub struct Weighted<T, C: Constant> {
27 /// The weight 29 /// The weight
28 pub weight : C, 30 pub weight: C,
29 /// The base [`Mapping`] being weighted. 31 /// The base [`Mapping`] being weighted.
30 pub base_fn : T, 32 pub base_fn: T,
31 } 33 }
32 34
33 impl<T, C> Weighted<T, C> 35 impl<T, C> Weighted<T, C>
34 where 36 where
35 C : Constant, 37 C: Constant,
36 { 38 {
37 /// Construct from an iterator. 39 /// Construct from an iterator.
38 pub fn new(weight : C, base_fn : T) -> Self { 40 pub fn new(weight: C, base_fn: T) -> Self {
39 Weighted{ weight, base_fn } 41 Weighted { weight, base_fn }
40 } 42 }
41 } 43 }
42 44
43 impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C> 45 impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C>
44 where 46 where
45 F : Float, 47 F: Float,
46 D : Space, 48 D: Space,
47 T : Mapping<D, Codomain=V>, 49 T: Mapping<D, Codomain = V>,
48 V : Space + ClosedMul<F>, 50 V: Space + ClosedMul<F>,
49 C : Constant<Type=F> 51 C: Constant<Type = F>,
50 { 52 {
51 type Codomain = V; 53 type Codomain = V;
52 54
53 #[inline] 55 #[inline]
54 fn apply<I : Instance<D>>(&self, x : I) -> Self::Codomain { 56 fn apply<I: Instance<D>>(&self, x: I) -> Self::Codomain {
55 self.base_fn.apply(x) * self.weight.value() 57 self.base_fn.apply(x) * self.weight.value()
56 } 58 }
57 } 59 }
58 60
59 impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C> 61 impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C>
60 where 62 where
61 F : Float, 63 F: Float,
62 D : Space, 64 D: Space,
63 T : DifferentiableMapping<D, DerivativeDomain=V>, 65 T: DifferentiableMapping<D, DerivativeDomain = V>,
64 V : Space + std::ops::Mul<F, Output=V>, 66 V: Space + std::ops::Mul<F, Output = V>,
65 C : Constant<Type=F> 67 C: Constant<Type = F>,
66 { 68 {
67 type Derivative = V; 69 type Derivative = V;
68 70
69 #[inline] 71 #[inline]
70 fn differential_impl<I : Instance<D>>(&self, x : I) -> Self::Derivative { 72 fn differential_impl<I: Instance<D>>(&self, x: I) -> Self::Derivative {
71 self.base_fn.differential(x) * self.weight.value() 73 self.base_fn.differential(x) * self.weight.value()
72 } 74 }
73 } 75 }
74 76
75 /// A sum of [`Mapping`]s. 77 /// A sum of [`Mapping`]s.
76 #[derive(Serialize, Debug, Clone)] 78 #[derive(Serialize, Debug, Clone)]
77 pub struct MappingSum<M>(Vec<M>); 79 pub struct MappingSum<M>(Vec<M>);
78 80
79 impl< M> MappingSum<M> { 81 impl<M> MappingSum<M> {
80 /// Construct from an iterator. 82 /// Construct from an iterator.
81 pub fn new<I : IntoIterator<Item = M>>(iter : I) -> Self { 83 pub fn new<I: IntoIterator<Item = M>>(iter: I) -> Self {
82 MappingSum(iter.into_iter().collect()) 84 MappingSum(iter.into_iter().collect())
83 } 85 }
84 86
85 /// Iterate over the component functions of the sum 87 /// Iterate over the component functions of the sum
86 pub fn iter(&self) -> std::slice::Iter<'_, M> { 88 pub fn iter(&self) -> std::slice::Iter<'_, M> {
88 } 90 }
89 } 91 }
90 92
91 impl<Domain, M> Mapping<Domain> for MappingSum<M> 93 impl<Domain, M> Mapping<Domain> for MappingSum<M>
92 where 94 where
93 Domain : Space + Clone, 95 Domain: Space + Clone,
94 M : Mapping<Domain>, 96 M: Mapping<Domain>,
95 M::Codomain : std::iter::Sum + Clone 97 M::Codomain: std::iter::Sum + Clone,
96 { 98 {
97 type Codomain = M::Codomain; 99 type Codomain = M::Codomain;
98 100
99 fn apply<I : Instance<Domain>>(&self, x : I) -> Self::Codomain { 101 fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain {
100 let xr = x.ref_instance(); 102 let xr = x.ref_instance();
101 self.0.iter().map(|c| c.apply(xr)).sum() 103 self.0.iter().map(|c| c.apply(xr)).sum()
102 } 104 }
103 } 105 }
104 106
105 impl<Domain, M> DifferentiableImpl<Domain> for MappingSum< M> 107 impl<Domain, M> DifferentiableImpl<Domain> for MappingSum<M>
106 where 108 where
107 Domain : Space + Clone, 109 Domain: Space,
108 M : DifferentiableMapping<Domain>, 110 M: DifferentiableMapping<Domain>,
109 M :: DerivativeDomain : std::iter::Sum 111 M::DerivativeDomain: std::iter::Sum,
110 { 112 {
111 type Derivative = M::DerivativeDomain; 113 type Derivative = M::DerivativeDomain;
112 114
113 fn differential_impl<I : Instance<Domain>>(&self, x : I) -> Self::Derivative { 115 fn differential_impl<I: Instance<Domain>>(&self, x: I) -> Self::Derivative {
114 let xr = x.ref_instance(); 116 let xr = x.ref_instance();
115 self.0.iter().map(|c| c.differential(xr)).sum() 117 self.0.iter().map(|c| c.differential(xr)).sum()
116 } 118 }
117 } 119 }

mercurial