diff -r 9738b51d90d7 -r 4f468d35fa29 src/forward_pdps.rs --- a/src/forward_pdps.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/forward_pdps.rs Thu Feb 26 11:38:43 2026 -0500 @@ -3,132 +3,158 @@ primal-dual proximal splitting with a forward step. */ -use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; - +use crate::fb::*; +use crate::measures::merging::SpikeMerging; +use crate::measures::{DiscreteMeasure, RNDM}; +use crate::plot::Plotter; +use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair}; +use crate::regularisation::RegTerm; +use crate::types::*; +use alg_tools::convex::{Conjugable, Prox, Zero}; +use alg_tools::direct_product::Pair; +use alg_tools::error::DynResult; +use alg_tools::euclidean::ClosedEuclidean; use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::euclidean::Euclidean; -use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; -use alg_tools::norms::Norm; -use alg_tools::direct_product::Pair; +use alg_tools::linops::{BoundedLinear, IdOp, SimplyAdjointable, ZeroOp, AXPY, GEMV}; +use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance}; use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::linops::{ - BoundedLinear, AXPY, GEMV, Adjointable, IdOp, -}; -use alg_tools::convex::{Conjugable, Prox}; -use alg_tools::norms::{L2, PairNorm}; - -use crate::types::*; -use crate::measures::{DiscreteMeasure, Radon, RNDM}; -use crate::measures::merging::SpikeMerging; -use crate::forward_model::{ - ForwardModel, - AdjointProductPairBoundedBy, -}; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; -use crate::fb::*; -use crate::regularisation::RegTerm; -use crate::dataterm::calculate_residual; +use alg_tools::norms::L2; +use anyhow::ensure; +use numeric_literals::replace_float_literals; +use serde::{Deserialize, Serialize}; /// Settings for [`pointsource_forward_pdps_pair`]. 
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct ForwardPDPSConfig { - /// Primal step length scaling. - pub τ0 : F, - /// Primal step length scaling. - pub σp0 : F, - /// Dual step length scaling. - pub σd0 : F, +pub struct ForwardPDPSConfig { + /// Overall primal step length scaling. + pub τ0: F, + /// Primal step length scaling for additional variable. + pub σp0: F, + /// Dual step length scaling for additional variable. + /// + /// Taken zero for [`pointsource_fb_pair`]. + pub σd0: F, /// Generic parameters - pub insertion : FBGenericConfig, + pub insertion: InsertionConfig, } #[replace_float_literals(F::cast_from(literal))] -impl Default for ForwardPDPSConfig { +impl Default for ForwardPDPSConfig { fn default() -> Self { - ForwardPDPSConfig { - τ0 : 0.99, - σd0 : 0.05, - σp0 : 0.99, - insertion : Default::default() - } + ForwardPDPSConfig { τ0: 0.99, σd0: 0.05, σp0: 0.99, insertion: Default::default() } } } -type MeasureZ = Pair, Z>; +type MeasureZ = Pair, Z>; /// Iteratively solve the pointsource localisation with an additional variable /// using primal-dual proximal splitting with a forward step. +/// +/// The problem is +/// $$ +/// \min_{μ, z}~ F(μ, z) + R(z) + H(K_z z) + Q(μ), +/// $$ +/// where +/// * The data term $F$ is given in `f`, +/// * the measure (Radon or positivity-constrained Radon) regulariser in $Q$ is given in `reg`, +/// * the functions $R$ and $H$ are given in `fnR` and `fnH`, and +/// * the operator $K_z$ in `opKz`. +/// +/// This is dualised to +/// $$ +/// \min_{μ, z}\max_y~ F(μ, z) + R(z) + ⟨K_z z, y⟩ + Q(μ) - H^*(y). +/// $$ +/// +/// The algorithm is controlled by: +/// * the proximal penalty in `prox_penalty`. +/// * the initial iterates in `z`, `y` +/// * The configuration in `config`. +/// * The `iterator` that controls stopping and reporting. +/// Moreover, plotting is performed by `plotter`. 
+/// +/// The step lengths need to satisfy +/// $$ +/// τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 +/// $$ ^^^^^^^^^^^^^^^^^^^^^^^^^ +/// with $1 > σ_p L_z$ and $1 > τ L$. +/// Since we are given “scalings” $τ_0$, $σ_{p,0}$, and $σ_{d,0}$ in `config`, we take +/// $σ_d=σ_{d,0}/‖K_z‖$, and $σ_p = σ_{p,0} / (L_z + σ_d‖K_z‖^2)$. This satisfies the +/// part $[σ_p L_z + σ_pσ_d‖K_z‖^2] < 1$. Then with these choices, we solve +/// $$ +/// τ = τ_0 \frac{1 - σ_{p,0}}{σ_d M (1-σ_p L_z) + (1 - σ_{p,0}) L}. +/// $$ #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_forward_pdps_pair< - F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize + F, + I, + S, + Dat, + Reg, + P, + Z, + R, + Y, + /*KOpM, */ KOpZ, + H, + Plot, + const N: usize, >( - opA : &A, - b : &A::Observable, - reg : Reg, - prox_penalty : &P, - config : &ForwardPDPSConfig, - iterator : I, - mut plotter : SeqPlotter, + f: &Dat, + reg: &Reg, + prox_penalty: &P, + config: &ForwardPDPSConfig, + iterator: I, + mut plotter: Plot, + (μ0, mut z, mut y): (Option>, Z, Y), //opKμ : KOpM, - opKz : &KOpZ, - fnR : &R, - fnH : &H, - mut z : Z, - mut y : Y, -) -> MeasureZ + opKz: &KOpZ, + fnR: &R, + fnH: &H, +) -> DynResult> where - F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - A : ForwardModel< - MeasureZ, - F, - PairNorm, - PreadjointCodomain = Pair, - > - + AdjointProductPairBoundedBy, P, IdOp, FloatType=F>, - S: DifferentiableRealMapping, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - PlotLookup : Plotting, - RNDM : SpikeMerging, - Reg : RegTerm, - P : ProxPenalty, - KOpZ : BoundedLinear - + GEMV - + Adjointable, - for<'b> KOpZ::Adjoint<'b> : GEMV, - Y : AXPY + Euclidean + Clone + ClosedAdd, - for<'b> &'b Y : Instance, - Z : AXPY + Euclidean + Clone + Norm, - for<'b> &'b Z : Instance, - R : Prox, - H : Conjugable, - for<'b> H::Conjugate<'b> : Prox, + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F, 
DerivativeDomain = Pair>, + //Pair: ClosedMul, // Doesn't really need to be closed, if make this signature more complex… + S: DifferentiableRealMapping + ClosedMul, + RNDM: SpikeMerging, + Reg: RegTerm, F>, + P: ProxPenalty, S, Reg, F>, + for<'a> Pair<&'a P, &'a IdOp>: StepLengthBoundPair, + KOpZ: BoundedLinear + + GEMV + + SimplyAdjointable, + KOpZ::SimpleAdjoint: GEMV, + Y: ClosedEuclidean, + for<'b> &'b Y: Instance, + Z: ClosedEuclidean, + for<'b> &'b Z: Instance, + R: Prox, + H: Conjugable, + for<'b> H::Conjugate<'b>: Prox, + Plot: Plotter>, { - // Check parameters - assert!(config.τ0 > 0.0 && - config.τ0 < 1.0 && - config.σp0 > 0.0 && - config.σp0 < 1.0 && - config.σd0 > 0.0 && - config.σp0 * config.σd0 <= 1.0, - "Invalid step length parameters"); + // ensure!( + // config.τ0 > 0.0 + // && config.τ0 < 1.0 + // && config.σp0 > 0.0 + // && config.σp0 < 1.0 + // && config.σd0 >= 0.0 + // && config.σp0 * config.σd0 <= 1.0, + // "Invalid step length parameters" + // ); // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut residual = calculate_residual(Pair(&μ, &z), opA, b); + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); // Set up parameters let bigM = 0.0; //opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt(); - let nKz = opKz.opnorm_bound(L2, L2); - let opIdZ = IdOp::new(); - let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); + let nKz = opKz.opnorm_bound(L2, L2)?; + let idOpZ = IdOp::new(); + let opKz_adj = opKz.adjoint(); + let (l, l_z) = Pair(prox_penalty, &idOpZ).step_length_bound_pair(&f)?; // We need to satisfy // // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 @@ -137,14 +163,15 @@ // // To do so, we first solve σ_p and σ_d from standard PDPS step length condition // ^^^^^ < 1. then we solve τ from the rest. - let σ_d = config.σd0 / nKz; + // If opKZ is the zero operator, then we set σ_d = 0 for τ to be calculated correctly below. 
+ let σ_d = if nKz == 0.0 { 0.0 } else { config.σd0 / nKz }; let σ_p = config.σp0 / (l_z + config.σd0 * nKz); // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0} // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L) // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0}) let φ = 1.0 - config.σp0; let a = 1.0 - σ_p * l_z; - let τ = config.τ0 * φ / ( σ_d * bigM * a + φ * l ); + let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l); // Acceleration is not currently supported // let γ = dataterm.factor_of_strong_convexity(); let ω = 1.0; @@ -157,28 +184,37 @@ let starH = fnH.conjugate(); // Statistics - let full_stats = |residual : &A::Observable, μ : &RNDM, z : &Z, ε, stats| IterInfo { - value : residual.norm2_squared_div2() + fnR.apply(z) - + reg.apply(μ) + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)), - n_spikes : μ.len(), + let full_stats = |μ: &RNDM, z: &Z, ε, stats| IterInfo { + value: f.apply(Pair(μ, z)) + + fnR.apply(z) + + reg.apply(μ) + + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)), + n_spikes: μ.len(), ε, // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), - .. stats + ..stats }; let mut stats = IterInfo::new(); // Run the algorithm - for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { + for state in iterator.iter_init(|| full_stats(&μ, &z, ε, stats.clone())) { // Calculate initial transport - let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ); + let Pair(mut τv, τz) = f.differential(Pair(&μ, &z)); let μ_base = μ.clone(); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( - &mut μ, &mut τv, &μ_base, None, - τ, ε, &config.insertion, - ®, &state, &mut stats, - ); + &mut μ, + &mut τv, + &μ_base, + None, + τ, + ε, + &config.insertion, + ®, + &state, + &mut stats, + )?; // Merge spikes. 
// This crucially expects the merge routine to be stable with respect to spike locations, @@ -189,8 +225,9 @@ let ins = &config.insertion; if ins.merge_now(&state) { stats.merged += prox_penalty.merge_spikes_no_fitness( - &mut μ, &mut τv, &μ_base, None, τ, ε, ins, ®, - //Some(|μ̃ : &RNDM| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), + &mut μ, &mut τv, &μ_base, None, τ, ε, ins, + ®, + //Some(|μ̃ : &RNDM| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), ); } @@ -199,19 +236,16 @@ // Do z variable primal update let mut z_new = τz; - opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p/τ); + opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ); z_new = fnR.prox(σ_p, z_new + &z); // Do dual update // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] - opKz.gemv(&mut y, σ_d*(1.0 + ω), &z_new, 1.0); + opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0); // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b - opKz.gemv(&mut y, -σ_d*ω, z, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b + opKz.gemv(&mut y, -σ_d * ω, z, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b y = starH.prox(σ_d, y); z = z_new; - // Update residual - residual = calculate_residual(Pair(&μ, &z), opA, b); - // Update step length parameters // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); @@ -221,20 +255,73 @@ state.if_verbose(|| { plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); - full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) + full_stats(&μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) }); // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } - let fit = |μ̃ : &RNDM| { - (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2() - //+ fnR.apply(z) + reg.apply(μ) + let fit = |μ̃: &RNDM| { + f.apply(Pair(μ̃, &z)) /*+ fnR.apply(z) + reg.apply(μ)*/ + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) }; 
μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v); μ.prune(); - Pair(μ, z) + Ok(Pair(μ, z)) } + +/// Iteratively solve the pointsource localisation with an additional variable +/// using forward-backward splitting. +/// +/// The implementation uses [`pointsource_forward_pdps_pair`] with appropriate dummy +/// variables, operators, and functions. +#[replace_float_literals(F::cast_from(literal))] +pub fn pointsource_fb_pair( + f: &Dat, + reg: &Reg, + prox_penalty: &P, + config: &FBConfig, + iterator: I, + plotter: Plot, + (μ0, z): (Option>, Z), + //opKμ : KOpM, + fnR: &R, +) -> DynResult> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F, DerivativeDomain = Pair>, + S: DifferentiableRealMapping + ClosedMul, + RNDM: SpikeMerging, + Reg: RegTerm, F>, + P: ProxPenalty, S, Reg, F>, + for<'a> Pair<&'a P, &'a IdOp>: StepLengthBoundPair, + Z: ClosedEuclidean + AXPY + Clone, + for<'b> &'b Z: Instance, + R: Prox, + Plot: Plotter>, + // We should not need to explicitly require this: + for<'b> &'b Loc<0, F>: Instance>, +{ + let opKz = ZeroOp::new_dualisable(Loc([]), z.dual_origin()); + let fnH = Zero::new(); + // Convert config. We don't implement From (that could be done with the o2o crate), as σd0 + // needs to be chosen in a general case; for the problem of this function, anything is valid. + let &FBConfig { τ0, σp0, insertion } = config; + let pdps_config = ForwardPDPSConfig { τ0, σp0, insertion, σd0: 0.0 }; + + pointsource_forward_pdps_pair( + f, + reg, + prox_penalty, + &pdps_config, + iterator, + plotter, + (μ0, z, Loc([])), + &opKz, + fnR, + &fnH, + ) +}