src/kernels/base.rs

branch
dev
changeset 32
56c8adc32b09
parent 0
eb3c7813b67a
child 35
b087e3eab191
equal deleted inserted replaced
30:bd13c2ae3450 32:56c8adc32b09
12 Bounds, 12 Bounds,
13 LocalAnalysis, 13 LocalAnalysis,
14 GlobalAnalysis, 14 GlobalAnalysis,
15 Bounded, 15 Bounded,
16 }; 16 };
17 use alg_tools::mapping::Apply; 17 use alg_tools::mapping::{Apply, Differentiable};
18 use alg_tools::maputil::{array_init, map2}; 18 use alg_tools::maputil::{array_init, map2, map1_indexed};
19 use alg_tools::sets::SetOrd; 19 use alg_tools::sets::SetOrd;
20 20
21 use crate::fourier::Fourier; 21 use crate::fourier::Fourier;
22 use crate::types::Lipschitz;
22 23
23 /// Representation of the product of two kernels. 24 /// Representation of the product of two kernels.
24 /// 25 ///
25 /// The kernels typically implement [`Support`] and [`Mapping`][alg_tools::mapping::Mapping]. 26 /// The kernels typically implement [`Support`] and [`Mapping`][alg_tools::mapping::Mapping].
26 /// 27 ///
53 fn apply(&self, x : &'a Loc<F, N>) -> Self::Output { 54 fn apply(&self, x : &'a Loc<F, N>) -> Self::Output {
54 self.0.apply(x) * self.1.apply(x) 55 self.0.apply(x) * self.1.apply(x)
55 } 56 }
56 } 57 }
57 58
59 impl<A, B, F : Float, const N : usize> Differentiable<Loc<F, N>>
60 for SupportProductFirst<A, B>
61 where A : for<'a> Apply<&'a Loc<F, N>, Output=F>
62 + for<'a> Differentiable<&'a Loc<F, N>, Output=Loc<F, N>>,
63 B : for<'a> Apply<&'a Loc<F, N>, Output=F>
64 + for<'a> Differentiable<&'a Loc<F, N>, Output=Loc<F, N>> {
65 type Output = Loc<F, N>;
66 #[inline]
67 fn differential(&self, x : Loc<F, N>) -> Self::Output {
68 self.0.differential(&x) * self.1.apply(&x) + self.1.differential(&x) * self.0.apply(&x)
69 }
70 }
71
72 impl<'a, A, B, F : Float, const N : usize> Differentiable<&'a Loc<F, N>>
73 for SupportProductFirst<A, B>
74 where A : Apply<&'a Loc<F, N>, Output=F>
75 + Differentiable<&'a Loc<F, N>, Output=Loc<F, N>>,
76 B : Apply<&'a Loc<F, N>, Output=F>
77 + Differentiable<&'a Loc<F, N>, Output=Loc<F, N>> {
78 type Output = Loc<F, N>;
79 #[inline]
80 fn differential(&self, x : &'a Loc<F, N>) -> Self::Output {
81 self.0.differential(&x) * self.1.apply(&x) + self.1.differential(&x) * self.0.apply(&x)
82 }
83 }
84
85
58 impl<'a, A, B, F : Float, const N : usize> Support<F, N> 86 impl<'a, A, B, F : Float, const N : usize> Support<F, N>
59 for SupportProductFirst<A, B> 87 for SupportProductFirst<A, B>
60 where A : Support<F, N>, 88 where A : Support<F, N>,
61 B : Support<F, N> { 89 B : Support<F, N> {
62 #[inline] 90 #[inline]
128 fn apply(&self, x : Loc<F, N>) -> Self::Output { 156 fn apply(&self, x : Loc<F, N>) -> Self::Output {
129 self.0.apply(&x) + self.1.apply(&x) 157 self.0.apply(&x) + self.1.apply(&x)
130 } 158 }
131 } 159 }
132 160
161 impl<'a, A, B, F : Float, const N : usize> Differentiable<&'a Loc<F, N>>
162 for SupportSum<A, B>
163 where A : Differentiable<&'a Loc<F, N>, Output=Loc<F, N>>,
164 B : Differentiable<&'a Loc<F, N>, Output=Loc<F, N>> {
165 type Output = Loc<F, N>;
166 #[inline]
167 fn differential(&self, x : &'a Loc<F, N>) -> Self::Output {
168 self.0.differential(x) + self.1.differential(x)
169 }
170 }
171
172 impl<A, B, F : Float, const N : usize> Differentiable<Loc<F, N>>
173 for SupportSum<A, B>
174 where A : for<'a> Differentiable<&'a Loc<F, N>, Output=Loc<F, N>>,
175 B : for<'a> Differentiable<&'a Loc<F, N>, Output=Loc<F, N>> {
176 type Output = Loc<F, N>;
177 #[inline]
178 fn differential(&self, x : Loc<F, N>) -> Self::Output {
179 self.0.differential(&x) + self.1.differential(&x)
180 }
181 }
182
133 impl<'a, A, B, F : Float, const N : usize> Support<F, N> 183 impl<'a, A, B, F : Float, const N : usize> Support<F, N>
134 for SupportSum<A, B> 184 for SupportSum<A, B>
135 where A : Support<F, N>, 185 where A : Support<F, N>,
136 B : Support<F, N>, 186 B : Support<F, N>,
137 Cube<F, N> : SetOrd { 187 Cube<F, N> : SetOrd {
171 #[inline] 221 #[inline]
172 fn local_analysis(&self, cube : &Cube<F, N>) -> Bounds<F> { 222 fn local_analysis(&self, cube : &Cube<F, N>) -> Bounds<F> {
173 self.0.local_analysis(cube) + self.1.local_analysis(cube) 223 self.0.local_analysis(cube) + self.1.local_analysis(cube)
174 } 224 }
175 } 225 }
226
227 impl<F : Float, M : Copy, A, B> Lipschitz<M> for SupportSum<A, B>
228 where A : Lipschitz<M, FloatType = F>,
229 B : Lipschitz<M, FloatType = F> {
230 type FloatType = F;
231
232 fn lipschitz_factor(&self, m : M) -> Option<F> {
233 match (self.0.lipschitz_factor(m), self.1.lipschitz_factor(m)) {
234 (Some(l0), Some(l1)) => Some(l0 + l1),
235 _ => None
236 }
237 }
238 }
239
176 240
177 /// Representation of the convolution of two kernels. 241 /// Representation of the convolution of two kernels.
178 /// 242 ///
179 /// The kernels typically implement [`Support`]s and [`Mapping`][alg_tools::mapping::Mapping]. 243 /// The kernels typically implement [`Support`]s and [`Mapping`][alg_tools::mapping::Mapping].
180 // 244 //
185 pub A, 249 pub A,
186 /// Second kernel 250 /// Second kernel
187 pub B 251 pub B
188 ); 252 );
189 253
254 impl<F : Float, M, A, B> Lipschitz<M> for Convolution<A, B>
255 where A : Bounded<F> ,
256 B : Lipschitz<M, FloatType = F> {
257 type FloatType = F;
258
259 fn lipschitz_factor(&self, m : M) -> Option<F> {
260 self.1.lipschitz_factor(m).map(|l| l * self.0.bounds().uniform())
261 }
262 }
263
190 /// Representation of the autoconvolution of a kernel. 264 /// Representation of the autoconvolution of a kernel.
191 /// 265 ///
192 /// The kernel typically implements [`Support`] and [`Mapping`][alg_tools::mapping::Mapping]. 266 /// The kernel typically implements [`Support`] and [`Mapping`][alg_tools::mapping::Mapping].
193 /// 267 ///
194 /// Trait implementations have to be on a case-by-case basis. 268 /// Trait implementations have to be on a case-by-case basis.
195 #[derive(Copy,Clone,Serialize,Debug,Eq,PartialEq)] 269 #[derive(Copy,Clone,Serialize,Debug,Eq,PartialEq)]
196 pub struct AutoConvolution<A>( 270 pub struct AutoConvolution<A>(
197 /// The kernel to be autoconvolved 271 /// The kernel to be autoconvolved
198 pub A 272 pub A
199 ); 273 );
274
275 impl<F : Float, M, C> Lipschitz<M> for AutoConvolution<C>
276 where C : Lipschitz<M, FloatType = F> + Bounded<F> {
277 type FloatType = F;
278
279 fn lipschitz_factor(&self, m : M) -> Option<F> {
280 self.0.lipschitz_factor(m).map(|l| l * self.0.bounds().uniform())
281 }
282 }
283
200 284
201 /// Representation of a multi-dimensional product of a one-dimensional kernel. 285 /// Representation of a multi-dimensional product of a one-dimensional kernel.
202 /// 286 ///
203 /// For $G: ℝ → ℝ$, this is the function $F(x\_1, …, x\_n) := \prod_{i=1}^n G(x\_i)$. 287 /// For $G: ℝ → ℝ$, this is the function $F(x\_1, …, x\_n) := \prod_{i=1}^n G(x\_i)$.
204 /// The kernel $G$ typically implements [`Support`] and [`Mapping`][alg_tools::mapping::Mapping] 288 /// The kernel $G$ typically implements [`Support`] and [`Mapping`][alg_tools::mapping::Mapping]
224 where G : Apply<Loc<F, 1>, Output=F> { 308 where G : Apply<Loc<F, 1>, Output=F> {
225 type Output = F; 309 type Output = F;
226 #[inline] 310 #[inline]
227 fn apply(&self, x : Loc<F, N>) -> F { 311 fn apply(&self, x : Loc<F, N>) -> F {
228 x.into_iter().map(|y| self.0.apply(Loc([y]))).product() 312 x.into_iter().map(|y| self.0.apply(Loc([y]))).product()
313 }
314 }
315
316 impl<'a, G, F : Float, const N : usize> Differentiable<&'a Loc<F, N>>
317 for UniformProduct<G, N>
318 where G : Apply<Loc<F, 1>, Output=F> + Differentiable<Loc<F, 1>, Output=F> {
319 type Output = Loc<F, N>;
320 #[inline]
321 fn differential(&self, x : &'a Loc<F, N>) -> Loc<F, N> {
322 let vs = x.map(|y| self.0.apply(Loc([y])));
323 product_differential(x, &vs, |y| self.0.differential(Loc([y])))
324 }
325 }
326
327 /// Helper function to calulate the differential of $f(x)=∏_{i=1}^N g(x_i)$.
328 ///
329 /// The vector `x` is the location, `vs` consists of the values `g(x_i)`, and
330 /// `gd` calculates the derivative `g'`.
331 #[inline]
332 pub(crate) fn product_differential<F : Float, G : Fn(F) -> F, const N : usize>(
333 x : &Loc<F, N>,
334 vs : &Loc<F, N>,
335 gd : G
336 ) -> Loc<F, N> {
337 map1_indexed(x, |i, &y| {
338 gd(y) * vs.iter()
339 .zip(0..)
340 .filter_map(|(v, j)| (j != i).then_some(*v))
341 .product()
342 }).into()
343 }
344
345 impl<G, F : Float, const N : usize> Differentiable<Loc<F, N>>
346 for UniformProduct<G, N>
347 where G : Apply<Loc<F, 1>, Output=F> + Differentiable<Loc<F, 1>, Output=F> {
348 type Output = Loc<F, N>;
349 #[inline]
350 fn differential(&self, x : Loc<F, N>) -> Loc<F, N> {
351 self.differential(&x)
229 } 352 }
230 } 353 }
231 354
232 impl<G, F : Float, const N : usize> Support<F, N> 355 impl<G, F : Float, const N : usize> Support<F, N>
233 for UniformProduct<G, N> 356 for UniformProduct<G, N>

mercurial