src/fb.rs

branch
dev
changeset 32
56c8adc32b09
parent 29
87649ccfa6a8
child 34
efa60bc4f743
--- a/src/fb.rs	Fri Apr 28 13:15:19 2023 +0300
+++ b/src/fb.rs	Tue Dec 31 09:34:24 2024 -0500
@@ -83,17 +83,16 @@
 use numeric_literals::replace_float_literals;
 use serde::{Serialize, Deserialize};
 use colored::Colorize;
-use nalgebra::{DVector, DMatrix};
+use nalgebra::DVector;
 
 use alg_tools::iterate::{
     AlgIteratorFactory,
     AlgIteratorState,
 };
 use alg_tools::euclidean::Euclidean;
-use alg_tools::linops::Apply;
+use alg_tools::linops::{Apply, GEMV};
 use alg_tools::sets::Cube;
 use alg_tools::loc::Loc;
-use alg_tools::mapping::Mapping;
 use alg_tools::bisection_tree::{
     BTFN,
     PreBTFN,
@@ -104,7 +103,7 @@
     P2Minimise,
     SupportGenerator,
     LocalAnalysis,
-    Bounded,
+    BothGenerators,
 };
 use alg_tools::mapping::RealMapping;
 use alg_tools::nalgebra_support::ToNalgebraRealField;
@@ -119,12 +118,8 @@
     SpikeMerging,
 };
 use crate::forward_model::ForwardModel;
-use crate::seminorms::{
-    DiscreteMeasureOp, Lipschitz
-};
+use crate::seminorms::DiscreteMeasureOp;
 use crate::subproblem::{
-    nonneg::quadratic_nonneg,
-    unconstrained::quadratic_unconstrained,
     InnerSettings,
     InnerMethod,
 };
@@ -134,9 +129,11 @@
     Plotting,
     PlotLookup
 };
-use crate::regularisation::{
-    NonnegRadonRegTerm,
-    RadonRegTerm,
+use crate::regularisation::RegTerm;
+use crate::dataterm::{
+    calculate_residual,
+    L2Squared,
+    DataTerm,
 };
 
 /// Method for constructing $μ$ on each iteration
@@ -150,24 +147,12 @@
     Zero,
 }
 
-/// Meta-algorithm type
-#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
-#[allow(dead_code)]
-pub enum FBMetaAlgorithm {
-    /// No meta-algorithm
-    None,
-    /// FISTA-style inertia
-    InertiaFISTA,
-}
-
 /// Settings for [`pointsource_fb_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
 pub struct FBConfig<F : Float> {
     /// Step length scaling
     pub τ0 : F,
-    /// Meta-algorithm to apply
-    pub meta : FBMetaAlgorithm,
     /// Generic parameters
     pub insertion : FBGenericConfig<F>,
 }
@@ -209,7 +194,6 @@
     fn default() -> Self {
         FBConfig {
             τ0 : 0.99,
-            meta : FBMetaAlgorithm::None,
             insertion : Default::default()
         }
     }
@@ -240,486 +224,236 @@
     }
 }
 
-/// Trait for specialisation of [`generic_pointsource_fb_reg`] to basic FB, FISTA.
-///
-/// The idea is that the residual $Aμ - b$ in the forward step can be replaced by an arbitrary
-/// value. For example, to implement [primal-dual proximal splitting][crate::pdps] we replace it
-/// with the dual variable $y$. We can then also implement alternative data terms, as the
-/// (pre)differential of $F(μ)=F\_0(Aμ-b)$ is $F\'(μ) = A\_*F\_0\'(Aμ-b)$. In the case of the
-/// quadratic fidelity $F_0(y)=\frac{1}{2}\\|y\\|_2^2$ in a Hilbert space, of course,
-/// $F\_0\'(Aμ-b)=Aμ-b$ is the residual.
-pub trait FBSpecialisation<F : Float, Observable : Euclidean<F>, const N : usize> : Sized {
-    /// Updates the residual and does any necessary pruning of `μ`.
-    ///
-    /// Returns the new residual and possibly a new step length.
-    ///
-    /// The measure `μ` may also be modified to apply, e.g., inertia to it.
-    /// The updated residual should correspond to the residual at `μ`.
-    /// See the [trait documentation][FBSpecialisation] for the use and meaning of the residual.
-    ///
-    /// The parameter `μ_base` is the base point of the iteration, typically the previous iterate,
-    /// but for, e.g., FISTA has inertia applied to it.
-    fn update(
-        &mut self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
-        μ_base : &DiscreteMeasure<Loc<F, N>, F>,
-    ) -> (Observable, Option<F>);
-
-    /// Calculates the data term value corresponding to iterate `μ` and available residual.
-    ///
-    /// Inertia and other modifications, as deemed, necessary, should be applied to `μ`.
-    ///
-    /// The blanket implementation correspondsn to the 2-norm-squared data fidelity
-    /// $\\|\text{residual}\\|\_2^2/2$.
-    fn calculate_fit(
-        &self,
-        _μ : &DiscreteMeasure<Loc<F, N>, F>,
-        residual : &Observable
-    ) -> F {
-        residual.norm2_squared_div2()
-    }
-
-    /// Calculates the data term value at $μ$.
-    ///
-    /// Unlike [`Self::calculate_fit`], no inertia, etc., should be applied to `μ`.
-    fn calculate_fit_simple(
-        &self,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-    ) -> F;
-
-    /// Returns the final iterate after any necessary postprocess pruning, merging, etc.
-    fn postprocess(self, mut μ : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>)
-    -> DiscreteMeasure<Loc<F, N>, F>
-    where  DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
-        μ.merge_spikes_fitness(merging,
-                               |μ̃| self.calculate_fit_simple(μ̃),
-                               |&v| v);
-        μ.prune();
-        μ
-    }
-
-    /// Returns measure to be used for value calculations, which may differ from μ.
-    fn value_μ<'c, 'b : 'c>(&'b self, μ : &'c DiscreteMeasure<Loc<F, N>, F>)
-    -> &'c DiscreteMeasure<Loc<F, N>, F> {
-        μ
-    }
-}
-
-/// Specialisation of [`generic_pointsource_fb_reg`] to basic μFB.
-struct BasicFB<
-    'a,
-    F : Float + ToNalgebraRealField,
-    A : ForwardModel<Loc<F, N>, F>,
-    const N : usize
-> {
-    /// The data
-    b : &'a A::Observable,
-    /// The forward operator
-    opA : &'a A,
-}
-
-/// Implementation of [`FBSpecialisation`] for basic μFB forward-backward splitting.
 #[replace_float_literals(F::cast_from(literal))]
-impl<'a, F : Float + ToNalgebraRealField , A : ForwardModel<Loc<F, N>, F>, const N : usize>
-FBSpecialisation<F, A::Observable, N> for BasicFB<'a, F, A, N> {
-    fn update(
-        &mut self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
-        _μ_base : &DiscreteMeasure<Loc<F, N>, F>
-    ) -> (A::Observable, Option<F>) {
-        μ.prune();
-        //*residual = self.opA.apply(μ) - self.b;
-        let mut residual = self.b.clone();
-        self.opA.gemv(&mut residual, 1.0, μ, -1.0);
-        (residual, None)
-    }
-
-    fn calculate_fit_simple(
-        &self,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-    ) -> F {
-        let mut residual = self.b.clone();
-        self.opA.gemv(&mut residual, 1.0, μ, -1.0);
-        residual.norm2_squared_div2()
-    }
-}
-
-/// Specialisation of [`generic_pointsource_fb_reg`] to FISTA.
-struct FISTA<
-    'a,
-    F : Float + ToNalgebraRealField,
-    A : ForwardModel<Loc<F, N>, F>,
-    const N : usize
-> {
-    /// The data
-    b : &'a A::Observable,
-    /// The forward operator
-    opA : &'a A,
-    /// Current inertial parameter
-    λ : F,
-    /// Previous iterate without inertia applied.
-    /// We need to store this here because `μ_base` passed to [`FBSpecialisation::update`] will
-    /// have inertia applied to it, so is not useful to use.
-    μ_prev : DiscreteMeasure<Loc<F, N>, F>,
-}
-
-/// Implementation of [`FBSpecialisation`] for μFISTA inertial forward-backward splitting.
-#[replace_float_literals(F::cast_from(literal))]
-impl<'a, F : Float + ToNalgebraRealField, A : ForwardModel<Loc<F, N>, F>, const N : usize>
-FBSpecialisation<F, A::Observable, N> for FISTA<'a, F, A, N> {
-    fn update(
-        &mut self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
-        _μ_base : &DiscreteMeasure<Loc<F, N>, F>
-    ) -> (A::Observable, Option<F>) {
-        // Update inertial parameters
-        let λ_prev = self.λ;
-        self.λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() );
-        let θ = self.λ / λ_prev - self.λ;
-        // Perform inertial update on μ.
-        // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ
-        // and μ_prev have zero weight. Since both have weights from the finite-dimensional
-        // subproblem with a proximal projection step, this is likely to happen when the
-        // spike is not needed. A copy of the pruned μ without artithmetic performed is
-        // stored in μ_prev.
-        μ.pruning_sub(1.0 + θ, θ, &mut self.μ_prev);
-
-        //*residual = self.opA.apply(μ) - self.b;
-        let mut residual = self.b.clone();
-        self.opA.gemv(&mut residual, 1.0, μ, -1.0);
-        (residual, None)
-    }
-
-    fn calculate_fit_simple(
-        &self,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-    ) -> F {
-        let mut residual = self.b.clone();
-        self.opA.gemv(&mut residual, 1.0, μ, -1.0);
-        residual.norm2_squared_div2()
-    }
-
-    fn calculate_fit(
-        &self,
-        _μ : &DiscreteMeasure<Loc<F, N>, F>,
-        _residual : &A::Observable
-    ) -> F {
-        self.calculate_fit_simple(&self.μ_prev)
-    }
-
-    // For FISTA we need to do a final pruning as well, due to the limited
-    // pruning that can be done on each step.
-    fn postprocess(mut self, μ_base : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>)
-    -> DiscreteMeasure<Loc<F, N>, F>
-    where  DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
-        let mut μ = self.μ_prev;
-        self.μ_prev = μ_base;
-        μ.merge_spikes_fitness(merging,
-                               |μ̃| self.calculate_fit_simple(μ̃),
-                               |&v| v);
-        μ.prune();
-        μ
-    }
-
-    fn value_μ<'c, 'b : 'c>(&'c self, _μ : &'c DiscreteMeasure<Loc<F, N>, F>)
-    -> &'c DiscreteMeasure<Loc<F, N>, F> {
-        &self.μ_prev
-    }
-}
-
-
-/// Abstraction of regularisation terms for [`generic_pointsource_fb_reg`].
-pub trait RegTerm<F : Float + ToNalgebraRealField, const N : usize>
-: for<'a> Apply<&'a DiscreteMeasure<Loc<F, N>, F>, Output = F> {
-    /// Approximately solve the problem
-    /// <div>$$
-    ///     \min_{x ∈ ℝ^n} \frac{1}{2} x^⊤Ax - g^⊤ x + τ G(x)
-    /// $$</div>
-    /// for $G$ depending on the trait implementation.
-    ///
-    /// The parameter `mA` is $A$. An estimate for its opeator norm should be provided in
-    /// `mA_normest`. The initial iterate and output is `x`. The current main tolerance is `ε`.
-    ///
-    /// Returns the number of iterations taken.
-    fn solve_findim(
-        &self,
-        mA : &DMatrix<F::MixedType>,
-        g : &DVector<F::MixedType>,
-        τ : F,
-        x : &mut DVector<F::MixedType>,
-        mA_normest : F,
-        ε : F,
-        config : &FBGenericConfig<F>
-    ) -> usize;
-
-    /// Find a point where `d` may violate the tolerance `ε`.
-    ///
-    /// If `skip_by_rough_check` is set, do not find the point if a rough check indicates that we
-    /// are in bounds. `ε` is the current main tolerance and `τ` a scaling factor for the
-    /// regulariser.
-    ///
-    /// Returns `None` if `d` is in bounds either based on the rough check, or a more precise check
-    /// terminating early. Otherwise returns a possibly violating point, the value of `d` there,
-    /// and a boolean indicating whether the found point is in bounds.
-    fn find_tolerance_violation<G, BT>(
-        &self,
-        d : &mut BTFN<F, G, BT, N>,
-        τ : F,
-        ε : F,
-        skip_by_rough_check : bool,
-        config : &FBGenericConfig<F>,
-    ) -> Option<(Loc<F, N>, F, bool)>
-    where BT : BTSearch<F, N, Agg=Bounds<F>>,
-          G : SupportGenerator<F, N, Id=BT::Data>,
-          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
-                           + LocalAnalysis<F, Bounds<F>, N>;
-
-    /// Verify that `d` is in bounds `ε` for a merge candidate `μ`
-    ///
-    /// `ε` is the current main tolerance and `τ` a scaling factor for the regulariser.
-    fn verify_merge_candidate<G, BT>(
-        &self,
-        d : &mut BTFN<F, G, BT, N>,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-        τ : F,
-        ε : F,
-        config : &FBGenericConfig<F>,
-    ) -> bool
-    where BT : BTSearch<F, N, Agg=Bounds<F>>,
-          G : SupportGenerator<F, N, Id=BT::Data>,
-          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
-                           + LocalAnalysis<F, Bounds<F>, N>;
-
-    fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>>;
-
-    /// Returns a scaling factor for the tolerance sequence.
-    ///
-    /// Typically this is the regularisation parameter.
-    fn tolerance_scaling(&self) -> F;
-}
-
-#[replace_float_literals(F::cast_from(literal))]
-impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for NonnegRadonRegTerm<F>
-where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
-    fn solve_findim(
-        &self,
-        mA : &DMatrix<F::MixedType>,
-        g : &DVector<F::MixedType>,
-        τ : F,
-        x : &mut DVector<F::MixedType>,
-        mA_normest : F,
-        ε : F,
-        config : &FBGenericConfig<F>
-    ) -> usize {
-        let inner_tolerance = ε * config.inner.tolerance_mult;
-        let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
-        let inner_τ = config.inner.τ0 / mA_normest;
-        quadratic_nonneg(config.inner.method, mA, g, τ * self.α(), x,
-                         inner_τ, inner_it)
-    }
-
-    #[inline]
-    fn find_tolerance_violation<G, BT>(
-        &self,
-        d : &mut BTFN<F, G, BT, N>,
-        τ : F,
-        ε : F,
-        skip_by_rough_check : bool,
-        config : &FBGenericConfig<F>,
-    ) -> Option<(Loc<F, N>, F, bool)>
-    where BT : BTSearch<F, N, Agg=Bounds<F>>,
-          G : SupportGenerator<F, N, Id=BT::Data>,
-          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
-                           + LocalAnalysis<F, Bounds<F>, N> {
-        let τα = τ * self.α();
-        let keep_below = τα + ε;
-        let maximise_above = τα + ε * config.insertion_cutoff_factor;
-        let refinement_tolerance = ε * config.refinement.tolerance_mult;
-
-        // If preliminary check indicates that we are in bonds, and if it otherwise matches
-        // the insertion strategy, skip insertion.
-        if skip_by_rough_check && d.bounds().upper() <= keep_below {
-            None
-        } else {
-            // If the rough check didn't indicate no insertion needed, find maximising point.
-            d.maximise_above(maximise_above, refinement_tolerance, config.refinement.max_steps)
-             .map(|(ξ, v_ξ)| (ξ, v_ξ, v_ξ <= keep_below))
+pub(crate) fn μ_diff<F : Float, const N : usize>(
+    μ_new : &DiscreteMeasure<Loc<F, N>, F>,
+    μ_base : &DiscreteMeasure<Loc<F, N>, F>,
+    ν_delta : Option<&DiscreteMeasure<Loc<F, N>, F>>,
+    config : &FBGenericConfig<F>
+) -> DiscreteMeasure<Loc<F, N>, F> {
+    let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style {
+        InsertionStyle::Reuse => {
+            μ_new.iter_spikes()
+                 .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0)))
+                 .map(|(δ, α_base)| (δ.x, α_base - δ.α))
+                 .collect()
+        },
+        InsertionStyle::Zero => {
+            μ_new.iter_spikes()
+                 .map(|δ| -δ)
+                 .chain(μ_base.iter_spikes().copied())
+                 .collect()
         }
-    }
-
-    fn verify_merge_candidate<G, BT>(
-        &self,
-        d : &mut BTFN<F, G, BT, N>,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-        τ : F,
-        ε : F,
-        config : &FBGenericConfig<F>,
-    ) -> bool
-    where BT : BTSearch<F, N, Agg=Bounds<F>>,
-          G : SupportGenerator<F, N, Id=BT::Data>,
-          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
-                           + LocalAnalysis<F, Bounds<F>, N> {
-        let τα = τ * self.α();
-        let refinement_tolerance = ε * config.refinement.tolerance_mult;
-        let merge_tolerance = config.merge_tolerance_mult * ε;
-        let keep_below = τα + merge_tolerance;
-        let keep_supp_above = τα - merge_tolerance;
-        let bnd = d.bounds();
-
-        return (
-            bnd.lower() >= keep_supp_above
-            ||
-            μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| {
-                (β == 0.0) || d.apply(x) >= keep_supp_above
-            }).all(std::convert::identity)
-         ) && (
-            bnd.upper() <= keep_below
-            ||
-            d.has_upper_bound(keep_below, refinement_tolerance, config.refinement.max_steps)
-        )
-    }
-
-    fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> {
-        let τα = τ * self.α();
-        Some(Bounds(τα - ε,  τα + ε))
-    }
-
-    fn tolerance_scaling(&self) -> F {
-        self.α()
+    };
+    ν.prune(); // Potential small performance improvement
+    // Add ν_delta if given
+    match ν_delta {
+        None => ν,
+        Some(ν_d) => ν + ν_d,
     }
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for RadonRegTerm<F>
-where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
-    fn solve_findim(
-        &self,
-        mA : &DMatrix<F::MixedType>,
-        g : &DVector<F::MixedType>,
-        τ : F,
-        x : &mut DVector<F::MixedType>,
-        mA_normest: F,
-        ε : F,
-        config : &FBGenericConfig<F>
-    ) -> usize {
-        let inner_tolerance = ε * config.inner.tolerance_mult;
-        let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
-        let inner_τ = config.inner.τ0 / mA_normest;
-        quadratic_unconstrained(config.inner.method, mA, g, τ * self.α(), x,
-                                inner_τ, inner_it)
+pub(crate) fn insert_and_reweigh<
+    'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
+>(
+    μ : &mut DiscreteMeasure<Loc<F, N>, F>,
+    minus_τv : &BTFN<F, GA, BTA, N>,
+    μ_base : &DiscreteMeasure<Loc<F, N>, F>,
+    ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>,
+    op𝒟 : &'a 𝒟,
+    op𝒟norm : F,
+    τ : F,
+    ε : F,
+    config : &FBGenericConfig<F>,
+    reg : &Reg,
+    state : &State,
+    stats : &mut IterInfo<F, N>,
+) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool)
+where F : Float + ToNalgebraRealField,
+      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
+      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
+      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
+      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
+      𝒟::Codomain : RealMapping<F, N>,
+      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
+      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
+      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
+      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
+      Reg : RegTerm<F, N>,
+      State : AlgIteratorState {
+
+    // Maximum insertion count and measure difference calculation depend on insertion style.
+    let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
+        (i, Some((l, k))) if i <= l => (k, false),
+        _ => (config.max_insertions, !state.is_quiet()),
+    };
+    let max_insertions = match config.insertion_style {
+        InsertionStyle::Zero => {
+            todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled.");
+            // let n = μ.len();
+            // μ = DiscreteMeasure::new();
+            // n + m
+        },
+        InsertionStyle::Reuse => m,
+    };
+
+    // TODO: should avoid a second copy of μ here; μ_base already stores a copy.
+    let ω0 = op𝒟.apply(match ν_delta {
+        None => μ.clone(),
+        Some(ν_d) => &*μ + ν_d,
+    });
+
+    // Add points to support until within error tolerance or maximum insertion count reached.
+    let mut count = 0;
+    let (within_tolerances, d) = 'insertion: loop {
+        if μ.len() > 0 {
+            // Form finite-dimensional subproblem. The subproblem references to the original μ^k
+            // from the beginning of the iteration are all contained in the immutable c and g.
+            let à = op𝒟.findim_matrix(μ.iter_locations());
+            let g̃ = DVector::from_iterator(μ.len(),
+                                           μ.iter_locations()
+                                            .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ))
+                                            .map(F::to_nalgebra_mixed));
+            let mut x = μ.masses_dvector();
+
+            // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
+            // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
+            // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤  sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
+            // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2
+            // = n |𝒟| |x|_2, where n is the number of points. Therefore
+            let Ã_normest = op𝒟norm * F::cast_from(μ.len());
+
+            // Solve finite-dimensional subproblem.
+            stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config);
+
+            // Update masses of μ based on solution of finite-dimensional subproblem.
+            μ.set_masses_dvector(&x);
+        }
+
+        // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
+        // conditions in the predual space, and finding new points for insertion, if necessary.
+        let mut d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_base, ν_delta, config));
+
+        // If no merging heuristic is used, let's be more conservative about spike insertion,
+        // and skip it after first round. If merging is done, being more greedy about spike
+        // insertion also seems to improve performance.
+        let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging {
+            false
+        } else {
+            count > 0
+        };
+
+        // Find a spike to insert, if needed
+        let (ξ, _v_ξ, in_bounds) =  match reg.find_tolerance_violation(
+            &mut d, τ, ε, skip_by_rough_check, config
+        ) {
+            None => break 'insertion (true, d),
+            Some(res) => res,
+        };
+
+        // Break if maximum insertion count reached
+        if count >= max_insertions {
+            break 'insertion (in_bounds, d)
+        }
+
+        // No point in optimising the weight here; the finite-dimensional algorithm is fast.
+        *μ += DeltaMeasure { x : ξ, α : 0.0 };
+        count += 1;
+    };
+
+    // TODO: should redo everything if some transports cause a problem.
+    // Maybe implementation should call above loop as a closure.
+
+    if !within_tolerances && warn_insertions {
+        // Complain (but continue) if we failed to get within tolerances
+        // by inserting more points.
+        let err = format!("Maximum insertions reached without achieving \
+                            subproblem solution tolerance");
+        println!("{}", err.red());
     }
 
-   fn find_tolerance_violation<G, BT>(
-        &self,
-        d : &mut BTFN<F, G, BT, N>,
-        τ : F,
-        ε : F,
-        skip_by_rough_check : bool,
-        config : &FBGenericConfig<F>,
-    ) -> Option<(Loc<F, N>, F, bool)>
-    where BT : BTSearch<F, N, Agg=Bounds<F>>,
-          G : SupportGenerator<F, N, Id=BT::Data>,
-          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
-                           + LocalAnalysis<F, Bounds<F>, N> {
-        let τα = τ * self.α();
-        let keep_below = τα + ε;
-        let keep_above = -τα - ε;
-        let maximise_above = τα + ε * config.insertion_cutoff_factor;
-        let minimise_below = -τα - ε * config.insertion_cutoff_factor;
-        let refinement_tolerance = ε * config.refinement.tolerance_mult;
+    (d, within_tolerances)
+}
 
-        // If preliminary check indicates that we are in bonds, and if it otherwise matches
-        // the insertion strategy, skip insertion.
-        if skip_by_rough_check && Bounds(keep_above, keep_below).superset(&d.bounds()) {
-            None
-        } else {
-            // If the rough check didn't indicate no insertion needed, find maximising point.
-            let mx = d.maximise_above(maximise_above, refinement_tolerance,
-                                      config.refinement.max_steps);
-            let mi = d.minimise_below(minimise_below, refinement_tolerance,
-                                      config.refinement.max_steps);
+#[replace_float_literals(F::cast_from(literal))]
+pub(crate) fn prune_and_maybe_simple_merge<
+    'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
+>(
+    μ : &mut DiscreteMeasure<Loc<F, N>, F>,
+    minus_τv : &BTFN<F, GA, BTA, N>,
+    μ_base : &DiscreteMeasure<Loc<F, N>, F>,
+    op𝒟 : &'a 𝒟,
+    τ : F,
+    ε : F,
+    config : &FBGenericConfig<F>,
+    reg : &Reg,
+    state : &State,
+    stats : &mut IterInfo<F, N>,
+)
+where F : Float + ToNalgebraRealField,
+      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
+      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
+      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
+      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
+      𝒟::Codomain : RealMapping<F, N>,
+      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
+      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
+      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
+      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
+      Reg : RegTerm<F, N>,
+      State : AlgIteratorState {
+    if state.iteration() % config.merge_every == 0 {
+        let n_before_merge = μ.len();
+        μ.merge_spikes(config.merging, |μ_candidate| {
+            let μd = μ_diff(&μ_candidate, &μ_base, None, config);
+            let mut d = minus_τv + op𝒟.preapply(μd);
 
-            match (mx, mi) {
-                (None, None) => None,
-                (Some((ξ, v_ξ)), None) => Some((ξ, v_ξ, keep_below >= v_ξ)),
-                (None, Some((ζ, v_ζ))) => Some((ζ, v_ζ, keep_above <= v_ζ)),
-                (Some((ξ, v_ξ)), Some((ζ, v_ζ))) => {
-                    if v_ξ - τα > τα - v_ζ {
-                        Some((ξ, v_ξ, keep_below >= v_ξ))
-                    } else {
-                        Some((ζ, v_ζ, keep_above <= v_ζ))
-                    }
-                }
-            }
-        }
+            reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
+                .then_some(())
+        });
+        debug_assert!(μ.len() >= n_before_merge);
+        stats.merged += μ.len() - n_before_merge;
     }
 
-    fn verify_merge_candidate<G, BT>(
-        &self,
-        d : &mut BTFN<F, G, BT, N>,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-        τ : F,
-        ε : F,
-        config : &FBGenericConfig<F>,
-    ) -> bool
-    where BT : BTSearch<F, N, Agg=Bounds<F>>,
-          G : SupportGenerator<F, N, Id=BT::Data>,
-          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
-                           + LocalAnalysis<F, Bounds<F>, N> {
-        let τα = τ * self.α();
-        let refinement_tolerance = ε * config.refinement.tolerance_mult;
-        let merge_tolerance = config.merge_tolerance_mult * ε;
-        let keep_below = τα + merge_tolerance;
-        let keep_above = -τα - merge_tolerance;
-        let keep_supp_pos_above = τα - merge_tolerance;
-        let keep_supp_neg_below = -τα + merge_tolerance;
-        let bnd = d.bounds();
-
-        return (
-            (bnd.lower() >= keep_supp_pos_above && bnd.upper() <= keep_supp_neg_below)
-            ||
-            μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| {
-                use std::cmp::Ordering::*;
-                match β.partial_cmp(&0.0) {
-                    Some(Greater) => d.apply(x) >= keep_supp_pos_above,
-                    Some(Less) => d.apply(x) <= keep_supp_neg_below,
-                    _ => true,
-                }
-            }).all(std::convert::identity)
-        ) && (
-            bnd.upper() <= keep_below
-            ||
-            d.has_upper_bound(keep_below, refinement_tolerance,
-                              config.refinement.max_steps)
-        ) && (
-            bnd.lower() >= keep_above
-            ||
-            d.has_lower_bound(keep_above, refinement_tolerance,
-                              config.refinement.max_steps)
-        )
-    }
-
-    fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> {
-        let τα = τ * self.α();
-        Some(Bounds(-τα - ε,  τα + ε))
-    }
-
-    fn tolerance_scaling(&self) -> F {
-        self.α()
-    }
+    let n_before_prune = μ.len();
+    μ.prune();
+    debug_assert!(μ.len() <= n_before_prune);
+    stats.pruned += n_before_prune - μ.len();
 }
 
+#[replace_float_literals(F::cast_from(literal))]
+pub(crate) fn postprocess<
+    F : Float,
+    V : Euclidean<F> + Clone,
+    A : GEMV<F, DiscreteMeasure<Loc<F, N>, F>, Codomain = V>,
+    D : DataTerm<F, V, N>,
+    const N : usize
+> (
+    mut μ : DiscreteMeasure<Loc<F, N>, F>,
+    config : &FBGenericConfig<F>,
+    dataterm : D,
+    opA : &A,
+    b : &V,
+) -> DiscreteMeasure<Loc<F, N>, F>
+where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
+    μ.merge_spikes_fitness(config.merging,
+                           |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
+                           |&v| v);
+    μ.prune();
+    μ
+}
 
-/// Generic implementation of [`pointsource_fb_reg`].
+/// Iteratively solve the pointsource localisation problem using forward-backward splitting.
 ///
-/// The method can be specialised to even primal-dual proximal splitting through the
-/// [`FBSpecialisation`] parameter `specialisation`.
-/// The settings in `config` have their [respective documentation](FBGenericConfig). `opA` is the
+/// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
 /// as documented in [`alg_tools::iterate`].
 ///
+/// For details on the mathematical formulation, see the [module level](self) documentation.
+///
 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
 /// sums of simple functions usign bisection trees, and the related
 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
@@ -729,252 +463,16 @@
 ///
 /// Returns the final iterate.
 #[replace_float_literals(F::cast_from(literal))]
-pub fn generic_pointsource_fb_reg<
-    'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, Reg, const N : usize
+pub fn pointsource_fb_reg<
+    'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize
 >(
     opA : &'a A,
-    reg : Reg,
-    op𝒟 : &'a 𝒟,
-    mut τ : F,
-    config : &FBGenericConfig<F>,
-    iterator : I,
-    mut plotter : SeqPlotter<F, N>,
-    mut residual : A::Observable,
-    mut specialisation : Spec
-) -> DiscreteMeasure<Loc<F, N>, F>
-where F : Float + ToNalgebraRealField,
-      I : AlgIteratorFactory<IterInfo<F, N>>,
-      Spec : FBSpecialisation<F, A::Observable, N>,
-      A::Observable : std::ops::MulAssign<F>,
-      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
-          + Lipschitz<𝒟, FloatType=F>,
-      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
-      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
-      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
-      𝒟::Codomain : RealMapping<F, N>,
-      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
-      PlotLookup : Plotting<N>,
-      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
-      Reg : RegTerm<F, N> {
-
-    // Set up parameters
-    let quiet = iterator.is_quiet();
-    let op𝒟norm = op𝒟.opnorm_bound();
-    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
-    // by τ compared to the conditional gradient approach.
-    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
-    let mut ε = tolerance.initial();
-
-    // Initialise operators
-    let preadjA = opA.preadjoint();
-
-    // Initialise iterates
-    let mut μ = DiscreteMeasure::new();
-
-    let mut inner_iters = 0;
-    let mut this_iters = 0;
-    let mut pruned = 0;
-    let mut merged = 0;
-
-    let μ_diff = |μ_new : &DiscreteMeasure<Loc<F, N>, F>,
-                  μ_base : &DiscreteMeasure<Loc<F, N>, F>| {
-        let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style {
-            InsertionStyle::Reuse => {
-                μ_new.iter_spikes()
-                        .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0)))
-                        .map(|(δ, α_base)| (δ.x, α_base - δ.α))
-                        .collect()
-            },
-            InsertionStyle::Zero => {
-                μ_new.iter_spikes()
-                        .map(|δ| -δ)
-                        .chain(μ_base.iter_spikes().copied())
-                        .collect()
-            }
-        };
-        ν.prune(); // Potential small performance improvement
-        ν
-    };
-
-    // Run the algorithm
-    iterator.iterate(|state| {
-        // Maximum insertion count and measure difference calculation depend on insertion style.
-        let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
-            (i, Some((l, k))) if i <= l => (k, false),
-            _ => (config.max_insertions, !quiet),
-        };
-        let max_insertions = match config.insertion_style {
-            InsertionStyle::Zero => {
-                todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled.");
-                // let n = μ.len();
-                // μ = DiscreteMeasure::new();
-                // n + m
-            },
-            InsertionStyle::Reuse => m,
-        };
-
-        // Calculate smooth part of surrogate model.
-        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
-        // has no significant overhead. For some reosn Rust doesn't allow us simply moving
-        // the residual and replacing it below before the end of this closure.
-        residual *= -τ;
-        let r = std::mem::replace(&mut residual, opA.empty_observable());
-        let minus_τv = preadjA.apply(r);     // minus_τv = -τA^*(Aμ^k-b)
-        // TODO: should avoid a second copy of μ here; μ_base already stores a copy.
-        let ω0 = op𝒟.apply(μ.clone());       // 𝒟μ^k
-        //let g = &minus_τv + ω0;            // Linear term of surrogate model
-
-        // Save current base point
-        let μ_base = μ.clone();
-            
-        // Add points to support until within error tolerance or maximum insertion count reached.
-        let mut count = 0;
-        let (within_tolerances, d) = 'insertion: loop {
-            if μ.len() > 0 {
-                // Form finite-dimensional subproblem. The subproblem references to the original μ^k
-                // from the beginning of the iteration are all contained in the immutable c and g.
-                let à = op𝒟.findim_matrix(μ.iter_locations());
-                let g̃ = DVector::from_iterator(μ.len(),
-                                               μ.iter_locations()
-                                                .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ))
-                                                .map(F::to_nalgebra_mixed));
-                let mut x = μ.masses_dvector();
-
-                // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
-                // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
-                // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤  sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
-                // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2
-                // = n |𝒟| |x|_2, where n is the number of points. Therefore
-                let Ã_normest = op𝒟norm * F::cast_from(μ.len());
-
-                // Solve finite-dimensional subproblem.
-                inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config);
-
-                // Update masses of μ based on solution of finite-dimensional subproblem.
-                μ.set_masses_dvector(&x);
-            }
-
-            // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
-            // conditions in the predual space, and finding new points for insertion, if necessary.
-            let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ, &μ_base));
-
-            // If no merging heuristic is used, let's be more conservative about spike insertion,
-            // and skip it after first round. If merging is done, being more greedy about spike
-            // insertion also seems to improve performance.
-            let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging {
-                false
-            } else {
-                count > 0
-            };
-
-            // Find a spike to insert, if needed
-            let (ξ, _v_ξ, in_bounds) =  match reg.find_tolerance_violation(
-                &mut d, τ, ε, skip_by_rough_check, config
-            ) {
-                None => break 'insertion (true, d),
-                Some(res) => res,
-            };
-
-            // Break if maximum insertion count reached
-            if count >= max_insertions {
-                break 'insertion (in_bounds, d)
-            }
-
-            // No point in optimising the weight here; the finite-dimensional algorithm is fast.
-            μ += DeltaMeasure { x : ξ, α : 0.0 };
-            count += 1;
-        };
-
-        if !within_tolerances && warn_insertions {
-            // Complain (but continue) if we failed to get within tolerances
-            // by inserting more points.
-            let err = format!("Maximum insertions reached without achieving \
-                                subproblem solution tolerance");
-            println!("{}", err.red());
-        }
-
-        // Merge spikes
-        if state.iteration() % config.merge_every == 0 {
-            let n_before_merge = μ.len();
-            μ.merge_spikes(config.merging, |μ_candidate| {
-                let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ_candidate, &μ_base));
-
-                reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
-                   .then_some(())
-            });
-            debug_assert!(μ.len() >= n_before_merge);
-            merged += μ.len() - n_before_merge;
-        }
-
-        let n_before_prune = μ.len();
-        (residual, τ) = match specialisation.update(&mut μ, &μ_base) {
-            (r, None) => (r, τ),
-            (r, Some(new_τ)) => (r, new_τ)
-        };
-        debug_assert!(μ.len() <= n_before_prune);
-        pruned += n_before_prune - μ.len();
-
-        this_iters += 1;
-
-        // Update main tolerance for next iteration
-        let ε_prev = ε;
-        ε = tolerance.update(ε, state.iteration());
-
-        // Give function value if needed
-        state.if_verbose(|| {
-            let value_μ = specialisation.value_μ(&μ);
-            // Plot if so requested
-            plotter.plot_spikes(
-                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
-                "start".to_string(), Some(&minus_τv),
-                reg.target_bounds(τ, ε_prev), value_μ,
-            );
-            // Calculate mean inner iterations and reset relevant counters.
-            // Return the statistics
-            let res = IterInfo {
-                value : specialisation.calculate_fit(&μ, &residual) + reg.apply(value_μ),
-                n_spikes : value_μ.len(),
-                inner_iters,
-                this_iters,
-                merged,
-                pruned,
-                ε : ε_prev,
-                postprocessing: config.postprocessing.then(|| value_μ.clone()),
-            };
-            inner_iters = 0;
-            this_iters = 0;
-            merged = 0;
-            pruned = 0;
-            res
-        })
-    });
-
-    specialisation.postprocess(μ, config.final_merging)
-}
-
-/// Iteratively solve the pointsource localisation problem using forward-backward splitting
-///
-/// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
-/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
-/// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
-/// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
-/// as documented in [`alg_tools::iterate`].
-///
-/// For details on the mathematical formulation, see the [module level](self) documentation.
-///
-/// Returns the final iterate.
-#[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>(
-    opA : &'a A,
     b : &A::Observable,
     reg : Reg,
     op𝒟 : &'a 𝒟,
-    config : &FBConfig<F>,
+    fbconfig : &FBConfig<F>,
     iterator : I,
-    plotter : SeqPlotter<F, N>,
+    mut plotter : SeqPlotter<F, N>,
 ) -> DiscreteMeasure<Loc<F, N>, F>
 where F : Float + ToNalgebraRealField,
       I : AlgIteratorFactory<IterInfo<F, N>>,
@@ -983,7 +481,7 @@
       A::Observable : std::ops::MulAssign<F>,
       GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
       A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
-          + Lipschitz<𝒟, FloatType=F>,
+          + Lipschitz<&'a 𝒟, FloatType=F>,
       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
       𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
@@ -996,17 +494,227 @@
       DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
       Reg : RegTerm<F, N> {
 
-    let initial_residual = -b;
-    let τ = config.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
+    // Set up parameters
+    let config = &fbconfig.insertion;
+    let op𝒟norm = op𝒟.opnorm_bound();
+    let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
+    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
+    // by τ compared to the conditional gradient approach.
+    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
+    let mut ε = tolerance.initial();
+
+    // Initialise iterates
+    let mut μ = DiscreteMeasure::new();
+    let mut residual = -b;
+    let mut stats = IterInfo::new();
+
+    // Run the algorithm
+    iterator.iterate(|state| {
+        // Calculate smooth part of surrogate model.
+        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
+        // has no significant overhead. For some reosn Rust doesn't allow us simply moving
+        // the residual and replacing it below before the end of this closure.
+        residual *= -τ;
+        let r = std::mem::replace(&mut residual, opA.empty_observable());
+        let minus_τv = opA.preadjoint().apply(r);
+
+        // Save current base point
+        let μ_base = μ.clone();
+            
+        // Insert and reweigh
+        let (d, within_tolerances) = insert_and_reweigh(
+            &mut μ, &minus_τv, &μ_base, None,
+            op𝒟, op𝒟norm,
+            τ, ε,
+            config, &reg, state, &mut stats
+        );
+
+        // Prune and possibly merge spikes
+        prune_and_maybe_simple_merge(
+            &mut μ, &minus_τv, &μ_base,
+            op𝒟,
+            τ, ε,
+            config, &reg, state, &mut stats
+        );
+
+        // Update residual
+        residual = calculate_residual(&μ, opA, b);
+
+        // Update main tolerance for next iteration
+        let ε_prev = ε;
+        ε = tolerance.update(ε, state.iteration());
+        stats.this_iters += 1;
+
+        // Give function value if needed
+        state.if_verbose(|| {
+            // Plot if so requested
+            plotter.plot_spikes(
+                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
+                "start".to_string(), Some(&minus_τv),
+                reg.target_bounds(τ, ε_prev), &μ,
+            );
+            // Calculate mean inner iterations and reset relevant counters.
+            // Return the statistics
+            let res = IterInfo {
+                value : residual.norm2_squared_div2() + reg.apply(&μ),
+                n_spikes : μ.len(),
+                ε : ε_prev,
+                postprocessing: config.postprocessing.then(|| μ.clone()),
+                .. stats
+            };
+            stats = IterInfo::new();
+            res
+        })
+    });
+
+    postprocess(μ, config, L2Squared, opA, b)
+}
 
-    match config.meta {
-        FBMetaAlgorithm::None => generic_pointsource_fb_reg(
-            opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual,
-            BasicFB{ b, opA },
-        ),
-        FBMetaAlgorithm::InertiaFISTA => generic_pointsource_fb_reg(
-            opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual,
-            FISTA{ b, opA, λ : 1.0, μ_prev : DiscreteMeasure::new() },
-        ),
-    }
+/// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting.
+///
+/// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
+/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
+/// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
+/// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
+/// as documented in [`alg_tools::iterate`].
+///
+/// For details on the mathematical formulation, see the [module level](self) documentation.
+///
+/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
+/// sums of simple functions usign bisection trees, and the related
+/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
+/// active at a specific points, and to maximise their sums. Through the implementation of the
+/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
+/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
+///
+/// Returns the final iterate.
+#[replace_float_literals(F::cast_from(literal))]
+pub fn pointsource_fista_reg<
+    'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize
+>(
+    opA : &'a A,
+    b : &A::Observable,
+    reg : Reg,
+    op𝒟 : &'a 𝒟,
+    fbconfig : &FBConfig<F>,
+    iterator : I,
+    mut plotter : SeqPlotter<F, N>,
+) -> DiscreteMeasure<Loc<F, N>, F>
+where F : Float + ToNalgebraRealField,
+      I : AlgIteratorFactory<IterInfo<F, N>>,
+      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
+                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
+      A::Observable : std::ops::MulAssign<F>,
+      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
+      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
+          + Lipschitz<&'a 𝒟, FloatType=F>,
+      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
+      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
+      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
+      𝒟::Codomain : RealMapping<F, N>,
+      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
+      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
+      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
+      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
+      PlotLookup : Plotting<N>,
+      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
+      Reg : RegTerm<F, N> {
+
+    // Set up parameters
+    let config = &fbconfig.insertion;
+    let op𝒟norm = op𝒟.opnorm_bound();
+    let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
+    let mut λ = 1.0;
+    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
+    // by τ compared to the conditional gradient approach.
+    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
+    let mut ε = tolerance.initial();
+
+    // Initialise iterates
+    let mut μ = DiscreteMeasure::new();
+    let mut μ_prev = DiscreteMeasure::new();
+    let mut residual = -b;
+    let mut stats = IterInfo::new();
+    let mut warned_merging = false;
+
+    // Run the algorithm
+    iterator.iterate(|state| {
+        // Calculate smooth part of surrogate model.
+        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
+        // has no significant overhead. For some reosn Rust doesn't allow us simply moving
+        // the residual and replacing it below before the end of this closure.
+        residual *= -τ;
+        let r = std::mem::replace(&mut residual, opA.empty_observable());
+        let minus_τv = opA.preadjoint().apply(r);
+
+        // Save current base point
+        let μ_base = μ.clone();
+            
+        // Insert new spikes and reweigh
+        let (d, within_tolerances) = insert_and_reweigh(
+            &mut μ, &minus_τv, &μ_base, None,
+            op𝒟, op𝒟norm,
+            τ, ε,
+            config, &reg, state, &mut stats
+        );
+
+        // (Do not) merge spikes.
+        if state.iteration() % config.merge_every == 0 {
+            match config.merging {
+                SpikeMergingMethod::None => { },
+                _ => if !warned_merging {
+                    let err = format!("Merging not supported for μFISTA");
+                    println!("{}", err.red());
+                    warned_merging = true;
+                }
+            }
+        }
+
+        // Update inertial prameters
+        let λ_prev = λ;
+        λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() );
+        let θ = λ / λ_prev - λ;
+
+        // Perform inertial update on μ.
+        // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ
+        // and μ_prev have zero weight. Since both have weights from the finite-dimensional
+        // subproblem with a proximal projection step, this is likely to happen when the
+        // spike is not needed. A copy of the pruned μ without artithmetic performed is
+        // stored in μ_prev.
+        let n_before_prune = μ.len();
+        μ.pruning_sub(1.0 + θ, θ, &mut μ_prev);
+        debug_assert!(μ.len() <= n_before_prune);
+        stats.pruned += n_before_prune - μ.len();
+
+        // Update residual
+        residual = calculate_residual(&μ, opA, b);
+
+        // Update main tolerance for next iteration
+        let ε_prev = ε;
+        ε = tolerance.update(ε, state.iteration());
+        stats.this_iters += 1;
+
+        // Give function value if needed
+        state.if_verbose(|| {
+            // Plot if so requested
+            plotter.plot_spikes(
+                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
+                "start".to_string(), Some(&minus_τv),
+                reg.target_bounds(τ, ε_prev), &μ_prev,
+            );
+            // Calculate mean inner iterations and reset relevant counters.
+            // Return the statistics
+            let res = IterInfo {
+                value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev),
+                n_spikes : μ_prev.len(),
+                ε : ε_prev,
+                postprocessing: config.postprocessing.then(|| μ_prev.clone()),
+                .. stats
+            };
+            stats = IterInfo::new();
+            res
+        })
+    });
+
+    postprocess(μ_prev, config, L2Squared, opA, b)
 }

mercurial