src/operator_arithmetic.rs

branch
dev
changeset 137
d5dfcb6abcf5
parent 133
2b13f8a0c8ba
equal deleted inserted replaced
136:22fd33834ab7 137:d5dfcb6abcf5
1 /*! 1 /*!
2 Arithmetic of [`Mapping`]s. 2 Arithmetic of [`Mapping`]s.
3 */ 3 */
4 4
5 use crate::instance::{Instance, Space}; 5 use crate::instance::{Instance, Instantiated, Space};
6 use crate::mapping::{DifferentiableImpl, DifferentiableMapping, Mapping}; 6 use crate::mapping::{DifferentiableImpl, DifferentiableMapping, Mapping};
7 use crate::types::*; 7 use crate::types::*;
8 use serde::Serialize; 8 use serde::Serialize;
9 9
10 /// A trait for encoding constant [`Float`] values 10 /// A trait for encoding constant [`Float`] values
97 M::Codomain: std::iter::Sum + Clone, 97 M::Codomain: std::iter::Sum + Clone,
98 { 98 {
99 type Codomain = M::Codomain; 99 type Codomain = M::Codomain;
100 100
101 fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain { 101 fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain {
102 x.eval_ref_decompose(|xr| self.0.iter().map(|c| c.apply(xr)).sum()) 102 let xi = x.instantiate();
103 let xr = xi.ref_inst();
104 self.0.iter().map(|c| c.apply(xr)).sum()
103 } 105 }
104 } 106 }
105 107
106 impl<Domain, M> DifferentiableImpl<Domain> for MappingSum<M> 108 impl<Domain, M> DifferentiableImpl<Domain> for MappingSum<M>
107 where 109 where
110 M::DerivativeDomain: std::iter::Sum, 112 M::DerivativeDomain: std::iter::Sum,
111 { 113 {
112 type Derivative = M::DerivativeDomain; 114 type Derivative = M::DerivativeDomain;
113 115
114 fn differential_impl<I: Instance<Domain>>(&self, x: I) -> Self::Derivative { 116 fn differential_impl<I: Instance<Domain>>(&self, x: I) -> Self::Derivative {
115 x.eval_ref_decompose(|xr| self.0.iter().map(|c| c.differential(xr)).sum()) 117 let xi = x.instantiate();
118 let xr = xi.ref_inst();
119
120 self.0.iter().map(|c| c.differential(xr)).sum()
116 } 121 }
117 } 122 }

mercurial