src/sliding_fb.rs

Fri, 16 Jan 2026 19:39:22 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Fri, 16 Jan 2026 19:39:22 -0500
branch
dev
changeset 62
32328a74c790
parent 61
4f468d35fa29
permissions
-rw-r--r--

Lipschitz estimation attempt (incomplete, not implemented for sliding. Doesn't work anyway for basic FB either.)

/*!
Solver for the point source localisation problem using a sliding
forward-backward splitting method.
*/

use numeric_literals::replace_float_literals;
use serde::{Deserialize, Serialize};
//use colored::Colorize;
//use nalgebra::{DVector, DMatrix};
use itertools::izip;
use std::iter::Iterator;

use crate::fb::*;
use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
use crate::measures::merging::SpikeMerging;
use crate::measures::{DiscreteMeasure, Radon, RNDM};
use crate::plot::Plotter;
use crate::prox_penalty::{ProxPenalty, StepLengthBound};
use crate::regularisation::SlidingRegTerm;
use crate::types::*;
use alg_tools::error::DynResult;
use alg_tools::euclidean::Euclidean;
use alg_tools::instance::{ClosedSpace, Instance};
use alg_tools::iterate::AlgIteratorFactory;
use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping};
use alg_tools::nalgebra_support::ToNalgebraRealField;
use alg_tools::norms::Norm;
use anyhow::ensure;

/// Transport settings for [`pointsource_sliding_fb_reg`].
///
/// Due to `#[serde(default)]`, fields missing from serialised input are taken from the
/// [`Default`] implementation. The values can be validated with [`TransportConfig::check`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct TransportConfig<F: Float> {
    /// Transport step length $θ$ normalised to $(0, 1)$.
    ///
    /// [`pointsource_sliding_fb_reg`] divides this by the step length times a sum of
    /// Lipschitz-type factors to obtain the actual transport step length.
    /// [`TransportConfig::check`] only verifies that the value is positive.
    pub θ0: F,
    /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance.
    ///
    /// Used by [`aposteriori_transport`] as an extra multiplier when scaling down the
    /// transport after the a posteriori bound has been violated.
    pub adaptation: F,
    /// A posteriori transport tolerance multiplier (C_pos)
    ///
    /// The a posteriori transport bound of [`aposteriori_transport`] is the current
    /// tolerance $ε$ multiplied by this value.
    pub tolerance_mult_con: F,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F: Float> TransportConfig<F> {
    /// Check that the parameters are ok.
    ///
    /// This does not panic: violations are reported through the returned result.
    ///
    /// # Errors
    ///
    /// Returns an error describing the first violated condition if
    ///
    ///  * `θ0` is not strictly positive,
    ///  * `adaptation` is not strictly between zero and one, or
    ///  * `tolerance_mult_con` is not strictly positive.
    ///
    /// Not-a-number values fail all of the comparisons, so are rejected as well.
    pub fn check(&self) -> DynResult<()> {
        ensure!(
            self.θ0 > 0.0,
            "Invalid transport step length parameter θ0: must be positive"
        );
        ensure!(
            0.0 < self.adaptation && self.adaptation < 1.0,
            "Invalid transport adaptation factor: must lie strictly between 0 and 1"
        );
        ensure!(
            self.tolerance_mult_con > 0.0,
            "Invalid a posteriori transport tolerance multiplier: must be positive"
        );
        Ok(())
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F: Float> Default for TransportConfig<F> {
    /// Default transport settings: normalised step length and adaptation factor of `0.9`,
    /// and an a posteriori tolerance multiplier of `100`.
    fn default() -> Self {
        Self {
            θ0: 0.9,
            adaptation: 0.9,
            tolerance_mult_con: 100.0,
        }
    }
}

/// Settings for [`pointsource_sliding_fb_reg`].
///
/// A reference to this structure can be converted into an [`FBConfig`]; the conversion
/// copies the step length and insertion settings, and drops `transport` and `guess`.
///
/// Due to `#[serde(default)]`, fields missing from serialised input are taken from the
/// [`Default`] implementation.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct SlidingFBConfig<F: Float> {
    /// Step length scaling
    ///
    /// [`pointsource_sliding_fb_reg`] returns an error unless this is positive.
    pub τ0: F,
    /// Auxiliary variable step length scaling for
    /// `crate::sliding_pdps::pointsource_sliding_fb_pair`.
    ///
    /// Not used directly in this module beyond being forwarded to [`FBConfig`].
    pub σp0: F,
    /// Transport parameters
    pub transport: TransportConfig<F>,
    /// Generic parameters
    pub insertion: InsertionConfig<F>,
    /// Guess for curvature bound calculations.
    ///
    /// Passed to the `curvature_bound_components` method of the data term.
    pub guess: BoundedCurvatureGuess,
    /// Always adaptive step length
    ///
    /// Forwarded as such to [`FBConfig`].
    pub always_adaptive_τ: bool,
}

/// Extraction of basic forward-backward settings from sliding forward-backward settings.
///
/// The step length scalings, the insertion settings, and the adaptivity flag are copied;
/// the transport settings and the curvature guess have no counterpart in [`FBConfig`] and
/// are dropped.
///
/// This is implemented as [`From`] rather than [`Into`]: the blanket implementation in the
/// standard library then provides the corresponding `Into<FBConfig<F>>` for
/// `&SlidingFBConfig<F>`, so `config.into()` continues to work.
impl<'a, F: Float> From<&'a SlidingFBConfig<F>> for FBConfig<F> {
    fn from(config: &'a SlidingFBConfig<F>) -> Self {
        // The configuration is `Copy`, so the fields can be destructured out of the reference.
        let SlidingFBConfig { τ0, σp0, insertion, always_adaptive_τ, .. } = *config;
        FBConfig { τ0, σp0, insertion, always_adaptive_τ }
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F: Float> Default for SlidingFBConfig<F> {
    fn default() -> Self {
        SlidingFBConfig {
            τ0: 0.99,
            σp0: 0.99,
            transport: Default::default(),
            insertion: Default::default(),
            guess: BoundedCurvatureGuess::BetterThanZero,
            always_adaptive_τ: false,
        }
    }
}

/// Internal type of adaptive transport step length calculation
///
/// The state is updated in place by [`initial_transport`]. The function `g` is called as
/// `g(l, max_transport)` and returns the transport step length $θ$.
pub(crate) enum TransportStepLength<F: Float, G: Fn(F, F) -> F> {
    /// Fixed, known step length
    ///
    /// The contained value is used directly as $θ$.
    // NOTE(review): the only construction of this variant visible in this file is commented
    // out, which is presumably why dead code is allowed here — confirm against other modules.
    #[allow(dead_code)]
    Fixed(F),
    /// Adaptive step length, only wrt. maximum transport.
    /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
    ///
    /// [`initial_transport`] keeps `l` fixed, and updates `max_transport` to the largest
    /// Radon norm of the transport seen so far.
    AdaptiveMax { l: F, max_transport: F, g: G },
    /// Adaptive step length.
    /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
    ///
    /// [`initial_transport`] updates `max_transport` as for [`Self::AdaptiveMax`], and
    /// additionally increases `l` based on difference quotients of the differential of
    /// its argument `v` along the transport rays.
    FullyAdaptive { l: F, max_transport: F, g: G },
}

/// Construction of initial transport `γ1` from initial measure `μ` and `v=F'(μ)`
/// with step length `τ` and transport step length `θ_or_adaptive`.
///
/// The construction proceeds as follows:
///
///  1. `γ1` is extended by the spikes of `μ` beyond the previous length of `γ1`, so that
///     the two measures have equally many spikes.
///  2. The mass of each spike of `γ1` is set to the mass of the corresponding spike of `μ`,
///     or to zero if the previous mass of the spike of `γ1` had the opposite sign.
///  3. Each spike of `γ1` is placed at `x - sign(α) θ τ ∇v(x)`, where `x` is the location of
///     the corresponding spike of `μ`, `α` is the transported mass, and `θ` is obtained from
///     `θ_or_adaptive`. The adaptive variants of [`TransportStepLength`] are updated in
///     place: `max_transport` with the Radon norm of `γ1`, and for
///     [`TransportStepLength::FullyAdaptive`] also the Lipschitz factor estimate `l`.
///  4. The spikes of `μ` whose transported mass is not negligible are moved to the
///     transported locations.
///
/// Returns the vector of masses of `μ` on entry, and a measure with the spike locations of
/// `μ` on entry and the masses of `μ` on entry minus the masses of `γ1`
/// (that is, `μ^k - π_♯^0γ^{k+1}`).
///
/// # Panics
///
/// Panics if `γ1` has more spikes than `μ` on entry.
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn initial_transport<F, G, D, const N: usize>(
    γ1: &mut RNDM<N, F>,
    μ: &mut RNDM<N, F>,
    τ: F,
    θ_or_adaptive: &mut TransportStepLength<F, G>,
    v: D,
) -> (Vec<F>, RNDM<N, F>)
where
    F: Float + ToNalgebraRealField,
    G: Fn(F, F) -> F,
    D: DifferentiableRealMapping<N, F>,
{
    use TransportStepLength::*;

    // Save current base point and shift μ to new positions. Idea is that
    //  μ_base(_masses) = μ^k (vector of masses)
    //  μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
    //  γ1 = π_♯^1γ^{k+1}
    //  μ = μ^{k+1}
    let μ_base_masses: Vec<F> = μ.iter_masses().collect();
    // The locations are those of μ^k; the weights will be set at the end of this function.
    let mut μ_base_minus_γ0 = μ.clone();

    // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
    let γ_prev_len = γ1.len();
    assert!(μ.len() >= γ_prev_len);
    γ1.extend(μ[γ_prev_len..].iter().cloned());

    // Calculate initial transport and step length.
    // First calculate initial transported weights
    for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
        // If old transport has opposing sign, the new transport will be none.
        ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) {
            0.0
        } else {
            δ.α
        };
    }

    // Calculate transport rays.
    match *θ_or_adaptive {
        Fixed(θ) => {
            let θτ = τ * θ;
            for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
                ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
            }
        }
        AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θ } => {
            *max_transport = max_transport.max(γ1.norm(Radon));
            let θτ = τ * calculate_θ(ℓ_F, *max_transport);
            for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
                ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
            }
        }
        FullyAdaptive { l: ref mut adaptive_ℓ_F, ref mut max_transport, g: ref calculate_θ } => {
            *max_transport = max_transport.max(γ1.norm(Radon));
            let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
            // Do two runs through the spikes to update θ, breaking if first run did not cause
            // a change.
            for _i in 0..=1 {
                let mut changes = false;
                for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
                    let dv_x = v.differential(&δ.x);
                    let g = &dv_x * (ρ.α.signum() * θ * τ);
                    ρ.x = δ.x - g;
                    let n = g.norm2();
                    if n >= F::EPSILON {
                        // Estimate Lipschitz factor of ∇v
                        let this_ℓ_F = (dv_x - v.differential(&ρ.x)).norm2() / n;
                        // Only an increased estimate changes θ. If the estimate does not
                        // increase, all the rays have been computed with the same θ, and
                        // another run would merely repeat the same calculations.
                        if this_ℓ_F > *adaptive_ℓ_F {
                            *adaptive_ℓ_F = this_ℓ_F;
                            θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
                            changes = true;
                        }
                    }
                }
                if !changes {
                    break;
                }
            }
        }
    }

    // Set initial guess for μ=μ^{k+1}.
    for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) {
        if ρ.α.abs() > F::EPSILON {
            δ.x = ρ.x;
            //δ.α = ρ.α; // already set above
        } else {
            δ.α = β;
        }
    }
    // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b)
    μ_base_minus_γ0.set_masses(
        μ_base_masses
            .iter()
            .zip(γ1.iter_masses())
            .map(|(&a, b)| a - b),
    );
    (μ_base_masses, μ_base_minus_γ0)
}

/// A posteriori transport adaptation.
///
/// Verifies that the transport `γ1` is compatible with the new iterate `μ`. If it is not,
/// `γ1` is reduced, the masses of `μ_base_minus_γ0` are recomputed as `μ_base_masses` minus
/// the masses of `γ1`, and `false` is returned to indicate that the step should be retried.
/// Returns `true` if nothing had to be changed.
///
/// The optional `extra` is added to the norm estimate of the transport mismatch, and `ε` is
/// the current tolerance, which is scaled by `tconfig.tolerance_mult_con`.
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn aposteriori_transport<F, const N: usize>(
    γ1: &mut RNDM<N, F>,
    μ: &mut RNDM<N, F>,
    μ_base_minus_γ0: &mut RNDM<N, F>,
    μ_base_masses: &Vec<F>,
    extra: Option<F>,
    ε: F,
    tconfig: &TransportConfig<F>,
) -> bool
where
    F: Float + ToNalgebraRealField,
{
    let mut retry = false;

    // 1. Wherever π_♯^1γ^{k+1} = γ1 carries mass at a point y at which μ = μ^{k+1} has none,
    // the ansatz ∇w̃_x(y) = w^{k+1}(y) may fail. Remove the mass of γ1 at such points and
    // request a retry.
    for (mass_μ, mass_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) {
        if mass_μ == 0.0 && *mass_γ1 != 0.0 {
            *mass_γ1 = 0.0;
            retry = true;
        }
    }

    // 2. Bound ∫ B_ω(y, z) dλ(x, y, z) through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖
    //    for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1}. This holds for some C if the
    //    convolution kernel in 𝒟 has Lipschitz gradient.
    let transport_norm = γ1.norm(Radon);
    let mismatch_norm =
        μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0);
    let threshold = ε * tconfig.tolerance_mult_con;
    let estimate = transport_norm * mismatch_norm;
    if estimate > threshold {
        // Here threshold/estimate < 1, and also tconfig.adaptation < 1. Hence ‖γ‖ eventually
        // decreases enough for this branch not to be entered.
        *γ1 *= tconfig.adaptation * threshold / estimate;
        retry = true;
    }

    if retry {
        // The transport changed: refresh the weights of μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
        let updated_masses = μ_base_masses
            .iter()
            .zip(γ1.iter_masses())
            .map(|(&base, transported)| base - transported);
        μ_base_minus_γ0.set_masses(updated_masses);
    }

    !retry
}

/// Iteratively solve the pointsource localisation problem using sliding forward-backward
/// splitting
///
/// The parametrisation is as for [`pointsource_fb_reg`].
/// Inertia is currently not supported.
///
/// # Arguments
///
///  * `f` — the data term. It is evaluated for statistics, and differentiated on every
///    iteration.
///  * `reg` — the regularisation term.
///  * `prox_penalty` — the proximal penalty, used for spike insertion, reweighting and merging.
///  * `config` — algorithm settings.
///  * `iterator` — factory for the iterator that controls the main loop and verbosity.
///  * `plotter` — used to plot the spikes on verbose iterations.
///  * `μ0` — optional initial iterate. An empty measure is used if this is `None`.
///
/// Returns the result of `postprocess` applied to the final iterate.
///
/// # Errors
///
/// Returns an error if `config.τ0` is not positive, if `config.transport` does not pass
/// [`TransportConfig::check`], or if any of the following report an error: construction of
/// the adaptive step length, the transport component of `f.curvature_bound_components`,
/// `prox_penalty.insert_and_reweigh`, or `postprocess`.
///
/// # Panics
///
/// Panics on violation of the internal assertions on the spike counts of the iterate and
/// the transport.
///
/// # Output
///
/// Currently prints a `TODO` note to standard output on entry, and a `WARNING` on every
/// iteration.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_sliding_fb_reg<F, I, Dat, Reg, Plot, P, const N: usize>(
    f: &Dat,
    reg: &Reg,
    prox_penalty: &P,
    config: &SlidingFBConfig<F>,
    iterator: I,
    mut plotter: Plot,
    μ0: Option<RNDM<N, F>>,
) -> DynResult<RNDM<N, F>>
where
    F: Float + ToNalgebraRealField,
    I: AlgIteratorFactory<IterInfo<F>>,
    Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>,
    Dat::DerivativeDomain:
        ClosedMul<F> + DifferentiableRealMapping<N, F, Codomain = F> + ClosedSpace,
    RNDM<N, F>: SpikeMerging<F>,
    Reg: SlidingRegTerm<Loc<N, F>, F>,
    P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
    Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
    for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>,
{
    // Check parameters
    ensure!(config.τ0 > 0.0, "Invalid step length parameter");
    config.transport.check()?;

    // Initialise iterates
    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
    let mut γ1 = DiscreteMeasure::new();

    // Set up parameters
    // let opAnorm = opA.opnorm_bound(Radon, L2);
    //let max_transport = config.max_transport.scale
    //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0);
    //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
    let ℓ = 0.0;
    let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, &config.into())?;
    // Initial step length. The closure `calculate_θ` below captures this value, not the
    // step lengths produced by `adaptive_τ.update` within the main loop.
    let τ = adaptive_τ.current();
    println!("TODO: τ in calculate_θ should be adaptive");
    let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess);
    let transport_lip = maybe_transport_lip?;
    // Calculates the transport step length θ from the factor ℓ_F and the maximum transport.
    let calculate_θ = |ℓ_F, max_transport| {
        let ℓ_r = transport_lip * max_transport;
        config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r))
    };
    // If the data term provides ℓ_F, only the maximum transport is adapted. Otherwise ℓ_F is
    // estimated as well.
    let mut θ_or_adaptive = match maybe_ℓ_F {
        //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)),
        Ok(ℓ_F) => TransportStepLength::AdaptiveMax {
            l: ℓ_F, // TODO: could estimate computing the real residual
            max_transport: 0.0,
            g: calculate_θ,
        },
        Err(_) => TransportStepLength::FullyAdaptive {
            l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials
            max_transport: 0.0,
            g: calculate_θ,
        },
    };
    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Statistics
    let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo {
        value: f.apply(μ) + reg.apply(μ),
        n_spikes: μ.len(),
        ε,
        // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
        ..stats
    };
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
        // Calculate initial transport
        let (fμ, v) = f.apply_and_differential(&μ);
        // This step length shadows the initial one for the rest of the iteration.
        let τ = adaptive_τ.update(&μ, fμ, &v);

        let (μ_base_masses, mut μ_base_minus_γ0) =
            initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v);

        // Solve finite-dimensional subproblem several times until the dual variable for the
        // regularisation term conforms to the assumptions made for the transport above.
        let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop {
            // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
            //let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
            // TODO: this could be optimised by doing the differential like the
            // old residual2.
            let μ̆ = &γ1 + &μ_base_minus_γ0;
            let mut τv̆ = f.differential(μ̆) * τ;

            // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
            let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
                &mut μ,
                &mut τv̆,
                &γ1,
                Some(&μ_base_minus_γ0),
                τ,
                ε,
                &config.insertion,
                &reg,
                &state,
                &mut stats,
            )?;

            // A posteriori transport adaptation.
            // A return value of `false` means that the transport was reduced, and the
            // subproblem has to be solved again.
            if aposteriori_transport(
                &mut γ1,
                &mut μ,
                &mut μ_base_minus_γ0,
                &μ_base_masses,
                None,
                ε,
                &config.transport,
            ) {
                break 'adapt_transport (maybe_d, within_tolerances, τv̆);
            }
        };

        // We don't treat merge in adaptive Lipschitz.
        println!("WARNING: finish_step does not work with sliding");
        adaptive_τ.finish_step(&μ);

        // Accumulate pairs of (untransported mass, source mass) over the iterations.
        stats.untransported_fraction = Some({
            assert_eq!(μ_base_masses.len(), γ1.len());
            let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));
            let source = μ_base_masses.iter().map(|v| v.abs()).sum();
            (a + μ_base_minus_γ0.norm(Radon), b + source)
        });
        // Accumulate pairs of (distance between μ and the transport, transported mass).
        stats.transport_error = Some({
            assert_eq!(μ_base_masses.len(), γ1.len());
            let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
            (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
        });

        // Merge spikes.
        // This crucially expects the merge routine to be stable with respect to spike locations,
        // and not to perform any pruning. That is to be done below simultaneously for γ.
        let ins = &config.insertion;
        if ins.merge_now(&state) {
            stats.merged += prox_penalty.merge_spikes(
                &mut μ,
                &mut τv̆,
                &γ1,
                Some(&μ_base_minus_γ0),
                τ,
                ε,
                ins,
                &reg,
                Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)),
            );
        }

        // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
        // latter needs to be pruned when μ is.
        // TODO: This could do with a two-vector Vec::retain to avoid copies.
        let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
        if μ_new.len() != μ.len() {
            // The spikes of the unpruned μ are walked through in step with those of γ1.
            let mut μ_iter = μ.iter_spikes();
            γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
            stats.pruned += μ.len() - μ_new.len();
            μ = μ_new;
        }

        let iter = state.iteration();
        stats.this_iters += 1;

        // Give statistics if requested
        state.if_verbose(|| {
            plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
            // The accumulated statistics are handed over, and reset for the next round.
            full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        // Update main tolerance for next iteration
        ε = tolerance.update(ε, iter);
    }

    //postprocess(μ, &config.insertion, f)
    postprocess(μ, &config.insertion, |μ̃| f.apply(μ̃))
}

mercurial