src/operator_arithmetic.rs

changeset 90
b3c35d16affe
parent 75
e9f4550cfa18
child 86
d5b0e496b72f
equal deleted inserted replaced
25:d14c877e14b7 90:b3c35d16affe
1 /*!
2 Arithmetic of [`Mapping`]s.
3 */
4
5 use serde::Serialize;
6 use crate::types::*;
7 use crate::instance::{Space, Instance};
8 use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping};
9
10 /// A trait for encoding constant [`Float`] values
11 pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> {
12 /// The type of the value
13 type Type : Float;
14 /// Returns the value of the constant
15 fn value(&self) -> Self::Type;
16 }
17
18 impl<F : Float> Constant for F {
19 type Type = F;
20 #[inline]
21 fn value(&self) -> F { *self }
22 }
23
24 /// Weighting of a [`Mapping`] by scalar multiplication.
25 #[derive(Copy,Clone,Debug,Serialize)]
26 pub struct Weighted<T, C : Constant> {
27 /// The weight
28 pub weight : C,
29 /// The base [`Mapping`] being weighted.
30 pub base_fn : T,
31 }
32
33 impl<T, C> Weighted<T, C>
34 where
35 C : Constant,
36 {
37 /// Construct from an iterator.
38 pub fn new(weight : C, base_fn : T) -> Self {
39 Weighted{ weight, base_fn }
40 }
41 }
42
43 impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C>
44 where
45 F : Float,
46 D : Space,
47 T : Mapping<D, Codomain=V>,
48 V : Space + ClosedMul<F>,
49 C : Constant<Type=F>
50 {
51 type Codomain = V;
52
53 #[inline]
54 fn apply<I : Instance<D>>(&self, x : I) -> Self::Codomain {
55 self.base_fn.apply(x) * self.weight.value()
56 }
57 }
58
59 impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C>
60 where
61 F : Float,
62 D : Space,
63 T : DifferentiableMapping<D, DerivativeDomain=V>,
64 V : Space + std::ops::Mul<F, Output=V>,
65 C : Constant<Type=F>
66 {
67 type Derivative = V;
68
69 #[inline]
70 fn differential_impl<I : Instance<D>>(&self, x : I) -> Self::Derivative {
71 self.base_fn.differential(x) * self.weight.value()
72 }
73 }
74
75 /// A sum of [`Mapping`]s.
76 #[derive(Serialize, Debug, Clone)]
77 pub struct MappingSum<M>(Vec<M>);
78
79 impl< M> MappingSum<M> {
80 /// Construct from an iterator.
81 pub fn new<I : IntoIterator<Item = M>>(iter : I) -> Self {
82 MappingSum(iter.into_iter().collect())
83 }
84
85 /// Iterate over the component functions of the sum
86 pub fn iter(&self) -> std::slice::Iter<'_, M> {
87 self.0.iter()
88 }
89 }
90
91 impl<Domain, M> Mapping<Domain> for MappingSum<M>
92 where
93 Domain : Space + Clone,
94 M : Mapping<Domain>,
95 M::Codomain : std::iter::Sum + Clone
96 {
97 type Codomain = M::Codomain;
98
99 fn apply<I : Instance<Domain>>(&self, x : I) -> Self::Codomain {
100 let xr = x.ref_instance();
101 self.0.iter().map(|c| c.apply(xr)).sum()
102 }
103 }
104
105 impl<Domain, M> DifferentiableImpl<Domain> for MappingSum< M>
106 where
107 Domain : Space + Clone,
108 M : DifferentiableMapping<Domain>,
109 M :: DerivativeDomain : std::iter::Sum
110 {
111 type Derivative = M::DerivativeDomain;
112
113 fn differential_impl<I : Instance<Domain>>(&self, x : I) -> Self::Derivative {
114 let xr = x.ref_instance();
115 self.0.iter().map(|c| c.differential(xr)).sum()
116 }
117 }

mercurial