src/kernels/gaussian.rs

branch
dev
changeset 33
aec67cdd6b14
parent 31
6105b5cd8d89
child 34
efa60bc4f743
equal deleted inserted replaced
32:56c8adc32b09 33:aec67cdd6b14
218 SupportProductFirst(ref cut, 218 SupportProductFirst(ref cut,
219 ref gaussian)) = self; 219 ref gaussian)) = self;
220 let a = cut.r.value(); 220 let a = cut.r.value();
221 let b = ind.r.value(); 221 let b = ind.r.value();
222 let σ = gaussian.variance.value().sqrt(); 222 let σ = gaussian.variance.value().sqrt();
223 let π = F::PI;
224 let t = F::SQRT_2 * σ; 223 let t = F::SQRT_2 * σ;
225 let c = σ * (8.0/π).sqrt(); 224 let c = 0.5; // 1/(σ√(2π) * σ√(π/2) = 1/2
226 225
227 // This is just a product of one-dimensional versions 226 // This is just a product of one-dimensional versions
228 let unscaled = y.product_map(|x| { 227 y.product_map(|x| {
229 let c1 = -(a.min(b + x)); //(-a).max(-x-b); 228 let c1 = -(a.min(b + x)); //(-a).max(-x-b);
230 let c2 = a.min(b - x); 229 let c2 = a.min(b - x);
231 if c1 >= c2 { 230 if c1 >= c2 {
232 0.0 231 0.0
233 } else { 232 } else {
234 let e1 = F::cast_from(erf((c1 / t).as_())); 233 let e1 = F::cast_from(erf((c1 / t).as_()));
235 let e2 = F::cast_from(erf((c2 / t).as_())); 234 let e2 = F::cast_from(erf((c2 / t).as_()));
236 debug_assert!(e2 >= e1); 235 debug_assert!(e2 >= e1);
237 c * (e2 - e1) 236 c * (e2 - e1)
238 } 237 }
239 }); 238 })
240
241 unscaled / gaussian.scale()
242 } 239 }
243 } 240 }
244 241
245 impl<F : Float, R, C, S, const N : usize> Apply<Loc<F, N>> 242 impl<F : Float, R, C, S, const N : usize> Apply<Loc<F, N>>
246 for Convolution<CubeIndicator<R, N>, BasicCutGaussian<C, S, N>> 243 for Convolution<CubeIndicator<R, N>, BasicCutGaussian<C, S, N>>
274 SupportProductFirst(ref cut, 271 SupportProductFirst(ref cut,
275 ref gaussian)) = self; 272 ref gaussian)) = self;
276 let a = cut.r.value(); 273 let a = cut.r.value();
277 let b = ind.r.value(); 274 let b = ind.r.value();
278 let σ = gaussian.variance.value().sqrt(); 275 let σ = gaussian.variance.value().sqrt();
279 let π = F::PI;
280 let t = F::SQRT_2 * σ; 276 let t = F::SQRT_2 * σ;
281 let c = σ * (8.0/π).sqrt(); 277 let c = 0.5; // 1/(σ√(2π) * σ√(π/2) = 1/2
282 let cd = (8.0).sqrt(); // σ * (8.0/π).sqrt() / t * (√2/π) 278 let c_div_t = c / t;
283 279
284 // Calculate the values for all component functions of the 280 // Calculate the values for all component functions of the
285 // product. This is just the loop from apply above. 281 // product. This is just the loop from apply above.
286 let unscaled_vs = y.map(|x| { 282 let unscaled_vs = y.map(|x| {
287 let c1 = -(a.min(b + x)); //(-a).max(-x-b); 283 let c1 = -(a.min(b + x)); //(-a).max(-x-b);
304 } else { 300 } else {
305 // erf'(z) = (2/√π)*exp(-z^2), and we get extra factor -1/√(2*σ) = -1/t 301 // erf'(z) = (2/√π)*exp(-z^2), and we get extra factor -1/√(2*σ) = -1/t
306 // from the chain rule 302 // from the chain rule
307 let de1 = (-(c1/t).powi(2)).exp(); 303 let de1 = (-(c1/t).powi(2)).exp();
308 let de2 = (-(c2/t).powi(2)).exp(); 304 let de2 = (-(c2/t).powi(2)).exp();
309 cd * (de1 - de2) 305 c_div_t * (de1 - de2)
310 } 306 }
311 }) / gaussian.scale() 307 })
312 } 308 }
313 } 309 }
314 310
315 impl<F : Float, R, C, S, const N : usize> Differentiable<Loc<F, N>> 311 impl<F : Float, R, C, S, const N : usize> Differentiable<Loc<F, N>>
316 for Convolution<CubeIndicator<R, N>, BasicCutGaussian<C, S, N>> 312 for Convolution<CubeIndicator<R, N>, BasicCutGaussian<C, S, N>>

mercurial