src/pdps.rs

Mon, 06 Jan 2025 11:32:57 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Mon, 06 Jan 2025 11:32:57 -0500
branch
dev
changeset 36
fb911f72e698
parent 35
b087e3eab191
child 37
c5d8bd1a7728
permissions
-rw-r--r--

Factor fix

/*!
Solver for the point source localisation problem with primal-dual proximal splitting.

This corresponds to the manuscript

 * Valkonen T. - _Proximal methods for point source localisation_,
   [arXiv:2212.02991](https://arxiv.org/abs/2212.02991).

The main routine is [`pointsource_pdps_reg`].
Both norm-2-squared and norm-1 data terms are supported. That is, implemented are solvers for
<div>
$$
    \min_{μ ∈ ℳ(Ω)}~ F_0(Aμ - b) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ),
$$
for both $F_0(y)=\frac{1}{2}\|y\|_2^2$ and  $F_0(y)=\|y\|_1$ with the forward operator
$A \in 𝕃(ℳ(Ω); ℝ^n)$.
</div>

## Approach

<p>
The problem above can be written as
$$
    \min_μ \max_y G(μ) + ⟨y, Aμ-b⟩ - F_0^*(y),
$$
where $G(μ) = α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ)$.
The Fenchel–Rockafellar optimality conditions, employing the predual in $ℳ(Ω)$, are
$$
    0 ∈ A_*y + ∂G(μ)
    \quad\text{and}\quad
    Aμ - b ∈ ∂ F_0^*(y).
$$
The solution of the first part is as for forward-backward, treated in the manuscript.
This is the task of <code>generic_pointsource_fb</code>, where we use <code>FBSpecialisation</code>
to replace the specific residual $Aμ-b$ by $y$.
For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$.
For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$.
</p>
*/

use numeric_literals::replace_float_literals;
use serde::{Serialize, Deserialize};
use nalgebra::DVector;
use clap::ValueEnum;

use alg_tools::iterate::AlgIteratorFactory;
use alg_tools::loc::Loc;
use alg_tools::euclidean::Euclidean;
use alg_tools::linops::Mapping;
use alg_tools::norms::{
    Linfinity,
    Projection,
};
use alg_tools::bisection_tree::{
    BTFN,
    PreBTFN,
    Bounds,
    BTNodeLookup,
    BTNode,
    BTSearch,
    SupportGenerator,
    LocalAnalysis,
};
use alg_tools::mapping::{RealMapping, Instance};
use alg_tools::nalgebra_support::ToNalgebraRealField;
use alg_tools::linops::AXPY;

use crate::types::*;
use crate::measures::{DiscreteMeasure, RNDM, Radon};
use crate::measures::merging::SpikeMerging;
use crate::forward_model::{
    AdjointProductBoundedBy,
    ForwardModel
};
use crate::seminorms::DiscreteMeasureOp;
use crate::plot::{
    SeqPlotter,
    Plotting,
    PlotLookup
};
use crate::fb::{
    FBGenericConfig,
    insert_and_reweigh,
    postprocess,
    prune_with_stats
};
use crate::regularisation::RegTerm;
use crate::dataterm::{
    DataTerm,
    L2Squared,
    L1
};

/// Acceleration
///
/// Selects how the private `accelerate` method of this type rescales the PDPS step length
/// parameters $τ$ and $σ$. In the formulas below, $γ$ is the factor passed to that method;
/// [`pointsource_pdps_reg`] obtains it from [`PDPSDataTerm::factor_of_strong_convexity`].
// NOTE(review): the `help = "…"` strings below are user-facing command line text and omit the
// factor γ that the implementation multiplies σ with; they are deliberately left as they are.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)]
pub enum Acceleration {
    /// No acceleration
    #[clap(name = "none")]
    None,
    /// Partial acceleration, $ω = 1/\sqrt{1+γσ}$
    #[clap(name = "partial", help = "Partial acceleration, ω = 1/√(1+σ)")]
    Partial,
    /// Full acceleration, $ω = 1/\sqrt{1+2γσ}$; no gap convergence guaranteed
    #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")]
    Full
}

#[replace_float_literals(F::cast_from(literal))]
impl Acceleration {
    /// Rescale the PDPS step lengths in place and return the factor ω.
    ///
    /// With [`Acceleration::None`] neither `τ` nor `σ` is modified and `1` is returned.
    /// Otherwise $ω = 1/\sqrt{1 + cγσ}$, where $c = 1$ for [`Acceleration::Partial`] and
    /// $c = 2$ for [`Acceleration::Full`]; afterwards `σ` has been multiplied by ω and `τ`
    /// divided by ω, so the product `τσ` is unchanged up to rounding.
    ///
    /// The parameter `γ` is the factor of strong convexity on the dual side, not the primal.
    fn accelerate<F : Float>(self, τ : &mut F, σ : &mut F, γ : F) -> F {
        // The accelerated variants differ only in the argument of the square root.
        let inv_ω_squared = match self {
            Acceleration::None => return 1.0,
            Acceleration::Partial => 1.0 + γ * (*σ),
            Acceleration::Full => 1.0 + 2.0 * γ * (*σ),
        };
        let ω = 1.0 / inv_ω_squared.sqrt();
        *σ *= ω;
        *τ /= ω;
        ω
    }
}

/// Settings for [`pointsource_pdps_reg`].
///
/// Because of `#[serde(default)]`, fields missing from deserialised input are filled in from
/// the [`Default`] implementation of this type.
// NOTE(review): the field documentation asks for `τ0 * σ0 < 1`, whereas the assertion in
// `pointsource_pdps_reg` checks `τ0 * σ0 <= 1`, i.e., equality is accepted at run time.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct PDPSConfig<F : Float> {
    /// Primal step length scaling. We must have `τ0 * σ0 < 1`.
    ///
    /// Must be positive. The solver divides this by the square root of the bound returned by
    /// `adjoint_product_bound` of the forward operator to obtain the initial primal step length.
    pub τ0 : F,
    /// Dual step length scaling. We must have `τ0 * σ0 < 1`.
    ///
    /// Must be positive. It is scaled by the solver in the same way as `τ0` to obtain the
    /// initial dual step length.
    pub σ0 : F,
    /// Accelerate if available
    ///
    /// The chosen scheme updates the step lengths once per iteration of the solver.
    pub acceleration : Acceleration,
    /// Generic parameters
    ///
    /// The solver reads the tolerance and merging settings from here and passes the
    /// configuration on to the insertion and postprocessing routines.
    pub generic : FBGenericConfig<F>,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for PDPSConfig<F> {
    /// Default settings: `τ0 = 0.5`, `σ0 = 0.99/τ0` (so that `τ0 * σ0` is approximately `0.99`),
    /// partial acceleration, and the default generic parameters.
    fn default() -> Self {
        let primal_scaling : F = 0.5;
        // Tie the dual scaling to the primal one so that the product stays just below one.
        let dual_scaling = 0.99 / primal_scaling;
        Self {
            τ0 : primal_scaling,
            σ0 : dual_scaling,
            acceleration : Acceleration::Partial,
            generic : Default::default(),
        }
    }
}

/// Trait for data terms for the PDPS
///
/// Extends [`DataTerm`] with the operations that [`pointsource_pdps_reg`] needs on the dual
/// variable: an initial value, the factor used for acceleration, and the dual step itself.
#[replace_float_literals(F::cast_from(literal))]
pub trait PDPSDataTerm<F : Float, V, const N : usize> : DataTerm<F, V, N> {
    /// Calculate some subdifferential at `x` for the conjugate
    ///
    /// [`pointsource_pdps_reg`] calls this with `-b` to initialise the dual variable.
    // NOTE(review): the implementations in this file return `x` itself for `L2Squared` and the
    // componentwise sign of `x` for `L1`. These look like elements of the subdifferential of the
    // data term rather than of its conjugate — confirm the intended wording.
    fn some_subdifferential(&self, x : V) -> V;

    /// Factor of strong convexity of the conjugate
    ///
    /// The default is zero. The value is passed as `γ` to the step length acceleration.
    #[inline]
    fn factor_of_strong_convexity(&self) -> F {
        0.0
    }

    /// Perform dual update
    ///
    /// As called from [`pointsource_pdps_reg`]: on entry `_y` holds
    /// $A[(1+ω)μ^{k+1} - ω μ^k] - b$ and `_y_prev` the previous dual iterate; on return `_y`
    /// is used as the new dual iterate. `_σ` is the current dual step length.
    fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F);
}


#[replace_float_literals(F::cast_from(literal))]
impl<F : Float, V : Euclidean<F> + AXPY<F>, const N : usize> PDPSDataTerm<F, V, N> for L2Squared
where
    for<'b> &'b V : Instance<V>,
{
    /// Returns `x` unchanged.
    fn some_subdifferential(&self, x : V) -> V {
        x
    }

    /// The factor is one for this data term.
    fn factor_of_strong_convexity(&self) -> F {
        1.0
    }

    /// Combines `y` and `y_prev` with the weights `σ/(1+σ)` and `1/(1+σ)`.
    // Presumably `y.axpy(α, x, β)` performs `y ← βy + αx`, which would make this
    // `y ← (y_prev + σy)/(1+σ)` — TODO confirm against `alg_tools::linops::AXPY`.
    #[inline]
    fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) {
        let denominator = 1.0 + σ;
        y.axpy(1.0 / denominator, y_prev, σ / denominator);
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F, const N : usize> PDPSDataTerm<F, DVector<F>, N> for L1
where
    F : Float + nalgebra::RealField,
{
    /// Replaces every non-zero component of `x` by itself divided by its absolute value;
    /// zero components are left as they are.
    fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> {
        for component in x.iter_mut() {
            let value = *component;
            if value != F::ZERO {
                // `abs` is spelled out through `NumTraitsFloat` because nalgebra provides
                // its own copy of the same functionality.
                *component = value / <F as NumTraitsFloat>::abs(value);
            }
        }
        x
    }

    /// Takes the step `y.axpy(1, y_prev, σ)` and then projects `y` in place with
    /// `proj_ball_mut(1, Linfinity)`.
    #[inline]
    fn dual_update(&self, y : &mut DVector<F>, y_prev : &DVector<F>, σ : F) {
        y.axpy(1.0, y_prev, σ);
        y.proj_ball_mut(1.0, Linfinity);
    }
}

/// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting.
///
/// The `dataterm` should be either [`L1`] for norm-1 data term or [`L2Squared`] for norm-2-squared.
/// The settings in `pdpsconfig` have their [respective documentation](PDPSConfig). `opA` is the
/// forward operator $A$, `b` the observable, and `reg` the regularisation term.
/// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
/// operator. The `plotter` is invoked on verbose iterations to plot the current spikes.
/// Finally, the `iterator` is an outer loop verbosity and iteration count control
/// as documented in [`alg_tools::iterate`].
///
/// For the mathematical formulation, see the [module level](self) documentation and the manuscript.
///
/// Returns the final iterate, after it has been passed through `postprocess`.
///
/// # Panics
///
/// Panics with "Invalid step length parameters" unless `τ0 > 0`, `σ0 > 0`, and `τ0 * σ0 <= 1`
/// hold for `pdpsconfig`. Also panics if unwrapping the result of
/// `opA.adjoint_product_bound(&op𝒟)` fails.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>(
    opA : &'a A,
    b : &'a A::Observable,
    reg : Reg,
    op𝒟 : &'a 𝒟,
    pdpsconfig : &PDPSConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
    dataterm : D,
) -> RNDM<F, N>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
          + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
      𝒟::Codomain : RealMapping<F, N>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      PlotLookup : Plotting<N>,
      RNDM<F, N> : SpikeMerging<F>,
      D : PDPSDataTerm<F, A::Observable, N>,
      Reg : RegTerm<F, N> {

    // Check parameters
    assert!(pdpsconfig.τ0 > 0.0 &&
            pdpsconfig.σ0 > 0.0 &&
            pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0,
            "Invalid step length parameters");

    // Set up parameters
    let config = &pdpsconfig.generic;
    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
    // Both step lengths are the configured scalings divided by the square root of the bound.
    let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt();
    let mut τ = pdpsconfig.τ0 / l;
    let mut σ = pdpsconfig.σ0 / l;
    let γ = dataterm.factor_of_strong_convexity();

    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    // Note that this uses the initial τ; later acceleration does not rescale the tolerance.
    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates
    let mut μ = DiscreteMeasure::new();
    let mut y = dataterm.some_subdifferential(-b);
    let mut y_prev = y.clone();
    // Completes partially filled statistics with the function value, the spike count, and the
    // current tolerance.
    let full_stats = |μ : &RNDM<F, N>, ε, stats| IterInfo {
        value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ),
        n_spikes : μ.len(),
        ε,
        // postprocessing: config.postprocessing.then(|| μ.clone()),
        .. stats
    };
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
        // Calculate smooth part of surrogate model.
        let τv = opA.preadjoint().apply(y * τ);

        // Save current base point
        let μ_base = μ.clone();
        
        // Insert and reweigh
        let (d, _within_tolerances) = insert_and_reweigh(
            &mut μ, &τv, &μ_base, None,
            op𝒟, op𝒟norm,
            τ, ε,
            config, &reg, &state, &mut stats
        );

        // Prune and possibly merge spikes
        if config.merge_now(&state) {
            stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
                // Presumably `d` is τv + 𝒟(μ_candidate - μ_base), on which the regularisation
                // term decides whether to accept the merge — TODO confirm `sub_matching`.
                let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base));
                reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
            });
        }
        stats.pruned += prune_with_stats(&mut μ);

        // Update step length parameters
        let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);

        // Do dual update
        y = b.clone();                          // y = b
        opA.gemv(&mut y, 1.0 + ω, &μ, -1.0);    // y = A[(1+ω)μ^{k+1}]-b
        opA.gemv(&mut y, -ω, &μ_base, 1.0);     // y = A[(1+ω)μ^{k+1} - ω μ^k]-b
        dataterm.dual_update(&mut y, &y_prev, σ);
        y_prev.copy_from(&y);

        // Give statistics if requested
        let iter = state.iteration();
        stats.this_iters += 1;

        state.if_verbose(|| {
            plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ);
            // Hand over the accumulated statistics and start collecting afresh.
            full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        ε = tolerance.update(ε, iter);
    }

    postprocess(μ, config, dataterm, opA, b)
}

mercurial