src/sliding_pdps.rs

Mon, 23 Feb 2026 18:18:02 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Mon, 23 Feb 2026 18:18:02 -0500
branch
dev
changeset 64
d3be4f7ffd60
parent 63
7a8a55fd41c0
child 66
fe47ad484deb
permissions
-rw-r--r--

ATTEMPT, HAS BUGS: Make shifted_nonneg_soft_thresholding more readable

/*!
Solver for the point source localisation problem using a sliding
primal-dual proximal splitting method.
*/

use crate::fb::*;
use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
use crate::measures::merging::SpikeMerging;
use crate::measures::{DiscreteMeasure, RNDM};
use crate::plot::Plotter;
use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair};
use crate::regularisation::SlidingRegTerm;
use crate::sliding_fb::{SlidingFBConfig, Transport, TransportConfig, TransportStepLength};
use crate::types::*;
use alg_tools::convex::{Conjugable, Prox, Zero};
use alg_tools::direct_product::Pair;
use alg_tools::error::DynResult;
use alg_tools::euclidean::ClosedEuclidean;
use alg_tools::iterate::AlgIteratorFactory;
use alg_tools::linops::{
    BoundedLinear, IdOp, SimplyAdjointable, StaticEuclideanOriginGenerator, ZeroOp, AXPY, GEMV,
};
use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance};
use alg_tools::nalgebra_support::ToNalgebraRealField;
use alg_tools::norms::L2;
use anyhow::ensure;
use numeric_literals::replace_float_literals;
use serde::{Deserialize, Serialize};

/// Settings for [`pointsource_sliding_pdps_pair`].
///
/// The step-length fields are scalings rather than absolute step lengths: the solver combines
/// them with operator-norm and step-length bounds computed at start-up to obtain the step
/// lengths it actually uses.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct SlidingPDPSConfig<F: Float> {
    /// Overall primal step length scaling τ₀.
    pub τ0: F,
    /// Primal step length scaling σ_{p,0} for the additional variable.
    ///
    /// The solver's step-length formula uses `1 - σp0` as a factor, so values below 1 are
    /// expected.
    pub σp0: F,
    /// Dual step length scaling σ_{d,0} for the additional variable.
    ///
    /// Taken zero for [`pointsource_sliding_fb_pair`]. When the coupling operator `K_z` has
    /// zero norm bound, the solver uses a zero dual step length regardless of this value.
    pub σd0: F,
    /// Transport parameters; validated with `check` when the solver starts.
    pub transport: TransportConfig<F>,
    /// Generic parameters for spike insertion, merging and tolerances.
    pub insertion: InsertionConfig<F>,
    /// Guess for curvature bound calculations.
    pub guess: BoundedCurvatureGuess,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F: Float> Default for SlidingPDPSConfig<F> {
    fn default() -> Self {
        SlidingPDPSConfig {
            τ0: 0.99,
            σd0: 0.05,
            σp0: 0.99,
            transport: TransportConfig { θ0: 0.9, ..Default::default() },
            insertion: Default::default(),
            guess: BoundedCurvatureGuess::BetterThanZero,
        }
    }
}

/// Joint primal variable of the solvers in this module: a discrete measure `RNDM<N, F>`
/// paired with an additional variable of type `Z`.
type MeasureZ<F, Z, const N: usize> = Pair<RNDM<N, F>, Z>;

/// Iteratively solve the pointsource localisation with an additional variable
/// using sliding primal-dual proximal splitting.
///
/// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`].
///
/// A measure-side coupling operator K_μ is currently not supported. It is taken to be zero
/// (see the commented-out `opKμ` parameter), so only `opKz`, acting on the additional
/// variable, couples the primal variables to the dual variable `y`.
///
/// # Arguments
///
/// * `f` – differentiable data term of the pair `(μ, z)`.
/// * `reg` – regularisation term for the measure `μ`.
/// * `prox_penalty` – proximal penalty used for spike insertion, reweighing and merging.
/// * `config` – step length, transport and insertion settings.
/// * `iterator` – iteration controller; also decides when verbose output is produced.
/// * `plotter` – receives the current iterate and `τv̆` when verbose output is produced.
/// * `(μ0, z, y)` – initial measure (empty if `None`), additional variable, and dual variable.
/// * `opKz` – linear operator from `Z` to `Y`.
/// * `fnR` – function of `z`, used through its proximal map.
/// * `fnH` – function of `K_z z`, used through the proximal map of its conjugate.
///
/// # Returns
///
/// The final pair `(μ, z)`. Before returning, `μ` is merged with the configured final merging
/// method and then pruned.
///
/// # Errors
///
/// Fails if any of the following fails:
/// - the transport configuration check;
/// - computing the operator norm bound of `opKz`;
/// - computing the step length bounds;
/// - the step length condition `β < 1`;
/// - a spike insertion subproblem.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_sliding_pdps_pair<
    F,
    I,
    S,
    Dat,
    Reg,
    P,
    Z,
    R,
    Y,
    Plot,
    /*KOpM, */ KOpZ,
    H,
    const N: usize,
>(
    f: &Dat,
    reg: &Reg,
    prox_penalty: &P,
    config: &SlidingPDPSConfig<F>,
    iterator: I,
    mut plotter: Plot,
    (μ0, mut z, mut y): (Option<RNDM<N, F>>, Z, Y),
    //opKμ : KOpM,
    opKz: &KOpZ,
    fnR: &R,
    fnH: &H,
) -> DynResult<MeasureZ<F, Z, N>>
where
    F: Float + ToNalgebraRealField,
    I: AlgIteratorFactory<IterInfo<F>>,
    Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>
        + BoundedCurvature<F>,
    S: DifferentiableRealMapping<N, F> + ClosedMul<F>,
    for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>,
    //Pair<S, Z>: ClosedMul<F>,
    RNDM<N, F>: SpikeMerging<F>,
    Reg: SlidingRegTerm<Loc<N, F>, F>,
    P: ProxPenalty<Loc<N, F>, S, Reg, F>,
    // KOpM : Linear<RNDM<N, F>, Codomain=Y>
    //     + GEMV<F, RNDM<N, F>>
    //     + Preadjointable<
    //         RNDM<N, F>, Y,
    //         PreadjointCodomain = S,
    //     >
    //     + TransportLipschitz<L2Squared, FloatType=F>
    //     + AdjointProductBoundedBy<RNDM<N, F>, 𝒟, FloatType=F>,
    // for<'b> KOpM::Preadjoint<'b> : GEMV<F, Y>,
    // Since Z is Hilbert, we may just as well use adjoints for K_z.
    KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y>
        + GEMV<F, Z>
        + SimplyAdjointable<Z, Y, AdjointCodomain = Z>,
    KOpZ::SimpleAdjoint: GEMV<F, Y>,
    Y: ClosedEuclidean<F>,
    for<'b> &'b Y: Instance<Y>,
    Z: ClosedEuclidean<F>,
    for<'b> &'b Z: Instance<Z>,
    R: Prox<Z, Codomain = F>,
    H: Conjugable<Y, F, Codomain = F>,
    for<'b> H::Conjugate<'b>: Prox<Y>,
    Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>,
{
    // Check parameters
    // NOTE(review): disabled. As written, it would reject σd0 = 0, which
    // `pointsource_sliding_fb_pair` passes. If re-enabled, it needs to allow that case.
    /*ensure!(
        config.τ0 > 0.0
            && config.τ0 < 1.0
            && config.σp0 > 0.0
            && config.σp0 < 1.0
            && config.σd0 > 0.0
            && config.σp0 * config.σd0 <= 1.0,
        "Invalid step length parameters"
    );*/
    config.transport.check()?;

    // Initialise iterates
    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
    let mut γ = Transport::new();
    //let zero_z = z.similar_origin();

    // Set up parameters
    // TODO: maybe this PairNorm doesn't make sense here?
    // let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2);
    // K_μ is taken to be zero, so its transport Lipschitz factor 𝚹 and bound M vanish.
    let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared);
    let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt();
    let nKz = opKz.opnorm_bound(L2, L2)?;
    // Extra curvature term ℓ in the transport step length condition; currently zero.
    let ℓ = 0.0;
    let idOpZ = IdOp::new();
    let opKz_adj = opKz.adjoint();
    // Step length bounds L (for μ) and L_z (for z) used in the conditions below.
    let (l, l_z) = Pair(prox_penalty, &idOpZ).step_length_bound_pair(&f)?;

    // We need to satisfy
    //
    //     τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1
    //                                  ^^^^^^^^^^^^^^^^^^^^^^^^^
    // with 1 > σ_p L_z and 1 > τ L.
    //
    // To do so, we first solve σ_p and σ_d from the standard PDPS step length condition
    // ^^^^^ < 1. Then we solve τ from the rest.
    // If opKZ is the zero operator, then we set σ_d = 0 for τ to be calculated correctly below.
    let σ_d = if nKz == 0.0 { 0.0 } else { config.σd0 / nKz };
    let σ_p = config.σp0 / (l_z + config.σd0 * nKz);
    // Observe that 1 - [σ_p L_z + σ_pσ_d‖K_z‖^2] = 1 - σ_{p,0} by the choice of σ_p above.
    // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L)
    // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0})
    let φ = 1.0 - config.σp0;
    let a = 1.0 - σ_p * l_z;
    let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l);
    let ψ = 1.0 - τ * l;
    let β = σ_p * config.σd0 * nKz / a; // σ_p * σ_d * (nKz * nK_z) / a;
    ensure!(β < 1.0);
    // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as:
    let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM);
    //  The factor two in the manuscript disappears due to the definition of 𝚹 being
    // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2.

    // Choose a fixed transport step length if no transport Lipschitz bound is available;
    // otherwise adapt it, either from a known ℓ_F or by estimating ℓ_F from scratch.
    let mut θ_or_adaptive = match f.curvature_bound_components(config.guess) {
        (_, Err(_)) => TransportStepLength::Fixed(config.transport.θ0),
        (maybe_ℓ_F, Ok(transport_lip)) => {
            let calculate_θτ = move |ℓ_F, max_transport| {
                let ℓ_r = transport_lip * max_transport;
                config.transport.θ0 / ((ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport / τ)
            };
            match maybe_ℓ_F {
                Ok(ℓ_F) => TransportStepLength::AdaptiveMax {
                    l: ℓ_F, // TODO: could estimate computing the real residual
                    max_transport: 0.0,
                    g: calculate_θτ,
                },
                Err(_) => TransportStepLength::FullyAdaptive {
                    l: F::EPSILON, // Start with something very small to estimate differentials
                    max_transport: 0.0,
                    g: calculate_θτ,
                },
            }
        }
    };
    // Acceleration is not currently supported
    // let γ = dataterm.factor_of_strong_convexity();
    let ω = 1.0;

    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    let starH = fnH.conjugate();

    // Statistics
    // Objective value: F(μ, z) + R(z) + reg(μ) + H(K_z z).
    let full_stats = |μ: &RNDM<N, F>, z: &Z, ε, stats| IterInfo {
        value: f.apply(Pair(μ, z))
            + fnR.apply(z)
            + reg.apply(μ)
            + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)),
        n_spikes: μ.len(),
        ε,
        // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
        ..stats
    };
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&μ, &z, ε, stats.clone())) {
        // Calculate initial transport
        let Pair(v, _) = f.differential(Pair(&μ, &z));
        //opKμ.preadjoint().apply_add(&mut v, y);
        // We want to proceed as in Example 4.12 but with v and v̆ as in §5.
        // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have
        // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν,
        // where A_ν^* becomes a multiplier.
        // This is much easier with K_μ = 0, which is the only reason why we are enforcing it.
        // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0.

        γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v);

        let mut attempts = 0;

        // Solve finite-dimensional subproblem several times until the dual variable for the
        // regularisation term conforms to the assumptions made for the transport above.
        let (maybe_d, _within_tolerances, mut τv̆, z_new, μ̆) = 'adapt_transport: loop {
            // Set initial guess for μ=μ^{k+1}.
            γ.μ̆_into(&mut μ);
            let μ̆ = μ.clone();

            // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
            let Pair(mut τv̆, τz̆) = f.differential(Pair(&μ̆, &z)) * τ;
            // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0);

            // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
            let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
                &mut μ,
                &mut τv̆,
                τ,
                ε,
                &config.insertion,
                &reg,
                &state,
                &mut stats,
            )?;

            // Do z variable primal update here to be able to estimate B_{v̆^k-v^{k+1}}
            let mut z_new = τz̆;
            // Assuming `gemv(y, α, x, β)` computes y ← α·Ax + β·y, this yields
            // z_new = -σ_p (z̆ + K_z^* y), where τz̆ is τ times the z-differential of f.
            opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ);
            z_new = fnR.prox(σ_p, z_new + &z);

            // A posteriori transport adaptation.
            if γ.aposteriori_transport(
                &μ,
                &μ̆,
                &mut τv̆,
                Some(z_new.dist2(&z)),
                ε,
                &config.transport,
                &mut attempts,
            ) {
                break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new, μ̆);
            }
        };

        γ.get_transport_stats(&mut stats, &μ);

        // Merge spikes.
        // This crucially expects the merge routine to be stable with respect to spike locations,
        // and not to perform any pruning. That is to be done below simultaneously for γ.
        if config.insertion.merge_now(&state) {
            stats.merged += prox_penalty.merge_spikes_no_fitness(
                &mut μ,
                &mut τv̆,
                &μ̆,
                τ,
                ε,
                &config.insertion,
                &reg,
            );
        }

        γ.prune_compat(&mut μ, &mut stats);

        // Do dual update
        // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0);    // y = y + σ_d K[(1+ω)(μ,z)^{k+1}]
        opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0);
        // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
        opKz.gemv(&mut y, -σ_d * ω, z, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
        y = starH.prox(σ_d, y);
        z = z_new;

        // Update step length parameters
        // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);

        // Give statistics if requested
        let iter = state.iteration();
        stats.this_iters += 1;

        state.if_verbose(|| {
            plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
            full_stats(&μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        // Update main tolerance for next iteration
        ε = tolerance.update(ε, iter);
    }

    // Fitness used by the final merge: data term plus H(K_z z) at the final z.
    let fit = |μ̃: &RNDM<N, F>| {
        f.apply(Pair(μ̃, &z)) /*+ fnR.apply(z) + reg.apply(μ)*/
        + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z))
    };

    μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v);
    μ.prune();
    Ok(Pair(μ, z))
}

/// Iteratively solve the pointsource localisation with an additional variable
/// using sliding forward-backward splitting.
///
/// This delegates to [`pointsource_sliding_pdps_pair`] with a trivial dual part:
/// - the coupling operator is the zero map into the zero-dimensional space `Loc<0, F>`;
/// - the dual function is identically zero;
/// - the dual step length scaling `σd0` is zero.
///
/// All other settings are copied from `config`.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_sliding_fb_pair<F, I, S, Dat, Reg, P, Z, R, Plot, const N: usize>(
    f: &Dat,
    reg: &Reg,
    prox_penalty: &P,
    config: &SlidingFBConfig<F>,
    iterator: I,
    plotter: Plot,
    (μ0, z): (Option<RNDM<N, F>>, Z),
    //opKμ : KOpM,
    fnR: &R,
) -> DynResult<MeasureZ<F, Z, N>>
where
    F: Float + ToNalgebraRealField,
    I: AlgIteratorFactory<IterInfo<F>>,
    Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>
        + BoundedCurvature<F>,
    S: DifferentiableRealMapping<N, F> + ClosedMul<F>,
    RNDM<N, F>: SpikeMerging<F>,
    Reg: SlidingRegTerm<Loc<N, F>, F>,
    P: ProxPenalty<Loc<N, F>, S, Reg, F>,
    for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>,
    Z: ClosedEuclidean<F> + AXPY + Clone,
    for<'b> &'b Z: Instance<Z>,
    R: Prox<Z, Codomain = F>,
    Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>,
    // We should not need to explicitly require this:
    for<'b> &'b Loc<0, F>: Instance<Loc<0, F>>,
    // Loc<0, F>: StaticEuclidean<Field = F, PrincipalE = Loc<0, F>>
    //     + Instance<Loc<0, F>>
    //     + VectorSpace<Field = F>,
{
    // Translate the FB settings into PDPS settings. The destructuring is exhaustive, so adding
    // a field to `SlidingFBConfig` forces this conversion to be revisited. We don't implement
    // `From` because σd0 would have to be chosen in general; for the problem solved by this
    // function any value is valid, and zero is used.
    let &SlidingFBConfig { τ0, σp0, insertion, transport, guess } = config;
    let pdps_config = SlidingPDPSConfig { τ0, σp0, σd0: 0.0, insertion, transport, guess };

    // Dummy dual part: a zero operator into the zero-dimensional space, and H ≡ 0.
    let zero_coupling: ZeroOp<Z, Loc<0, F>, _, _, F> =
        ZeroOp::new_dualisable(StaticEuclideanOriginGenerator, z.dual_origin());
    let zero_dual_fn = Zero::new();

    pointsource_sliding_pdps_pair(
        f,
        reg,
        prox_penalty,
        &pdps_config,
        iterator,
        plotter,
        (μ0, z, Loc([])),
        &zero_coupling,
        fnR,
        &zero_dual_fn,
    )
}

mercurial