src/sliding_fb.rs

branch
dev
changeset 39
6316d68b58af
parent 37
c5d8bd1a7728
child 41
b6bdb6cb4d44
equal deleted inserted replaced
37:c5d8bd1a7728 39:6316d68b58af
33 }; 33 };
34 use crate::fb::*; 34 use crate::fb::*;
35 use crate::regularisation::SlidingRegTerm; 35 use crate::regularisation::SlidingRegTerm;
36 use crate::dataterm::{ 36 use crate::dataterm::{
37 L2Squared, 37 L2Squared,
38 //DataTerm, 38 DataTerm,
39 calculate_residual, 39 calculate_residual,
40 calculate_residual2, 40 calculate_residual2,
41 }; 41 };
42 //use crate::transport::TransportLipschitz; 42 //use crate::transport::TransportLipschitz;
43 43
47 pub struct TransportConfig<F : Float> { 47 pub struct TransportConfig<F : Float> {
48 /// Transport step length $θ$ normalised to $(0, 1)$. 48 /// Transport step length $θ$ normalised to $(0, 1)$.
49 pub θ0 : F, 49 pub θ0 : F,
50 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. 50 /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance.
51 pub adaptation : F, 51 pub adaptation : F,
52 /// Transport tolerance wrt. ω 52 /// A priori transport tolerance multiplier (C_pri)
53 pub tolerance_ω : F, 53 pub tolerance_mult_pri : F,
54 /// Transport tolerance wrt. ∇v 54 /// A posteriori transport tolerance multiplier (C_pos)
55 pub tolerance_dv : F, 55 pub tolerance_mult_pos : F,
56 } 56 }
57 57
58 #[replace_float_literals(F::cast_from(literal))] 58 #[replace_float_literals(F::cast_from(literal))]
59 impl <F : Float> TransportConfig<F> { 59 impl <F : Float> TransportConfig<F> {
60 /// Check that the parameters are ok. Panics if not. 60 /// Check that the parameters are ok. Panics if not.
61 pub fn check(&self) { 61 pub fn check(&self) {
62 assert!(self.θ0 > 0.0); 62 assert!(self.θ0 > 0.0);
63 assert!(0.0 < self.adaptation && self.adaptation < 1.0); 63 assert!(0.0 < self.adaptation && self.adaptation < 1.0);
64 assert!(self.tolerance_dv > 0.0); 64 assert!(self.tolerance_mult_pri > 0.0);
65 assert!(self.tolerance_ω > 0.0); 65 assert!(self.tolerance_mult_pos > 0.0);
66 } 66 }
67 } 67 }
68 68
69 #[replace_float_literals(F::cast_from(literal))] 69 #[replace_float_literals(F::cast_from(literal))]
70 impl<F : Float> Default for TransportConfig<F> { 70 impl<F : Float> Default for TransportConfig<F> {
71 fn default() -> Self { 71 fn default() -> Self {
72 TransportConfig { 72 TransportConfig {
73 θ0 : 0.01, 73 θ0 : 0.4,
74 adaptation : 0.9, 74 adaptation : 0.9,
75 tolerance_ω : 1000.0, // TODO: no idea what this should be 75 tolerance_mult_pos : 100.0,
76 tolerance_dv : 1000.0, // TODO: no idea what this should be 76 tolerance_mult_pri : 1000.0,
77 } 77 }
78 } 78 }
79 } 79 }
80 80
81 /// Settings for [`pointsource_sliding_fb_reg`]. 81 /// Settings for [`pointsource_sliding_fb_reg`].
179 } 179 }
180 }, 180 },
181 FullyAdaptive{ l : ref mut adaptive_ℓ_v, ref mut max_transport, g : ref calculate_θ } => { 181 FullyAdaptive{ l : ref mut adaptive_ℓ_v, ref mut max_transport, g : ref calculate_θ } => {
182 *max_transport = max_transport.max(γ1.norm(Radon)); 182 *max_transport = max_transport.max(γ1.norm(Radon));
183 let mut θ = calculate_θ(*adaptive_ℓ_v, *max_transport); 183 let mut θ = calculate_θ(*adaptive_ℓ_v, *max_transport);
184 loop { 184 // Do two runs through the spikes to update θ, breaking if first run did not cause
185 let θτ = τ * θ; 185 // a change.
186 for _i in 0..=1 {
187 let mut changes = false;
186 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { 188 for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
187 let dv_x = v.differential(&δ.x); 189 let dv_x = v.differential(&δ.x);
188 ρ.x = δ.x - &dv_x * (ρ.α.signum() * θτ); 190 let g = &dv_x * (ρ.α.signum() * θ * τ);
189 // Estimate Lipschitz factor of ∇v 191 ρ.x = δ.x - g;
190 let this_ℓ_v = (dv_x - v.differential(&ρ.x)).norm2(); 192 let n = g.norm2();
191 *adaptive_ℓ_v = adaptive_ℓ_v.max(this_ℓ_v); 193 if n >= F::EPSILON {
194 // Estimate Lipschitz factor of ∇v
195 let this_ℓ_v = (dv_x - v.differential(&ρ.x)).norm2() / n;
196 *adaptive_ℓ_v = adaptive_ℓ_v.max(this_ℓ_v);
197 θ = calculate_θ(*adaptive_ℓ_v, *max_transport);
198 changes = true
199 }
192 } 200 }
193 let new_θ = calculate_θ(*adaptive_ℓ_v / tconfig.adaptation, *max_transport); 201 if !changes {
194 if new_θ <= θ {
195 break 202 break
196 } 203 }
197 θ = new_θ;
198 } 204 }
199 } 205 }
200 } 206 }
201 207
202 // 2. Adjust transport mass, if needed. 208 // 2. Adjust transport mass, if needed.
207 let nr =γ1.norm(Radon); 213 let nr =γ1.norm(Radon);
208 let n = τ * 2.0 * opAnorm * (opAapply(&*γ1)-opAapply(&*μ)).norm2(); 214 let n = τ * 2.0 * opAnorm * (opAapply(&*γ1)-opAapply(&*μ)).norm2();
209 if n <= 0.0 || nr <= 0.0 { 215 if n <= 0.0 || nr <= 0.0 {
210 break 216 break
211 } 217 }
212 let reduction_needed = nr - (ε * tconfig.tolerance_dv / n); 218 let reduction_needed = nr - (ε * tconfig.tolerance_mult_pri / n);
213 if reduction_needed <= 0.0 { 219 if reduction_needed <= 0.0 {
214 break 220 break
215 } 221 }
216 let (min_nonzero, n_nonzero) = γ1.iter_masses() 222 let (min_nonzero, n_nonzero) = γ1.iter_masses()
217 .map(|α| α.abs()) 223 .map(|α| α.abs())
237 let na = a.norm2(); 243 let na = a.norm2();
238 let n = τ * 2.0 * opAnorm * na; 244 let n = τ * 2.0 * opAnorm * na;
239 if n <= 0.0 || nr <= 0.0 { 245 if n <= 0.0 || nr <= 0.0 {
240 break 246 break
241 } 247 }
242 let reduction_needed = nr - (ε * tconfig.tolerance_dv / n); 248 let reduction_needed = nr - (ε * tconfig.tolerance_mult_pri / n);
243 if reduction_needed <= 0.0 { 249 if reduction_needed <= 0.0 {
244 break 250 break
245 } 251 }
246 let mut max_d = 0.0; 252 let mut max_d = 0.0;
247 let mut max_d_ind = 0; 253 let mut max_d_ind = 0;
286 pub(crate) fn aposteriori_transport<F, const N : usize>( 292 pub(crate) fn aposteriori_transport<F, const N : usize>(
287 γ1 : &mut RNDM<F, N>, 293 γ1 : &mut RNDM<F, N>,
288 μ : &mut RNDM<F, N>, 294 μ : &mut RNDM<F, N>,
289 μ_base_minus_γ0 : &mut RNDM<F, N>, 295 μ_base_minus_γ0 : &mut RNDM<F, N>,
290 μ_base_masses : &Vec<F>, 296 μ_base_masses : &Vec<F>,
297 extra : Option<F>,
291 ε : F, 298 ε : F,
292 tconfig : &TransportConfig<F> 299 tconfig : &TransportConfig<F>
293 ) -> bool 300 ) -> bool
294 where F : Float + ToNalgebraRealField { 301 where F : Float + ToNalgebraRealField {
295 302
306 313
307 // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z). 314 // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z).
308 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1}, 315 // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1},
309 // which holds for some C if the convolution kernel in 𝒟 has Lipschitz gradient. 316 // which holds for some C if the convolution kernel in 𝒟 has Lipschitz gradient.
310 let nγ = γ1.norm(Radon); 317 let nγ = γ1.norm(Radon);
311 let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1); 318 let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0);
312 let t = ε * tconfig.tolerance_ω; 319 let t = ε * tconfig.tolerance_mult_pos;
313 if nγ*nΔ > t { 320 if nγ*nΔ > t {
314 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1, 321 // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1,
315 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we 322 // this will guarantee that eventually ‖γ‖ decreases sufficiently that we
316 // will not enter here. 323 // will not enter here.
317 *γ1 *= tconfig.adaptation * t / ( nγ * nΔ ); 324 *γ1 *= tconfig.adaptation * t / ( nγ * nΔ );
377 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); 384 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v));
378 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { 385 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() {
379 // We only estimate w (the uniform Lipschitz factor of v), if we also estimate ℓ_v 386 // We only estimate w (the uniform Lipschitz factor of v), if we also estimate ℓ_v
380 // (the uniform Lipschitz factor of ∇v). 387 // (the uniform Lipschitz factor of ∇v).
381 // We assume that the residual is decreasing. 388 // We assume that the residual is decreasing.
382 Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * residual.norm2(), 0.0)), 389 Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * b.norm2(), 0.0)),
383 None => TransportStepLength::FullyAdaptive { 390 None => TransportStepLength::FullyAdaptive {
384 l : 0.0, 391 l : 10.0 * F::EPSILON, // Start with something very small to estimate differentials
385 max_transport : 0.0, 392 max_transport : 0.0,
386 g : calculate_θ 393 g : calculate_θ
387 }, 394 },
388 }; 395 };
389 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 396 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
413 v, &config.transport, 420 v, &config.transport,
414 ); 421 );
415 422
416 // Solve finite-dimensional subproblem several times until the dual variable for the 423 // Solve finite-dimensional subproblem several times until the dual variable for the
417 // regularisation term conforms to the assumptions made for the transport above. 424 // regularisation term conforms to the assumptions made for the transport above.
418 let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop { 425 let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop {
419 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) 426 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
420 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); 427 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
421 let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); 428 let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ);
422 429
423 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 430 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
428 ); 435 );
429 436
430 // A posteriori transport adaptation. 437 // A posteriori transport adaptation.
431 if aposteriori_transport( 438 if aposteriori_transport(
432 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, 439 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses,
440 None,
433 ε, &config.transport 441 ε, &config.transport
434 ) { 442 ) {
435 break 'adapt_transport (maybe_d, within_tolerances, τv̆) 443 break 'adapt_transport (maybe_d, within_tolerances, τv̆)
436 } 444 }
437 }; 445 };
446 assert_eq!(μ_base_masses.len(), γ1.len()); 454 assert_eq!(μ_base_masses.len(), γ1.len());
447 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); 455 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
448 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) 456 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
449 }); 457 });
450 458
451 // // Merge spikes. 459 // Merge spikes.
452 // // This expects the prune below to prune γ. 460 // This crucially expects the merge routine to be stable with respect to spike locations,
453 // // TODO: This may not work correctly in all cases. 461 // and not to perform any pruning. That is to be done below simultaneously for γ.
454 // let ins = &config.insertion; 462 let ins = &config.insertion;
455 // if ins.merge_now(&state) { 463 if ins.merge_now(&state) {
456 // if let SpikeMergingMethod::None = ins.merging { 464 stats.merged += prox_penalty.merge_spikes(
457 // } else { 465 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, ins, &reg,
458 // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { 466 Some(|μ̃ : &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
459 // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; 467 );
460 // let mut d = &τv̆ + op𝒟.preapply(ν); 468 }
461 // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins)
462 // });
463 // }
464 // }
465 469
466 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the 470 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
467 // latter needs to be pruned when μ is. 471 // latter needs to be pruned when μ is.
468 // TODO: This could do with a two-vector Vec::retain to avoid copies. 472 // TODO: This could do with a two-vector Vec::retain to avoid copies.
469 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); 473 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());

mercurial