diff -r 6105b5cd8d89 -r f0e8704d3f0e src/sliding_pdps.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/sliding_pdps.rs Mon Feb 17 13:54:53 2025 -0500 @@ -0,0 +1,373 @@ +/*! +Solver for the point source localisation problem using a sliding +primal-dual proximal splitting method. +*/ + +use numeric_literals::replace_float_literals; +use serde::{Deserialize, Serialize}; +//use colored::Colorize; +//use nalgebra::{DVector, DMatrix}; +use std::iter::Iterator; + +use alg_tools::convex::{Conjugable, Prox}; +use alg_tools::direct_product::Pair; +use alg_tools::euclidean::Euclidean; +use alg_tools::iterate::AlgIteratorFactory; +use alg_tools::linops::{Adjointable, BoundedLinear, IdOp, AXPY, GEMV}; +use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping}; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::norms::{Dist, Norm}; +use alg_tools::norms::{PairNorm, L2}; + +use crate::forward_model::{AdjointProductPairBoundedBy, BoundedCurvature, ForwardModel}; +use crate::measures::merging::SpikeMerging; +use crate::measures::{DiscreteMeasure, Radon, RNDM}; +use crate::types::*; +// use crate::transport::TransportLipschitz; +//use crate::tolerance::Tolerance; +use crate::fb::*; +use crate::plot::{PlotLookup, Plotting, SeqPlotter}; +use crate::regularisation::SlidingRegTerm; +// use crate::dataterm::L2Squared; +use crate::dataterm::{calculate_residual, calculate_residual2}; +use crate::sliding_fb::{ + aposteriori_transport, initial_transport, TransportConfig, TransportStepLength, +}; + +/// Settings for [`pointsource_sliding_pdps_pair`]. +#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] +#[serde(default)] +pub struct SlidingPDPSConfig { + /// Primal step length scaling. + pub τ0: F, + /// Primal step length scaling. + pub σp0: F, + /// Dual step length scaling. + pub σd0: F, + /// Transport parameters + pub transport: TransportConfig, + /// Generic parameters + pub insertion: FBGenericConfig, +} + +#[replace_float_literals(F::cast_from(literal))] +impl Default for SlidingPDPSConfig { + fn default() -> Self { + SlidingPDPSConfig { + τ0: 0.99, + σd0: 0.05, + σp0: 0.99, + transport: TransportConfig { + θ0: 0.9, + ..Default::default() + }, + insertion: Default::default(), + } + } +} + +type MeasureZ = Pair, Z>; + +/// Iteratively solve the pointsource localisation with an additional variable +/// using sliding primal-dual proximal splitting +/// +/// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. +#[replace_float_literals(F::cast_from(literal))] +pub fn pointsource_sliding_pdps_pair< + F, + I, + A, + S, + Reg, + P, + Z, + R, + Y, + /*KOpM, */ KOpZ, + H, + const N: usize, +>( + opA: &A, + b: &A::Observable, + reg: Reg, + prox_penalty: &P, + config: &SlidingPDPSConfig, + iterator: I, + mut plotter: SeqPlotter, + //opKμ : KOpM, + opKz: &KOpZ, + fnR: &R, + fnH: &H, + mut z: Z, + mut y: Y, +) -> MeasureZ +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + A: ForwardModel, F, PairNorm, PreadjointCodomain = Pair> + + AdjointProductPairBoundedBy, P, IdOp, FloatType = F> + + BoundedCurvature, + S: DifferentiableRealMapping, + for<'b> &'b A::Observable: std::ops::Neg + Instance, + PlotLookup: Plotting, + RNDM: SpikeMerging, + Reg: SlidingRegTerm, + P: ProxPenalty, + // KOpM : Linear, Codomain=Y> + // + GEMV> + // + Preadjointable< + // RNDM, Y, + // PreadjointCodomain = S, + // > + // + TransportLipschitz + // + AdjointProductBoundedBy, 𝒟, FloatType=F>, + // for<'b> KOpM::Preadjoint<'b> : GEMV, + // Since Z is Hilbert, we may just as well use adjoints for K_z. + KOpZ: BoundedLinear + + GEMV + + Adjointable, + for<'b> KOpZ::Adjoint<'b>: GEMV, + Y: AXPY + Euclidean + Clone + ClosedAdd, + for<'b> &'b Y: Instance, + Z: AXPY + Euclidean + Clone + Norm + Dist, + for<'b> &'b Z: Instance, + R: Prox, + H: Conjugable, + for<'b> H::Conjugate<'b>: Prox, +{ + // Check parameters + assert!( + config.τ0 > 0.0 + && config.τ0 < 1.0 + && config.σp0 > 0.0 + && config.σp0 < 1.0 + && config.σd0 > 0.0 + && config.σp0 * config.σd0 <= 1.0, + "Invalid step length parameters" + ); + config.transport.check(); + + // Initialise iterates + let mut μ = DiscreteMeasure::new(); + let mut γ1 = DiscreteMeasure::new(); + let mut residual = calculate_residual(Pair(&μ, &z), opA, b); + let zero_z = z.similar_origin(); + + // Set up parameters + // TODO: maybe this PairNorm doesn't make sense here? + // let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); + let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared); + let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt(); + let nKz = opKz.opnorm_bound(L2, L2); + let ℓ = 0.0; + let opIdZ = IdOp::new(); + let (l, l_z) = opA + .adjoint_product_pair_bound(prox_penalty, &opIdZ) + .unwrap(); + // We need to satisfy + // + // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 + // ^^^^^^^^^^^^^^^^^^^^^^^^^ + // with 1 > σ_p L_z and 1 > τ L. + // + // To do so, we first solve σ_p and σ_d from standard PDPS step length condition + // ^^^^^ < 1. then we solve τ from the rest. + let σ_d = config.σd0 / nKz; + let σ_p = config.σp0 / (l_z + config.σd0 * nKz); + // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0} + // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L) + // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0}) + let φ = 1.0 - config.σp0; + let a = 1.0 - σ_p * l_z; + let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l); + let ψ = 1.0 - τ * l; + let β = σ_p * config.σd0 * nKz / a; // σ_p * σ_d * (nKz * nK_z) / a; + assert!(β < 1.0); + // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as: + let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); + // The factor two in the manuscript disappears due to the definition of 𝚹 being + // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. + let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components(); + let transport_lip = maybe_transport_lip.unwrap(); + let calculate_θ = |ℓ_F, max_transport| { + let ℓ_r = transport_lip * max_transport; + config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport) + }; + let mut θ_or_adaptive = match maybe_ℓ_F0 { + // We assume that the residual is decreasing. + Some(ℓ_F0) => TransportStepLength::AdaptiveMax { + l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual + max_transport: 0.0, + g: calculate_θ, + }, + None => TransportStepLength::FullyAdaptive { + l: F::EPSILON, + max_transport: 0.0, + g: calculate_θ, + }, + }; + // Acceleration is not currently supported + // let γ = dataterm.factor_of_strong_convexity(); + let ω = 1.0; + + // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled + // by τ compared to the conditional gradient approach. + let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling(); + let mut ε = tolerance.initial(); + + let starH = fnH.conjugate(); + + // Statistics + let full_stats = |residual: &A::Observable, μ: &RNDM, z: &Z, ε, stats| IterInfo { + value: residual.norm2_squared_div2() + + fnR.apply(z) + + reg.apply(μ) + + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)), + n_spikes: μ.len(), + ε, + // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), + ..stats + }; + let mut stats = IterInfo::new(); + + // Run the algorithm + for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { + // Calculate initial transport + let Pair(v, _) = opA.preadjoint().apply(&residual); + //opKμ.preadjoint().apply_add(&mut v, y); + // We want to proceed as in Example 4.12 but with v and v̆ as in §5. + // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have + // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν, + // where A_ν^* becomes a multiplier. + // This is much easier with K_μ = 0, which is the only reason why are enforcing it. + // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0. + + let (μ_base_masses, mut μ_base_minus_γ0) = + initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); + + // Solve finite-dimensional subproblem several times until the dual variable for the + // regularisation term conforms to the assumptions made for the transport above. + let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop { + // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) + let residual_μ̆ = + calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); + let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ); + // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); + + // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. + let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, + &mut τv̆, + &γ1, + Some(&μ_base_minus_γ0), + τ, + ε, + &config.insertion, + ®, + &state, + &mut stats, + ); + + // Do z variable primal update here to able to estimate B_{v̆^k-v^{k+1}} + let mut z_new = τz̆; + opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p / τ); + z_new = fnR.prox(σ_p, z_new + &z); + + // A posteriori transport adaptation. + if aposteriori_transport( + &mut γ1, + &mut μ, + &mut μ_base_minus_γ0, + &μ_base_masses, + Some(z_new.dist(&z, L2)), + ε, + &config.transport, + ) { + break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new); + } + }; + + stats.untransported_fraction = Some({ + assert_eq!(μ_base_masses.len(), γ1.len()); + let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0)); + let source = μ_base_masses.iter().map(|v| v.abs()).sum(); + (a + μ_base_minus_γ0.norm(Radon), b + source) + }); + stats.transport_error = Some({ + assert_eq!(μ_base_masses.len(), γ1.len()); + let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); + (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) + }); + + // Merge spikes. + // This crucially expects the merge routine to be stable with respect to spike locations, + // and not to performing any pruning. That is be to done below simultaneously for γ. + let ins = &config.insertion; + if ins.merge_now(&state) { + stats.merged += prox_penalty.merge_spikes_no_fitness( + &mut μ, + &mut τv̆, + &γ1, + Some(&μ_base_minus_γ0), + τ, + ε, + ins, + ®, + //Some(|μ̃ : &RNDM| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), + ); + } + + // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the + // latter needs to be pruned when μ is. + // TODO: This could do with a two-vector Vec::retain to avoid copies. + let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); + if μ_new.len() != μ.len() { + let mut μ_iter = μ.iter_spikes(); + γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO); + stats.pruned += μ.len() - μ_new.len(); + μ = μ_new; + } + + // Do dual update + // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] + opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0); + // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b + opKz.gemv(&mut y, -σ_d * ω, z, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b + y = starH.prox(σ_d, y); + z = z_new; + + // Update residual + residual = calculate_residual(Pair(&μ, &z), opA, b); + + // Update step length parameters + // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); + + // Give statistics if requested + let iter = state.iteration(); + stats.this_iters += 1; + + state.if_verbose(|| { + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); + full_stats( + &residual, + &μ, + &z, + ε, + std::mem::replace(&mut stats, IterInfo::new()), + ) + }); + + // Update main tolerance for next iteration + ε = tolerance.update(ε, iter); + } + + let fit = |μ̃: &RNDM| { + (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2() + //+ fnR.apply(z) + reg.apply(μ) + + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) + }; + + μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v); + μ.prune(); + Pair(μ, z) +}