src/kernels/hat_convolution.rs

branch:    dev
changeset: 32:56c8adc32b09
parent:    eb3c7813b67a
child:     34:efa60bc4f743
previous:  30:bd13c2ae3450
// ...
    Bounds,
    LocalAnalysis,
    GlobalAnalysis,
    Bounded,
};
use alg_tools::mapping::{Apply, Differentiable};
use alg_tools::maputil::array_init;

use crate::types::Lipschitz;
use super::base::*;
use super::ball_indicator::CubeIndicator;

/// Hat convolution kernel.
///
// ...
    fn apply(&self, y : Loc<S::Type, N>) -> Self::Output {
        self.apply(&y)
    }
}

#[replace_float_literals(S::Type::cast_from(literal))]
impl<S, const N : usize> Lipschitz<L1> for HatConv<S, N>
where S : Constant {
    type FloatType = S::Type;
    #[inline]
    fn lipschitz_factor(&self, L1 : L1) -> Option<Self::FloatType> {
        // For any ψ_i, we have
        // ∏_{i=1}^N ψ_i(x_i) - ∏_{i=1}^N ψ_i(y_i)
        //  = [ψ_1(x_1)-ψ_1(y_1)] ∏_{i=2}^N ψ_i(x_i)
        //    + ψ_1(y_1)[ ∏_{i=2}^N ψ_i(x_i) - ∏_{i=2}^N ψ_i(y_i)]
        //  = ∑_{j=1}^N [ψ_j(x_j)-ψ_j(y_j)] ∏_{i > j} ψ_i(x_i) ∏_{i < j} ψ_i(y_i)
        // Thus
        // |∏_{i=1}^N ψ_i(x_i) - ∏_{i=1}^N ψ_i(y_i)|
        //  ≤ ∑_{j=1}^N |ψ_j(x_j)-ψ_j(y_j)| ∏_{i ≠ j} ‖ψ_i‖_∞
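        // With ψ_i(x) = value_1d_σ1(x/σ)/σ for every i (cf. the differential
        // implementation below), each factor is Lipschitz with constant
        // lipschitz_1d_σ1()/σ² and bounded by value_1d_σ1(0)/σ = (4/3)/σ.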
        let σ = self.radius();
        Some((self.lipschitz_1d_σ1() / (σ*σ)) * (self.value_1d_σ1(0.0) / σ))
    }
}

impl<S, const N : usize> Lipschitz<L2> for HatConv<S, N>
where S : Constant {
    type FloatType = S::Type;
    #[inline]
    fn lipschitz_factor(&self, L2 : L2) -> Option<Self::FloatType> {
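        // ‖z‖_1 ≤ √N ‖z‖_2 on ℝ^N, so an ℓ¹ Lipschitz factor ℓ yields the ℓ² factor √N·ℓ.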
        self.lipschitz_factor(L1).map(|l1| l1 * <S::Type>::cast_from(N).sqrt())
    }
}


impl<'a, S, const N : usize> Differentiable<&'a Loc<S::Type, N>> for HatConv<S, N>
where S : Constant {
    type Output = Loc<S::Type, N>;
    #[inline]
    fn differential(&self, y : &'a Loc<S::Type, N>) -> Self::Output {
        let σ = self.radius();
        let σ2 = σ * σ;
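        // Each coordinate factor of the kernel is x ↦ value_1d_σ1(x/σ)/σ; by the
        // chain rule its derivative is diff_1d_σ1(x/σ)/σ².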
        let vs = y.map(|x| {
            self.value_1d_σ1(x / σ) / σ
        });
        product_differential(y, &vs, |x| {
            self.diff_1d_σ1(x / σ) / σ2
        })
    }
}

impl<'a, S, const N : usize> Differentiable<Loc<S::Type, N>> for HatConv<S, N>
where S : Constant {
    type Output = Loc<S::Type, N>;
    #[inline]
    fn differential(&self, y : Loc<S::Type, N>) -> Self::Output {
        self.differential(&y)
    }
}
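
// For reference: the product-rule gradient that `product_differential` (defined in
// super::base, not shown in this changeset) is assumed to compute is
//     ∂/∂x_j ∏_i v(x_i) = v'(x_j) ∏_{i ≠ j} v(x_i).
// A minimal self-contained sketch, with a hypothetical name and signature chosen
// purely for illustration:
//
//     fn product_rule_gradient<const N : usize>(
//         x : &[f64; N],
//         v : impl Fn(f64) -> f64,   // per-coordinate factor value
//         dv : impl Fn(f64) -> f64,  // per-coordinate factor derivative
//     ) -> [f64; N] {
//         let vs : [f64; N] = std::array::from_fn(|i| v(x[i]));
//         std::array::from_fn(|j| {
//             // derivative of the j-th factor times the product of the other factors
//             dv(x[j]) * (0..N).filter(|&i| i != j).map(|i| vs[i]).product::<f64>()
//         })
//     }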

#[replace_float_literals(S::Type::cast_from(literal))]
impl<'a, F : Float, S, const N : usize> HatConv<S, N>
where S : Constant<Type=F> {
    /// Computes the value of the kernel for $n=1$ with $σ=1$.
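    /// (Continuity check: both middle branches below give 1/3 at $y = 1/2$, and the
    /// branches meeting at $y = 1$ both give 0.)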
    #[inline]
    fn value_1d_σ1(&self, x : F) -> F {
        let y = x.abs();
        if y >= 1.0 {
            0.0
        } else if y > 0.5 {
            - (8.0/3.0) * (y - 1.0).powi(3)
        } else /* 0 ≤ y ≤ 0.5 */ {
            (4.0/3.0) + 8.0 * y * y * (y - 1.0)
        }
    }

    /// Computes the differential of the kernel for $n=1$ with $σ=1$.
    #[inline]
    fn diff_1d_σ1(&self, x : F) -> F {
        let y = x.abs();
        if y >= 1.0 {
            0.0
        } else if y > 0.5 {
            - 8.0 * (y - 1.0).powi(2)
        } else /* 0 ≤ y ≤ 0.5 */ {
            (24.0 * y - 16.0) * y
        }
    }

    /// Computes the Lipschitz factor of the kernel for $n=1$ with $σ=1$.
    #[inline]
    fn lipschitz_1d_σ1(&self) -> F {
        // Maximal absolute differential: |diff_1d_σ1(y)| = |(24y − 16)y| peaks at
        // y = ±1/3 with value 8/3 (the boundary values |diff_1d_σ1(±1/2)| = 2 are smaller).
        8.0/3.0
    }
}

impl<'a, S, const N : usize> Support<S::Type, N> for HatConv<S, N>
where S : Constant {
// ...
    fn apply(&self, y : Loc<F, N>) -> F {
        self.apply(&y)
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<'a, F : Float, R, C, const N : usize> Differentiable<&'a Loc<F, N>>
for Convolution<CubeIndicator<R, N>, HatConv<C, N>>
where R : Constant<Type=F>,
      C : Constant<Type=F> {

    type Output = Loc<F, N>;

    #[inline]
    fn differential(&self, y : &'a Loc<F, N>) -> Loc<F, N> {
        let Convolution(ref ind, ref hatconv) = self;
        let β = ind.r.value();
        let σ = hatconv.radius();
        let σ2 = σ * σ;

        let vs = y.map(|x| {
            self.value_1d_σ1(x / σ, β / σ)
        });
        product_differential(y, &vs, |x| {
            self.diff_1d_σ1(x / σ, β / σ) / σ2
        })
    }
}

impl<'a, F : Float, R, C, const N : usize> Differentiable<Loc<F, N>>
for Convolution<CubeIndicator<R, N>, HatConv<C, N>>
where R : Constant<Type=F>,
      C : Constant<Type=F> {

    type Output = Loc<F, N>;

    #[inline]
    fn differential(&self, y : Loc<F, N>) -> Loc<F, N> {
        self.differential(&y)
    }
}

/// Integrate a function whose antiderivative is $f$ and whose support is $[c, d]$,
/// over $[a, b] ∩ [c, d]$. If $b > d$, add $g()$ (the contribution of the
/// subsequent intervals) to the result.
#[inline]
#[replace_float_literals(F::cast_from(literal))]
fn i<F: Float>(a : F, b : F, c : F, d : F, f : impl Fn(F) -> F,
               g : impl Fn() -> F) -> F {
    if b < c {
        0.0
    } else if b <= d {
        if a <= c {
            f(b) - f(c)
        } else {
            f(b) - f(a)
        }
    } else /* b > d */ {
        g() + if a <= c {
            f(d) - f(c)
        } else if a < d {
            f(d) - f(a)
        } else {
            0.0
        }
    }
}
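
// Example of how `i` chains the subintervals: with a = -0.7 and b = 0.2,
// i(a, b, -1.0, -0.5, f₁, g₁) has b > d, so it returns f₁(-0.5) - f₁(-0.7) + g₁();
// g₁() then integrates the piece on [-0.5, 0] in full (f₂(0) - f₂(-0.5)) and, one
// level deeper, the piece on [0, 0.5] only up to b (f₃(0.2) - f₃(0)).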

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float, C, R, const N : usize> Convolution<CubeIndicator<R, N>, HatConv<C, N>>
where R : Constant<Type=F>,
      C : Constant<Type=F> {

    /// Calculates the value of the 1D hat convolution further convolved by an interval
    /// indicator. As both functions are piecewise polynomials, this is implemented by
    /// explicit integration over all subintervals of polynomiality of the cube indicator,
    /// using easily formed antiderivatives.
    #[inline]
    pub fn value_1d_σ1(&self, x : F, β : F) -> F {
        // The integration interval
        let a = x - β;
        let b = x + β;
        // ...
        fn pow4<F : Float>(x : F) -> F {
            let y = x * x;
            y * y
        }

        // Observe the factor 1/6 at the front from the antiderivatives below.
        // The factor 4 is from normalisation of the original function.
        (4.0/6.0) * i(a, b, -1.0, -0.5,
            // (2/3) (y+1)^3 on -1 < y ≤ -1/2
            // The antiderivative is (2/12)(y+1)^4 = (1/6)(y+1)^4
            |y| pow4(y+1.0),
            || i(a, b, -0.5, 0.0,
                // -2 y^3 - 2 y^2 + 1/3 on -1/2 < y ≤ 0
                // The antiderivative is -1/2 y^4 - 2/3 y^3 + 1/3 y
                // ...
                    )
                )
            )
        )
    }

    /// Calculates the derivative of the 1D hat convolution further convolved by an
    /// interval indicator. The implementation is similar to [`Self::value_1d_σ1`],
    /// using the fact that $(θ * ψ)' = θ * ψ'$.
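    /// Concretely, $(θ * ψ)'(x) = ∫_{x-β}^{x+β} ψ'(t) \, dt$, so the antiderivatives
    /// passed to `i` below are the polynomial pieces of $ψ$ itself (up to the factor 4
    /// pulled out front).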
    #[inline]
    pub fn diff_1d_σ1(&self, x : F, β : F) -> F {
        // The integration interval
        let a = x - β;
        let b = x + β;

        // The factor 4 is from normalisation of the original function.
        4.0 * i(a, b, -1.0, -0.5,
            // (2/3) (y+1)^3 on -1 < y ≤ -1/2
            |y| (2.0/3.0) * (y + 1.0).powi(3),
            || i(a, b, -0.5, 0.0,
                // -2 y^3 - 2 y^2 + 1/3 on -1/2 < y ≤ 0
                |y| -2.0*(y + 1.0) * y * y + (1.0/3.0),
                || i(a, b, 0.0, 0.5,
                    // 2 y^3 - 2 y^2 + 1/3 on 0 < y < 1/2
                    |y| 2.0*(y - 1.0) * y * y + (1.0/3.0),
                    || i(a, b, 0.5, 1.0,
                        // -(2/3) (y-1)^3 on 1/2 < y ≤ 1
                        |y| -(2.0/3.0) * (y - 1.0).powi(3),
                        || 0.0
                    )
                )
            )
        )
    }
}
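
// A hypothetical sanity check (not part of this changeset): diff_1d_σ1 should match
// a centred finite difference of value_1d_σ1, since both integrate the same piecewise
// polynomial. Sketch, assuming `conv` is some Convolution<CubeIndicator<_, 1>, HatConv<_, 1>>:
//
//     let (β, h) = (0.4, 1e-5);
//     for &x in &[-0.9, -0.3, 0.0, 0.2, 0.7] {
//         let fd = (conv.value_1d_σ1(x + h, β) - conv.value_1d_σ1(x - h, β)) / (2.0 * h);
//         assert!((fd - conv.diff_1d_σ1(x, β)).abs() < 1e-3);
//     }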

impl<F : Float, R, C, const N : usize>
Convolution<CubeIndicator<R, N>, HatConv<C, N>>
where R : Constant<Type=F>,
// ...
