src/sliding_fb.rs

changeset 70
ed16d0f10d08
parent 68
00d0881f89a6
--- a/src/sliding_fb.rs	Tue Apr 08 13:31:39 2025 -0500
+++ b/src/sliding_fb.rs	Fri May 08 16:47:58 2026 -0500
@@ -9,23 +9,24 @@
 //use nalgebra::{DVector, DMatrix};
 use itertools::izip;
 use std::iter::Iterator;
+use std::ops::MulAssign;
 
+use crate::fb::*;
+use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
+use crate::measures::merging::SpikeMerging;
+use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon, RNDM};
+use crate::plot::Plotter;
+use crate::prox_penalty::{ProxPenalty, StepLengthBound};
+use crate::regularisation::SlidingRegTerm;
+use crate::types::*;
+use alg_tools::error::DynResult;
 use alg_tools::euclidean::Euclidean;
 use alg_tools::iterate::AlgIteratorFactory;
-use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping};
+use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping};
 use alg_tools::nalgebra_support::ToNalgebraRealField;
 use alg_tools::norms::Norm;
-
-use crate::forward_model::{AdjointProductBoundedBy, BoundedCurvature, ForwardModel};
-use crate::measures::merging::SpikeMerging;
-use crate::measures::{DiscreteMeasure, Radon, RNDM};
-use crate::types::*;
-//use crate::tolerance::Tolerance;
-use crate::dataterm::{calculate_residual, calculate_residual2, DataTerm, L2Squared};
-use crate::fb::*;
-use crate::plot::{PlotLookup, Plotting, SeqPlotter};
-use crate::regularisation::SlidingRegTerm;
-//use crate::transport::TransportLipschitz;
+use anyhow::ensure;
+use std::ops::ControlFlow;
 
 /// Transport settings for [`pointsource_sliding_fb_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
@@ -37,15 +38,20 @@
     pub adaptation: F,
     /// A posteriori transport tolerance multiplier (C_pos)
     pub tolerance_mult_con: F,
+    /// Maximum number of a posteriori adaptation attempts before all transport is cancelled.
+    pub max_attempts: usize,
+    /// Maximum number of failed transportations for a single source point
+    pub max_fail: usize,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
 impl<F: Float> TransportConfig<F> {
     /// Check that the parameters are ok. Panics if not.
-    pub fn check(&self) {
-        assert!(self.θ0 > 0.0);
-        assert!(0.0 < self.adaptation && self.adaptation < 1.0);
-        assert!(self.tolerance_mult_con > 0.0);
+    pub fn check(&self) -> DynResult<()> {
+        ensure!(self.θ0 > 0.0);
+        ensure!(0.0 < self.adaptation && self.adaptation < 1.0);
+        ensure!(self.tolerance_mult_con > 0.0);
+        Ok(())
     }
 }
 
@@ -56,6 +62,8 @@
             θ0: 0.9,
             adaptation: 0.9,
             tolerance_mult_con: 100.0,
+            max_attempts: 2,
+            max_fail: usize::MAX,
         }
     }
 }
@@ -66,10 +74,14 @@
 pub struct SlidingFBConfig<F: Float> {
     /// Step length scaling
     pub τ0: F,
+    /// Auxiliary variable step length scaling for [`crate::sliding_pdps::pointsource_sliding_fb_pair`]
+    pub σp0: F,
     /// Transport parameters
     pub transport: TransportConfig<F>,
     /// Generic parameters
-    pub insertion: FBGenericConfig<F>,
+    pub insertion: InsertionConfig<F>,
+    /// Guess for curvature bound calculations.
+    pub guess: BoundedCurvatureGuess,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
@@ -77,8 +89,10 @@
     fn default() -> Self {
         SlidingFBConfig {
             τ0: 0.99,
+            σp0: 0.99,
             transport: Default::default(),
             insertion: Default::default(),
+            guess: BoundedCurvatureGuess::BetterThanZero,
         }
     }
 }
@@ -96,166 +110,439 @@
     FullyAdaptive { l: F, max_transport: F, g: G },
 }
 
-/// Constrution of initial transport `γ1` from initial measure `μ` and `v=F'(μ)`
-/// with step lengh τ and transport step length `θ_or_adaptive`.
-#[replace_float_literals(F::cast_from(literal))]
-pub(crate) fn initial_transport<F, G, D, const N: usize>(
-    γ1: &mut RNDM<F, N>,
-    μ: &mut RNDM<F, N>,
-    τ: F,
-    θ_or_adaptive: &mut TransportStepLength<F, G>,
-    v: D,
-) -> (Vec<F>, RNDM<F, N>)
-where
-    F: Float + ToNalgebraRealField,
-    G: Fn(F, F) -> F,
-    D: DifferentiableRealMapping<F, N>,
-{
-    use TransportStepLength::*;
+#[derive(Clone, Debug, Serialize)]
+pub struct SingleTransport<const N: usize, F: Float> {
+    /// Source point
+    x: Loc<N, F>,
+    /// Target point
+    y: Loc<N, F>,
+    /// Original mass
+    α_μ_orig: F,
+    /// Transported mass
+    α_γ: F,
+    /// Helper for pruning
+    prune: bool,
+    /// Fail count
+    fail_count: usize,
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct Transport<const N: usize, F: Float> {
+    vec: Vec<SingleTransport<N, F>>,
+}
 
-    // Save current base point and shift μ to new positions. Idea is that
-    //  μ_base(_masses) = μ^k (vector of masses)
-    //  μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
-    //  γ1 = π_♯^1γ^{k+1}
-    //  μ = μ^{k+1}
-    let μ_base_masses: Vec<F> = μ.iter_masses().collect();
-    let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below.
-                                         // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
-                                         //let mut sum_norm_dv = 0.0;
-    let γ_prev_len = γ1.len();
-    assert!(μ.len() >= γ_prev_len);
-    γ1.extend(μ[γ_prev_len..].iter().cloned());
+/// Whether partially transported points are allowed.
+///
+/// Partial transport can cause spike count explosion, so full or zero
+/// transport is generally preferred. If this is set to `true`, different
+/// transport adaptation heuristics will be used.
+const ALLOW_PARTIAL_TRANSPORT: bool = true;
+const MINIMAL_PARTIAL_TRANSPORT: bool = true;
+
+impl<const N: usize, F: Float> Transport<N, F> {
+    pub(crate) fn new() -> Self {
+        Transport { vec: Vec::new() }
+    }
 
-    // Calculate initial transport and step length.
-    // First calculate initial transported weights
-    for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
-        // If old transport has opposing sign, the new transport will be none.
-        ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) {
-            0.0
-        } else {
-            δ.α
-        };
+    pub(crate) fn iter(&self) -> impl Iterator<Item = &'_ SingleTransport<N, F>> {
+        self.vec.iter()
+    }
+
+    pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut SingleTransport<N, F>> {
+        self.vec.iter_mut()
+    }
+
+    pub(crate) fn extend<I>(&mut self, it: I)
+    where
+        I: IntoIterator<Item = SingleTransport<N, F>>,
+    {
+        self.vec.extend(it)
+    }
+
+    pub(crate) fn len(&self) -> usize {
+        self.vec.len()
     }
 
-    // Calculate transport rays.
-    match *θ_or_adaptive {
-        Fixed(θ) => {
-            let θτ = τ * θ;
-            for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
-                ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
+    // pub(crate) fn dist_matching(&self, μ: &RNDM<N, F>) -> F {
+    //     self.iter()
+    //         .zip(μ.iter_spikes())
+    //         .map(|(ρ, δ)| (ρ.α_γ - δ.α).abs())
+    //         .sum()
+    // }
+
+    /// Construct `μ̆`, replacing the contents of `μ`.
+    #[replace_float_literals(F::cast_from(literal))]
+    pub(crate) fn μ̆_into(&self, μ: &mut RNDM<N, F>) {
+        assert!(self.len() <= μ.len());
+
+        // First transported points
+        for (δ, ρ) in izip!(μ.iter_spikes_mut(), self.iter()) {
+            if ρ.α_γ.abs() > 0.0 {
+                // Transport – transported point
+                δ.α = ρ.α_γ;
+                δ.x = ρ.y;
+            } else {
+                // No transport – original point
+                δ.α = ρ.α_μ_orig;
+                δ.x = ρ.x;
             }
         }
-        AdaptiveMax {
-            l: ℓ_F,
-            ref mut max_transport,
-            g: ref calculate_θ,
-        } => {
-            *max_transport = max_transport.max(γ1.norm(Radon));
-            let θτ = τ * calculate_θ(ℓ_F, *max_transport);
-            for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
-                ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
+
+        // Then source points with partial transport
+        let mut i = self.len();
+        if ALLOW_PARTIAL_TRANSPORT {
+            // NOTE: partial transport can cause the number of points to explode.
+            for ρ in self.iter() {
+                let α = ρ.α_μ_orig - ρ.α_γ;
+                if ρ.α_γ.abs() > F::EPSILON && α != 0.0 {
+                    let δ = DeltaMeasure { α, x: ρ.x };
+                    if i < μ.len() {
+                        μ[i] = δ;
+                    } else {
+                        μ.push(δ)
+                    }
+                    i += 1;
+                }
             }
         }
-        FullyAdaptive {
-            l: ref mut adaptive_ℓ_F,
-            ref mut max_transport,
-            g: ref calculate_θ,
-        } => {
-            *max_transport = max_transport.max(γ1.norm(Radon));
-            let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
-            // Do two runs through the spikes to update θ, breaking if first run did not cause
-            // a change.
-            for _i in 0..=1 {
-                let mut changes = false;
-                for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
-                    let dv_x = v.differential(&δ.x);
-                    let g = &dv_x * (ρ.α.signum() * θ * τ);
-                    ρ.x = δ.x - g;
-                    let n = g.norm2();
-                    if n >= F::EPSILON {
-                        // Estimate Lipschitz factor of ∇v
-                        let this_ℓ_F = (dv_x - v.differential(&ρ.x)).norm2() / n;
-                        *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F);
-                        θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
-                        changes = true
+        μ.truncate(i);
+    }
+
+    /// Construction of the initial transport (stored in `self`) from the measure `μ` and `v=F'(μ)`
+    /// using the transport step length `τθ_or_adaptive`; the step length `_τ` is unused.
+    #[replace_float_literals(F::cast_from(literal))]
+    pub(crate) fn initial_transport<G, D>(
+        &mut self,
+        μ: &RNDM<N, F>,
+        _τ: F,
+        τθ_or_adaptive: &mut TransportStepLength<F, G>,
+        v: D,
+        tconfig: &TransportConfig<F>,
+    ) where
+        G: Fn(F, F) -> F,
+        D: DifferentiableRealMapping<N, F>,
+    {
+        use TransportStepLength::*;
+
+        // Initialise transport structure weights
+        for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) {
+            ρ.α_μ_orig = δ.α;
+            ρ.x = δ.x;
+            if ρ.fail_count > tconfig.max_fail {
+                ρ.α_γ = 0.0
+            } else {
+                // If old transport has opposing sign, the new transport will be none.
+                ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) {
+                    0.0
+                } else {
+                    δ.α
+                }
+            }
+        }
+
+        let γ_prev_len = self.len();
+        assert!(μ.len() >= γ_prev_len);
+        self.extend(μ[γ_prev_len..].iter().map(|δ| SingleTransport {
+            x: δ.x,
+            y: δ.x, // Just something, will be filled properly in the next phase
+            α_μ_orig: δ.α,
+            α_γ: δ.α,
+            prune: false,
+            fail_count: 0,
+        }));
+
+        // Calculate transport rays.
+        match *τθ_or_adaptive {
+            Fixed(θ) => {
+                for ρ in self.iter_mut() {
+                    if ρ.fail_count <= tconfig.max_fail {
+                        ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ);
                     }
                 }
-                if !changes {
-                    break;
+            }
+            AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θτ } => {
+                *max_transport = max_transport.max(self.norm(Radon));
+                let θτ = calculate_θτ(ℓ_F, *max_transport);
+                for ρ in self.iter_mut() {
+                    if ρ.fail_count <= tconfig.max_fail {
+                        ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ);
+                    }
+                }
+            }
+            FullyAdaptive {
+                l: ref mut adaptive_ℓ_F,
+                ref mut max_transport,
+                g: ref calculate_θτ,
+            } => {
+                *max_transport = max_transport.max(self.norm(Radon));
+                let mut θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport);
+                // Do two runs through the spikes to update θ, breaking if first run did not cause
+                // a change.
+                for _i in 0..=1 {
+                    let mut changes = false;
+                    for ρ in self.iter_mut() {
+                        if ρ.fail_count <= tconfig.max_fail {
+                            let dv_x = v.differential(&ρ.x);
+                            let g = &dv_x * (ρ.α_γ.signum() * θτ);
+                            ρ.y = ρ.x - g;
+                            let n = g.norm2();
+                            if n >= F::EPSILON {
+                                // Estimate Lipschitz factor of ∇v
+                                let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n;
+                                *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F);
+                                θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport);
+                                changes = true
+                            }
+                        }
+                    }
+                    if !changes {
+                        break;
+                    }
                 }
             }
         }
     }
 
-    // Set initial guess for μ=μ^{k+1}.
-    for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) {
-        if ρ.α.abs() > F::EPSILON {
-            δ.x = ρ.x;
-            //δ.α = ρ.α; // already set above
-        } else {
-            δ.α = β;
+    /// A posteriori transport adaptation.
+    #[replace_float_literals(F::cast_from(literal))]
+    pub(crate) fn aposteriori_transport<D>(
+        &mut self,
+        μ: &RNDM<N, F>,
+        μ̆: &RNDM<N, F>,
+        _v: &mut D,
+        extra: Option<F>,
+        ε: F,
+        tconfig: &TransportConfig<F>,
+        attempts: &mut usize,
+    ) -> bool
+    where
+        D: DifferentiableRealMapping<N, F>,
+    {
+        *attempts += 1;
+
+        // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
+        // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
+        // at that point to zero, and retry.
+        let mut all_ok = true;
+        for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) {
+            if δ.α == 0.0 && ρ.α_γ != 0.0 {
+                all_ok = false;
+                ρ.α_γ = 0.0;
+            }
         }
+
+        // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z).
+        //    through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ̆^k
+        //    which holds for some C if the convolution kernel in 𝒟 has Lipschitz gradient.
+        let nγ = self.norm(Radon);
+        let nΔ = μ.dist_matching(&μ̆) + extra.unwrap_or(0.0);
+        let t = ε * tconfig.tolerance_mult_con;
+        if nγ * nΔ > t && *attempts >= tconfig.max_attempts {
+            all_ok = false;
+        } else if nγ * nΔ > t {
+            // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1,
+            // this will guarantee that eventually ‖γ‖ decreases sufficiently that we
+            // will not enter here.
+            //*self *= tconfig.adaptation * t / (nγ * nΔ);
+
+            // We want a consistent behaviour that has the potential to set many weights to zero.
+            // Therefore, we find the smallest uniform reduction `chg_one`, subtracted
+            // from all weights, that achieves total `adapt` adaptation.
+            let adapt_to = tconfig.adaptation * t / nΔ;
+            let reduction_target = nγ - adapt_to;
+            assert!(reduction_target > 0.0);
+            if ALLOW_PARTIAL_TRANSPORT {
+                if MINIMAL_PARTIAL_TRANSPORT {
+                    // This reduces weights of transport, starting from … until `adapt` is
+                    // exhausted. It will, therefore, only ever cause one extra point insertion
+                    // at the sources, unlike “full” partial transport.
+                    //let refs = self.vec.iter_mut().collect::<Vec<_>>();
+                    //refs.sort_by(|ρ1, ρ2| ρ1.α_γ.abs().partial_cmp(&ρ2.α_γ.abs()).unwrap());
+                    // let mut it = refs.into_iter();
+                    //
+                    // Maybe sort by differential norm
+                    // let mut refs = self
+                    //     .vec
+                    //     .iter_mut()
+                    //     .map(|ρ| {
+                    //         let val = v.differential(&ρ.x).norm2_squared();
+                    //         (ρ, val)
+                    //     })
+                    //     .collect::<Vec<_>>();
+                    // refs.sort_by(|(_, v1), (_, v2)| v2.partial_cmp(&v1).unwrap());
+                    // let mut it = refs.into_iter().map(|(ρ, _)| ρ);
+                    let mut it = self.vec.iter_mut().rev();
+                    let _unused = it.try_fold(reduction_target, |left, ρ| {
+                        let w = ρ.α_γ.abs();
+                        if left <= w {
+                            ρ.α_γ = ρ.α_γ.signum() * (w - left);
+                            ControlFlow::Break(())
+                        } else {
+                            ρ.α_γ = 0.0;
+                            ControlFlow::Continue(left - w)
+                        }
+                    });
+                } else {
+                    // This version equally reduces all weights. It causes partial transport, which
+                    // has the problem that we then need to adapt weights in both start and
+                    // end points, in insert_and_reweigh, sometimes causing the number of spikes μ
+                    // to explode.
+                    let mut abs_weights = self
+                        .vec
+                        .iter()
+                        .map(|ρ| ρ.α_γ.abs())
+                        .filter(|t| *t > F::EPSILON)
+                        .collect::<Vec<F>>();
+                    abs_weights.sort_by(|a, b| a.partial_cmp(b).unwrap());
+                    let n = abs_weights.len();
+                    // Cannot have partial transport; can cause spike count explosion
+                    let chg = abs_weights.into_iter().zip((1..=n).rev()).try_fold(
+                        0.0,
+                        |smaller_total, (w, m)| {
+                            let mf = F::cast_from(m);
+                            let reduction = w * mf + smaller_total;
+                            if reduction >= reduction_target {
+                                ControlFlow::Break((reduction_target - smaller_total) / mf)
+                            } else {
+                                ControlFlow::Continue(smaller_total + w)
+                            }
+                        },
+                    );
+                    match chg {
+                        ControlFlow::Continue(_) => self.vec.iter_mut().for_each(|δ| δ.α_γ = 0.0),
+                        ControlFlow::Break(chg_one) => self.vec.iter_mut().for_each(|ρ| {
+                            let t = ρ.α_γ.abs();
+                            if t > 0.0 {
+                                if ALLOW_PARTIAL_TRANSPORT {
+                                    let new = (t - chg_one).max(0.0);
+                                    ρ.α_γ = ρ.α_γ.signum() * new;
+                                }
+                            }
+                        }),
+                    }
+                }
+            } else {
+                // This version zeroes smallest weights, avoiding partial transport.
+                let mut abs_weights_idx = self
+                    .vec
+                    .iter()
+                    .map(|ρ| ρ.α_γ.abs())
+                    .zip(0..)
+                    .filter(|(w, _)| *w >= 0.0)
+                    .collect::<Vec<(F, usize)>>();
+                abs_weights_idx.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap());
+
+                let mut left = reduction_target;
+
+                for (w, i) in abs_weights_idx {
+                    left -= w;
+                    let ρ = &mut self.vec[i];
+                    ρ.α_γ = 0.0;
+                    if left < 0.0 {
+                        break;
+                    }
+                }
+            }
+
+            all_ok = false
+        }
+
+        if !all_ok && *attempts >= tconfig.max_attempts {
+            for ρ in self.iter_mut() {
+                ρ.α_γ = 0.0;
+            }
+        }
+
+        for ρ in self.iter_mut() {
+            if ρ.α_γ == 0.0 {
+                ρ.fail_count += 1;
+            } else if all_ok {
+                ρ.fail_count = 0;
+            }
+        }
+
+        all_ok
     }
-    // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b)
-    μ_base_minus_γ0.set_masses(
-        μ_base_masses
+
+    /// Returns $‖μ\^k - π\_♯\^0γ\^{k+1}‖$
+    pub(crate) fn μ0_minus_γ0_radon(&self) -> F {
+        self.vec.iter().map(|ρ| (ρ.α_μ_orig - ρ.α_γ).abs()).sum()
+    }
+
+    /// Returns $∫ c_2 d|γ|$
+    #[replace_float_literals(F::cast_from(literal))]
+    pub(crate) fn c2integral(&self) -> F {
+        self.vec
             .iter()
-            .zip(γ1.iter_masses())
-            .map(|(&a, b)| a - b),
-    );
-    (μ_base_masses, μ_base_minus_γ0)
+            .map(|ρ| ρ.y.dist2_squared(&ρ.x) / 2.0 * ρ.α_γ.abs())
+            .sum()
+    }
+
+    #[replace_float_literals(F::cast_from(literal))]
+    pub(crate) fn get_transport_stats(&self, stats: &mut IterInfo<F>, μ: &RNDM<N, F>) {
+        // TODO: This doesn't take into account μ[i].α becoming zero in the latest transport
+        // attempt, for i < self.len(), when a corresponding source term also exists with index
+        // j ≥ self.len(). For now, we let that be reflected in the prune count.
+        stats.inserted += μ.len() - self.len();
+
+        let transp = stats.get_transport_mut();
+
+        transp.dist = {
+            let (a, b) = transp.dist;
+            (a + self.c2integral(), b + self.norm(Radon))
+        };
+        transp.untransported_fraction = {
+            let (a, b) = transp.untransported_fraction;
+            let source = self.iter().map(|ρ| ρ.α_μ_orig.abs()).sum();
+            (a + self.μ0_minus_γ0_radon(), b + source)
+        };
+        transp.transport_error = {
+            let (a, b) = transp.transport_error;
+            //(a + self.dist_matching(&μ), b + self.norm(Radon))
+
+            // This ignores points that have not been transported at all, to only calculate
+            // destination error; untransported_fraction accounts for not being able to transport
+            // at all.
+            self.iter()
+                .zip(μ.iter_spikes())
+                .fold((a, b), |(a, b), (ρ, δ)| {
+                    let transported = ρ.α_γ.abs();
+                    if transported > F::EPSILON {
+                        (a + (ρ.α_γ - δ.α).abs(), b + transported)
+                    } else {
+                        (a, b)
+                    }
+                })
+        };
+    }
+
+    /// Prune spikes with zero weight. To maintain correct ordering between μ and γ, also the
+    /// latter needs to be pruned when μ is.
+    pub(crate) fn prune_compat(&mut self, μ: &mut RNDM<N, F>, stats: &mut IterInfo<F>) {
+        assert!(self.vec.len() <= μ.len());
+        let old_len = μ.len();
+        for (ρ, δ) in self.vec.iter_mut().zip(μ.iter_spikes()) {
+            ρ.prune = !(δ.α.abs() > F::EPSILON);
+        }
+        μ.prune_by(|δ| δ.α.abs() > F::EPSILON);
+        stats.pruned += old_len - μ.len();
+        self.vec.retain(|ρ| !ρ.prune);
+        assert!(self.vec.len() <= μ.len());
+    }
 }
 
-/// A posteriori transport adaptation.
-#[replace_float_literals(F::cast_from(literal))]
-pub(crate) fn aposteriori_transport<F, const N: usize>(
-    γ1: &mut RNDM<F, N>,
-    μ: &mut RNDM<F, N>,
-    μ_base_minus_γ0: &mut RNDM<F, N>,
-    μ_base_masses: &Vec<F>,
-    extra: Option<F>,
-    ε: F,
-    tconfig: &TransportConfig<F>,
-) -> bool
-where
-    F: Float + ToNalgebraRealField,
-{
-    // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
-    // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
-    // at that point to zero, and retry.
-    let mut all_ok = true;
-    for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) {
-        if α_μ == 0.0 && *α_γ1 != 0.0 {
-            all_ok = false;
-            *α_γ1 = 0.0;
+impl<const N: usize, F: Float> Norm<Radon, F> for Transport<N, F> {
+    fn norm(&self, _: Radon) -> F {
+        self.iter().map(|ρ| ρ.α_γ.abs()).sum()
+    }
+}
+
+impl<const N: usize, F: Float> MulAssign<F> for Transport<N, F> {
+    fn mul_assign(&mut self, factor: F) {
+        for ρ in self.iter_mut() {
+            ρ.α_γ *= factor;
         }
     }
-
-    // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z).
-    //    through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1},
-    //    which holds for some some C if the convolution kernel in 𝒟 has Lipschitz gradient.
-    let nγ = γ1.norm(Radon);
-    let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0);
-    let t = ε * tconfig.tolerance_mult_con;
-    if nγ * nΔ > t {
-        // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1,
-        // this will guarantee that eventually ‖γ‖ decreases sufficiently that we
-        // will not enter here.
-        *γ1 *= tconfig.adaptation * t / (nγ * nΔ);
-        all_ok = false
-    }
-
-    if !all_ok {
-        // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
-        μ_base_minus_γ0.set_masses(
-            μ_base_masses
-                .iter()
-                .zip(γ1.iter_masses())
-                .map(|(&a, b)| a - b),
-        );
-    }
-
-    all_ok
 }
 
 /// Iteratively solve the pointsource localisation problem using sliding forward-backward
@@ -264,36 +551,33 @@
 /// The parametrisation is as for [`pointsource_fb_reg`].
 /// Inertia is currently not supported.
 #[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N: usize>(
-    opA: &A,
-    b: &A::Observable,
-    reg: Reg,
+pub fn pointsource_sliding_fb_reg<F, I, Dat, Reg, Plot, P, const N: usize>(
+    f: &Dat,
+    reg: &Reg,
     prox_penalty: &P,
     config: &SlidingFBConfig<F>,
     iterator: I,
-    mut plotter: SeqPlotter<F, N>,
-) -> RNDM<F, N>
+    mut plotter: Plot,
+    μ0: Option<RNDM<N, F>>,
+) -> DynResult<RNDM<N, F>>
 where
     F: Float + ToNalgebraRealField,
-    I: AlgIteratorFactory<IterInfo<F, N>>,
-    A: ForwardModel<RNDM<F, N>, F>
-        + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>
-        + BoundedCurvature<FloatType = F>,
-    for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>,
-    A::PreadjointCodomain: DifferentiableRealMapping<F, N>,
-    RNDM<F, N>: SpikeMerging<F>,
-    Reg: SlidingRegTerm<F, N>,
-    P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
-    PlotLookup: Plotting<N>,
+    I: AlgIteratorFactory<IterInfo<F>>,
+    Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>,
+    Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>,
+    //for<'a> Dat::Differential<'a>: Lipschitz<&'a P, FloatType = F>,
+    RNDM<N, F>: SpikeMerging<F>,
+    Reg: SlidingRegTerm<Loc<N, F>, F>,
+    P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
+    Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
 {
     // Check parameters
-    assert!(config.τ0 > 0.0, "Invalid step length parameter");
-    config.transport.check();
+    ensure!(config.τ0 > 0.0, "Invalid step length parameter");
+    config.transport.check()?;
 
     // Initialise iterates
-    let mut μ = DiscreteMeasure::new();
-    let mut γ1 = DiscreteMeasure::new();
-    let mut residual = -b; // Has to equal $Aμ-b$.
+    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
+    let mut γ = Transport::new();
 
     // Set up parameters
     // let opAnorm = opA.opnorm_bound(Radon, L2);
@@ -301,25 +585,28 @@
     //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0);
     //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
     let ℓ = 0.0;
-    let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
-    let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components();
-    let transport_lip = maybe_transport_lip.unwrap();
-    let calculate_θ = |ℓ_F, max_transport| {
-        let ℓ_r = transport_lip * max_transport;
-        config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r))
-    };
-    let mut θ_or_adaptive = match maybe_ℓ_F0 {
-        //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)),
-        Some(ℓ_F0) => TransportStepLength::AdaptiveMax {
-            l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual
-            max_transport: 0.0,
-            g: calculate_θ,
-        },
-        None => TransportStepLength::FullyAdaptive {
-            l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials
-            max_transport: 0.0,
-            g: calculate_θ,
-        },
+    let τ = config.τ0 / prox_penalty.step_length_bound(&f)?;
+
+    let mut θ_or_adaptive = match f.curvature_bound_components(config.guess) {
+        (_, Err(_)) => TransportStepLength::Fixed(config.transport.θ0),
+        (maybe_ℓ_F, Ok(transport_lip)) => {
+            let calculate_θτ = move |ℓ_F, max_transport| {
+                let ℓ_r = transport_lip * max_transport;
+                config.transport.θ0 / (ℓ + ℓ_F + ℓ_r)
+            };
+            match maybe_ℓ_F {
+                Ok(ℓ_F) => TransportStepLength::AdaptiveMax {
+                    l: ℓ_F, // TODO: could estimate computing the real reesidual
+                    max_transport: 0.0,
+                    g: calculate_θτ,
+                },
+                Err(_) => TransportStepLength::FullyAdaptive {
+                    l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials
+                    max_transport: 0.0,
+                    g: calculate_θτ,
+                },
+            }
+        }
     };
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
@@ -327,8 +614,8 @@
     let mut ε = tolerance.initial();
 
     // Statistics
-    let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo {
-        value: residual.norm2_squared_div2() + reg.apply(μ),
+    let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo {
+        value: f.apply(μ) + reg.apply(μ),
         n_spikes: μ.len(),
         ε,
         // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
@@ -337,90 +624,67 @@
     let mut stats = IterInfo::new();
 
     // Run the algorithm
-    for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
+    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate initial transport
-        let v = opA.preadjoint().apply(residual);
-        let (μ_base_masses, mut μ_base_minus_γ0) =
-            initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v);
+        let v = f.differential(&μ);
+        γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v, &config.transport);
+
+        let mut attempts = 0;
 
         // Solve finite-dimensional subproblem several times until the dual variable for the
         // regularisation term conforms to the assumptions made for the transport above.
-        let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop {
+        let (maybe_d, _within_tolerances, mut τv̆, μ̆) = 'adapt_transport: loop {
+            // Set initial guess for μ=μ^{k+1}.
+            γ.μ̆_into(&mut μ);
+            let μ̆ = μ.clone();
+
             // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
-            let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
-            let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ);
+            //let residual_μ̆ = calculate_residual2(&γ1, &μ0_minus_γ0, opA, b);
+            // TODO: this could be optimised by doing the differential like the
+            // old residual2.
+            // NOTE: This assumes that μ = γ1
+            let mut τv̆ = f.differential(&μ̆) * τ;
 
             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
             let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
                 &mut μ,
                 &mut τv̆,
-                &γ1,
-                Some(&μ_base_minus_γ0),
                 τ,
                 ε,
                 &config.insertion,
                 &reg,
                 &state,
                 &mut stats,
-            );
+            )?;
 
             // A posteriori transport adaptation.
-            if aposteriori_transport(
-                &mut γ1,
-                &mut μ,
-                &mut μ_base_minus_γ0,
-                &μ_base_masses,
-                None,
-                ε,
-                &config.transport,
-            ) {
-                break 'adapt_transport (maybe_d, within_tolerances, τv̆);
+            if γ.aposteriori_transport(&μ, &μ̆, &mut τv̆, None, ε, &config.transport, &mut attempts)
+            {
+                break 'adapt_transport (maybe_d, within_tolerances, τv̆, μ̆);
             }
+
+            stats.get_transport_mut().readjustment_iters += 1;
         };
 
-        stats.untransported_fraction = Some({
-            assert_eq!(μ_base_masses.len(), γ1.len());
-            let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));
-            let source = μ_base_masses.iter().map(|v| v.abs()).sum();
-            (a + μ_base_minus_γ0.norm(Radon), b + source)
-        });
-        stats.transport_error = Some({
-            assert_eq!(μ_base_masses.len(), γ1.len());
-            let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
-            (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
-        });
+        γ.get_transport_stats(&mut stats, &μ);
 
         // Merge spikes.
         // This crucially expects the merge routine to be stable with respect to spike locations,
         // and not to perform any pruning. That is to be done below simultaneously for γ.
-        let ins = &config.insertion;
-        if ins.merge_now(&state) {
+        if config.insertion.merge_now(&state) {
             stats.merged += prox_penalty.merge_spikes(
                 &mut μ,
                 &mut τv̆,
-                &γ1,
-                Some(&μ_base_minus_γ0),
+                &μ̆,
                 τ,
                 ε,
-                ins,
+                &config.insertion,
                 &reg,
-                Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
+                Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)),
             );
         }
 
-        // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
-        // latter needs to be pruned when μ is.
-        // TODO: This could do with a two-vector Vec::retain to avoid copies.
-        let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
-        if μ_new.len() != μ.len() {
-            let mut μ_iter = μ.iter_spikes();
-            γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
-            stats.pruned += μ.len() - μ_new.len();
-            μ = μ_new;
-        }
-
-        // Update residual
-        residual = calculate_residual(&μ, opA, b);
+        γ.prune_compat(&mut μ, &mut stats);
 
         let iter = state.iteration();
         stats.this_iters += 1;
@@ -428,17 +692,13 @@
         // Give statistics if requested
         state.if_verbose(|| {
             plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
-            full_stats(
-                &residual,
-                &μ,
-                ε,
-                std::mem::replace(&mut stats, IterInfo::new()),
-            )
+            full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
         });
 
         // Update main tolerance for next iteration
         ε = tolerance.update(ε, iter);
     }
 
-    postprocess(μ, &config.insertion, L2Squared, opA, b)
+    //postprocess(μ, &config.insertion, f)
+    postprocess(μ, &config.insertion, |μ̃| f.apply(μ̃))
 }

mercurial