diff -r 9738b51d90d7 -r 4f468d35fa29 src/pdps.rs --- a/src/pdps.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/pdps.rs Thu Feb 26 11:38:43 2026 -0500 @@ -38,50 +38,25 @@

*/ -use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; -use nalgebra::DVector; -use clap::ValueEnum; - -use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::euclidean::Euclidean; -use alg_tools::linops::Mapping; -use alg_tools::norms::{ - Linfinity, - Projection, -}; -use alg_tools::mapping::{RealMapping, Instance}; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::linops::AXPY; - -use crate::types::*; -use crate::measures::{DiscreteMeasure, RNDM}; +use crate::fb::{postprocess, prune_with_stats}; +use crate::forward_model::ForwardModel; use crate::measures::merging::SpikeMerging; -use crate::forward_model::{ - ForwardModel, - AdjointProductBoundedBy, -}; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; -use crate::fb::{ - postprocess, - prune_with_stats -}; -pub use crate::prox_penalty::{ - FBGenericConfig, - ProxPenalty -}; +use crate::measures::merging::SpikeMergingMethod; +use crate::measures::{DiscreteMeasure, RNDM}; +use crate::plot::Plotter; +pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBoundPD}; use crate::regularisation::RegTerm; -use crate::dataterm::{ - DataTerm, - L2Squared, - L1 -}; -use crate::measures::merging::SpikeMergingMethod; - +use crate::types::*; +use alg_tools::convex::{Conjugable, ConvexMapping, Prox}; +use alg_tools::error::DynResult; +use alg_tools::iterate::AlgIteratorFactory; +use alg_tools::linops::{Mapping, AXPY}; +use alg_tools::mapping::{DataTerm, Instance}; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use anyhow::ensure; +use clap::ValueEnum; +use numeric_literals::replace_float_literals; +use serde::{Deserialize, Serialize}; /// Acceleration #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)] @@ -93,15 +68,18 @@ #[clap(name = "partial", help = "Partial acceleration, ω = 1/√(1+σ)")] Partial, /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed - #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")] - Full + #[clap( + name = "full", + help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed" + )] + Full, } #[replace_float_literals(F::cast_from(literal))] impl Acceleration { /// PDPS parameter acceleration. Updates τ and σ and returns ω. /// This uses dual strong convexity, not primal. - fn accelerate(self, τ : &mut F, σ : &mut F, γ : F) -> F { + fn accelerate(self, τ: &mut F, σ: &mut F, γ: F) -> F { match self { Acceleration::None => 1.0, Acceleration::Partial => { @@ -109,13 +87,13 @@ *σ *= ω; *τ /= ω; ω - }, + } Acceleration::Full => { let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt(); *σ *= ω; *τ /= ω; ω - }, + } } } } @@ -123,91 +101,35 @@ /// Settings for [`pointsource_pdps_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct PDPSConfig { +pub struct PDPSConfig { /// Primal step length scaling. We must have `τ0 * σ0 < 1`. - pub τ0 : F, + pub τ0: F, /// Dual step length scaling. We must have `τ0 * σ0 < 1`. - pub σ0 : F, + pub σ0: F, /// Accelerate if available - pub acceleration : Acceleration, + pub acceleration: Acceleration, /// Generic parameters - pub generic : FBGenericConfig, + pub generic: InsertionConfig, } #[replace_float_literals(F::cast_from(literal))] -impl Default for PDPSConfig { +impl Default for PDPSConfig { fn default() -> Self { let τ0 = 5.0; PDPSConfig { τ0, - σ0 : 0.99/τ0, - acceleration : Acceleration::Partial, - generic : FBGenericConfig { - merging : SpikeMergingMethod { enabled : true, ..Default::default() }, - .. Default::default() + σ0: 0.99 / τ0, + acceleration: Acceleration::Partial, + generic: InsertionConfig { + merging: SpikeMergingMethod { enabled: true, ..Default::default() }, + ..Default::default() }, } } } -/// Trait for data terms for the PDPS -#[replace_float_literals(F::cast_from(literal))] -pub trait PDPSDataTerm : DataTerm { - /// Calculate some subdifferential at `x` for the conjugate - fn some_subdifferential(&self, x : V) -> V; - - /// Factor of strong convexity of the conjugate - #[inline] - fn factor_of_strong_convexity(&self) -> F { - 0.0 - } - - /// Perform dual update - fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F); -} - - -#[replace_float_literals(F::cast_from(literal))] -impl PDPSDataTerm -for L2Squared -where - F : Float, - V : Euclidean + AXPY, - for<'b> &'b V : Instance, -{ - fn some_subdifferential(&self, x : V) -> V { x } - - fn factor_of_strong_convexity(&self) -> F { - 1.0 - } - - #[inline] - fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) { - y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ)); - } -} - -#[replace_float_literals(F::cast_from(literal))] -impl -PDPSDataTerm, N> -for L1 { - fn some_subdifferential(&self, mut x : DVector) -> DVector { - // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well. - x.iter_mut() - .for_each(|v| if *v != F::ZERO { *v = *v/::abs(*v) }); - x - } - - #[inline] - fn dual_update(&self, y : &mut DVector, y_prev : &DVector, σ : F) { - y.axpy(1.0, y_prev, σ); - y.proj_ball_mut(1.0, Linfinity); - } -} - /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting. /// -/// The `dataterm` should be either [`L1`] for norm-1 data term or [`L2Squared`] for norm-2-squared. /// The settings in `config` have their [respective documentation](PDPSConfig). `opA` is the /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution @@ -218,42 +140,44 @@ /// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_pdps_reg( - opA : &A, - b : &A::Observable, - reg : Reg, - prox_penalty : &P, - pdpsconfig : &PDPSConfig, - iterator : I, - mut plotter : SeqPlotter, - dataterm : D, -) -> RNDM +pub fn pointsource_pdps_reg<'a, F, I, A, Phi, Reg, Plot, P, const N: usize>( + f: &'a DataTerm, A, Phi>, + reg: &Reg, + prox_penalty: &P, + pdpsconfig: &PDPSConfig, + iterator: I, + mut plotter: Plot, + μ0 : Option>, +) -> DynResult> where - F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - A : ForwardModel, F> - + AdjointProductBoundedBy, P, FloatType=F>, - A::PreadjointCodomain : RealMapping, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - PlotLookup : Plotting, - RNDM : SpikeMerging, - D : PDPSDataTerm, - Reg : RegTerm, - P : ProxPenalty, + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + A: ForwardModel, F>, + for<'b> &'b A::Observable: Instance, + A::Observable: AXPY, + RNDM: SpikeMerging, + Reg: RegTerm, F>, + Phi: Conjugable, + for<'b> Phi::Conjugate<'b>: Prox, + P: ProxPenalty, A::PreadjointCodomain, Reg, F> + StepLengthBoundPD>, + Plot: Plotter>, { + // Check parameters + ensure!( + pdpsconfig.τ0 > 0.0 && pdpsconfig.σ0 > 0.0 && pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, + "Invalid step length parameters" + ); - // Check parameters - assert!(pdpsconfig.τ0 > 0.0 && - pdpsconfig.σ0 > 0.0 && - pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, - "Invalid step length parameters"); + let opA = f.operator(); + let b = f.data(); + let phistar = f.fidelity().conjugate(); // Set up parameters let config = &pdpsconfig.generic; - let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt(); + let l = prox_penalty.step_length_bound_pd(opA)?; let mut τ = pdpsconfig.τ0 / l; let mut σ = pdpsconfig.σ0 / l; - let γ = dataterm.factor_of_strong_convexity(); + let γ = phistar.factor_of_strong_convexity(); // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. @@ -261,38 +185,35 @@ let mut ε = tolerance.initial(); // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut y = dataterm.some_subdifferential(-b); - let mut y_prev = y.clone(); - let full_stats = |μ : &RNDM, ε, stats| IterInfo { - value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ), - n_spikes : μ.len(), + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); + let mut y = f.residual(&μ); + let full_stats = |μ: &RNDM, ε, stats| IterInfo { + value: f.apply(μ) + reg.apply(μ), + n_spikes: μ.len(), ε, // postprocessing: config.postprocessing.then(|| μ.clone()), - .. stats + ..stats }; let mut stats = IterInfo::new(); // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let mut τv = opA.preadjoint().apply(y * τ); + // FIXME: the clone is required to avoid compiler overflows with reference-Mul requirement above. + let mut τv = opA.preadjoint().apply(y.clone() * τ); // Save current base point let μ_base = μ.clone(); - + // Insert and reweigh let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( - &mut μ, &mut τv, &μ_base, None, - τ, ε, - config, ®, &state, &mut stats - ); + &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, + )?; // Prune and possibly merge spikes if config.merge_now(&state) { - stats.merged += prox_penalty.merge_spikes_no_fitness( - &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, - ); + stats.merged += prox_penalty + .merge_spikes_no_fitness(&mut μ, &mut τv, &μ_base, None, τ, ε, config, ®); } stats.pruned += prune_with_stats(&mut μ); @@ -300,11 +221,13 @@ let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); // Do dual update - y = b.clone(); // y = b - opA.gemv(&mut y, 1.0 + ω, &μ, -1.0); // y = A[(1+ω)μ^{k+1}]-b - opA.gemv(&mut y, -ω, &μ_base, 1.0); // y = A[(1+ω)μ^{k+1} - ω μ^k]-b - dataterm.dual_update(&mut y, &y_prev, σ); - y_prev.copy_from(&y); + // y = y_prev + τb + y.axpy(τ, b, 1.0); + // y = y_prev - τ(A[(1+ω)μ^{k+1}]-b) + opA.gemv(&mut y, -τ * (1.0 + ω), &μ, 1.0); + // y = y_prev - τ(A[(1+ω)μ^{k+1} - ω μ^k]-b) + opA.gemv(&mut y, τ * ω, &μ_base, 1.0); + y = phistar.prox(τ, y); // Give statistics if requested let iter = state.iteration(); @@ -318,6 +241,5 @@ ε = tolerance.update(ε, iter); } - postprocess(μ, config, dataterm, opA, b) + postprocess(μ, config, |μ| f.apply(μ)) } -