src/sliding_fb.rs

branch
dev
changeset 35
b087e3eab191
parent 34
efa60bc4f743
child 37
c5d8bd1a7728
--- a/src/sliding_fb.rs	Thu Aug 29 00:00:00 2024 -0500
+++ b/src/sliding_fb.rs	Tue Dec 31 09:25:45 2024 -0500
@@ -10,15 +10,12 @@
 use itertools::izip;
 use std::iter::Iterator;
 
-use alg_tools::iterate::{
-    AlgIteratorFactory,
-    AlgIteratorState
-};
+use alg_tools::iterate::AlgIteratorFactory;
 use alg_tools::euclidean::Euclidean;
 use alg_tools::sets::Cube;
 use alg_tools::loc::Loc;
-use alg_tools::mapping::{Apply, Differentiable};
-use alg_tools::norms::{Norm, L2};
+use alg_tools::mapping::{Mapping, DifferentiableMapping, Instance};
+use alg_tools::norms::Norm;
 use alg_tools::bisection_tree::{
     BTFN,
     PreBTFN,
@@ -33,14 +30,19 @@
 };
 use alg_tools::mapping::RealMapping;
 use alg_tools::nalgebra_support::ToNalgebraRealField;
+use alg_tools::norms::{L2, Linfinity};
 
 use crate::types::*;
-use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon};
+use crate::measures::{DiscreteMeasure, Radon, RNDM};
 use crate::measures::merging::{
-    //SpikeMergingMethod,
+    SpikeMergingMethod,
     SpikeMerging,
 };
-use crate::forward_model::ForwardModel;
+use crate::forward_model::{
+    ForwardModel,
+    AdjointProductBoundedBy,
+    LipschitzValues,
+};
 use crate::seminorms::DiscreteMeasureOp;
 //use crate::tolerance::Tolerance;
 use crate::plot::{
@@ -56,7 +58,44 @@
     calculate_residual,
     calculate_residual2,
 };
-use crate::transport::TransportLipschitz;
+//use crate::transport::TransportLipschitz;
+
+/// Transport settings for [`pointsource_sliding_fb_reg`].
+#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
+#[serde(default)]
+pub struct TransportConfig<F : Float> {
+    /// Transport step length $θ$ normalised to $(0, 1)$.
+    pub θ0 : F,
+    /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance.
+    pub adaptation : F,
+    /// Transport tolerance wrt. ω
+    pub tolerance_ω : F,
+    /// Transport tolerance wrt. ∇v
+    pub tolerance_dv : F,
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl <F : Float> TransportConfig<F> {
+    /// Check that the parameters are ok. Panics if not.
+    pub fn check(&self) {
+        assert!(self.θ0 > 0.0);
+        assert!(0.0 < self.adaptation && self.adaptation < 1.0);
+        assert!(self.tolerance_dv > 0.0);
+        assert!(self.tolerance_ω > 0.0);
+    }
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F : Float> Default for TransportConfig<F> {
+    fn default() -> Self {
+        TransportConfig {
+            θ0 : 0.01,
+            adaptation : 0.9,
+            tolerance_ω : 1000.0, // TODO: no idea what this should be
+            tolerance_dv : 1000.0, // TODO: no idea what this should be
+        }
+    }
+}
 
 /// Settings for [`pointsource_sliding_fb_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
@@ -64,15 +103,8 @@
 pub struct SlidingFBConfig<F : Float> {
     /// Step length scaling
     pub τ0 : F,
-    /// Transport step length $θ$ normalised to $(0, 1)$.
-    pub θ0 : F,
-    /// Maximum transport mass scaling.
-    // /// The maximum transported mass is this factor times $\norm{b}^2/(2α)$.
-    // pub max_transport_scale : F,
-    /// Transport tolerance wrt. ω
-    pub transport_tolerance_ω : F,
-    /// Transport tolerance wrt. ∇v
-    pub transport_tolerance_dv : F,
+    /// Transport parameters
+    pub transport : TransportConfig<F>,
     /// Generic parameters
     pub insertion : FBGenericConfig<F>,
 }
@@ -82,38 +114,243 @@
     fn default() -> Self {
         SlidingFBConfig {
             τ0 : 0.99,
-            θ0 : 0.99,
-            //max_transport_scale : 10.0,
-            transport_tolerance_ω : 1.0, // TODO: no idea what this should be
-            transport_tolerance_dv : 1.0, // TODO: no idea what this should be
+            transport : Default::default(),
             insertion : Default::default()
         }
     }
 }
 
-/// Scale each |γ|_i ≠ 0 by q_i=q̄/g(γ_i)
+/// Internal type of adaptive transport step length calculation
+pub(crate) enum TransportStepLength<F : Float, G : Fn(F, F) -> F> {
+    /// Fixed, known step length
+    Fixed(F),
+    /// Adaptive step length, only wrt. maximum transport.
+    /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
+    AdaptiveMax{ l : F, max_transport : F, g : G },
+    /// Adaptive step length.
+    /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
+    FullyAdaptive{ l : F, max_transport : F, g : G },
+}
+
+/// Constrution and a priori transport adaptation.
 #[replace_float_literals(F::cast_from(literal))]
-fn scale_down<'a, I, F, G, const N : usize>(
-    iter : I,
-    q̄ : F,
-    mut g : G
-) where F : Float,
-        I : Iterator<Item = &'a mut DeltaMeasure<Loc<F,N>, F>>,
-        G : FnMut(&DeltaMeasure<Loc<F,N>, F>) -> F {
-    iter.for_each(|δ| {
-        if δ.α != 0.0 {
-            let b = g(δ);
-            if b * δ.α > 0.0 {
-                δ.α *= q̄/b;
+pub(crate) fn initial_transport<F, G, D, Observable, const N : usize>(
+    γ1 : &mut RNDM<F, N>,
+    μ : &mut RNDM<F, N>,
+    opAapply : impl Fn(&RNDM<F, N>) -> Observable,
+    ε : F,
+    τ : F,
+    θ_or_adaptive : &mut TransportStepLength<F, G>,
+    opAnorm : F,
+    v : D,
+    tconfig : &TransportConfig<F>
+) -> (Vec<F>, RNDM<F, N>)
+where
+    F : Float + ToNalgebraRealField,
+    G : Fn(F, F) -> F,
+    Observable : Euclidean<F, Output=Observable>,
+    for<'a> &'a Observable : Instance<Observable>,
+    //for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
+    D : DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F, N>>,
+{
+
+    use TransportStepLength::*;
+
+    // Save current base point and shift μ to new positions. Idea is that
+    //  μ_base(_masses) = μ^k (vector of masses)
+    //  μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
+    //  γ1 = π_♯^1γ^{k+1}
+    //  μ = μ^{k+1}
+    let μ_base_masses : Vec<F> = μ.iter_masses().collect();
+    let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below.
+    // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
+    //let mut sum_norm_dv = 0.0;
+    let γ_prev_len = γ1.len();
+    assert!(μ.len() >= γ_prev_len);
+    γ1.extend(μ[γ_prev_len..].iter().cloned());
+
+    // Calculate initial transport and step length.
+    // First calculate initial transported weights
+    for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
+        // If old transport has opposing sign, the new transport will be none.
+        ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) {
+            0.0
+        } else {
+            δ.α
+        };
+    };
+
+    // A priori transport adaptation based on bounding 2 ‖A‖ ‖A(γ₁-γ₀)‖‖γ‖ by scaling γ.
+    // 1. Calculate transport rays.
+    //    If the Lipschitz factor of the values v=∇F(μ) are not known, estimate it.
+    match *θ_or_adaptive {
+        Fixed(θ) => {
+            let θτ = τ * θ;
+            for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
+                ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
+            }
+        },
+        AdaptiveMax{ l : ℓ_v, ref mut max_transport, g : ref calculate_θ } => {
+            *max_transport = max_transport.max(γ1.norm(Radon));
+            let θτ = τ * calculate_θ(ℓ_v, *max_transport);
+            for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
+                ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
+            }
+        },
+        FullyAdaptive{ l : ref mut adaptive_ℓ_v, ref mut max_transport, g : ref calculate_θ } => {
+            *max_transport = max_transport.max(γ1.norm(Radon));
+            let mut θ = calculate_θ(*adaptive_ℓ_v, *max_transport);
+            loop {
+                let θτ = τ * θ;
+                for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
+                    let dv_x = v.differential(&δ.x);
+                    ρ.x = δ.x - &dv_x * (ρ.α.signum() * θτ);
+                    // Estimate Lipschitz factor of ∇v
+                    let this_ℓ_v = (dv_x - v.differential(&ρ.x)).norm2();
+                    *adaptive_ℓ_v = adaptive_ℓ_v.max(this_ℓ_v);
+                }
+                let new_θ = calculate_θ(*adaptive_ℓ_v / tconfig.adaptation, *max_transport);
+                if new_θ <= θ {
+                    break
+                }
+                θ = new_θ;
             }
         }
-    });
+    }
+
+    // 2. Adjust transport mass, if needed.
+    // This tries to remove the smallest transport masses first.
+    if true {
+        // Alternative 1 : subtract same amount from all transport rays until reaching zero
+        loop {
+            let nr =γ1.norm(Radon);
+            let n = τ * 2.0 * opAnorm * (opAapply(&*γ1)-opAapply(&*μ)).norm2();
+            if n <= 0.0 || nr <= 0.0 {
+                break
+            }
+            let reduction_needed = nr - (ε * tconfig.tolerance_dv / n);
+            if reduction_needed <= 0.0 {
+                break
+            }
+            let (min_nonzero, n_nonzero) = γ1.iter_masses()
+                                            .map(|α| α.abs())
+                                            .filter(|α| *α > F::EPSILON)
+                                            .fold((F::INFINITY, 0), |(a, n), b| (a.min(b), n+1));
+            assert!(n_nonzero > 0);
+            // Reduction that can be done in all nonzero spikes simultaneously
+            let h = (reduction_needed / F::cast_from(n_nonzero)).min(min_nonzero);
+            for (δ, ρ) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes_mut()) {
+                ρ.α = ρ.α.signum() * (ρ.α.abs() - h).max(0.0);
+                δ.α = ρ.α;
+            }
+            if min_nonzero * F::cast_from(n_nonzero) >= reduction_needed {
+                break
+            }
+        }
+    } else {
+        // Alternative 2: first reduce transport rays with greater effect based on differential.
+        // This is a an inefficient quick-and-dirty implementation.
+        loop {
+            let nr = γ1.norm(Radon);
+            let a = opAapply(&*γ1)-opAapply(&*μ);
+            let na = a.norm2();
+            let n = τ * 2.0 * opAnorm * na;
+            if n <= 0.0 || nr <= 0.0 {
+                break
+            }
+            let reduction_needed = nr - (ε * tconfig.tolerance_dv / n);
+            if reduction_needed <= 0.0 {
+                break
+            }
+            let mut max_d = 0.0;
+            let mut max_d_ind = 0;
+            for (δ, ρ, i) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), 0..) {
+                // Calculate differential of  ‖A(γ₁-γ₀)‖‖γ‖  wrt. each spike
+                let s = δ.α.signum();
+                // TODO: this is very inefficient implementation due to the limitations
+                // of the closure parameters.
+                let δ1 = DiscreteMeasure::from([(ρ.x, s)]);
+                let δ2 = DiscreteMeasure::from([(δ.x, s)]);
+                let a_part = opAapply(&δ1)-opAapply(&δ2);
+                let d = a.dot(&a_part)/na * nr + 2.0 * na;
+                if d > max_d {
+                    max_d = d;
+                    max_d_ind = i;
+                }
+            }
+            // Just set mass to zero for transport ray with greater differential
+            assert!(max_d > 0.0);
+            γ1[max_d_ind].α = 0.0;
+            μ[max_d_ind].α = 0.0;
+        }
+    }
+
+    // Set initial guess for μ=μ^{k+1}.
+    for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) {
+        if ρ.α.abs() > F::EPSILON {
+            δ.x = ρ.x;
+            //δ.α = ρ.α; // already set above
+        } else {
+            δ.α = β;
+        }
+    }
+    // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b)
+    μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses())
+                                                   .map(|(&a,b)| a - b));
+    (μ_base_masses, μ_base_minus_γ0)
+}
+
+/// A posteriori transport adaptation.
+#[replace_float_literals(F::cast_from(literal))]
+pub(crate) fn aposteriori_transport<F, const N : usize>(
+    γ1 : &mut RNDM<F, N>,
+    μ : &mut RNDM<F, N>,
+    μ_base_minus_γ0 : &mut RNDM<F, N>,
+    μ_base_masses : &Vec<F>,
+    ε : F,
+    tconfig : &TransportConfig<F>
+) -> bool
+where F : Float + ToNalgebraRealField {
+
+    // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
+    // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
+    // at that point to zero, and retry.
+    let mut all_ok = true;
+    for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) {
+        if α_μ == 0.0 && *α_γ1 != 0.0 {
+            all_ok = false;
+            *α_γ1 = 0.0;
+        }
+    }
+
+    // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z).
+    //    through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1},
+    //    which holds for some some C if the convolution kernel in 𝒟 has Lipschitz gradient.
+    let nγ = γ1.norm(Radon);
+    let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1);
+    let t = ε * tconfig.tolerance_ω;
+    if nγ*nΔ > t {
+        // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1,
+        // this will guarantee that eventually ‖γ‖ decreases sufficiently that we
+        // will not enter here.
+        *γ1 *= tconfig.adaptation * t / ( nγ * nΔ );
+        all_ok = false
+    }
+
+    if !all_ok {
+        // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
+        μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses())
+                                                        .map(|(&a,b)| a - b));
+
+    }
+
+    all_ok
 }
 
 /// Iteratively solve the pointsource localisation problem using sliding forward-backward
 /// splitting
 ///
-/// The parametrisatio is as for [`pointsource_fb_reg`].
+/// The parametrisation is as for [`pointsource_fb_reg`].
 /// Inertia is currently not supported.
 #[replace_float_literals(F::cast_from(literal))]
 pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>(
@@ -121,203 +358,113 @@
     b : &A::Observable,
     reg : Reg,
     op𝒟 : &'a 𝒟,
-    sfbconfig : &SlidingFBConfig<F>,
+    config : &SlidingFBConfig<F>,
     iterator : I,
     mut plotter : SeqPlotter<F, N>,
-) -> DiscreteMeasure<Loc<F, N>, F>
+) -> RNDM<F, N>
 where F : Float + ToNalgebraRealField,
       I : AlgIteratorFactory<IterInfo<F, N>>,
-      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
-                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
-      A::Observable : std::ops::MulAssign<F>,
-      A::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output=Loc<F, N>>,
+      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
+      for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
+      A::PreadjointCodomain : DifferentiableMapping<
+        Loc<F, N>, DerivativeDomain=Loc<F, N>, Codomain=F
+      >,
       GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
-          + Lipschitz<&'a 𝒟, FloatType=F> + TransportLipschitz<L2Squared, FloatType=F>,
+      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
+          + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
+          //+ TransportLipschitz<L2Squared, FloatType=F>,
       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
       𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>,
                                           Codomain = BTFN<F, G𝒟, BT𝒟, N>>,
       BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>
-         + Differentiable<Loc<F, N>, Output=Loc<F,N>>,
+         + DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F,N>>,
       K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-         //+ Differentiable<Loc<F, N>, Output=Loc<F,N>>,
+         //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>,
       BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
       Cube<F, N>: P2Minimise<Loc<F, N>, F>,
       PlotLookup : Plotting<N>,
-      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
+      RNDM<F, N> : SpikeMerging<F>,
       Reg : SlidingRegTerm<F, N> {
 
-    assert!(sfbconfig.τ0 > 0.0 &&
-            sfbconfig.θ0 > 0.0);
-
-    // Set up parameters
-    let config = &sfbconfig.insertion;
-    let op𝒟norm = op𝒟.opnorm_bound();
-    //let max_transport = sfbconfig.max_transport_scale
-    //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0);
-    //let tlip = opA.transport_lipschitz_factor(L2Squared) * max_transport;
-    //let ℓ = 0.0;
-    let θ = sfbconfig.θ0; // (ℓ + tlip);
-    let τ = sfbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
-    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
-    // by τ compared to the conditional gradient approach.
-    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
-    let mut ε = tolerance.initial();
+    // Check parameters
+    assert!(config.τ0 > 0.0, "Invalid step length parameter");
+    config.transport.check();
 
     // Initialise iterates
     let mut μ = DiscreteMeasure::new();
     let mut γ1 = DiscreteMeasure::new();
-    let mut residual = -b;
+    let mut residual = -b; // Has to equal $Aμ-b$.
+
+    // Set up parameters
+    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
+    let opAnorm = opA.opnorm_bound(Radon, L2);
+    //let max_transport = config.max_transport.scale
+    //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0);
+    //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
+    let ℓ = 0.0;
+    let τ = config.τ0 / opA.adjoint_product_bound(&op𝒟).unwrap();
+    let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v));
+    let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() {
+        // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v
+        // (the uniform Lipschitz factor of ∇v).
+        // We assume that the residual is decreasing.
+        Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * residual.norm2(), 0.0)),
+        None => TransportStepLength::FullyAdaptive {
+            l : 0.0,
+            max_transport : 0.0,
+            g : calculate_θ
+        },
+    };
+    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
+    // by τ compared to the conditional gradient approach.
+    let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling();
+    let mut ε = tolerance.initial();
+
+    // Statistics
+    let full_stats = |residual : &A::Observable,
+                      μ : &RNDM<F, N>,
+                      ε, stats| IterInfo {
+        value : residual.norm2_squared_div2() + reg.apply(μ),
+        n_spikes : μ.len(),
+        ε,
+        // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
+        .. stats
+    };
     let mut stats = IterInfo::new();
 
     // Run the algorithm
-    iterator.iterate(|state| {
-        // Calculate smooth part of surrogate model.
-        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
-        // has no significant overhead. For some reosn Rust doesn't allow us simply moving
-        // the residual and replacing it below before the end of this closure.
-        let r = std::mem::replace(&mut residual, opA.empty_observable());
-        let v = opA.preadjoint().apply(r);
-
-        // Save current base point and shift μ to new positions. Idea is that
-        //  μ_base(_masses) = μ^k (vector of masses)
-        //  μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
-        //  γ1 = π_♯^1γ^{k+1}
-        //  μ = μ^{k+1}
-        let μ_base_masses : Vec<F> = μ.iter_masses().collect();
-        let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below.
-        // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
-        let mut sum_norm_dv_times_γinit = 0.0;
-        let mut sum_abs_γinit = 0.0;
-        //let mut sum_norm_dv = 0.0;
-        let γ_prev_len = γ1.len();
-        assert!(μ.len() >= γ_prev_len);
-        γ1.extend(μ[γ_prev_len..].iter().cloned());
-        for (δ, ρ) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes_mut()) {
-            let d_v_x = v.differential(&δ.x);
-            // If old transport has opposing sign, the new transport will be none.
-            ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) {
-                0.0
-            } else {
-                δ.α
-            };
-            δ.x -= d_v_x * (θ * δ.α.signum()); // This is δ.α.signum() when δ.α ≠ 0.
-            ρ.x = δ.x;
-            let nrm = d_v_x.norm(L2);
-            let a = ρ.α.abs();
-            let v = nrm * a;
-            if v > 0.0 {
-                sum_norm_dv_times_γinit += v;
-                sum_abs_γinit += a;
-            }
-        }
-
-        // A priori transport adaptation based on bounding ∫ ⟨∇v(x), z-y⟩ dλ(x, y, z).
-        // This is just one option, there are many.
-        let t = ε * sfbconfig.transport_tolerance_dv;
-        if sum_norm_dv_times_γinit > t {
-            // Scale each |γ|_i by q_i=q̄/‖vx‖_i such that ∑_i |γ|_i q_i ‖vx‖_i = t
-            // TODO: store the closure values above?
-            scale_down(γ1.iter_spikes_mut(),
-                       t / sum_abs_γinit,
-                       |δ| v.differential(&δ.x).norm(L2));
-        }
-        //println!("|γ| = {}, |μ| = {}", γ1.norm(crate::measures::Radon), μ.norm(crate::measures::Radon));
+    for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
+        // Calculate initial transport
+        let v = opA.preadjoint().apply(residual);
+        let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport(
+            &mut γ1, &mut μ, |ν| opA.apply(ν),
+            ε, τ, &mut θ_or_adaptive, opAnorm,
+            v, &config.transport,
+        );
 
         // Solve finite-dimensional subproblem several times until the dual variable for the
         // regularisation term conforms to the assumptions made for the transport above.
-        let (d, within_tolerances) = 'adapt_transport: loop {
-            // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
-            for (δ_γ1, δ_μ_base_minus_γ0, &α_μ_base) in izip!(γ1.iter_spikes(),
-                                                              μ_base_minus_γ0.iter_spikes_mut(),
-                                                              μ_base_masses.iter()) {
-                δ_μ_base_minus_γ0.set_mass(α_μ_base - δ_γ1.get_mass());
-            }
-
-            // Calculate transported_minus_τv = -τA_*(A[μ_transported + μ_transported_base]-b)
+        let (d, _within_tolerances, τv̆) = 'adapt_transport: loop {
+            // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
             let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
-            let transported_minus_τv̆ = opA.preadjoint().apply(residual_μ̆ * (-τ));
+            let τv̆ = opA.preadjoint().apply(residual_μ̆ * τ);
 
             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
             let (d, within_tolerances) = insert_and_reweigh(
-                &mut μ, &transported_minus_τv̆, &γ1, Some(&μ_base_minus_γ0),
+                &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0),
                 op𝒟, op𝒟norm,
-                τ, ε,
-                config,
-                &reg, state, &mut stats,
+                τ, ε, &config.insertion,
+                &reg, &state, &mut stats,
             );
 
-            // A posteriori transport adaptation based on bounding (1/τ)∫ ω(z) - ω(y) dλ(x, y, z).
-            let all_ok = if false { // Basic check
-                // If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
-                // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
-                // at that point to zero, and retry.
-                let mut all_ok = true;
-                for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) {
-                    if α_μ == 0.0 && *α_γ1 != 0.0 {
-                        all_ok = false;
-                        *α_γ1 = 0.0;
-                    }
-                }
-                all_ok
-            } else {
-                // TODO: Could maybe optimise, as this is also formed in insert_and_reweigh above.
-                let mut minus_ω = op𝒟.apply(γ1.sub_matching(&μ) + &μ_base_minus_γ0);
-
-                // let vpos = γ1.iter_spikes()
-                //              .filter(|δ| δ.α > 0.0)
-                //              .map(|δ| minus_ω.apply(&δ.x))
-                //              .reduce(F::max)
-                //              .and_then(|threshold| {
-                //                 minus_ω.minimise_below(threshold,
-                //                                         ε * config.refinement.tolerance_mult,
-                //                                         config.refinement.max_steps)
-                //                        .map(|(_z, minus_ω_z)| minus_ω_z)
-                //              });
-
-                // let vneg = γ1.iter_spikes()
-                //              .filter(|δ| δ.α < 0.0)
-                //              .map(|δ| minus_ω.apply(&δ.x))
-                //              .reduce(F::min)
-                //              .and_then(|threshold| {
-                //                 minus_ω.maximise_above(threshold,
-                //                                         ε * config.refinement.tolerance_mult,
-                //                                         config.refinement.max_steps)
-                //                        .map(|(_z, minus_ω_z)| minus_ω_z)
-                //              });
-                let (_, vpos) = minus_ω.minimise(ε * config.refinement.tolerance_mult,
-                                                 config.refinement.max_steps);
-                let (_, vneg) = minus_ω.maximise(ε * config.refinement.tolerance_mult,
-                                                 config.refinement.max_steps);
-            
-                let t = τ * ε * sfbconfig.transport_tolerance_ω;
-                let val = |δ : &DeltaMeasure<Loc<F, N>, F>| {
-                    δ.α * (minus_ω.apply(&δ.x) - if δ.α >= 0.0 { vpos } else { vneg })
-                    // match if δ.α >= 0.0 { vpos } else { vneg } {
-                    //     None => 0.0,
-                    //     Some(v) => δ.α * (minus_ω.apply(&δ.x) - v)
-                    // }
-                };
-                // Calculate positive/bad (rp) values under the integral.
-                // Also store sum of masses for the positive entries.
-                let (rp, w) = γ1.iter_spikes().fold((0.0, 0.0), |(p, w), δ| {
-                    let v = val(δ);
-                    if v <= 0.0 { (p, w) } else { (p + v, w + δ.α.abs()) }
-                });
-
-                if rp > t {
-                    // TODO: store v above?
-                    scale_down(γ1.iter_spikes_mut(), t / w, val);
-                    false
-                } else {
-                    true
-                }
-            };
-
-            if all_ok {
-                break 'adapt_transport (d, within_tolerances)
+            // A posteriori transport adaptation.
+            if aposteriori_transport(
+                &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses,
+                ε, &config.transport
+            ) {
+                break 'adapt_transport (d, within_tolerances, τv̆)
             }
         };
 
@@ -330,10 +477,24 @@
         stats.transport_error = Some({
             assert_eq!(μ_base_masses.len(), γ1.len());
             let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
-            let err = izip!(μ.iter_masses(), γ1.iter_masses()).map(|(v,w)| (v-w).abs()).sum();
-            (a + err, b + γ1.norm(Radon))
+            (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
         });
 
+        // Merge spikes.
+        // This expects the prune below to prune γ.
+        // TODO: This may not work correctly in all cases.
+        let ins = &config.insertion;
+        if ins.merge_now(&state) {
+            if let SpikeMergingMethod::None = ins.merging {
+            } else {
+                stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| {
+                    let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0;
+                    let mut d = &τv̆ + op𝒟.preapply(ν);
+                    reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins)
+                });
+            }
+        }
+
         // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
         // latter needs to be pruned when μ is.
         // TODO: This could do with a two-vector Vec::retain to avoid copies.
@@ -341,40 +502,25 @@
         if μ_new.len() != μ.len() {
             let mut μ_iter = μ.iter_spikes();
             γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
+            stats.pruned += μ.len() - μ_new.len();
             μ = μ_new;
         }
 
-        // TODO: how to merge?
-
         // Update residual
         residual = calculate_residual(&μ, opA, b);
 
-        // Update main tolerance for next iteration
-        let ε_prev = ε;
-        ε = tolerance.update(ε, state.iteration());
+        let iter = state.iteration();
         stats.this_iters += 1;
 
-        // Give function value if needed
+        // Give statistics if requested
         state.if_verbose(|| {
-            // Plot if so requested
-            plotter.plot_spikes(
-                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
-                "start".to_string(), None::<&A::PreadjointCodomain>, // TODO: Should be Some(&((-τ) * v)), but not implemented
-                reg.target_bounds(τ, ε_prev), &μ,
-            );
-            // Calculate mean inner iterations and reset relevant counters.
-            // Return the statistics
-            let res = IterInfo {
-                value : residual.norm2_squared_div2() + reg.apply(&μ),
-                n_spikes : μ.len(),
-                ε : ε_prev,
-                postprocessing: config.postprocessing.then(|| μ.clone()),
-                .. stats
-            };
-            stats = IterInfo::new();
-            res
-        })
-    });
+            plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ);
+            full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
+        });
 
-    postprocess(μ, config, L2Squared, opA, b)
+        // Update main tolerance for next iteration
+        ε = tolerance.update(ε, iter);
+    }
+
+    postprocess(μ, &config.insertion, L2Squared, opA, b)
 }

mercurial