src/fb.rs

Fri, 16 Jan 2026 19:39:22 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Fri, 16 Jan 2026 19:39:22 -0500
branch
dev
changeset 62
32328a74c790
parent 61
4f468d35fa29
permissions
-rw-r--r--

Lipschitz estimation attempt (incomplete, not implemented for sliding. Doesn't work anyway for basic FB either.)

/*!
Solver for the point source localisation problem using a forward-backward splitting method.

This corresponds to the manuscript

 * Valkonen T. - _Proximal methods for point source localisation_,
   [arXiv:2212.02991](https://arxiv.org/abs/2212.02991).

The main routine is [`pointsource_fb_reg`].

## Problem

<p>
Our objective is to solve
$$
    \min_{μ ∈ ℳ(Ω)}~ F_0(Aμ-b) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ),
$$
where $F_0(y)=\frac{1}{2}\|y\|_2^2$ and the forward operator $A \in 𝕃(ℳ(Ω); ℝ^n)$.
</p>

## Approach

<p>
As documented in more detail in the paper, on each step we approximately solve
$$
    \min_{μ ∈ ℳ(Ω)}~ F(μ) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ) + \frac{1}{2}\|μ-μ^k\|_𝒟^2,
$$
where $𝒟 ∈ 𝕃(ℳ(Ω); C_c(Ω))$ is typically a convolution operator.
</p>

## Finite-dimensional subproblems.

With $C$ a projection from [`DiscreteMeasure`] to the weights, and $x^k$ such that $x^k=Cμ^k$, we
form the discretised linearised inner problem
<p>
$$
    \min_{x ∈ ℝ^n}~ τ\bigl(F(Cx^k) + [C^*∇F(Cx^k)]^⊤(x-x^k) + α {\vec 1}^⊤ x\bigr)
                    + δ_{≥ 0}(x) + \frac{1}{2}\|x-x^k\|_{C^*𝒟C}^2,
$$
equivalently
$$
    \begin{aligned}
    \min_x~ & τF(Cx^k) - τ[C^*∇F(Cx^k)]^⊤x^k + \frac{1}{2} (x^k)^⊤ C^*𝒟C x^k
            \\
            &
            - [C^*𝒟C x^k - τC^*∇F(Cx^k)]^⊤ x
            \\
            &
            + \frac{1}{2} x^⊤ C^*𝒟C x
            + τα {\vec 1}^⊤ x + δ_{≥ 0}(x),
    \end{aligned}
$$
In other words, we obtain the quadratic non-negativity constrained problem
$$
    \min_{x ∈ ℝ^n}~ \frac{1}{2} x^⊤ Ã x - g̃^⊤ x + c + τα {\vec 1}^⊤ x + δ_{≥ 0}(x).
$$
where
$$
   \begin{aligned}
    Ã & = C^*𝒟C,
    \\
    g̃ & = C^*𝒟C x^k - τ C^*∇F(Cx^k)
        = C^* 𝒟 μ^k - τ C^*A^*(Aμ^k - b)
    \\
    c & = τ F(Cx^k) - τ[C^*∇F(Cx^k)]^⊤x^k + \frac{1}{2} (x^k)^⊤ C^*𝒟C x^k
        \\
        &
        = \frac{τ}{2} \|Aμ^k-b\|^2 - τ[Aμ^k-b]^⊤Aμ^k + \frac{1}{2} \|μ^k\|_{𝒟}^2
        \\
        &
        = -\frac{τ}{2} \|Aμ^k-b\|^2 + τ[Aμ^k-b]^⊤ b + \frac{1}{2} \|μ^k\|_{𝒟}^2.
   \end{aligned}
$$
</p>

We solve this with either SSN or FB as determined by
[`crate::subproblem::InnerSettings`] in [`InsertionConfig::inner`].
*/

use crate::measures::merging::SpikeMerging;
use crate::measures::{DiscreteMeasure, RNDM};
use crate::plot::Plotter;
use crate::prox_penalty::StepLengthBoundValue;
pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound};
use crate::regularisation::RegTerm;
use crate::types::*;
use alg_tools::error::DynResult;
use alg_tools::instance::{ClosedSpace, Instance};
use alg_tools::iterate::AlgIteratorFactory;
use alg_tools::mapping::{DifferentiableMapping, Mapping};
use alg_tools::nalgebra_support::ToNalgebraRealField;
use anyhow::anyhow;
use colored::Colorize;
use numeric_literals::replace_float_literals;
use serde::{Deserialize, Serialize};

/// Settings for [`pointsource_fb_reg`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FBConfig<F: Float> {
    /// Step length scaling: the step length is `τ0 / l` for a Lipschitz factor
    /// estimate `l` (see [`AdaptiveStepLength`]). Default `0.99`.
    pub τ0: F,
    /// Auxiliary variable step length scaling for [`crate::forward_pdps::pointsource_fb_pair`].
    /// Default `0.99`.
    pub σp0: F,
    /// Generic spike insertion/merging parameters
    pub insertion: InsertionConfig<F>,
    /// Always use an adaptively estimated step length, even when a reliable
    /// Lipschitz factor is available. Default `false`.
    pub always_adaptive_τ: bool,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F: Float> Default for FBConfig<F> {
    fn default() -> Self {
        FBConfig { τ0: 0.99, σp0: 0.99, always_adaptive_τ: false, insertion: Default::default() }
    }
}

/// Prunes `μ` (see [`DiscreteMeasure`] pruning) and returns the number of
/// spikes removed in the process.
pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize {
    let before = μ.len();
    μ.prune();
    let after = μ.len();
    // Pruning can only ever remove spikes.
    debug_assert!(after <= before);
    before - after
}

/// Adaptive step length and Lipschitz parameter estimation state.
#[derive(Clone, Debug, Serialize)]
pub enum AdaptiveStepLength<const N: usize, F: Float> {
    /// Step length derived on each iteration from a running Lipschitz factor estimate;
    /// the current step length is `τ0 / l`.
    Adaptive {
        /// Current Lipschitz factor estimate.
        l: F,
        /// Iterate stored on the previous [`update`](Self::update) call.
        μ_old: RNDM<N, F>,
        /// Function value at `μ_old`.
        fμ_old: F,
        /// Distance between the two most recent iterates, as stored by
        /// [`finish_step`](Self::finish_step).
        μ_dist: F,
        /// Fractional step length parameter.
        τ0: F,
        /// Whether `l` is still the possibly unreliable initial estimate, to be
        /// overwritten (rather than merely raised) by the first adaptive estimate.
        l_is_initial: bool,
    },
    /// Fixed step length `τ`, used when a reliable Lipschitz factor is known up front.
    Fixed {
        τ: F,
    },
}

#[replace_float_literals(F::cast_from(literal))]
impl<const N: usize, F: Float> AdaptiveStepLength<N, F> {
    /// Initialises step length handling from the step length bound provided by `prox_penalty`.
    ///
    /// If the bound is a reliable Lipschitz factor and `fbconfig.always_adaptive_τ` is not
    /// set, a fixed step length `fbconfig.τ0 / l` is used. Otherwise the factor only serves
    /// as an initial estimate, refined adaptively on each [`Self::update`] call.
    ///
    /// Returns an error if no Lipschitz estimate is available at all.
    pub fn new<Dat, Reg, P>(f: &Dat, prox_penalty: &P, fbconfig: &FBConfig<F>) -> DynResult<Self>
    where
        F: ToNalgebraRealField,
        Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
        P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
        Reg: RegTerm<Loc<N, F>, F>,
    {
        // Builds the adaptive state, starting the comparison point from the zero measure.
        let adaptive = |l, l_is_initial| {
            let μ_old = DiscreteMeasure::new();
            let fμ_old = f.apply(&μ_old);
            AdaptiveStepLength::Adaptive {
                l,
                μ_old,
                fμ_old,
                μ_dist: 0.0,
                τ0: fbconfig.τ0,
                l_is_initial,
            }
        };

        match (
            prox_penalty.step_length_bound(&f),
            fbconfig.always_adaptive_τ,
        ) {
            (StepLengthBoundValue::LipschitzFactor(l), false) => {
                Ok(AdaptiveStepLength::Fixed { τ: fbconfig.τ0 / l })
            }
            (StepLengthBoundValue::LipschitzFactor(l), true) => Ok(adaptive(l, false)),
            (StepLengthBoundValue::UnreliableLipschitzFactor(l), _) => {
                println!("Lipschitz factor is unreliable; calculating adaptively.");
                Ok(adaptive(l, true))
            }
            (StepLengthBoundValue::Failure, _) => Err(anyhow!("No Lipschitz estimate available")),
        }
    }

    /// Returns the current value of the step length parameter `τ`.
    pub fn current(&self) -> F {
        match *self {
            AdaptiveStepLength::Adaptive { τ0, l, .. } => τ0 / l,
            AdaptiveStepLength::Fixed { τ } => τ,
        }
    }

    /// Updates the adaptive Lipschitz factor estimate and returns the new step length `τ`.
    ///
    /// Inputs:
    /// * `μ`: current point,
    /// * `fμ`: value of the function `f` at `μ`,
    /// * `v`: derivative of the function `f` at `μ`.
    ///
    /// For [`Self::Fixed`] this simply returns the fixed step length.
    pub fn update<'a, G>(&mut self, μ: &RNDM<N, F>, fμ: F, v: &'a G) -> F
    where
        G: ClosedMul<F> + Mapping<Loc<N, F>, Codomain = F> + ClosedSpace,
        &'a G: Instance<G>,
    {
        match self {
            AdaptiveStepLength::Adaptive { l, μ_old, fμ_old, μ_dist, τ0, l_is_initial } => {
                // Bregman-distance-like quantity f(μ_old) - f(μ) - ⟨v, μ_old - μ⟩
                // used to estimate the Lipschitz factor against ‖μ - μ_old‖²/2.
                let b = *fμ_old - fμ - μ_old.apply(v) + μ.apply(v);
                let d = *μ_dist;
                if d.abs() > F::EPSILON && μ.len() > 0 && μ_old.len() > 0 {
                    let lc = b / (d * d / 2.0);
                    if *l_is_initial {
                        // Overwrite the unreliable initial estimate outright.
                        *l = lc;
                        *l_is_initial = false;
                    } else {
                        // Keep a monotonically non-decreasing estimate.
                        *l = l.max(lc);
                    }
                }

                // Store the current point for the next iteration.
                *μ_old = μ.clone();
                *fμ_old = fμ;

                *τ0 / *l
            }
            AdaptiveStepLength::Fixed { τ } => *τ,
        }
    }

    /// Finalises a step, storing the distance of μ to the previous iterate.
    ///
    /// This is not included in [`Self::update`], as this function is to be called
    /// before pruning and merging, while μ and its previous version in their internal
    /// presentation still have matching indices for the same coordinate.
    pub fn finish_step(&mut self, μ: &RNDM<N, F>) {
        if let AdaptiveStepLength::Adaptive { μ_dist, μ_old, .. } = self {
            *μ_dist = μ.dist_matching(μ_old);
        }
    }
}

/// Postprocesses a measure by merging spikes and pruning.
///
/// Spikes are merged with the final merging method from `config`, using `f` as the
/// fitness (objective value) function for accepting or rejecting merge candidates;
/// the measure is then pruned before being returned.
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn postprocess<F: Float, Dat: Fn(&RNDM<N, F>) -> F, const N: usize>(
    mut μ: RNDM<N, F>,
    config: &InsertionConfig<F>,
    f: Dat,
) -> DynResult<RNDM<N, F>>
where
    RNDM<N, F>: SpikeMerging<F>,
    for<'a> &'a RNDM<N, F>: Instance<RNDM<N, F>>,
{
    μ.merge_spikes_fitness(config.final_merging_method(), f, |&v| v);
    μ.prune();
    Ok(μ)
}

/// Iteratively solve the pointsource localisation problem using forward-backward splitting.
///
/// The settings in `fbconfig` have their [respective documentation](FBConfig). `f` is the
/// smooth data fidelity term, `reg` the regularisation term, and `prox_penalty` forms the
/// proximal term (typically via a convolution operator) and provides the step length
/// bound. The `iterator` is an outer loop verbosity and iteration count control as
/// documented in [`alg_tools::iterate`]; `plotter` handles optional visualisation, and
/// `μ0` is an optional initial iterate (the zero measure is used if `None`).
///
/// For details on the mathematical formulation, see the [module level](self) documentation.
///
/// Returns the final iterate.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_fb_reg<F, I, Dat, Reg, P, Plot, const N: usize>(
    f: &Dat,
    reg: &Reg,
    prox_penalty: &P,
    fbconfig: &FBConfig<F>,
    iterator: I,
    mut plotter: Plot,
    μ0: Option<RNDM<N, F>>,
) -> DynResult<RNDM<N, F>>
where
    F: Float + ToNalgebraRealField,
    I: AlgIteratorFactory<IterInfo<F>>,
    RNDM<N, F>: SpikeMerging<F>,
    Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
    Dat::DerivativeDomain: ClosedMul<F> + Mapping<Loc<N, F>, Codomain = F> + ClosedSpace,
    Reg: RegTerm<Loc<N, F>, F>,
    P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
    Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
    for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>,
{
    // Set up parameters
    let config = &fbconfig.insertion;
    let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, fbconfig)?;

    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    let tolerance = config.tolerance * adaptive_τ.current() * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates
    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());

    // Statistics: objective value, spike count, and current tolerance.
    let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo {
        value: f.apply(μ) + reg.apply(μ),
        n_spikes: μ.len(),
        ε,
        ..stats
    };
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
        // Calculate smooth part of surrogate model.
        // TODO: optimise τ to be applied to residual.
        let (fμ, v) = f.apply_and_differential(&μ);
        let τ = adaptive_τ.update(&μ, fμ, &v);
        let mut τv = v * τ;

        // Save current base point
        let μ_base = μ.clone();

        // Insert and reweigh
        let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
            &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
        )?;

        // We don't treat merge in adaptive Lipschitz, so record the step distance
        // before spikes get merged or pruned.
        adaptive_τ.finish_step(&μ);

        // Prune and possibly merge spikes
        if config.merge_now(&state) {
            stats.merged += prox_penalty.merge_spikes(
                &mut μ,
                &mut τv,
                &μ_base,
                None,
                τ,
                ε,
                config,
                &reg,
                Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)),
            );
        }

        stats.pruned += prune_with_stats(&mut μ);

        let iter = state.iteration();
        stats.this_iters += 1;

        // Give statistics if needed
        state.if_verbose(|| {
            plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
            full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        // Update main tolerance for next iteration
        ε = tolerance.update(ε, iter);
    }

    postprocess(μ, config, |μ̃| f.apply(μ̃))
}

/// Iteratively solve the pointsource localisation problem using inertial
/// forward-backward splitting (μFISTA).
///
/// The settings in `fbconfig` have their [respective documentation](FBConfig). `f` is the
/// smooth data fidelity term, `reg` the regularisation term, and `prox_penalty` forms the
/// proximal term (typically via a convolution operator) and provides the step length
/// bound. The `iterator` is an outer loop verbosity and iteration count control as
/// documented in [`alg_tools::iterate`]; `plotter` handles optional visualisation, and
/// `μ0` is an optional initial iterate (the zero measure is used if `None`).
///
/// Spike merging is not supported by this method; a warning is printed once if the
/// configuration requests it.
///
/// For details on the mathematical formulation, see the [module level](self) documentation.
///
/// Returns the final (non-inertial) iterate.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_fista_reg<F, I, Dat, Reg, P, Plot, const N: usize>(
    f: &Dat,
    reg: &Reg,
    prox_penalty: &P,
    fbconfig: &FBConfig<F>,
    iterator: I,
    mut plotter: Plot,
    μ0: Option<RNDM<N, F>>,
) -> DynResult<RNDM<N, F>>
where
    F: Float + ToNalgebraRealField,
    I: AlgIteratorFactory<IterInfo<F>>,
    RNDM<N, F>: SpikeMerging<F>,
    Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
    Dat::DerivativeDomain: ClosedMul<F> + Mapping<Loc<N, F>, Codomain = F> + ClosedSpace,
    Reg: RegTerm<Loc<N, F>, F>,
    P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
    Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
    for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>,
{
    // Set up parameters
    let config = &fbconfig.insertion;
    let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, fbconfig)?;

    // Inertial parameter sequence, starting from λ = 1.
    let mut λ = 1.0;
    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    let tolerance = config.tolerance * adaptive_τ.current() * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates
    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
    let mut μ_prev = μ.clone();
    let mut warned_merging = false;

    // Statistics: objective value, spike count, and current tolerance.
    let full_stats = |ν: &RNDM<N, F>, ε, stats| IterInfo {
        value: f.apply(ν) + reg.apply(ν),
        n_spikes: ν.len(),
        ε,
        ..stats
    };
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
        // Calculate smooth part of surrogate model.
        let (fμ, v) = f.apply_and_differential(&μ);
        let τ = adaptive_τ.update(&μ, fμ, &v);
        let mut τv = v * τ;

        // Save current base point
        let μ_base = μ.clone();

        // Insert new spikes and reweigh
        let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
            &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
        )?;

        // We don't treat merge in adaptive Lipschitz, so record the step distance
        // before any further modification of μ.
        adaptive_τ.finish_step(&μ);

        // (Do not) merge spikes.
        if config.merge_now(&state) && !warned_merging {
            println!("{}", "Merging not supported for μFISTA".red());
            warned_merging = true;
        }

        // Update inertial parameters
        let λ_prev = λ;
        λ = 2.0 * λ_prev / (λ_prev + (4.0 + λ_prev * λ_prev).sqrt());
        let θ = λ / λ_prev - λ;

        // Perform inertial update on μ.
        // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ
        // and μ_prev have zero weight. Since both have weights from the finite-dimensional
        // subproblem with a proximal projection step, this is likely to happen when the
        // spike is not needed. A copy of the pruned μ without arithmetic performed is
        // stored in μ_prev.
        let n_before_prune = μ.len();
        μ.pruning_sub(1.0 + θ, θ, &mut μ_prev);
        debug_assert!(μ.len() <= n_before_prune);
        stats.pruned += n_before_prune - μ.len();

        let iter = state.iteration();
        stats.this_iters += 1;

        // Give statistics if needed; report the non-inertial iterate μ_prev.
        state.if_verbose(|| {
            plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ_prev);
            full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        // Update main tolerance for next iteration
        ε = tolerance.update(ε, iter);
    }

    postprocess(μ_prev, config, |μ̃| f.apply(μ̃))
}

mercurial