src/forward_pdps.rs

branch
dev
changeset 61
4f468d35fa29
parent 39
6316d68b58af
child 62
32328a74c790
child 63
7a8a55fd41c0
--- a/src/forward_pdps.rs	Sun Apr 27 15:03:51 2025 -0500
+++ b/src/forward_pdps.rs	Thu Feb 26 11:38:43 2026 -0500
@@ -3,132 +3,158 @@
 primal-dual proximal splitting with a forward step.
 */
 
-use numeric_literals::replace_float_literals;
-use serde::{Serialize, Deserialize};
-
+use crate::fb::*;
+use crate::measures::merging::SpikeMerging;
+use crate::measures::{DiscreteMeasure, RNDM};
+use crate::plot::Plotter;
+use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair};
+use crate::regularisation::RegTerm;
+use crate::types::*;
+use alg_tools::convex::{Conjugable, Prox, Zero};
+use alg_tools::direct_product::Pair;
+use alg_tools::error::DynResult;
+use alg_tools::euclidean::ClosedEuclidean;
 use alg_tools::iterate::AlgIteratorFactory;
-use alg_tools::euclidean::Euclidean;
-use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance};
-use alg_tools::norms::Norm;
-use alg_tools::direct_product::Pair;
+use alg_tools::linops::{BoundedLinear, IdOp, SimplyAdjointable, ZeroOp, AXPY, GEMV};
+use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance};
 use alg_tools::nalgebra_support::ToNalgebraRealField;
-use alg_tools::linops::{
-    BoundedLinear, AXPY, GEMV, Adjointable, IdOp,
-};
-use alg_tools::convex::{Conjugable, Prox};
-use alg_tools::norms::{L2, PairNorm};
-
-use crate::types::*;
-use crate::measures::{DiscreteMeasure, Radon, RNDM};
-use crate::measures::merging::SpikeMerging;
-use crate::forward_model::{
-    ForwardModel,
-    AdjointProductPairBoundedBy,
-};
-use crate::plot::{
-    SeqPlotter,
-    Plotting,
-    PlotLookup
-};
-use crate::fb::*;
-use crate::regularisation::RegTerm;
-use crate::dataterm::calculate_residual;
+use alg_tools::norms::L2;
+use anyhow::ensure;
+use numeric_literals::replace_float_literals;
+use serde::{Deserialize, Serialize};
 
 /// Settings for [`pointsource_forward_pdps_pair`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
-pub struct ForwardPDPSConfig<F : Float> {
-    /// Primal step length scaling.
-    pub τ0 : F,
-    /// Primal step length scaling.
-    pub σp0 : F,
-    /// Dual step length scaling.
-    pub σd0 : F,
+pub struct ForwardPDPSConfig<F: Float> {
+    /// Overall primal step length scaling.
+    pub τ0: F,
+    /// Primal step length scaling for additional variable.
+    pub σp0: F,
+    /// Dual step length scaling for additional variable.
+    ///
+    /// Taken as zero for [`pointsource_fb_pair`].
+    pub σd0: F,
     /// Generic parameters
-    pub insertion : FBGenericConfig<F>,
+    pub insertion: InsertionConfig<F>,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> Default for ForwardPDPSConfig<F> {
+impl<F: Float> Default for ForwardPDPSConfig<F> {
     fn default() -> Self {
-        ForwardPDPSConfig {
-            τ0 : 0.99,
-            σd0 : 0.05,
-            σp0 : 0.99,
-            insertion : Default::default()
-        }
+        ForwardPDPSConfig { τ0: 0.99, σd0: 0.05, σp0: 0.99, insertion: Default::default() }
     }
 }
 
-type MeasureZ<F, Z, const N : usize> = Pair<RNDM<F, N>, Z>;
+type MeasureZ<F, Z, const N: usize> = Pair<RNDM<N, F>, Z>;
 
 /// Iteratively solve the pointsource localisation with an additional variable
 /// using primal-dual proximal splitting with a forward step.
+///
+/// The problem is
+/// $$
+///    \min_{μ, z}~ F(μ, z) + R(z) + H(K_z z) + Q(μ),
+/// $$
+/// where
+///   * The data term $F$ is given in `f`,
+///   * the measure (Radon or positivity-constrained Radon) regulariser in $Q$ is given in `reg`,
+///   * the functions $R$ and $H$ are given in `fnR` and `fnH`, and
+///   * the operator $K_z$ in `opKz`.
+///
+/// This is dualised to
+/// $$
+///    \min_{μ, z}\max_y~ F(μ, z) + R(z) + ⟨K_z z, y⟩ + Q(μ) - H^*(y).
+/// $$
+///
+/// The algorithm is controlled by:
+///   * the proximal penalty in `prox_penalty`.
+///   * the initial iterates in `z`, `y`
+///   * The configuration in `config`.
+///   * The `iterator` that controls stopping and reporting.
+/// Moreover, plotting is performed by `plotter`.
+///
+/// The step lengths need to satisfy
+/// $$
+///     τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1
+/// $$                               ^^^^^^^^^^^^^^^^^^^^^^^^^
+/// with $1 > σ_p L_z$ and $1 > τ L$.
+/// Since we are given “scalings” $τ_0$, $σ_{p,0}$, and $σ_{d,0}$ in `config`, we take
+/// $σ_d=σ_{d,0}/‖K_z‖$, and $σ_p = σ_{p,0} / (L_z + σ_d‖K_z‖^2)$. This satisfies the
+/// part $[σ_p L_z + σ_pσ_d‖K_z‖^2] < 1$. Then with these choices, we solve
+/// $$
+///     τ = τ_0 \frac{1 - σ_{p,0}}{σ_d M (1-σ_p L_z) + (1 - σ_{p,0}) L}.
+/// $$
 #[replace_float_literals(F::cast_from(literal))]
 pub fn pointsource_forward_pdps_pair<
-    F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize
+    F,
+    I,
+    S,
+    Dat,
+    Reg,
+    P,
+    Z,
+    R,
+    Y,
+    /*KOpM, */ KOpZ,
+    H,
+    Plot,
+    const N: usize,
 >(
-    opA : &A,
-    b : &A::Observable,
-    reg : Reg,
-    prox_penalty : &P,
-    config : &ForwardPDPSConfig<F>,
-    iterator : I,
-    mut plotter : SeqPlotter<F, N>,
+    f: &Dat,
+    reg: &Reg,
+    prox_penalty: &P,
+    config: &ForwardPDPSConfig<F>,
+    iterator: I,
+    mut plotter: Plot,
+    (μ0, mut z, mut y): (Option<RNDM<N, F>>, Z, Y),
     //opKμ : KOpM,
-    opKz : &KOpZ,
-    fnR : &R,
-    fnH : &H,
-    mut z : Z,
-    mut y : Y,
-) -> MeasureZ<F, Z, N>
+    opKz: &KOpZ,
+    fnR: &R,
+    fnH: &H,
+) -> DynResult<MeasureZ<F, Z, N>>
 where
-    F : Float + ToNalgebraRealField,
-    I : AlgIteratorFactory<IterInfo<F, N>>,
-    A : ForwardModel<
-            MeasureZ<F, Z, N>,
-            F,
-            PairNorm<Radon, L2, L2>,
-            PreadjointCodomain = Pair<S, Z>,
-        >
-        + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType=F>,
-    S: DifferentiableRealMapping<F, N>,
-    for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
-    PlotLookup : Plotting<N>,
-    RNDM<F, N> : SpikeMerging<F>,
-    Reg : RegTerm<F, N>,
-    P : ProxPenalty<F, S, Reg, N>,
-    KOpZ : BoundedLinear<Z, L2, L2, F, Codomain=Y>
-        + GEMV<F, Z>
-        + Adjointable<Z, Y, AdjointCodomain = Z>,
-    for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>,
-    Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd,
-    for<'b> &'b Y : Instance<Y>,
-    Z : AXPY<F, Owned=Z> + Euclidean<F, Output=Z> + Clone + Norm<F, L2>,
-    for<'b> &'b Z : Instance<Z>,
-    R : Prox<Z, Codomain=F>,
-    H : Conjugable<Y, F, Codomain=F>,
-    for<'b> H::Conjugate<'b> : Prox<Y>,
+    F: Float + ToNalgebraRealField,
+    I: AlgIteratorFactory<IterInfo<F>>,
+    Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>,
+    //Pair<S, Z>: ClosedMul<F>, // Doesn't really need to be closed, if make this signature more complex…
+    S: DifferentiableRealMapping<N, F> + ClosedMul<F>,
+    RNDM<N, F>: SpikeMerging<F>,
+    Reg: RegTerm<Loc<N, F>, F>,
+    P: ProxPenalty<Loc<N, F>, S, Reg, F>,
+    for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>,
+    KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y>
+        + GEMV<F, Z, Y>
+        + SimplyAdjointable<Z, Y, Codomain = Y, AdjointCodomain = Z>,
+    KOpZ::SimpleAdjoint: GEMV<F, Y, Z>,
+    Y: ClosedEuclidean<F>,
+    for<'b> &'b Y: Instance<Y>,
+    Z: ClosedEuclidean<F>,
+    for<'b> &'b Z: Instance<Z>,
+    R: Prox<Z, Codomain = F>,
+    H: Conjugable<Y, F, Codomain = F>,
+    for<'b> H::Conjugate<'b>: Prox<Y>,
+    Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>,
 {
-
     // Check parameters
-    assert!(config.τ0 > 0.0 &&
-            config.τ0 < 1.0 &&
-            config.σp0 > 0.0 &&
-            config.σp0 < 1.0 &&
-            config.σd0 > 0.0 &&
-            config.σp0 * config.σd0 <= 1.0,
-            "Invalid step length parameters");
+    // ensure!(
+    //     config.τ0 > 0.0
+    //         && config.τ0 < 1.0
+    //         && config.σp0 > 0.0
+    //         && config.σp0 < 1.0
+    //         && config.σd0 >= 0.0
+    //         && config.σp0 * config.σd0 <= 1.0,
+    //     "Invalid step length parameters"
+    // );
 
     // Initialise iterates
-    let mut μ = DiscreteMeasure::new();
-    let mut residual = calculate_residual(Pair(&μ, &z), opA, b);
+    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
 
     // Set up parameters
     let bigM = 0.0; //opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt();
-    let nKz = opKz.opnorm_bound(L2, L2);
-    let opIdZ = IdOp::new();
-    let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap();
+    let nKz = opKz.opnorm_bound(L2, L2)?;
+    let idOpZ = IdOp::new();
+    let opKz_adj = opKz.adjoint();
+    let (l, l_z) = Pair(prox_penalty, &idOpZ).step_length_bound_pair(&f)?;
     // We need to satisfy
     //
     //     τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1
@@ -137,14 +163,15 @@
     //
     // To do so, we first solve σ_p and σ_d from standard PDPS step length condition
     // ^^^^^ < 1. then we solve τ from  the rest.
-    let σ_d = config.σd0 / nKz;
+    // If `opKz` is the zero operator, then we set σ_d = 0 for τ to be calculated correctly below.
+    let σ_d = if nKz == 0.0 { 0.0 } else { config.σd0 / nKz };
     let σ_p = config.σp0 / (l_z + config.σd0 * nKz);
     // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0}
     // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L)
     // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0})
     let φ = 1.0 - config.σp0;
     let a = 1.0 - σ_p * l_z;
-    let τ = config.τ0 * φ / ( σ_d * bigM * a + φ * l );
+    let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l);
     // Acceleration is not currently supported
     // let γ = dataterm.factor_of_strong_convexity();
     let ω = 1.0;
@@ -157,28 +184,37 @@
     let starH = fnH.conjugate();
 
     // Statistics
-    let full_stats = |residual : &A::Observable, μ : &RNDM<F, N>, z : &Z, ε, stats| IterInfo {
-        value : residual.norm2_squared_div2() + fnR.apply(z)
-                + reg.apply(μ) + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)),
-        n_spikes : μ.len(),
+    let full_stats = |μ: &RNDM<N, F>, z: &Z, ε, stats| IterInfo {
+        value: f.apply(Pair(μ, z))
+            + fnR.apply(z)
+            + reg.apply(μ)
+            + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)),
+        n_spikes: μ.len(),
         ε,
         // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
-        .. stats
+        ..stats
     };
     let mut stats = IterInfo::new();
 
     // Run the algorithm
-    for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) {
+    for state in iterator.iter_init(|| full_stats(&μ, &z, ε, stats.clone())) {
         // Calculate initial transport
-        let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ);
+        let Pair(mut τv, τz) = f.differential(Pair(&μ, &z));
         let μ_base = μ.clone();
 
         // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
         let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
-            &mut μ, &mut τv, &μ_base, None,
-            τ, ε, &config.insertion,
-            &reg, &state, &mut stats,
-        );
+            &mut μ,
+            &mut τv,
+            &μ_base,
+            None,
+            τ,
+            ε,
+            &config.insertion,
+            &reg,
+            &state,
+            &mut stats,
+        )?;
 
         // Merge spikes.
         // This crucially expects the merge routine to be stable with respect to spike locations,
@@ -189,8 +225,9 @@
         let ins = &config.insertion;
         if ins.merge_now(&state) {
             stats.merged += prox_penalty.merge_spikes_no_fitness(
-                &mut μ, &mut τv, &μ_base, None, τ, ε, ins, &reg,
-                //Some(|μ̃ : &RNDM<F, N>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()),
+                &mut μ, &mut τv, &μ_base, None, τ, ε, ins,
+                &reg,
+                //Some(|μ̃ : &RNDM<N, F>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()),
             );
         }
 
@@ -199,19 +236,16 @@
 
         // Do z variable primal update
         let mut z_new = τz;
-        opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p/τ);
+        opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ);
         z_new = fnR.prox(σ_p, z_new + &z);
         // Do dual update
         // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0);    // y = y + σ_d K[(1+ω)(μ,z)^{k+1}]
-        opKz.gemv(&mut y, σ_d*(1.0 + ω), &z_new, 1.0);
+        opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0);
         // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
-        opKz.gemv(&mut y, -σ_d*ω, z, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
+        opKz.gemv(&mut y, -σ_d * ω, z, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
         y = starH.prox(σ_d, y);
         z = z_new;
 
-        // Update residual
-        residual = calculate_residual(Pair(&μ, &z), opA, b);
-
         // Update step length parameters
         // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);
 
@@ -221,20 +255,73 @@
 
         state.if_verbose(|| {
             plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
-            full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new()))
+            full_stats(&μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new()))
         });
 
         // Update main tolerance for next iteration
         ε = tolerance.update(ε, iter);
     }
 
-    let fit = |μ̃ : &RNDM<F, N>| {
-        (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2()
-        //+ fnR.apply(z) + reg.apply(μ)
+    let fit = |μ̃: &RNDM<N, F>| {
+        f.apply(Pair(μ̃, &z)) /*+ fnR.apply(z) + reg.apply(μ)*/
         + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z))
     };
 
     μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v);
     μ.prune();
-    Pair(μ, z)
+    Ok(Pair(μ, z))
 }
+
+/// Iteratively solve the pointsource localisation with an additional variable
+/// using forward-backward splitting.
+///
+/// The implementation uses [`pointsource_forward_pdps_pair`] with appropriate dummy
+/// variables, operators, and functions.
+#[replace_float_literals(F::cast_from(literal))]
+pub fn pointsource_fb_pair<F, I, S, Dat, Reg, P, Z, R, Plot, const N: usize>(
+    f: &Dat,
+    reg: &Reg,
+    prox_penalty: &P,
+    config: &FBConfig<F>,
+    iterator: I,
+    plotter: Plot,
+    (μ0, z): (Option<RNDM<N, F>>, Z),
+    //opKμ : KOpM,
+    fnR: &R,
+) -> DynResult<MeasureZ<F, Z, N>>
+where
+    F: Float + ToNalgebraRealField,
+    I: AlgIteratorFactory<IterInfo<F>>,
+    Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>,
+    S: DifferentiableRealMapping<N, F> + ClosedMul<F>,
+    RNDM<N, F>: SpikeMerging<F>,
+    Reg: RegTerm<Loc<N, F>, F>,
+    P: ProxPenalty<Loc<N, F>, S, Reg, F>,
+    for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>,
+    Z: ClosedEuclidean<F> + AXPY<Field = F> + Clone,
+    for<'b> &'b Z: Instance<Z>,
+    R: Prox<Z, Codomain = F>,
+    Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>,
+    // We should not need to explicitly require this:
+    for<'b> &'b Loc<0, F>: Instance<Loc<0, F>>,
+{
+    let opKz = ZeroOp::new_dualisable(Loc([]), z.dual_origin());
+    let fnH = Zero::new();
+    // Convert config. We don't implement From (that could be done with the o2o crate), as σd0
+    // needs to be chosen in a general case; for the problem of this function, anything is valid.
+    let &FBConfig { τ0, σp0, insertion } = config;
+    let pdps_config = ForwardPDPSConfig { τ0, σp0, insertion, σd0: 0.0 };
+
+    pointsource_forward_pdps_pair(
+        f,
+        reg,
+        prox_penalty,
+        &pdps_config,
+        iterator,
+        plotter,
+        (μ0, z, Loc([])),
+        &opKz,
+        fnR,
+        &fnH,
+    )
+}

mercurial