47 pub struct TransportConfig<F : Float> { |
47 pub struct TransportConfig<F : Float> { |
48 /// Transport step length $θ$ normalised to $(0, 1)$. |
48 /// Transport step length $θ$ normalised to $(0, 1)$. |
49 pub θ0 : F, |
49 pub θ0 : F, |
50 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. |
50 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. |
51 pub adaptation : F, |
51 pub adaptation : F, |
52 /// Transport tolerance wrt. ω |
52 /// A priori transport tolerance multiplier (C_pri) |
53 pub tolerance_ω : F, |
53 pub tolerance_mult_pri : F, |
54 /// Transport tolerance wrt. ∇v |
54 /// A posteriori transport tolerance multiplier (C_pos) |
55 pub tolerance_dv : F, |
55 pub tolerance_mult_pos : F, |
56 } |
56 } |
57 |
57 |
58 #[replace_float_literals(F::cast_from(literal))] |
58 #[replace_float_literals(F::cast_from(literal))] |
59 impl <F : Float> TransportConfig<F> { |
59 impl <F : Float> TransportConfig<F> { |
60 /// Check that the parameters are ok. Panics if not. |
60 /// Check that the parameters are ok. Panics if not. |
61 pub fn check(&self) { |
61 pub fn check(&self) { |
62 assert!(self.θ0 > 0.0); |
62 assert!(self.θ0 > 0.0); |
63 assert!(0.0 < self.adaptation && self.adaptation < 1.0); |
63 assert!(0.0 < self.adaptation && self.adaptation < 1.0); |
64 assert!(self.tolerance_dv > 0.0); |
64 assert!(self.tolerance_mult_pri > 0.0); |
65 assert!(self.tolerance_ω > 0.0); |
65 assert!(self.tolerance_mult_pos > 0.0); |
66 } |
66 } |
67 } |
67 } |
68 |
68 |
69 #[replace_float_literals(F::cast_from(literal))] |
69 #[replace_float_literals(F::cast_from(literal))] |
70 impl<F : Float> Default for TransportConfig<F> { |
70 impl<F : Float> Default for TransportConfig<F> { |
71 fn default() -> Self { |
71 fn default() -> Self { |
72 TransportConfig { |
72 TransportConfig { |
73 θ0 : 0.01, |
73 θ0 : 0.4, |
74 adaptation : 0.9, |
74 adaptation : 0.9, |
75 tolerance_ω : 1000.0, // TODO: no idea what this should be |
75 tolerance_mult_pos : 100.0, |
76 tolerance_dv : 1000.0, // TODO: no idea what this should be |
76 tolerance_mult_pri : 1000.0, |
77 } |
77 } |
78 } |
78 } |
79 } |
79 } |
80 |
80 |
81 /// Settings for [`pointsource_sliding_fb_reg`]. |
81 /// Settings for [`pointsource_sliding_fb_reg`]. |
179 } |
179 } |
180 }, |
180 }, |
181 FullyAdaptive{ l : ref mut adaptive_ℓ_v, ref mut max_transport, g : ref calculate_θ } => { |
181 FullyAdaptive{ l : ref mut adaptive_ℓ_v, ref mut max_transport, g : ref calculate_θ } => { |
182 *max_transport = max_transport.max(γ1.norm(Radon)); |
182 *max_transport = max_transport.max(γ1.norm(Radon)); |
183 let mut θ = calculate_θ(*adaptive_ℓ_v, *max_transport); |
183 let mut θ = calculate_θ(*adaptive_ℓ_v, *max_transport); |
184 loop { |
184 // Do two runs through the spikes to update θ, breaking if first run did not cause |
185 let θτ = τ * θ; |
185 // a change. |
|
186 for _i in 0..=1 { |
|
187 let mut changes = false; |
186 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
188 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { |
187 let dv_x = v.differential(&δ.x); |
189 let dv_x = v.differential(&δ.x); |
188 ρ.x = δ.x - &dv_x * (ρ.α.signum() * θτ); |
190 let g = &dv_x * (ρ.α.signum() * θ * τ); |
189 // Estimate Lipschitz factor of ∇v |
191 ρ.x = δ.x - g; |
190 let this_ℓ_v = (dv_x - v.differential(&ρ.x)).norm2(); |
192 let n = g.norm2(); |
191 *adaptive_ℓ_v = adaptive_ℓ_v.max(this_ℓ_v); |
193 if n >= F::EPSILON { |
|
194 // Estimate Lipschitz factor of ∇v |
|
195 let this_ℓ_v = (dv_x - v.differential(&ρ.x)).norm2() / n; |
|
196 *adaptive_ℓ_v = adaptive_ℓ_v.max(this_ℓ_v); |
|
197 θ = calculate_θ(*adaptive_ℓ_v, *max_transport); |
|
198 changes = true |
|
199 } |
192 } |
200 } |
193 let new_θ = calculate_θ(*adaptive_ℓ_v / tconfig.adaptation, *max_transport); |
201 if !changes { |
194 if new_θ <= θ { |
|
195 break |
202 break |
196 } |
203 } |
197 θ = new_θ; |
|
198 } |
204 } |
199 } |
205 } |
200 } |
206 } |
201 |
207 |
202 // 2. Adjust transport mass, if needed. |
208 // 2. Adjust transport mass, if needed. |
207 let nr =γ1.norm(Radon); |
213 let nr =γ1.norm(Radon); |
208 let n = τ * 2.0 * opAnorm * (opAapply(&*γ1)-opAapply(&*μ)).norm2(); |
214 let n = τ * 2.0 * opAnorm * (opAapply(&*γ1)-opAapply(&*μ)).norm2(); |
209 if n <= 0.0 || nr <= 0.0 { |
215 if n <= 0.0 || nr <= 0.0 { |
210 break |
216 break |
211 } |
217 } |
212 let reduction_needed = nr - (ε * tconfig.tolerance_dv / n); |
218 let reduction_needed = nr - (ε * tconfig.tolerance_mult_pri / n); |
213 if reduction_needed <= 0.0 { |
219 if reduction_needed <= 0.0 { |
214 break |
220 break |
215 } |
221 } |
216 let (min_nonzero, n_nonzero) = γ1.iter_masses() |
222 let (min_nonzero, n_nonzero) = γ1.iter_masses() |
217 .map(|α| α.abs()) |
223 .map(|α| α.abs()) |
237 let na = a.norm2(); |
243 let na = a.norm2(); |
238 let n = τ * 2.0 * opAnorm * na; |
244 let n = τ * 2.0 * opAnorm * na; |
239 if n <= 0.0 || nr <= 0.0 { |
245 if n <= 0.0 || nr <= 0.0 { |
240 break |
246 break |
241 } |
247 } |
242 let reduction_needed = nr - (ε * tconfig.tolerance_dv / n); |
248 let reduction_needed = nr - (ε * tconfig.tolerance_mult_pri / n); |
243 if reduction_needed <= 0.0 { |
249 if reduction_needed <= 0.0 { |
244 break |
250 break |
245 } |
251 } |
246 let mut max_d = 0.0; |
252 let mut max_d = 0.0; |
247 let mut max_d_ind = 0; |
253 let mut max_d_ind = 0; |
286 pub(crate) fn aposteriori_transport<F, const N : usize>( |
292 pub(crate) fn aposteriori_transport<F, const N : usize>( |
287 γ1 : &mut RNDM<F, N>, |
293 γ1 : &mut RNDM<F, N>, |
288 μ : &mut RNDM<F, N>, |
294 μ : &mut RNDM<F, N>, |
289 μ_base_minus_γ0 : &mut RNDM<F, N>, |
295 μ_base_minus_γ0 : &mut RNDM<F, N>, |
290 μ_base_masses : &Vec<F>, |
296 μ_base_masses : &Vec<F>, |
|
297 extra : Option<F>, |
291 ε : F, |
298 ε : F, |
292 tconfig : &TransportConfig<F> |
299 tconfig : &TransportConfig<F> |
293 ) -> bool |
300 ) -> bool |
294 where F : Float + ToNalgebraRealField { |
301 where F : Float + ToNalgebraRealField { |
295 |
302 |
306 |
313 |
307 // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z). |
314 // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z). |
308 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1}, |
315 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1}, |
309 // which holds for some C if the convolution kernel in 𝒟 has Lipschitz gradient. |
316 // which holds for some C if the convolution kernel in 𝒟 has Lipschitz gradient. |
310 let nγ = γ1.norm(Radon); |
317 let nγ = γ1.norm(Radon); |
311 let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1); |
318 let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0); |
312 let t = ε * tconfig.tolerance_ω; |
319 let t = ε * tconfig.tolerance_mult_pos; |
313 if nγ*nΔ > t { |
320 if nγ*nΔ > t { |
314 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1, |
321 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1, |
315 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we |
322 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we |
316 // will not enter here. |
323 // will not enter here. |
317 *γ1 *= tconfig.adaptation * t / ( nγ * nΔ ); |
324 *γ1 *= tconfig.adaptation * t / ( nγ * nΔ ); |
377 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); |
384 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); |
378 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { |
385 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { |
379 // We only estimate w (the uniform Lipschitz factor of v), if we also estimate ℓ_v |
386 // We only estimate w (the uniform Lipschitz factor of v), if we also estimate ℓ_v |
380 // (the uniform Lipschitz factor of ∇v). |
387 // (the uniform Lipschitz factor of ∇v). |
381 // We assume that the residual is decreasing. |
388 // We assume that the residual is decreasing. |
382 Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * residual.norm2(), 0.0)), |
389 Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * b.norm2(), 0.0)), |
383 None => TransportStepLength::FullyAdaptive { |
390 None => TransportStepLength::FullyAdaptive { |
384 l : 0.0, |
391 l : 10.0 * F::EPSILON, // Start with something very small to estimate differentials |
385 max_transport : 0.0, |
392 max_transport : 0.0, |
386 g : calculate_θ |
393 g : calculate_θ |
387 }, |
394 }, |
388 }; |
395 }; |
389 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
396 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
413 v, &config.transport, |
420 v, &config.transport, |
414 ); |
421 ); |
415 |
422 |
416 // Solve finite-dimensional subproblem several times until the dual variable for the |
423 // Solve finite-dimensional subproblem several times until the dual variable for the |
417 // regularisation term conforms to the assumptions made for the transport above. |
424 // regularisation term conforms to the assumptions made for the transport above. |
418 let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop { |
425 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop { |
419 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
426 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
420 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); |
427 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); |
421 let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); |
428 let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); |
422 |
429 |
423 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
430 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
428 ); |
435 ); |
429 |
436 |
430 // A posteriori transport adaptation. |
437 // A posteriori transport adaptation. |
431 if aposteriori_transport( |
438 if aposteriori_transport( |
432 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, |
439 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, |
|
440 None, |
433 ε, &config.transport |
441 ε, &config.transport |
434 ) { |
442 ) { |
435 break 'adapt_transport (maybe_d, within_tolerances, τv̆) |
443 break 'adapt_transport (maybe_d, within_tolerances, τv̆) |
436 } |
444 } |
437 }; |
445 }; |
446 assert_eq!(μ_base_masses.len(), γ1.len()); |
454 assert_eq!(μ_base_masses.len(), γ1.len()); |
447 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); |
455 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); |
448 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) |
456 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) |
449 }); |
457 }); |
450 |
458 |
451 // // Merge spikes. |
459 // Merge spikes. |
452 // // This expects the prune below to prune γ. |
460 // This crucially expects the merge routine to be stable with respect to spike locations, |
453 // // TODO: This may not work correctly in all cases. |
461 // and not to perform any pruning. That is to be done below simultaneously for γ. |
454 // let ins = &config.insertion; |
462 let ins = &config.insertion; |
455 // if ins.merge_now(&state) { |
463 if ins.merge_now(&state) { |
456 // if let SpikeMergingMethod::None = ins.merging { |
464 stats.merged += prox_penalty.merge_spikes( |
457 // } else { |
465 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, ins, &reg, |
458 // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { |
466 Some(|μ̃ : &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), |
459 // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; |
467 ); |
460 // let mut d = &τv̆ + op𝒟.preapply(ν); |
468 } |
461 // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) |
|
462 // }); |
|
463 // } |
|
464 // } |
|
465 |
469 |
466 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the |
470 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the |
467 // latter needs to be pruned when μ is. |
471 // latter needs to be pruned when μ is. |
468 // TODO: This could do with a two-vector Vec::retain to avoid copies. |
472 // TODO: This could do with a two-vector Vec::retain to avoid copies. |
469 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); |
473 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); |