| |
1 /*! |
| |
2 Arithmetic of [`Mapping`]s. |
| |
3 */ |
| |
4 |
| |
5 use serde::Serialize; |
| |
6 use crate::types::*; |
| |
7 use crate::instance::{Space, Instance}; |
| |
8 use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping}; |
| |
9 |
| |
10 /// A trait for encoding constant [`Float`] values |
| |
11 pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> { |
| |
12 /// The type of the value |
| |
13 type Type : Float; |
| |
14 /// Returns the value of the constant |
| |
15 fn value(&self) -> Self::Type; |
| |
16 } |
| |
17 |
| |
18 impl<F : Float> Constant for F { |
| |
19 type Type = F; |
| |
20 #[inline] |
| |
21 fn value(&self) -> F { *self } |
| |
22 } |
| |
23 |
| |
24 /// Weighting of a [`Support`] and [`Apply`] by scalar multiplication; |
| |
25 /// output of [`Support::weigh`]. |
| |
26 #[derive(Copy,Clone,Debug,Serialize)] |
| |
27 pub struct Weighted<T, C : Constant> { |
| |
28 /// The weight |
| |
29 pub weight : C, |
| |
30 /// The base [`Support`] or [`Apply`] being weighted. |
| |
31 pub base_fn : T, |
| |
32 } |
| |
33 |
| |
34 impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C> |
| |
35 where |
| |
36 F : Float, |
| |
37 D : Space, |
| |
38 T : Mapping<D, Codomain=V>, |
| |
39 V : Space + ClosedMul<F>, |
| |
40 C : Constant<Type=F> |
| |
41 { |
| |
42 type Codomain = V; |
| |
43 |
| |
44 #[inline] |
| |
45 fn apply<I : Instance<D>>(&self, x : I) -> Self::Codomain { |
| |
46 self.base_fn.apply(x) * self.weight.value() |
| |
47 } |
| |
48 } |
| |
49 |
| |
50 impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C> |
| |
51 where |
| |
52 F : Float, |
| |
53 D : Space, |
| |
54 T : DifferentiableMapping<D, DerivativeDomain=V>, |
| |
55 V : Space + std::ops::Mul<F, Output=V>, |
| |
56 C : Constant<Type=F> |
| |
57 { |
| |
58 type Derivative = V; |
| |
59 |
| |
60 #[inline] |
| |
61 fn differential_impl<I : Instance<D>>(&self, x : I) -> Self::Derivative { |
| |
62 self.base_fn.differential(x) * self.weight.value() |
| |
63 } |
| |
64 } |