src/forward_pdps.rs

Wed, 22 Apr 2026 23:43:00 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Wed, 22 Apr 2026 23:43:00 -0500
branch
dev
changeset 69
3ad8879ee6e1
parent 66
fe47ad484deb
permissions
-rw-r--r--

Bump versions

/*!
Solver for the point source localisation problem using a
primal-dual proximal splitting with a forward step.
*/

use crate::fb::*;
use crate::measures::merging::SpikeMerging;
use crate::measures::{DiscreteMeasure, RNDM};
use crate::plot::Plotter;
use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair};
use crate::regularisation::RegTerm;
use crate::types::*;
use alg_tools::convex::{Conjugable, Prox, Zero};
use alg_tools::direct_product::Pair;
use alg_tools::error::DynResult;
use alg_tools::euclidean::ClosedEuclidean;
use alg_tools::iterate::AlgIteratorFactory;
use alg_tools::linops::{BoundedLinear, IdOp, SimplyAdjointable, ZeroOp, AXPY, GEMV};
use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance};
use alg_tools::nalgebra_support::ToNalgebraRealField;
use alg_tools::norms::L2;
use anyhow::ensure;
use numeric_literals::replace_float_literals;
use serde::{Deserialize, Serialize};

/// Settings for [`pointsource_forward_pdps_pair`].
///
/// The fields are *scalings*: the actual step lengths used by the solver are derived
/// from them together with the step length bounds of the problem; see
/// [`pointsource_forward_pdps_pair`] for the formulas.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
// Fields missing from deserialized input fall back to `Default::default()`.
#[serde(default)]
pub struct ForwardPDPSConfig<F: Float> {
    /// Overall primal step length scaling.
    ///
    /// Must lie in the open interval $(0, 1)$.
    pub τ0: F,
    /// Primal step length scaling for additional variable.
    ///
    /// Must lie in the open interval $(0, 1)$.
    pub σp0: F,
    /// Dual step length scaling for additional variable.
    ///
    /// Must be non-negative, with `σp0 * σd0 <= 1`.
    /// Taken zero for [`pointsource_fb_pair`].
    pub σd0: F,
    /// Generic parameters controlling spike insertion, merging, and tolerances.
    pub insertion: InsertionConfig<F>,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F: Float> Default for ForwardPDPSConfig<F> {
    fn default() -> Self {
        ForwardPDPSConfig { τ0: 0.99, σd0: 0.05, σp0: 0.99, insertion: Default::default() }
    }
}

/// The combined primal variable $(μ, z)$: a discrete measure (`RNDM<N, F>`) paired
/// with the additional variable of type `Z`.
type MeasureZ<F, Z, const N: usize> = Pair<RNDM<N, F>, Z>;

/// Iteratively solve the pointsource localisation with an additional variable
/// using primal-dual proximal splitting with a forward step.
///
/// The problem is
/// $$
///    \min_{μ, z}~ F(μ, z) + R(z) + H(K_z z) + Q(μ),
/// $$
/// where
///   * The data term $F$ is given in `f`,
///   * the measure (Radon or positivity-constrained Radon) regulariser in $Q$ is given in `reg`,
///   * the functions $R$ and $H$ are given in `fnR` and `fnH`, and
///   * the operator $K_z$ in `opKz`.
///
/// This is dualised to
/// $$
///    \min_{μ, z}\max_y~ F(μ, z) + R(z) + ⟨K_z z, y⟩ + Q(μ) - H^*(y).
/// $$
///
/// The algorithm is controlled by:
///   * the proximal penalty in `prox_penalty`.
///   * the initial iterates in `μ0`, `z`, `y` (if `μ0` is `None`, we start from the empty
///     measure),
///   * The configuration in `config`.
///   * The `iterator` that controls stopping and reporting.
/// Moreover, plotting is performed by `plotter`.
///
/// The step lengths need to satisfy
/// $$
///     τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1
/// $$
/// with $1 > σ_p L_z$ and $1 > τ L$. The bracketed term is the standard PDPS step length
/// condition for $(z, y)$. Here $L$ and $L_z$ are the step length bounds obtained from
/// the proximal penalty and the data term. $M$ bounds an operator $K_μ$, which is not
/// currently supported, so $M=0$.
///
/// Since we are given “scalings” $τ_0$, $σ_{p,0}$, and $σ_{d,0}$ in `config`, we take
/// $σ_d=σ_{d,0}/‖K_z‖$ (or $σ_d=0$ if $K_z=0$), and
/// $σ_p = σ_{p,0} / (L_z + σ_d‖K_z‖^2) = σ_{p,0} / (L_z + σ_{d,0}‖K_z‖)$.
/// This makes the bracketed term equal to $σ_{p,0} < 1$. With these choices, we solve
/// $$
///     τ = τ_0 \frac{1 - σ_{p,0}}{σ_d M (1-σ_p L_z) + (1 - σ_{p,0}) L}.
/// $$
///
/// # Errors
///
/// Returns an error in the following cases:
///   * the scalings in `config` are invalid;
///   * the step length bounds are degenerate, so that $σ_p$ or $τ$ would be infinite or
///     undefined;
///   * computing the operator norm bound, the step length bounds, or inserting spikes fails.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_forward_pdps_pair<
    F,
    I,
    S,
    Dat,
    Reg,
    P,
    Z,
    R,
    Y,
    /*KOpM, */ KOpZ,
    H,
    Plot,
    const N: usize,
>(
    f: &Dat,
    reg: &Reg,
    prox_penalty: &P,
    config: &ForwardPDPSConfig<F>,
    iterator: I,
    mut plotter: Plot,
    (μ0, mut z, mut y): (Option<RNDM<N, F>>, Z, Y),
    //opKμ : KOpM,
    opKz: &KOpZ,
    fnR: &R,
    fnH: &H,
) -> DynResult<MeasureZ<F, Z, N>>
where
    F: Float + ToNalgebraRealField,
    I: AlgIteratorFactory<IterInfo<F>>,
    Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>,
    //Pair<S, Z>: ClosedMul<F>, // Doesn't really need to be closed, if make this signature more complex…
    S: DifferentiableRealMapping<N, F> + ClosedMul<F>,
    RNDM<N, F>: SpikeMerging<F>,
    Reg: RegTerm<Loc<N, F>, F>,
    P: ProxPenalty<Loc<N, F>, S, Reg, F>,
    for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>,
    KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y>
        + GEMV<F, Z, Y>
        + SimplyAdjointable<Z, Y, Codomain = Y, AdjointCodomain = Z>,
    KOpZ::SimpleAdjoint: GEMV<F, Y, Z>,
    Y: ClosedEuclidean<F>,
    for<'b> &'b Y: Instance<Y>,
    Z: ClosedEuclidean<F>,
    for<'b> &'b Z: Instance<Z>,
    R: Prox<Z, Codomain = F>,
    H: Conjugable<Y, F, Codomain = F>,
    for<'b> H::Conjugate<'b>: Prox<Y>,
    Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>,
{
    // Check parameters
    ensure!(
        config.τ0 > 0.0
            && config.τ0 < 1.0
            && config.σp0 > 0.0
            && config.σp0 < 1.0
            && config.σd0 >= 0.0
            && config.σp0 * config.σd0 <= 1.0,
        "Invalid step length parameters"
    );

    // Initialise iterates
    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());

    // Set up parameters.
    // There is currently no operator K_μ, so its bound M is zero.
    let bigM = 0.0; //opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt();
    let nKz = opKz.opnorm_bound(L2, L2)?;
    let is_fb = nKz == 0.0;
    let idOpZ = IdOp::new();
    let opKz_adj = opKz.adjoint();
    let (l, l_z) = Pair(prox_penalty, &idOpZ).step_length_bound_pair(&f)?;
    // We need to satisfy
    //
    //     τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1
    //
    // with 1 > σ_p L_z and 1 > τ L.
    //
    // To do so, we first solve σ_p and σ_d from the standard PDPS step length condition
    // [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1, and then solve τ from the rest.
    // If K_z is the zero operator, then we set σ_d = 0 for τ to be calculated correctly below.
    let σ_d = if is_fb { 0.0 } else { config.σd0 / nKz };
    // σ_d‖K_z‖^2 = σ_{d,0}‖K_z‖; both are zero when K_z = 0.
    let σ_p_denominator = l_z + config.σd0 * nKz;
    // A zero denominator (L_z = 0 together with K_z = 0) would give σ_p = ∞, and
    // subsequently NaN step lengths and iterates.
    ensure!(
        σ_p_denominator > 0.0,
        "Degenerate step length bounds for the additional variable"
    );
    let σ_p = config.σp0 / σ_p_denominator;
    // With this choice, the bracketed term equals σ_{p,0}, so 1 - [ … ] = 1 - σ_{p,0}.
    // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L)
    // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0})
    let φ = 1.0 - config.σp0;
    let a = 1.0 - σ_p * l_z;
    let τ_denominator = σ_d * bigM * a + φ * l;
    // A zero denominator would give τ = ∞.
    ensure!(τ_denominator > 0.0, "Degenerate step length bound for the measure");
    let τ = config.τ0 * φ / τ_denominator;
    // Acceleration is not currently supported
    // let γ = dataterm.factor_of_strong_convexity();
    let ω = 1.0;

    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    let starH = fnH.conjugate();

    // Statistics: the full objective value F(μ, z) + R(z) + Q(μ) + H(K_z z).
    let full_stats = |μ: &RNDM<N, F>, z: &Z, ε, stats| IterInfo {
        value: f.apply(Pair(μ, z))
            + fnR.apply(z)
            + reg.apply(μ)
            + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)),
        n_spikes: μ.len(),
        ε,
        // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
        ..stats
    };
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&μ, &z, ε, stats.clone())) {
        // Calculate initial transport
        let Pair(mut τv, τz) = f.differential(Pair(&μ, &z)) * τ;
        let μ_base = μ.clone();

        // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
        let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
            &mut μ,
            &mut τv,
            τ,
            ε,
            &config.insertion,
            &reg,
            &state,
            &mut stats,
        )?;

        stats.inserted += μ.len() - μ_base.len();

        // Merge spikes.
        // This crucially expects the merge routine to be stable with respect to spike locations,
        // and not to perform any pruning. Pruning is done separately below.
        // The data-term fitness closure is only supplied in the pure FB case (K_z = 0).
        let ins = &config.insertion;
        if ins.merge_now(&state) {
            stats.merged += prox_penalty.merge_spikes(
                &mut μ,
                &mut τv,
                &μ_base,
                τ,
                ε,
                ins,
                &reg,
                is_fb.then_some(|μ̃: &RNDM<N, F>| f.apply(Pair(μ̃, &z))),
            );
        }

        // Prune spikes with zero weight.
        stats.pruned += prune_with_stats(&mut μ);

        // Do z variable primal update:
        // z^{k+1} = prox_{σ_p R}(z^k - σ_p (∇_z F(μ^k, z^k) + K_z^* y^k)).
        // Since τz = τ∇_z F, scaling it by -σ_p/τ gives -σ_p ∇_z F.
        let mut z_new = τz;
        opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ);
        z_new = fnR.prox(σ_p, z_new + &z);
        // Do dual update
        // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0);    // y = y + σ_d K[(1+ω)(μ,z)^{k+1}]
        opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0);
        // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
        opKz.gemv(&mut y, -σ_d * ω, z, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
        y = starH.prox(σ_d, y);
        z = z_new;

        // Update step length parameters
        // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);

        // Give statistics if requested
        let iter = state.iteration();
        stats.this_iters += 1;

        state.if_verbose(|| {
            plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
            full_stats(&μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        // Update main tolerance for next iteration
        ε = tolerance.update(ε, iter);
    }

    // Fitness for the final merge. The H(K_z z) term does not depend on μ̃.
    let fit = |μ̃: &RNDM<N, F>| {
        f.apply(Pair(μ̃, &z)) /*+ fnR.apply(z) + reg.apply(μ)*/
        + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z))
    };

    μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v);
    μ.prune();
    Ok(Pair(μ, z))
}

/// Iteratively solve the pointsource localisation with an additional variable
/// using forward-backward splitting.
///
/// The implementation uses [`pointsource_forward_pdps_pair`] with appropriate dummy
/// variables, operators, and functions.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_fb_pair<F, I, S, Dat, Reg, P, Z, R, Plot, const N: usize>(
    f: &Dat,
    reg: &Reg,
    prox_penalty: &P,
    config: &FBConfig<F>,
    iterator: I,
    plotter: Plot,
    (μ0, z): (Option<RNDM<N, F>>, Z),
    //opKμ : KOpM,
    fnR: &R,
) -> DynResult<MeasureZ<F, Z, N>>
where
    F: Float + ToNalgebraRealField,
    I: AlgIteratorFactory<IterInfo<F>>,
    Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>,
    S: DifferentiableRealMapping<N, F> + ClosedMul<F>,
    RNDM<N, F>: SpikeMerging<F>,
    Reg: RegTerm<Loc<N, F>, F>,
    P: ProxPenalty<Loc<N, F>, S, Reg, F>,
    for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>,
    Z: ClosedEuclidean<F> + AXPY<Field = F> + Clone,
    for<'b> &'b Z: Instance<Z>,
    R: Prox<Z, Codomain = F>,
    Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>,
    // We should not need to explicitly require this:
    for<'b> &'b Loc<0, F>: Instance<Loc<0, F>>,
{
    let opKz = ZeroOp::new_dualisable(Loc([]), z.dual_origin());
    let fnH = Zero::new();
    // Convert config. We don't implement From (that could be done with the o2o crate), as σd0
    // needs to be chosen in a general case; for the problem of this fucntion, anything is valid.
    let &FBConfig { τ0, σp0, insertion } = config;
    let pdps_config = ForwardPDPSConfig { τ0, σp0, insertion, σd0: 0.0 };

    pointsource_forward_pdps_pair(
        f,
        reg,
        prox_penalty,
        &pdps_config,
        iterator,
        plotter,
        (μ0, z, Loc([])),
        &opKz,
        fnR,
        &fnH,
    )
}

mercurial