src/pdps.rs

Thu, 26 Feb 2026 13:05:07 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 26 Feb 2026 13:05:07 -0500
branch
dev
changeset 66
fe47ad484deb
parent 63
7a8a55fd41c0
permissions
-rw-r--r--

Allow fitness merge when forward_pdps and sliding_pdps are used as forward-backward with aux variable.

/*!
Solver for the point source localisation problem with primal-dual proximal splitting.

This corresponds to the manuscript

 * Valkonen T. - _Proximal methods for point source localisation_,
   [arXiv:2212.02991](https://arxiv.org/abs/2212.02991).

The main routine is [`pointsource_pdps_reg`].
Both norm-2-squared and norm-1 data terms are supported; that is, this module implements solvers for
<div>
$$
    \min_{μ ∈ ℳ(Ω)}~ F_0(Aμ - b) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ),
$$
for both $F_0(y)=\frac{1}{2}\|y\|_2^2$ and  $F_0(y)=\|y\|_1$ with the forward operator
$A \in 𝕃(ℳ(Ω); ℝ^n)$.
</div>

## Approach

<p>
The problem above can be written as
$$
    \min_μ \max_y G(μ) + ⟨y, Aμ-b⟩ - F_0^*(y),
$$
where $G(μ) = α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ)$.
The Fenchel–Rockafellar optimality conditions, employing the predual in $ℳ(Ω)$, are
$$
    0 ∈ A_*y + ∂G(μ)
    \quad\text{and}\quad
    Aμ - b ∈ ∂ F_0^*(y).
$$
The solution of the first part is as for forward-backward, treated in the manuscript.
This is the task of <code>generic_pointsource_fb</code>, where we use <code>FBSpecialisation</code>
to replace the specific residual $Aμ-b$ by $y$.
For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$.
For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$.
</p>
*/

use crate::fb::{postprocess, prune_with_stats};
use crate::forward_model::ForwardModel;
use crate::measures::merging::SpikeMerging;
use crate::measures::merging::SpikeMergingMethod;
use crate::measures::{DiscreteMeasure, RNDM};
use crate::plot::Plotter;
pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBoundPD};
use crate::regularisation::RegTerm;
use crate::types::*;
use alg_tools::convex::{Conjugable, ConvexMapping, Prox};
use alg_tools::error::DynResult;
use alg_tools::iterate::AlgIteratorFactory;
use alg_tools::linops::{Mapping, AXPY};
use alg_tools::mapping::{DataTerm, Instance};
use alg_tools::nalgebra_support::ToNalgebraRealField;
use anyhow::ensure;
use clap::ValueEnum;
use numeric_literals::replace_float_literals;
use serde::{Deserialize, Serialize};

/// Step length acceleration strategy for the PDPS iteration.
///
/// Acceleration is based on the strong convexity of the conjugate data term, i.e., it is
/// driven by the dual variable; see [`Acceleration::accelerate`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize, ValueEnum)]
pub enum Acceleration {
    // NOTE: the doc comment below doubles as the command-line help text of this value.
    /// No acceleration
    #[clap(name = "none")]
    None,
    /// Partial acceleration with $ω = 1/\sqrt{1+γσ}$.
    #[clap(name = "partial", help = "Partial acceleration, ω = 1/√(1+σ)")]
    Partial,
    /// Full acceleration with $ω = 1/\sqrt{1+2γσ}$; convergence of the gap is not guaranteed.
    #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")]
    Full,
}

#[replace_float_literals(F::cast_from(literal))]
impl Acceleration {
    /// Apply one step of PDPS step length acceleration.
    ///
    /// Given the factor `γ` of strong convexity of the dual (conjugate) term, computes the
    /// over-relaxation parameter `ω`, shrinks the dual step `σ` by `ω`, enlarges the primal
    /// step `τ` by `1/ω`, and returns `ω`. Without acceleration the steps are left untouched
    /// and `ω = 1`.
    fn accelerate<F: Float>(self, τ: &mut F, σ: &mut F, γ: F) -> F {
        // Coefficient of γσ under the square root.
        let κ = match self {
            Acceleration::None => return 1.0,
            Acceleration::Partial => 1.0,
            Acceleration::Full => 2.0,
        };
        let ω = 1.0 / (1.0 + κ * γ * (*σ)).sqrt();
        // Keep the product τσ invariant while trading primal for dual step length.
        *σ *= ω;
        *τ /= ω;
        ω
    }
}

/// Settings for [`pointsource_pdps_reg`].
///
/// Thanks to `#[serde(default)]`, fields missing from deserialised input are taken from
/// the [`Default`] implementation.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct PDPSConfig<F: Float> {
    /// Primal step length scaling. We must have `τ0 * σ0 < 1`, and `τ0 > 0`.
    /// The solver uses the initial primal step `τ0 / l`, where `l` is the step length bound
    /// reported by the proximal penalty for the forward operator.
    pub τ0: F,
    /// Dual step length scaling. We must have `τ0 * σ0 < 1`, and `σ0 > 0`.
    /// The solver uses the initial dual step `σ0 / l`, with `l` as for `τ0`.
    pub σ0: F,
    /// Step length acceleration strategy, applied if available; see [`Acceleration`].
    pub acceleration: Acceleration,
    /// Generic parameters for spike insertion and merging; see [`InsertionConfig`].
    pub generic: InsertionConfig<F>,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F: Float> Default for PDPSConfig<F> {
    /// Default PDPS settings: `τ0 = 5`, `σ0 = 0.99 / τ0` (so that `τ0 * σ0 = 0.99`),
    /// partial acceleration, and spike merging enabled with otherwise default generic settings.
    fn default() -> Self {
        let primal_scale = 5.0;
        let merging = SpikeMergingMethod { enabled: true, ..Default::default() };
        let generic = InsertionConfig { merging, ..Default::default() };
        PDPSConfig {
            τ0: primal_scale,
            σ0: 0.99 / primal_scale,
            acceleration: Acceleration::Partial,
            generic,
        }
    }
}

/// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting.
///
/// The settings in `pdpsconfig` have their [respective documentation](PDPSConfig).
/// `f` is the data term $F_0(Aμ - b)$: it provides the forward operator $A$, the observable $b$,
/// and the fidelity $F_0$, whose conjugate $F_0^*$ is used in the dual update. `reg` is the
/// regularisation term. `prox_penalty` provides the step length bound and performs spike
/// insertion and merging. The `iterator` is an outer loop verbosity and iteration count control
/// as documented in [`alg_tools::iterate`], `plotter` is used on verbose iterations, and `μ0` is
/// an optional initial iterate (an empty measure by default).
///
/// For the mathematical formulation, see the [module level](self) documentation and the manuscript.
///
/// # Errors
///
/// Fails if the step length parameters are invalid (non-positive, or `τ0 * σ0 > 1`), or if
/// computing the step length bound or spike insertion fails.
///
/// Returns the final iterate.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_pdps_reg<'a, F, I, A, Phi, Reg, Plot, P, const N: usize>(
    f: &'a DataTerm<F, RNDM<N, F>, A, Phi>,
    reg: &Reg,
    prox_penalty: &P,
    pdpsconfig: &PDPSConfig<F>,
    iterator: I,
    mut plotter: Plot,
    μ0: Option<RNDM<N, F>>,
) -> DynResult<RNDM<N, F>>
where
    F: Float + ToNalgebraRealField,
    I: AlgIteratorFactory<IterInfo<F>>,
    A: ForwardModel<RNDM<N, F>, F>,
    for<'b> &'b A::Observable: Instance<A::Observable>,
    A::Observable: AXPY<Field = F>,
    RNDM<N, F>: SpikeMerging<F>,
    Reg: RegTerm<Loc<N, F>, F>,
    Phi: Conjugable<A::Observable, F>,
    for<'b> Phi::Conjugate<'b>: Prox<A::Observable>,
    P: ProxPenalty<Loc<N, F>, A::PreadjointCodomain, Reg, F> + StepLengthBoundPD<F, A, RNDM<N, F>>,
    Plot: Plotter<P::ReturnMapping, A::PreadjointCodomain, RNDM<N, F>>,
{
    // Check parameters
    ensure!(
        pdpsconfig.τ0 > 0.0 && pdpsconfig.σ0 > 0.0 && pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0,
        "Invalid step length parameters"
    );

    let opA = f.operator();
    let b = f.data();
    let phistar = f.fidelity().conjugate();

    // Set up parameters: τ is the primal and σ the dual step length.
    let config = &pdpsconfig.generic;
    let l = prox_penalty.step_length_bound_pd(opA)?;
    let mut τ = pdpsconfig.τ0 / l;
    let mut σ = pdpsconfig.σ0 / l;
    let γ = phistar.factor_of_strong_convexity();

    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates
    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
    let mut y = f.residual(&μ);
    let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo {
        value: f.apply(μ) + reg.apply(μ),
        n_spikes: μ.len(),
        ε,
        // postprocessing: config.postprocessing.then(|| μ.clone()),
        ..stats
    };
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
        // Calculate smooth part of surrogate model.
        // FIXME: the clone is required to avoid compiler overflows with reference-Mul requirement above.
        let mut τv = opA.preadjoint().apply(y.clone() * τ);

        // Save current base point
        let μ_base = μ.clone();

        // Insert and reweigh
        let (maybe_d, _within_tolerances) = prox_penalty
            .insert_and_reweigh(&mut μ, &mut τv, τ, ε, config, &reg, &state, &mut stats)?;

        // Prune and possibly merge spikes
        if config.merge_now(&state) {
            stats.merged +=
                prox_penalty.merge_spikes_no_fitness(&mut μ, &mut τv, &μ_base, τ, ε, config, &reg);
        }
        stats.inserted += μ.len() - μ_base.len();
        stats.pruned += prune_with_stats(&mut μ);

        // Update step length parameters
        let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);

        // Do dual update y^{k+1} = prox_{σF_0^*}(y^k + σ(A μ̄ - b)) with the over-relaxed
        // primal point μ̄ = (1+ω)μ^{k+1} - ω μ^k. This uses the *dual* step σ and ascends in
        // y, so that fixed points satisfy Aμ - b ∈ ∂F_0^*(y), matching y = Aμ - b as used in
        // the primal step above.
        // y = y_prev - σb
        y.axpy(-σ, b, 1.0);
        // y = y_prev + σ(A[(1+ω)μ^{k+1}]-b)
        opA.gemv(&mut y, σ * (1.0 + ω), &μ, 1.0);
        // y = y_prev + σ(A[(1+ω)μ^{k+1} - ω μ^k]-b)
        opA.gemv(&mut y, -σ * ω, &μ_base, 1.0);
        y = phistar.prox(σ, y);

        // Give statistics if requested
        let iter = state.iteration();
        stats.this_iters += 1;

        state.if_verbose(|| {
            plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
            full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        ε = tolerance.update(ε, iter);
    }

    postprocess(μ, config, |μ| f.apply(μ))
}

mercurial