src/pdps.rs

branch
dev
changeset 61
4f468d35fa29
parent 39
6316d68b58af
child 63
7a8a55fd41c0
--- a/src/pdps.rs	Sun Apr 27 15:03:51 2025 -0500
+++ b/src/pdps.rs	Thu Feb 26 11:38:43 2026 -0500
@@ -38,50 +38,25 @@
 </p>
 */
 
-use numeric_literals::replace_float_literals;
-use serde::{Serialize, Deserialize};
-use nalgebra::DVector;
-use clap::ValueEnum;
-
-use alg_tools::iterate::AlgIteratorFactory;
-use alg_tools::euclidean::Euclidean;
-use alg_tools::linops::Mapping;
-use alg_tools::norms::{
-    Linfinity,
-    Projection,
-};
-use alg_tools::mapping::{RealMapping, Instance};
-use alg_tools::nalgebra_support::ToNalgebraRealField;
-use alg_tools::linops::AXPY;
-
-use crate::types::*;
-use crate::measures::{DiscreteMeasure, RNDM};
+use crate::fb::{postprocess, prune_with_stats};
+use crate::forward_model::ForwardModel;
 use crate::measures::merging::SpikeMerging;
-use crate::forward_model::{
-    ForwardModel,
-    AdjointProductBoundedBy,
-};
-use crate::plot::{
-    SeqPlotter,
-    Plotting,
-    PlotLookup
-};
-use crate::fb::{
-    postprocess,
-    prune_with_stats
-};
-pub use crate::prox_penalty::{
-    FBGenericConfig,
-    ProxPenalty
-};
+use crate::measures::merging::SpikeMergingMethod;
+use crate::measures::{DiscreteMeasure, RNDM};
+use crate::plot::Plotter;
+pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBoundPD};
 use crate::regularisation::RegTerm;
-use crate::dataterm::{
-    DataTerm,
-    L2Squared,
-    L1
-};
-use crate::measures::merging::SpikeMergingMethod;
-
+use crate::types::*;
+use alg_tools::convex::{Conjugable, ConvexMapping, Prox};
+use alg_tools::error::DynResult;
+use alg_tools::iterate::AlgIteratorFactory;
+use alg_tools::linops::{Mapping, AXPY};
+use alg_tools::mapping::{DataTerm, Instance};
+use alg_tools::nalgebra_support::ToNalgebraRealField;
+use anyhow::ensure;
+use clap::ValueEnum;
+use numeric_literals::replace_float_literals;
+use serde::{Deserialize, Serialize};
 
 /// Acceleration
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)]
@@ -93,15 +68,18 @@
     #[clap(name = "partial", help = "Partial acceleration, ω = 1/√(1+σ)")]
     Partial,
     /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed
-    #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")]
-    Full
+    #[clap(
+        name = "full",
+        help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed"
+    )]
+    Full,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
 impl Acceleration {
     /// PDPS parameter acceleration. Updates τ and σ and returns ω.
     /// This uses dual strong convexity, not primal.
-    fn accelerate<F : Float>(self, τ : &mut F, σ : &mut F, γ : F) -> F {
+    fn accelerate<F: Float>(self, τ: &mut F, σ: &mut F, γ: F) -> F {
         match self {
             Acceleration::None => 1.0,
             Acceleration::Partial => {
@@ -109,13 +87,13 @@
                 *σ *= ω;
                 *τ /= ω;
                 ω
-            },
+            }
             Acceleration::Full => {
                 let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt();
                 *σ *= ω;
                 *τ /= ω;
                 ω
-            },
+            }
         }
     }
 }
@@ -123,91 +101,35 @@
 /// Settings for [`pointsource_pdps_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
-pub struct PDPSConfig<F : Float> {
+pub struct PDPSConfig<F: Float> {
     /// Primal step length scaling. We must have `τ0 * σ0 < 1`.
-    pub τ0 : F,
+    pub τ0: F,
     /// Dual step length scaling. We must have `τ0 * σ0 < 1`.
-    pub σ0 : F,
+    pub σ0: F,
     /// Accelerate if available
-    pub acceleration : Acceleration,
+    pub acceleration: Acceleration,
     /// Generic parameters
-    pub generic : FBGenericConfig<F>,
+    pub generic: InsertionConfig<F>,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> Default for PDPSConfig<F> {
+impl<F: Float> Default for PDPSConfig<F> {
     fn default() -> Self {
         let τ0 = 5.0;
         PDPSConfig {
             τ0,
-            σ0 : 0.99/τ0,
-            acceleration : Acceleration::Partial,
-            generic : FBGenericConfig {
-                merging : SpikeMergingMethod { enabled : true, ..Default::default() },
-                .. Default::default()
+            σ0: 0.99 / τ0,
+            acceleration: Acceleration::Partial,
+            generic: InsertionConfig {
+                merging: SpikeMergingMethod { enabled: true, ..Default::default() },
+                ..Default::default()
             },
         }
     }
 }
 
-/// Trait for data terms for the PDPS
-#[replace_float_literals(F::cast_from(literal))]
-pub trait PDPSDataTerm<F : Float, V, const N : usize> : DataTerm<F, V, N> {
-    /// Calculate some subdifferential at `x` for the conjugate
-    fn some_subdifferential(&self, x : V) -> V;
-
-    /// Factor of strong convexity of the conjugate
-    #[inline]
-    fn factor_of_strong_convexity(&self) -> F {
-        0.0
-    }
-
-    /// Perform dual update
-    fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F);
-}
-
-
-#[replace_float_literals(F::cast_from(literal))]
-impl<F, V, const N : usize> PDPSDataTerm<F, V, N>
-for L2Squared
-where
-    F : Float,
-    V :  Euclidean<F> + AXPY<F>,
-    for<'b> &'b V : Instance<V>,
-{
-    fn some_subdifferential(&self, x : V) -> V { x }
-
-    fn factor_of_strong_convexity(&self) -> F {
-        1.0
-    }
-
-    #[inline]
-    fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) {
-        y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ));
-    }
-}
-
-#[replace_float_literals(F::cast_from(literal))]
-impl<F : Float + nalgebra::RealField, const N : usize>
-PDPSDataTerm<F, DVector<F>, N>
-for L1 {
-    fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> {
-        // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well.
-        x.iter_mut()
-         .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) });
-        x
-    }
-
-     #[inline]
-     fn dual_update(&self, y : &mut DVector<F>, y_prev : &DVector<F>, σ : F) {
-        y.axpy(1.0, y_prev, σ);
-        y.proj_ball_mut(1.0, Linfinity);
-    }
-}
-
 /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting.
 ///
-/// The `dataterm` should be either [`L1`] for norm-1 data term or [`L2Squared`] for norm-2-squared.
 /// The settings in `config` have their [respective documentation](PDPSConfig). `opA` is the
 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
@@ -218,42 +140,44 @@
 ///
 /// Returns the final iterate.
 #[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_pdps_reg<F, I, A, D, Reg, P, const N : usize>(
-    opA : &A,
-    b : &A::Observable,
-    reg : Reg,
-    prox_penalty : &P,
-    pdpsconfig : &PDPSConfig<F>,
-    iterator : I,
-    mut plotter : SeqPlotter<F, N>,
-    dataterm : D,
-) -> RNDM<F, N>
+pub fn pointsource_pdps_reg<'a, F, I, A, Phi, Reg, Plot, P, const N: usize>(
+    f: &'a DataTerm<F, RNDM<N, F>, A, Phi>,
+    reg: &Reg,
+    prox_penalty: &P,
+    pdpsconfig: &PDPSConfig<F>,
+    iterator: I,
+    mut plotter: Plot,
+    μ0 : Option<RNDM<N, F>>,
+) -> DynResult<RNDM<N, F>>
 where
-    F : Float + ToNalgebraRealField,
-    I : AlgIteratorFactory<IterInfo<F, N>>,
-    A : ForwardModel<RNDM<F, N>, F>
-        + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>,
-    A::PreadjointCodomain : RealMapping<F, N>,
-    for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
-    PlotLookup : Plotting<N>,
-    RNDM<F, N> : SpikeMerging<F>,
-    D : PDPSDataTerm<F, A::Observable, N>,
-    Reg : RegTerm<F, N>,
-    P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
+    F: Float + ToNalgebraRealField,
+    I: AlgIteratorFactory<IterInfo<F>>,
+    A: ForwardModel<RNDM<N, F>, F>,
+    for<'b> &'b A::Observable: Instance<A::Observable>,
+    A::Observable: AXPY<Field = F>,
+    RNDM<N, F>: SpikeMerging<F>,
+    Reg: RegTerm<Loc<N, F>, F>,
+    Phi: Conjugable<A::Observable, F>,
+    for<'b> Phi::Conjugate<'b>: Prox<A::Observable>,
+    P: ProxPenalty<Loc<N, F>, A::PreadjointCodomain, Reg, F> + StepLengthBoundPD<F, A, RNDM<N, F>>,
+    Plot: Plotter<P::ReturnMapping, A::PreadjointCodomain, RNDM<N, F>>,
 {
+    // Check parameters
+    ensure!(
+        pdpsconfig.τ0 > 0.0 && pdpsconfig.σ0 > 0.0 && pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0,
+        "Invalid step length parameters"
+    );
 
-    // Check parameters
-    assert!(pdpsconfig.τ0 > 0.0 &&
-            pdpsconfig.σ0 > 0.0 &&
-            pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0,
-            "Invalid step length parameters");
+    let opA = f.operator();
+    let b = f.data();
+    let phistar = f.fidelity().conjugate();
 
     // Set up parameters
     let config = &pdpsconfig.generic;
-    let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt();
+    let l = prox_penalty.step_length_bound_pd(opA)?;
     let mut τ = pdpsconfig.τ0 / l;
     let mut σ = pdpsconfig.σ0 / l;
-    let γ = dataterm.factor_of_strong_convexity();
+    let γ = phistar.factor_of_strong_convexity();
 
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
@@ -261,38 +185,35 @@
     let mut ε = tolerance.initial();
 
     // Initialise iterates
-    let mut μ = DiscreteMeasure::new();
-    let mut y = dataterm.some_subdifferential(-b);
-    let mut y_prev = y.clone();
-    let full_stats = |μ : &RNDM<F, N>, ε, stats| IterInfo {
-        value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ),
-        n_spikes : μ.len(),
+    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
+    let mut y = f.residual(&μ);
+    let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo {
+        value: f.apply(μ) + reg.apply(μ),
+        n_spikes: μ.len(),
         ε,
         // postprocessing: config.postprocessing.then(|| μ.clone()),
-        .. stats
+        ..stats
     };
     let mut stats = IterInfo::new();
 
     // Run the algorithm
     for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
-        let mut τv = opA.preadjoint().apply(y * τ);
+        // FIXME: the clone is required to avoid compiler overflows with reference-Mul requirement above.
+        let mut τv = opA.preadjoint().apply(y.clone() * τ);
 
         // Save current base point
         let μ_base = μ.clone();
-        
+
         // Insert and reweigh
         let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
-            &mut μ, &mut τv, &μ_base, None,
-            τ, ε,
-            config, &reg, &state, &mut stats
-        );
+            &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
+        )?;
 
         // Prune and possibly merge spikes
         if config.merge_now(&state) {
-            stats.merged += prox_penalty.merge_spikes_no_fitness(
-                &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg,
-            );
+            stats.merged += prox_penalty
+                .merge_spikes_no_fitness(&mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg);
         }
         stats.pruned += prune_with_stats(&mut μ);
 
@@ -300,11 +221,13 @@
         let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);
 
         // Do dual update
-        y = b.clone();                          // y = b
-        opA.gemv(&mut y, 1.0 + ω, &μ, -1.0);    // y = A[(1+ω)μ^{k+1}]-b
-        opA.gemv(&mut y, -ω, &μ_base, 1.0);     // y = A[(1+ω)μ^{k+1} - ω μ^k]-b
-        dataterm.dual_update(&mut y, &y_prev, σ);
-        y_prev.copy_from(&y);
+        // y = y_prev + τb
+        y.axpy(τ, b, 1.0);
+        // y = y_prev - τ(A[(1+ω)μ^{k+1}]-b)
+        opA.gemv(&mut y, -τ * (1.0 + ω), &μ, 1.0);
+        // y = y_prev - τ(A[(1+ω)μ^{k+1} - ω μ^k]-b)
+        opA.gemv(&mut y, τ * ω, &μ_base, 1.0);
+        y = phistar.prox(τ, y);
 
         // Give statistics if requested
         let iter = state.iteration();
@@ -318,6 +241,5 @@
         ε = tolerance.update(ε, iter);
     }
 
-    postprocess(μ, config, dataterm, opA, b)
+    postprocess(μ, config, |μ| f.apply(μ))
 }
-

mercurial