--- a/src/pdps.rs Thu Aug 29 00:00:00 2024 -0500 +++ b/src/pdps.rs Tue Dec 31 09:25:45 2024 -0500 @@ -6,8 +6,7 @@ * Valkonen T. - _Proximal methods for point source localisation_, [arXiv:2212.02991](https://arxiv.org/abs/2212.02991). -The main routine is [`pointsource_pdps`]. It is based on specilisatinn of -[`generic_pointsource_fb_reg`] through relevant [`FBSpecialisation`] implementations. +The main routine is [`pointsource_pdps_reg`]. Both norm-2-squared and norm-1 data terms are supported. That is, implemented are solvers for <div> $$ @@ -37,10 +36,6 @@ For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$. For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$. </p> - -Based on zero initialisation for $μ$, we use the [`Subdifferentiable`] trait to make an -initialisation corresponding to the second part of the optimality conditions. -In the algorithm itself, standard proximal steps are taking with respect to $F\_0^* + ⟨b, ·⟩$. */ use numeric_literals::replace_float_literals; @@ -48,13 +43,10 @@ use nalgebra::DVector; use clap::ValueEnum; -use alg_tools::iterate::{ - AlgIteratorFactory, - AlgIteratorState, -}; +use alg_tools::iterate::AlgIteratorFactory; use alg_tools::loc::Loc; use alg_tools::euclidean::Euclidean; -use alg_tools::linops::Apply; +use alg_tools::linops::Mapping; use alg_tools::norms::{ Linfinity, Projection, @@ -69,14 +61,17 @@ SupportGenerator, LocalAnalysis, }; -use alg_tools::mapping::RealMapping; +use alg_tools::mapping::{RealMapping, Instance}; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::linops::AXPY; use crate::types::*; -use crate::measures::DiscreteMeasure; +use crate::measures::{DiscreteMeasure, RNDM, Radon}; use crate::measures::merging::SpikeMerging; -use crate::forward_model::ForwardModel; +use crate::forward_model::{ + AdjointProductBoundedBy, + ForwardModel +}; use crate::seminorms::DiscreteMeasureOp; use crate::plot::{ SeqPlotter, @@ -87,7 +82,7 @@ FBGenericConfig, 
insert_and_reweigh, postprocess, - prune_and_maybe_simple_merge + prune_with_stats }; use crate::regularisation::RegTerm; use crate::dataterm::{ @@ -110,7 +105,30 @@ Full } -/// Settings for [`pointsource_pdps`]. +#[replace_float_literals(F::cast_from(literal))] +impl Acceleration { + /// PDPS parameter acceleration. Updates τ and σ and returns ω. + /// This uses dual strong convexity, not primal. + fn accelerate<F : Float>(self, τ : &mut F, σ : &mut F, γ : F) -> F { + match self { + Acceleration::None => 1.0, + Acceleration::Partial => { + let ω = 1.0 / (1.0 + γ * (*σ)).sqrt(); + *σ *= ω; + *τ /= ω; + ω + }, + Acceleration::Full => { + let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt(); + *σ *= ω; + *τ /= ω; + ω + }, + } + } +} + +/// Settings for [`pointsource_pdps_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] pub struct PDPSConfig<F : Float> { @@ -155,9 +173,13 @@ #[replace_float_literals(F::cast_from(literal))] -impl<F : Float, V : Euclidean<F> + AXPY<F>, const N : usize> -PDPSDataTerm<F, V, N> -for L2Squared { +impl<F, V, const N : usize> PDPSDataTerm<F, V, N> +for L2Squared +where + F : Float, + V : Euclidean<F> + AXPY<F>, + for<'b> &'b V : Instance<V>, +{ fn some_subdifferential(&self, x : V) -> V { x } fn factor_of_strong_convexity(&self) -> F { @@ -166,7 +188,7 @@ #[inline] fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) { - y.axpy(1.0 / (1.0 + σ), &y_prev, σ / (1.0 + σ)); + y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ)); } } @@ -210,16 +232,13 @@ iterator : I, mut plotter : SeqPlotter<F, N>, dataterm : D, -) -> DiscreteMeasure<Loc<F, N>, F> +) -> RNDM<F, N> where F : Float + ToNalgebraRealField, I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> - + std::ops::Add<A::Observable, Output=A::Observable>, - //+ std::ops::Mul<F, Output=A::Observable>, // <-- FIXME: compiler overflow - A::Observable : std::ops::MulAssign<F>, + for<'b> &'b A::Observable : 
std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> - + Lipschitz<&'a 𝒟, FloatType=F>, + A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> + + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, @@ -228,14 +247,20 @@ K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, PlotLookup : Plotting<N>, - DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, + RNDM<F, N> : SpikeMerging<F>, D : PDPSDataTerm<F, A::Observable, N>, Reg : RegTerm<F, N> { + // Check parameters + assert!(pdpsconfig.τ0 > 0.0 && + pdpsconfig.σ0 > 0.0 && + pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, + "Invalid step length parameters"); + // Set up parameters let config = &pdpsconfig.generic; - let op𝒟norm = op𝒟.opnorm_bound(); - let l = opA.lipschitz_factor(&op𝒟).unwrap().sqrt(); + let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); + let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt(); let mut τ = pdpsconfig.τ0 / l; let mut σ = pdpsconfig.σ0 / l; let γ = dataterm.factor_of_strong_convexity(); @@ -249,53 +274,42 @@ let mut μ = DiscreteMeasure::new(); let mut y = dataterm.some_subdifferential(-b); let mut y_prev = y.clone(); + let full_stats = |μ : &RNDM<F, N>, ε, stats| IterInfo { + value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ), + n_spikes : μ.len(), + ε, + // postprocessing: config.postprocessing.then(|| μ.clone()), + .. stats + }; let mut stats = IterInfo::new(); // Run the algorithm - iterator.iterate(|state| { + for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. 
- // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` - // has no significant overhead. For some reosn Rust doesn't allow us simply moving - // the residual and replacing it below before the end of this closure. - y *= -τ; - let r = std::mem::replace(&mut y, opA.empty_observable()); - let minus_τv = opA.preadjoint().apply(r); + let τv = opA.preadjoint().apply(y * τ); // Save current base point let μ_base = μ.clone(); // Insert and reweigh - let (d, within_tolerances) = insert_and_reweigh( - &mut μ, &minus_τv, &μ_base, None, + let (d, _within_tolerances) = insert_and_reweigh( + &mut μ, &τv, &μ_base, None, op𝒟, op𝒟norm, τ, ε, - config, &reg, state, &mut stats + config, &reg, &state, &mut stats ); // Prune and possibly merge spikes - prune_and_maybe_simple_merge( - &mut μ, &minus_τv, &μ_base, - op𝒟, - τ, ε, - config, &reg, state, &mut stats - ); + if config.merge_now(&state) { + stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { + let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base)); + reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) + }); + } + stats.pruned += prune_with_stats(&mut μ); // Update step length parameters - let ω = match pdpsconfig.acceleration { - Acceleration::None => 1.0, - Acceleration::Partial => { - let ω = 1.0 / (1.0 + γ * σ).sqrt(); - σ = σ * ω; - τ = τ / ω; - ω - }, - Acceleration::Full => { - let ω = 1.0 / (1.0 + 2.0 * γ * σ).sqrt(); - σ = σ * ω; - τ = τ / ω; - ω - }, - }; + let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); // Do dual update y = b.clone(); // y = b @@ -304,32 +318,17 @@ dataterm.dual_update(&mut y, &y_prev, σ); y_prev.copy_from(&y); - // Update main tolerance for next iteration - let ε_prev = ε; - ε = tolerance.update(ε, state.iteration()); + // Give statistics if requested + let iter = state.iteration(); stats.this_iters += 1; - // Give function value if needed state.if_verbose(|| { - // Plot if so requested - plotter.plot_spikes( - format!("iter {} 
end; {}", state.iteration(), within_tolerances), &d, - "start".to_string(), Some(&minus_τv), - reg.target_bounds(τ, ε_prev), &μ, - ); - // Calculate mean inner iterations and reset relevant counters. - // Return the statistics - let res = IterInfo { - value : dataterm.calculate_fit_op(&μ, opA, b) + reg.apply(&μ), - n_spikes : μ.len(), - ε : ε_prev, - postprocessing: config.postprocessing.then(|| μ.clone()), - .. stats - }; - stats = IterInfo::new(); - res - }) - }); + plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); + full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) + }); + + ε = tolerance.update(ε, iter); + } postprocess(μ, config, dataterm, opA, b) }