src/sliding_fb.rs

changeset 52
f0e8704d3f0e
parent 49
6b0db7251ebe
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/sliding_fb.rs	Mon Feb 17 13:54:53 2025 -0500
@@ -0,0 +1,444 @@
+/*!
+Solver for the point source localisation problem using a sliding
+forward-backward splitting method.
+*/
+
+use numeric_literals::replace_float_literals;
+use serde::{Deserialize, Serialize};
+//use colored::Colorize;
+//use nalgebra::{DVector, DMatrix};
+use itertools::izip;
+use std::iter::Iterator;
+
+use alg_tools::euclidean::Euclidean;
+use alg_tools::iterate::AlgIteratorFactory;
+use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping};
+use alg_tools::nalgebra_support::ToNalgebraRealField;
+use alg_tools::norms::Norm;
+
+use crate::forward_model::{AdjointProductBoundedBy, BoundedCurvature, ForwardModel};
+use crate::measures::merging::SpikeMerging;
+use crate::measures::{DiscreteMeasure, Radon, RNDM};
+use crate::types::*;
+//use crate::tolerance::Tolerance;
+use crate::dataterm::{calculate_residual, calculate_residual2, DataTerm, L2Squared};
+use crate::fb::*;
+use crate::plot::{PlotLookup, Plotting, SeqPlotter};
+use crate::regularisation::SlidingRegTerm;
+//use crate::transport::TransportLipschitz;
+
+/// Transport settings for [`pointsource_sliding_fb_reg`].
+#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
+#[serde(default)]
+pub struct TransportConfig<F: Float> {
+    /// Transport step length $θ$ normalised to $(0, 1)$.
+    pub θ0: F,
+    /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance.
+    pub adaptation: F,
+    /// A posteriori transport tolerance multiplier (C_pos)
+    pub tolerance_mult_con: F,
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F: Float> TransportConfig<F> {
+    /// Check that the parameters are ok. Panics if not.
+    pub fn check(&self) {
+        assert!(self.θ0 > 0.0);
+        assert!(0.0 < self.adaptation && self.adaptation < 1.0);
+        assert!(self.tolerance_mult_con > 0.0);
+    }
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F: Float> Default for TransportConfig<F> {
+    fn default() -> Self {
+        TransportConfig {
+            θ0: 0.9,
+            adaptation: 0.9,
+            tolerance_mult_con: 100.0,
+        }
+    }
+}
+
+/// Settings for [`pointsource_sliding_fb_reg`].
+#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
+#[serde(default)]
+pub struct SlidingFBConfig<F: Float> {
+    /// Step length scaling
+    pub τ0: F,
+    /// Transport parameters
+    pub transport: TransportConfig<F>,
+    /// Generic parameters
+    pub insertion: FBGenericConfig<F>,
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F: Float> Default for SlidingFBConfig<F> {
+    fn default() -> Self {
+        SlidingFBConfig {
+            τ0: 0.99,
+            transport: Default::default(),
+            insertion: Default::default(),
+        }
+    }
+}
+
+/// Internal type of adaptive transport step length calculation
+pub(crate) enum TransportStepLength<F: Float, G: Fn(F, F) -> F> {
+    /// Fixed, known step length
+    #[allow(dead_code)]
+    Fixed(F),
+    /// Adaptive step length, only wrt. maximum transport.
+    /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
+    AdaptiveMax { l: F, max_transport: F, g: G },
+    /// Adaptive step length.
+    /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
+    FullyAdaptive { l: F, max_transport: F, g: G },
+}
+
+/// Constrution of initial transport `γ1` from initial measure `μ` and `v=F'(μ)`
+/// with step lengh τ and transport step length `θ_or_adaptive`.
+#[replace_float_literals(F::cast_from(literal))]
+pub(crate) fn initial_transport<F, G, D, const N: usize>(
+    γ1: &mut RNDM<F, N>,
+    μ: &mut RNDM<F, N>,
+    τ: F,
+    θ_or_adaptive: &mut TransportStepLength<F, G>,
+    v: D,
+) -> (Vec<F>, RNDM<F, N>)
+where
+    F: Float + ToNalgebraRealField,
+    G: Fn(F, F) -> F,
+    D: DifferentiableRealMapping<F, N>,
+{
+    use TransportStepLength::*;
+
+    // Save current base point and shift μ to new positions. Idea is that
+    //  μ_base(_masses) = μ^k (vector of masses)
+    //  μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
+    //  γ1 = π_♯^1γ^{k+1}
+    //  μ = μ^{k+1}
+    let μ_base_masses: Vec<F> = μ.iter_masses().collect();
+    let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below.
+                                         // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
+                                         //let mut sum_norm_dv = 0.0;
+    let γ_prev_len = γ1.len();
+    assert!(μ.len() >= γ_prev_len);
+    γ1.extend(μ[γ_prev_len..].iter().cloned());
+
+    // Calculate initial transport and step length.
+    // First calculate initial transported weights
+    for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
+        // If old transport has opposing sign, the new transport will be none.
+        ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) {
+            0.0
+        } else {
+            δ.α
+        };
+    }
+
+    // Calculate transport rays.
+    match *θ_or_adaptive {
+        Fixed(θ) => {
+            let θτ = τ * θ;
+            for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
+                ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
+            }
+        }
+        AdaptiveMax {
+            l: ℓ_F,
+            ref mut max_transport,
+            g: ref calculate_θ,
+        } => {
+            *max_transport = max_transport.max(γ1.norm(Radon));
+            let θτ = τ * calculate_θ(ℓ_F, *max_transport);
+            for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
+                ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
+            }
+        }
+        FullyAdaptive {
+            l: ref mut adaptive_ℓ_F,
+            ref mut max_transport,
+            g: ref calculate_θ,
+        } => {
+            *max_transport = max_transport.max(γ1.norm(Radon));
+            let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
+            // Do two runs through the spikes to update θ, breaking if first run did not cause
+            // a change.
+            for _i in 0..=1 {
+                let mut changes = false;
+                for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
+                    let dv_x = v.differential(&δ.x);
+                    let g = &dv_x * (ρ.α.signum() * θ * τ);
+                    ρ.x = δ.x - g;
+                    let n = g.norm2();
+                    if n >= F::EPSILON {
+                        // Estimate Lipschitz factor of ∇v
+                        let this_ℓ_F = (dv_x - v.differential(&ρ.x)).norm2() / n;
+                        *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F);
+                        θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
+                        changes = true
+                    }
+                }
+                if !changes {
+                    break;
+                }
+            }
+        }
+    }
+
+    // Set initial guess for μ=μ^{k+1}.
+    for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) {
+        if ρ.α.abs() > F::EPSILON {
+            δ.x = ρ.x;
+            //δ.α = ρ.α; // already set above
+        } else {
+            δ.α = β;
+        }
+    }
+    // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b)
+    μ_base_minus_γ0.set_masses(
+        μ_base_masses
+            .iter()
+            .zip(γ1.iter_masses())
+            .map(|(&a, b)| a - b),
+    );
+    (μ_base_masses, μ_base_minus_γ0)
+}
+
+/// A posteriori transport adaptation.
+#[replace_float_literals(F::cast_from(literal))]
+pub(crate) fn aposteriori_transport<F, const N: usize>(
+    γ1: &mut RNDM<F, N>,
+    μ: &mut RNDM<F, N>,
+    μ_base_minus_γ0: &mut RNDM<F, N>,
+    μ_base_masses: &Vec<F>,
+    extra: Option<F>,
+    ε: F,
+    tconfig: &TransportConfig<F>,
+) -> bool
+where
+    F: Float + ToNalgebraRealField,
+{
+    // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
+    // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
+    // at that point to zero, and retry.
+    let mut all_ok = true;
+    for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) {
+        if α_μ == 0.0 && *α_γ1 != 0.0 {
+            all_ok = false;
+            *α_γ1 = 0.0;
+        }
+    }
+
+    // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z).
+    //    through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1},
+    //    which holds for some some C if the convolution kernel in 𝒟 has Lipschitz gradient.
+    let nγ = γ1.norm(Radon);
+    let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0);
+    let t = ε * tconfig.tolerance_mult_con;
+    if nγ * nΔ > t {
+        // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1,
+        // this will guarantee that eventually ‖γ‖ decreases sufficiently that we
+        // will not enter here.
+        *γ1 *= tconfig.adaptation * t / (nγ * nΔ);
+        all_ok = false
+    }
+
+    if !all_ok {
+        // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
+        μ_base_minus_γ0.set_masses(
+            μ_base_masses
+                .iter()
+                .zip(γ1.iter_masses())
+                .map(|(&a, b)| a - b),
+        );
+    }
+
+    all_ok
+}
+
+/// Iteratively solve the pointsource localisation problem using sliding forward-backward
+/// splitting
+///
+/// The parametrisation is as for [`pointsource_fb_reg`].
+/// Inertia is currently not supported.
+#[replace_float_literals(F::cast_from(literal))]
+pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N: usize>(
+    opA: &A,
+    b: &A::Observable,
+    reg: Reg,
+    prox_penalty: &P,
+    config: &SlidingFBConfig<F>,
+    iterator: I,
+    mut plotter: SeqPlotter<F, N>,
+) -> RNDM<F, N>
+where
+    F: Float + ToNalgebraRealField,
+    I: AlgIteratorFactory<IterInfo<F, N>>,
+    A: ForwardModel<RNDM<F, N>, F>
+        + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>
+        + BoundedCurvature<FloatType = F>,
+    for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>,
+    A::PreadjointCodomain: DifferentiableRealMapping<F, N>,
+    RNDM<F, N>: SpikeMerging<F>,
+    Reg: SlidingRegTerm<F, N>,
+    P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
+    PlotLookup: Plotting<N>,
+{
+    // Check parameters
+    assert!(config.τ0 > 0.0, "Invalid step length parameter");
+    config.transport.check();
+
+    // Initialise iterates
+    let mut μ = DiscreteMeasure::new();
+    let mut γ1 = DiscreteMeasure::new();
+    let mut residual = -b; // Has to equal $Aμ-b$.
+
+    // Set up parameters
+    // let opAnorm = opA.opnorm_bound(Radon, L2);
+    //let max_transport = config.max_transport.scale
+    //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0);
+    //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
+    let ℓ = 0.0;
+    let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
+    let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components();
+    let transport_lip = maybe_transport_lip.unwrap();
+    let calculate_θ = |ℓ_F, max_transport| {
+        let ℓ_r = transport_lip * max_transport;
+        config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r))
+    };
+    let mut θ_or_adaptive = match maybe_ℓ_F0 {
+        //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)),
+        Some(ℓ_F0) => TransportStepLength::AdaptiveMax {
+            l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual
+            max_transport: 0.0,
+            g: calculate_θ,
+        },
+        None => TransportStepLength::FullyAdaptive {
+            l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials
+            max_transport: 0.0,
+            g: calculate_θ,
+        },
+    };
+    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
+    // by τ compared to the conditional gradient approach.
+    let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling();
+    let mut ε = tolerance.initial();
+
+    // Statistics
+    let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo {
+        value: residual.norm2_squared_div2() + reg.apply(μ),
+        n_spikes: μ.len(),
+        ε,
+        // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
+        ..stats
+    };
+    let mut stats = IterInfo::new();
+
+    // Run the algorithm
+    for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
+        // Calculate initial transport
+        let v = opA.preadjoint().apply(residual);
+        let (μ_base_masses, mut μ_base_minus_γ0) =
+            initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v);
+
+        // Solve finite-dimensional subproblem several times until the dual variable for the
+        // regularisation term conforms to the assumptions made for the transport above.
+        let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop {
+            // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
+            let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
+            let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ);
+
+            // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
+            let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
+                &mut μ,
+                &mut τv̆,
+                &γ1,
+                Some(&μ_base_minus_γ0),
+                τ,
+                ε,
+                &config.insertion,
+                &reg,
+                &state,
+                &mut stats,
+            );
+
+            // A posteriori transport adaptation.
+            if aposteriori_transport(
+                &mut γ1,
+                &mut μ,
+                &mut μ_base_minus_γ0,
+                &μ_base_masses,
+                None,
+                ε,
+                &config.transport,
+            ) {
+                break 'adapt_transport (maybe_d, within_tolerances, τv̆);
+            }
+        };
+
+        stats.untransported_fraction = Some({
+            assert_eq!(μ_base_masses.len(), γ1.len());
+            let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));
+            let source = μ_base_masses.iter().map(|v| v.abs()).sum();
+            (a + μ_base_minus_γ0.norm(Radon), b + source)
+        });
+        stats.transport_error = Some({
+            assert_eq!(μ_base_masses.len(), γ1.len());
+            let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
+            (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
+        });
+
+        // Merge spikes.
+        // This crucially expects the merge routine to be stable with respect to spike locations,
+        // and not to performing any pruning. That is be to done below simultaneously for γ.
+        let ins = &config.insertion;
+        if ins.merge_now(&state) {
+            stats.merged += prox_penalty.merge_spikes(
+                &mut μ,
+                &mut τv̆,
+                &γ1,
+                Some(&μ_base_minus_γ0),
+                τ,
+                ε,
+                ins,
+                &reg,
+                Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
+            );
+        }
+
+        // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
+        // latter needs to be pruned when μ is.
+        // TODO: This could do with a two-vector Vec::retain to avoid copies.
+        let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
+        if μ_new.len() != μ.len() {
+            let mut μ_iter = μ.iter_spikes();
+            γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
+            stats.pruned += μ.len() - μ_new.len();
+            μ = μ_new;
+        }
+
+        // Update residual
+        residual = calculate_residual(&μ, opA, b);
+
+        let iter = state.iteration();
+        stats.this_iters += 1;
+
+        // Give statistics if requested
+        state.if_verbose(|| {
+            plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
+            full_stats(
+                &residual,
+                &μ,
+                ε,
+                std::mem::replace(&mut stats, IterInfo::new()),
+            )
+        });
+
+        // Update main tolerance for next iteration
+        ε = tolerance.update(ε, iter);
+    }
+
+    postprocess(μ, &config.insertion, L2Squared, opA, b)
+}

mercurial