src/sliding_pdps.rs

branch
dev
changeset 61
4f468d35fa29
parent 49
6b0db7251ebe
child 62
32328a74c790
child 63
7a8a55fd41c0
--- a/src/sliding_pdps.rs	Sun Apr 27 15:03:51 2025 -0500
+++ b/src/sliding_pdps.rs	Thu Feb 26 11:38:43 2026 -0500
@@ -3,51 +3,53 @@
 primal-dual proximal splitting method.
 */
 
+use crate::fb::*;
+use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
+use crate::measures::merging::SpikeMerging;
+use crate::measures::{DiscreteMeasure, Radon, RNDM};
+use crate::plot::Plotter;
+use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair};
+use crate::regularisation::SlidingRegTerm;
+use crate::sliding_fb::{
+    aposteriori_transport, initial_transport, SlidingFBConfig, TransportConfig, TransportStepLength,
+};
+use crate::types::*;
+use alg_tools::convex::{Conjugable, Prox, Zero};
+use alg_tools::direct_product::Pair;
+use alg_tools::error::DynResult;
+use alg_tools::euclidean::ClosedEuclidean;
+use alg_tools::iterate::AlgIteratorFactory;
+use alg_tools::linops::{
+    BoundedLinear, IdOp, SimplyAdjointable, StaticEuclideanOriginGenerator, ZeroOp, AXPY, GEMV,
+};
+use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance};
+use alg_tools::nalgebra_support::ToNalgebraRealField;
+use alg_tools::norms::{Norm, L2};
+use anyhow::ensure;
 use numeric_literals::replace_float_literals;
 use serde::{Deserialize, Serialize};
 //use colored::Colorize;
 //use nalgebra::{DVector, DMatrix};
 use std::iter::Iterator;
 
-use alg_tools::convex::{Conjugable, Prox};
-use alg_tools::direct_product::Pair;
-use alg_tools::euclidean::Euclidean;
-use alg_tools::iterate::AlgIteratorFactory;
-use alg_tools::linops::{Adjointable, BoundedLinear, IdOp, AXPY, GEMV};
-use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping};
-use alg_tools::nalgebra_support::ToNalgebraRealField;
-use alg_tools::norms::{Dist, Norm};
-use alg_tools::norms::{PairNorm, L2};
-
-use crate::forward_model::{AdjointProductPairBoundedBy, BoundedCurvature, ForwardModel};
-use crate::measures::merging::SpikeMerging;
-use crate::measures::{DiscreteMeasure, Radon, RNDM};
-use crate::types::*;
-// use crate::transport::TransportLipschitz;
-//use crate::tolerance::Tolerance;
-use crate::fb::*;
-use crate::plot::{PlotLookup, Plotting, SeqPlotter};
-use crate::regularisation::SlidingRegTerm;
-// use crate::dataterm::L2Squared;
-use crate::dataterm::{calculate_residual, calculate_residual2};
-use crate::sliding_fb::{
-    aposteriori_transport, initial_transport, TransportConfig, TransportStepLength,
-};
-
 /// Settings for [`pointsource_sliding_pdps_pair`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
 pub struct SlidingPDPSConfig<F: Float> {
-    /// Primal step length scaling.
+    /// Overall primal step length scaling.
     pub τ0: F,
-    /// Primal step length scaling.
+    /// Primal step length scaling for additional variable.
     pub σp0: F,
-    /// Dual step length scaling.
+    /// Dual step length scaling for additional variable.
+    ///
+    /// Taken zero for [`pointsource_sliding_fb_pair`].
     pub σd0: F,
     /// Transport parameters
     pub transport: TransportConfig<F>,
     /// Generic parameters
-    pub insertion: FBGenericConfig<F>,
+    pub insertion: InsertionConfig<F>,
+    /// Guess for curvature bound calculations.
+    pub guess: BoundedCurvatureGuess,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
@@ -57,16 +59,14 @@
             τ0: 0.99,
             σd0: 0.05,
             σp0: 0.99,
-            transport: TransportConfig {
-                θ0: 0.9,
-                ..Default::default()
-            },
+            transport: TransportConfig { θ0: 0.9, ..Default::default() },
             insertion: Default::default(),
+            guess: BoundedCurvatureGuess::BetterThanZero,
         }
     }
 }
 
-type MeasureZ<F, Z, const N: usize> = Pair<RNDM<F, N>, Z>;
+type MeasureZ<F, Z, const N: usize> = Pair<RNDM<N, F>, Z>;
 
 /// Iteratively solve the pointsource localisation with an additional variable
 /// using sliding primal-dual proximal splitting
@@ -76,67 +76,66 @@
 pub fn pointsource_sliding_pdps_pair<
     F,
     I,
-    A,
     S,
+    Dat,
     Reg,
     P,
     Z,
     R,
     Y,
+    Plot,
     /*KOpM, */ KOpZ,
     H,
     const N: usize,
 >(
-    opA: &A,
-    b: &A::Observable,
-    reg: Reg,
+    f: &Dat,
+    reg: &Reg,
     prox_penalty: &P,
     config: &SlidingPDPSConfig<F>,
     iterator: I,
-    mut plotter: SeqPlotter<F, N>,
+    mut plotter: Plot,
+    (μ0, mut z, mut y): (Option<RNDM<N, F>>, Z, Y),
     //opKμ : KOpM,
     opKz: &KOpZ,
     fnR: &R,
     fnH: &H,
-    mut z: Z,
-    mut y: Y,
-) -> MeasureZ<F, Z, N>
+) -> DynResult<MeasureZ<F, Z, N>>
 where
     F: Float + ToNalgebraRealField,
-    I: AlgIteratorFactory<IterInfo<F, N>>,
-    A: ForwardModel<MeasureZ<F, Z, N>, F, PairNorm<Radon, L2, L2>, PreadjointCodomain = Pair<S, Z>>
-        + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType = F>
-        + BoundedCurvature<FloatType = F>,
-    S: DifferentiableRealMapping<F, N>,
-    for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>,
-    PlotLookup: Plotting<N>,
-    RNDM<F, N>: SpikeMerging<F>,
-    Reg: SlidingRegTerm<F, N>,
-    P: ProxPenalty<F, S, Reg, N>,
-    // KOpM : Linear<RNDM<F, N>, Codomain=Y>
-    //     + GEMV<F, RNDM<F, N>>
+    I: AlgIteratorFactory<IterInfo<F>>,
+    Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>
+        + BoundedCurvature<F>,
+    S: DifferentiableRealMapping<N, F> + ClosedMul<F>,
+    for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>,
+    //Pair<S, Z>: ClosedMul<F>,
+    RNDM<N, F>: SpikeMerging<F>,
+    Reg: SlidingRegTerm<Loc<N, F>, F>,
+    P: ProxPenalty<Loc<N, F>, S, Reg, F>,
+    // KOpM : Linear<RNDM<N, F>, Codomain=Y>
+    //     + GEMV<F, RNDM<N, F>>
     //     + Preadjointable<
-    //         RNDM<F, N>, Y,
+    //         RNDM<N, F>, Y,
     //         PreadjointCodomain = S,
     //     >
     //     + TransportLipschitz<L2Squared, FloatType=F>
-    //     + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
+    //     + AdjointProductBoundedBy<RNDM<N, F>, 𝒟, FloatType=F>,
     // for<'b> KOpM::Preadjoint<'b> : GEMV<F, Y>,
     // Since Z is Hilbert, we may just as well use adjoints for K_z.
     KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y>
         + GEMV<F, Z>
-        + Adjointable<Z, Y, AdjointCodomain = Z>,
-    for<'b> KOpZ::Adjoint<'b>: GEMV<F, Y>,
-    Y: AXPY<F> + Euclidean<F, Output = Y> + Clone + ClosedAdd,
+        + SimplyAdjointable<Z, Y, AdjointCodomain = Z>,
+    KOpZ::SimpleAdjoint: GEMV<F, Y>,
+    Y: ClosedEuclidean<F>,
     for<'b> &'b Y: Instance<Y>,
-    Z: AXPY<F, Owned = Z> + Euclidean<F, Output = Z> + Clone + Norm<F, L2> + Dist<F, L2>,
+    Z: ClosedEuclidean<F>,
     for<'b> &'b Z: Instance<Z>,
     R: Prox<Z, Codomain = F>,
     H: Conjugable<Y, F, Codomain = F>,
     for<'b> H::Conjugate<'b>: Prox<Y>,
+    Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>,
 {
     // Check parameters
-    assert!(
+    /*ensure!(
         config.τ0 > 0.0
             && config.τ0 < 1.0
             && config.σp0 > 0.0
@@ -144,26 +143,25 @@
             && config.σd0 > 0.0
             && config.σp0 * config.σd0 <= 1.0,
         "Invalid step length parameters"
-    );
-    config.transport.check();
+    );*/
+    config.transport.check()?;
 
     // Initialise iterates
-    let mut μ = DiscreteMeasure::new();
+    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
     let mut γ1 = DiscreteMeasure::new();
-    let mut residual = calculate_residual(Pair(&μ, &z), opA, b);
-    let zero_z = z.similar_origin();
+    //let zero_z = z.similar_origin();
 
     // Set up parameters
     // TODO: maybe this PairNorm doesn't make sense here?
     // let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2);
     let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared);
     let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt();
-    let nKz = opKz.opnorm_bound(L2, L2);
+    let nKz = opKz.opnorm_bound(L2, L2)?;
     let ℓ = 0.0;
-    let opIdZ = IdOp::new();
-    let (l, l_z) = opA
-        .adjoint_product_pair_bound(prox_penalty, &opIdZ)
-        .unwrap();
+    let idOpZ = IdOp::new();
+    let opKz_adj = opKz.adjoint();
+    let (l, l_z) = Pair(prox_penalty, &idOpZ).step_length_bound_pair(&f)?;
+
     // We need to satisfy
     //
     //     τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1
@@ -172,7 +170,8 @@
     //
     // To do so, we first solve σ_p and σ_d from standard PDPS step length condition
     // ^^^^^ < 1. then we solve τ from  the rest.
-    let σ_d = config.σd0 / nKz;
+    // If opKZ is the zero operator, then we set σ_d = 0 for τ to be calculated correctly below.
+    let σ_d = if nKz == 0.0 { 0.0 } else { config.σd0 / nKz };
     let σ_p = config.σp0 / (l_z + config.σd0 * nKz);
     // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0}
     // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L)
@@ -182,29 +181,29 @@
     let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l);
     let ψ = 1.0 - τ * l;
     let β = σ_p * config.σd0 * nKz / a; // σ_p * σ_d * (nKz * nK_z) / a;
-    assert!(β < 1.0);
+    ensure!(β < 1.0);
     // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as:
     let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM);
     //  The factor two in the manuscript disappears due to the definition of 𝚹 being
     // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2.
-    let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components();
-    let transport_lip = maybe_transport_lip.unwrap();
+    let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess);
+    let transport_lip = maybe_transport_lip?;
     let calculate_θ = |ℓ_F, max_transport| {
         let ℓ_r = transport_lip * max_transport;
         config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport)
     };
-    let mut θ_or_adaptive = match maybe_ℓ_F0 {
+    let mut θ_or_adaptive = match maybe_ℓ_F {
         // We assume that the residual is decreasing.
-        Some(ℓ_F0) => TransportStepLength::AdaptiveMax {
-            l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual
+        Ok(ℓ_F) => TransportStepLength::AdaptiveMax {
+            l: ℓ_F, // TODO: could estimate computing the real reesidual
             max_transport: 0.0,
             g: calculate_θ,
         },
-        None => TransportStepLength::FullyAdaptive {
-            l: F::EPSILON,
-            max_transport: 0.0,
-            g: calculate_θ,
-        },
+        Err(_) => {
+            TransportStepLength::FullyAdaptive {
+                l: F::EPSILON, max_transport: 0.0, g: calculate_θ
+            }
+        }
     };
     // Acceleration is not currently supported
     // let γ = dataterm.factor_of_strong_convexity();
@@ -218,8 +217,8 @@
     let starH = fnH.conjugate();
 
     // Statistics
-    let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, z: &Z, ε, stats| IterInfo {
-        value: residual.norm2_squared_div2()
+    let full_stats = |μ: &RNDM<N, F>, z: &Z, ε, stats| IterInfo {
+        value: f.apply(Pair(μ, z))
             + fnR.apply(z)
             + reg.apply(μ)
             + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)),
@@ -231,9 +230,9 @@
     let mut stats = IterInfo::new();
 
     // Run the algorithm
-    for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) {
+    for state in iterator.iter_init(|| full_stats(&μ, &z, ε, stats.clone())) {
         // Calculate initial transport
-        let Pair(v, _) = opA.preadjoint().apply(&residual);
+        let Pair(v, _) = f.differential(Pair(&μ, &z));
         //opKμ.preadjoint().apply_add(&mut v, y);
         // We want to proceed as in Example 4.12 but with v and v̆ as in §5.
         // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have
@@ -242,6 +241,8 @@
         // This is much easier with K_μ = 0, which is the only reason why are enforcing it.
         // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0.
 
+        //dbg!(&μ);
+
         let (μ_base_masses, mut μ_base_minus_γ0) =
             initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v);
 
@@ -249,9 +250,11 @@
         // regularisation term conforms to the assumptions made for the transport above.
         let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop {
             // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
-            let residual_μ̆ =
-                calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b);
-            let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ);
+            // let residual_μ̆ =
+            //     calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b);
+            // let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ);
+            // TODO: might be able to optimise the measure sum working as calculate_residual2 above.
+            let Pair(mut τv̆, τz̆) = f.differential(Pair(&γ1 + &μ_base_minus_γ0, &z)) * τ;
             // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0);
 
             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
@@ -266,11 +269,11 @@
                 &reg,
                 &state,
                 &mut stats,
-            );
+            )?;
 
             // Do z variable primal update here to able to estimate B_{v̆^k-v^{k+1}}
             let mut z_new = τz̆;
-            opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p / τ);
+            opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ);
             z_new = fnR.prox(σ_p, z_new + &z);
 
             // A posteriori transport adaptation.
@@ -279,7 +282,7 @@
                 &mut μ,
                 &mut μ_base_minus_γ0,
                 &μ_base_masses,
-                Some(z_new.dist(&z, L2)),
+                Some(z_new.dist2(&z)),
                 ε,
                 &config.transport,
             ) {
@@ -313,7 +316,7 @@
                 ε,
                 ins,
                 &reg,
-                //Some(|μ̃ : &RNDM<F, N>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()),
+                //Some(|μ̃ : &RNDM<N, F>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()),
             );
         }
 
@@ -336,9 +339,6 @@
         y = starH.prox(σ_d, y);
         z = z_new;
 
-        // Update residual
-        residual = calculate_residual(Pair(&μ, &z), opA, b);
-
         // Update step length parameters
         // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);
 
@@ -348,26 +348,78 @@
 
         state.if_verbose(|| {
             plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
-            full_stats(
-                &residual,
-                &μ,
-                &z,
-                ε,
-                std::mem::replace(&mut stats, IterInfo::new()),
-            )
+            full_stats(&μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new()))
         });
 
         // Update main tolerance for next iteration
         ε = tolerance.update(ε, iter);
     }
 
-    let fit = |μ̃: &RNDM<F, N>| {
-        (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2()
-        //+ fnR.apply(z) + reg.apply(μ)
+    let fit = |μ̃: &RNDM<N, F>| {
+        f.apply(Pair(μ̃, &z)) /*+ fnR.apply(z) + reg.apply(μ)*/
         + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z))
     };
 
     μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v);
     μ.prune();
-    Pair(μ, z)
+    Ok(Pair(μ, z))
 }
+
+/// Iteratively solve the pointsource localisation with an additional variable
+/// using sliding forward-backward splitting.
+///
+/// The implementation uses [`pointsource_sliding_pdps_pair`] with appropriate dummy
+/// variables, operators, and functions.
+#[replace_float_literals(F::cast_from(literal))]
+pub fn pointsource_sliding_fb_pair<F, I, S, Dat, Reg, P, Z, R, Plot, const N: usize>(
+    f: &Dat,
+    reg: &Reg,
+    prox_penalty: &P,
+    config: &SlidingFBConfig<F>,
+    iterator: I,
+    plotter: Plot,
+    (μ0, z): (Option<RNDM<N, F>>, Z),
+    //opKμ : KOpM,
+    fnR: &R,
+) -> DynResult<MeasureZ<F, Z, N>>
+where
+    F: Float + ToNalgebraRealField,
+    I: AlgIteratorFactory<IterInfo<F>>,
+    Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>
+        + BoundedCurvature<F>,
+    S: DifferentiableRealMapping<N, F> + ClosedMul<F>,
+    RNDM<N, F>: SpikeMerging<F>,
+    Reg: SlidingRegTerm<Loc<N, F>, F>,
+    P: ProxPenalty<Loc<N, F>, S, Reg, F>,
+    for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>,
+    Z: ClosedEuclidean<F> + AXPY + Clone,
+    for<'b> &'b Z: Instance<Z>,
+    R: Prox<Z, Codomain = F>,
+    Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>,
+    // We should not need to explicitly require this:
+    for<'b> &'b Loc<0, F>: Instance<Loc<0, F>>,
+    // Loc<0, F>: StaticEuclidean<Field = F, PrincipalE = Loc<0, F>>
+    //     + Instance<Loc<0, F>>
+    //     + VectorSpace<Field = F>,
+{
+    let opKz: ZeroOp<Z, Loc<0, F>, _, _, F> =
+        ZeroOp::new_dualisable(StaticEuclideanOriginGenerator, z.dual_origin());
+    let fnH = Zero::new();
+    // Convert config. We don't implement From (that could be done with the o2o crate), as σd0
+    // needs to be chosen in a general case; for the problem of this fucntion, anything is valid.
+    let &SlidingFBConfig { τ0, σp0, insertion, transport, guess } = config;
+    let pdps_config = SlidingPDPSConfig { τ0, σp0, insertion, transport, guess, σd0: 0.0 };
+
+    pointsource_sliding_pdps_pair(
+        f,
+        reg,
+        prox_penalty,
+        &pdps_config,
+        iterator,
+        plotter,
+        (μ0, z, Loc([])),
+        &opKz,
+        fnR,
+        &fnH,
+    )
+}

mercurial