diff -r c5d8bd1a7728 -r 6316d68b58af src/sliding_fb.rs --- a/src/sliding_fb.rs Thu Jan 23 23:35:28 2025 +0100 +++ b/src/sliding_fb.rs Thu Jan 23 23:34:05 2025 +0100 @@ -35,7 +35,7 @@ use crate::regularisation::SlidingRegTerm; use crate::dataterm::{ L2Squared, - //DataTerm, + DataTerm, calculate_residual, calculate_residual2, }; @@ -49,10 +49,10 @@ pub θ0 : F, /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. pub adaptation : F, - /// Transport tolerance wrt. ω - pub tolerance_ω : F, - /// Transport tolerance wrt. ∇v - pub tolerance_dv : F, + /// A priori transport tolerance multiplier (C_pri) + pub tolerance_mult_pri : F, + /// A posteriori transport tolerance multiplier (C_pos) + pub tolerance_mult_pos : F, } #[replace_float_literals(F::cast_from(literal))] @@ -61,8 +61,8 @@ pub fn check(&self) { assert!(self.θ0 > 0.0); assert!(0.0 < self.adaptation && self.adaptation < 1.0); - assert!(self.tolerance_dv > 0.0); - assert!(self.tolerance_ω > 0.0); + assert!(self.tolerance_mult_pri > 0.0); + assert!(self.tolerance_mult_pos > 0.0); } } @@ -70,10 +70,10 @@ impl Default for TransportConfig { fn default() -> Self { TransportConfig { - θ0 : 0.01, + θ0 : 0.4, adaptation : 0.9, - tolerance_ω : 1000.0, // TODO: no idea what this should be - tolerance_dv : 1000.0, // TODO: no idea what this should be + tolerance_mult_pos : 100.0, + tolerance_mult_pri : 1000.0, } } } @@ -181,20 +181,26 @@ FullyAdaptive{ l : ref mut adaptive_ℓ_v, ref mut max_transport, g : ref calculate_θ } => { *max_transport = max_transport.max(γ1.norm(Radon)); let mut θ = calculate_θ(*adaptive_ℓ_v, *max_transport); - loop { - let θτ = τ * θ; + // Do two runs through the spikes to update θ, breaking if first run did not cause + // a change. + for _i in 0..=1 { + let mut changes = false; for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { let dv_x = v.differential(&δ.x); - ρ.x = δ.x - &dv_x * (ρ.α.signum() * θτ); - // Estimate Lipschitz factor of ∇v - let this_ℓ_v = (dv_x - v.differential(&ρ.x)).norm2(); - *adaptive_ℓ_v = adaptive_ℓ_v.max(this_ℓ_v); + let g = &dv_x * (ρ.α.signum() * θ * τ); + ρ.x = δ.x - g; + let n = g.norm2(); + if n >= F::EPSILON { + // Estimate Lipschitz factor of ∇v + let this_ℓ_v = (dv_x - v.differential(&ρ.x)).norm2() / n; + *adaptive_ℓ_v = adaptive_ℓ_v.max(this_ℓ_v); + θ = calculate_θ(*adaptive_ℓ_v, *max_transport); + changes = true + } } - let new_θ = calculate_θ(*adaptive_ℓ_v / tconfig.adaptation, *max_transport); - if new_θ <= θ { + if !changes { break } - θ = new_θ; } } } @@ -209,7 +215,7 @@ if n <= 0.0 || nr <= 0.0 { break } - let reduction_needed = nr - (ε * tconfig.tolerance_dv / n); + let reduction_needed = nr - (ε * tconfig.tolerance_mult_pri / n); if reduction_needed <= 0.0 { break } @@ -239,7 +245,7 @@ if n <= 0.0 || nr <= 0.0 { break } - let reduction_needed = nr - (ε * tconfig.tolerance_dv / n); + let reduction_needed = nr - (ε * tconfig.tolerance_mult_pri / n); if reduction_needed <= 0.0 { break } @@ -288,6 +294,7 @@ μ : &mut RNDM, μ_base_minus_γ0 : &mut RNDM, μ_base_masses : &Vec, + extra : Option, ε : F, tconfig : &TransportConfig ) -> bool @@ -308,8 +315,8 @@ // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1}, // which holds for some some C if the convolution kernel in 𝒟 has Lipschitz gradient. let nγ = γ1.norm(Radon); - let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1); - let t = ε * tconfig.tolerance_ω; + let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0); + let t = ε * tconfig.tolerance_mult_pos; if nγ*nΔ > t { // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1, // this will guarantee that eventually ‖γ‖ decreases sufficiently that we @@ -379,9 +386,9 @@ // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v // (the uniform Lipschitz factor of ∇v). // We assume that the residual is decreasing. - Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * residual.norm2(), 0.0)), + Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * b.norm2(), 0.0)), None => TransportStepLength::FullyAdaptive { - l : 0.0, + l : 10.0 * F::EPSILON, // Start with something very small to estimate differentials max_transport : 0.0, g : calculate_θ }, @@ -415,7 +422,7 @@ // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. - let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop { + let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); @@ -430,6 +437,7 @@ // A posteriori transport adaptation. if aposteriori_transport( &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, + None, ε, &config.transport ) { break 'adapt_transport (maybe_d, within_tolerances, τv̆) @@ -448,20 +456,16 @@ (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) }); - // // Merge spikes. - // // This expects the prune below to prune γ. - // // TODO: This may not work correctly in all cases. - // let ins = &config.insertion; - // if ins.merge_now(&state) { - // if let SpikeMergingMethod::None = ins.merging { - // } else { - // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { - // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; - // let mut d = &τv̆ + op𝒟.preapply(ν); - // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) - // }); - // } - // } + // Merge spikes. + // This crucially expects the merge routine to be stable with respect to spike locations, + // and not to performing any pruning. That is be to done below simultaneously for γ. + let ins = &config.insertion; + if ins.merge_now(&state) { + stats.merged += prox_penalty.merge_spikes( + &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, ins, ®, + Some(|μ̃ : &RNDM| L2Squared.calculate_fit_op(μ̃, opA, b)), + ); + } // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the // latter needs to be pruned when μ is.