src/fb.rs

Sun, 25 Sep 2022 21:45:56 +0300

author
Tuomo Valkonen <tuomov@iki.fi>
date
Sun, 25 Sep 2022 21:45:56 +0300
changeset 4
5aa5c279e341
parent 0
eb3c7813b67a
permissions
-rw-r--r--

Attempt at directly referring to measure masses as SliceStorageMut

/*!
Solver for the point source localisation problem using a forward-backward splitting method.

This corresponds to the manuscript

 * Valkonen T. - _Proximal methods for point source localisation_. ARXIV TO INSERT.

The main routine is [`pointsource_fb`]. It is based on [`generic_pointsource_fb`], which is also
used by our [primal-dual proximal splitting][crate::pdps] implementation.

FISTA-type inertia can also be enabled through [`FBConfig::meta`].

## Problem

<p>
Our objective is to solve
$$
    \min_{μ ∈ ℳ(Ω)}~ F_0(Aμ-b) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ),
$$
where $F_0(y)=\frac{1}{2}\|y\|_2^2$ and the forward operator $A \in 𝕃(ℳ(Ω); ℝ^n)$.
</p>

## Approach

<p>
As documented in more detail in the paper, on each step we approximately solve
$$
    \min_{μ ∈ ℳ(Ω)}~ F(μ) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ) + \frac{1}{2}\|μ-μ^k\|_𝒟^2,
$$
where $𝒟 ∈ 𝕃(ℳ(Ω); C_c(Ω))$ is typically a convolution operator.
</p>

## Finite-dimensional subproblems.

With $C$ a projection from [`DiscreteMeasure`] to the weights, and $x^k$ such that $x^k=Cμ^k$, we
form the discretised linearised inner problem
<p>
$$
    \min_{x ∈ ℝ^n}~ τ\bigl(F(Cx^k) + [C^*∇F(Cx^k)]^⊤(x-x^k) + α {\vec 1}^⊤ x\bigr)
                    + δ_{≥ 0}(x) + \frac{1}{2}\|x-x^k\|_{C^*𝒟C}^2,
$$
equivalently
$$
    \begin{aligned}
    \min_x~ & τF(Cx^k) - τ[C^*∇F(Cx^k)]^⊤x^k + \frac{1}{2} (x^k)^⊤ C^*𝒟C x^k
            \\
            &
            - [C^*𝒟C x^k - τC^*∇F(Cx^k)]^⊤ x
            \\
            &
            + \frac{1}{2} x^⊤ C^*𝒟C x
            + τα {\vec 1}^⊤ x + δ_{≥ 0}(x),
    \end{aligned}
$$
In other words, we obtain the quadratic non-negativity constrained problem
$$
    \min_{x ∈ ℝ^n}~ \frac{1}{2} x^⊤ Ã x - g̃^⊤ x + c + τα {\vec 1}^⊤ x + δ_{≥ 0}(x).
$$
where
$$
   \begin{aligned}
    Ã & = C^*𝒟C,
    \\
    g̃ & = C^*𝒟C x^k - τ C^*∇F(Cx^k)
        = C^* 𝒟 μ^k - τ C^*A^*(Aμ^k - b)
    \\
    c & = τ F(Cx^k) - τ[C^*∇F(Cx^k)]^⊤x^k + \frac{1}{2} (x^k)^⊤ C^*𝒟C x^k
        \\
        &
        = \frac{τ}{2} \|Aμ^k-b\|^2 - τ[Aμ^k-b]^⊤Aμ^k + \frac{1}{2} \|μ^k\|_{𝒟}^2
        \\
        &
        = -\frac{τ}{2} \|Aμ^k-b\|^2 - τ[Aμ^k-b]^⊤ b + \frac{1}{2} \|μ^k\|_{𝒟}^2.
   \end{aligned}
$$
</p>

We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by
[`InnerSettings`] in [`FBGenericConfig::inner`].
*/

use numeric_literals::replace_float_literals;
use std::cmp::Ordering::*;
use serde::{Serialize, Deserialize};
use colored::Colorize;
use nalgebra::DVector;

use alg_tools::iterate::{
    AlgIteratorFactory,
    AlgIteratorState,
};
use alg_tools::euclidean::Euclidean;
use alg_tools::norms::Norm;
use alg_tools::linops::Apply;
use alg_tools::sets::Cube;
use alg_tools::loc::Loc;
use alg_tools::bisection_tree::{
    BTFN,
    PreBTFN,
    Bounds,
    BTNodeLookup,
    BTNode,
    BTSearch,
    P2Minimise,
    SupportGenerator,
    LocalAnalysis,
    Bounded,
};
use alg_tools::mapping::RealMapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;

use crate::types::*;
use crate::measures::{
    DiscreteMeasure,
    DeltaMeasure,
    Radon
};
use crate::measures::merging::{
    SpikeMergingMethod,
    SpikeMerging,
};
use crate::forward_model::ForwardModel;
use crate::seminorms::{
    DiscreteMeasureOp, Lipschitz
};
use crate::subproblem::{
    quadratic_nonneg,
    InnerSettings,
    InnerMethod,
};
use crate::tolerance::Tolerance;
use crate::plot::{
    SeqPlotter,
    Plotting,
    PlotLookup
};

/// Strategy for forming the measure $μ$ at the start of each outer iteration.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[allow(dead_code)]
pub enum InsertionStyle {
    /// Reuse the $μ$ of the previous iteration; its weights are re-optimised before any new
    /// spikes are inserted.
    Reuse,
    /// Restart every iteration from the zero measure $μ=0$.
    Zero,
}

/// Meta-algorithm layered on top of the basic forward-backward iteration.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[allow(dead_code)]
pub enum FBMetaAlgorithm {
    /// Plain forward-backward splitting without any meta-algorithm.
    None,
    /// Inertial acceleration in the style of FISTA.
    InertiaFISTA,
}

/// How the insertion tolerance is applied across iterations.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[allow(dead_code)]
pub enum ErgodicTolerance<F> {
    /// Apply the tolerance on each iteration separately (non-ergodically).
    NonErgodic,
    /// After the `n`th iteration, bound the accumulated quantity by `factor` times its value
    /// on the `n`th iteration.
    AfterNth{ n : usize, factor : F },
}

/// Configuration of [`pointsource_fb`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FBConfig<F : Float> {
    /// Scaling of the step length; [`pointsource_fb`] divides it by the Lipschitz factor of
    /// the forward operator to obtain the actual step length $τ$.
    pub τ0 : F,
    /// Meta-algorithm (e.g. FISTA-style inertia) applied on top of basic FB.
    pub meta : FBMetaAlgorithm,
    /// Generic settings shared with other methods built on [`generic_pointsource_fb`].
    pub insertion : FBGenericConfig<F>,
}

/// Settings for solving the stepwise optimality condition in algorithms built on
/// [`generic_pointsource_fb`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FBGenericConfig<F : Float> {
    /// How $μ$ is formed at the start of each iteration; see [`InsertionStyle`].
    pub insertion_style : InsertionStyle,
    /// Tolerance used for point insertion.
    pub tolerance : Tolerance<F>,
    /// The search for a predual maximum (the location for inserting a new point) stops below
    /// `tolerance` multiplied by this factor.
    pub insertion_cutoff_factor : F,
    /// Whether and how the tolerance is applied ergodically.
    pub ergodic_tolerance : ErgodicTolerance<F>,
    /// Branch-and-bound refinement settings used when searching for predual maxima.
    pub refinement : RefinementSettings<F>,
    /// Upper limit on insertions within one outer iteration.
    pub max_insertions : usize,
    /// Optional pair `(n, m)`: on the first `n` iterations allow at most `m` insertions.
    pub bootstrap_insertions : Option<(usize, usize)>,
    /// Settings for the inner (finite-dimensional) solver.
    pub inner : InnerSettings<F>,
    /// Spike merging method used during the iterations.
    pub merging : SpikeMergingMethod<F>,
    /// Multiplier applied to the tolerance when checking merges.
    pub merge_tolerance_mult : F,
    /// Spike merging method used after the final step.
    pub final_merging : SpikeMergingMethod<F>,
    /// Number of iterations between attempts of the merging heuristic.
    pub merge_every : usize,
    /// Whether to save $μ$ for postprocessing optimisation.
    pub postprocessing : bool
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for FBConfig<F> {
    fn default() -> Self {
        FBConfig {
            τ0 : 0.99,
            meta : FBMetaAlgorithm::None,
            insertion : Default::default()
        }
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for FBGenericConfig<F> {
    /// Default generic FB configuration.
    ///
    /// Spikes are reused between iterations, and at most one insertion is made on each of the
    /// first 10 iterations (100 afterwards). The inner subproblem uses SSN. No merging happens
    /// during the iterations, while the final merging uses the default [`SpikeMergingMethod`].
    fn default() -> Self {
        // Only the inner solver method deviates from the `InnerSettings` defaults.
        let inner = InnerSettings {
            method : InnerMethod::SSN,
            .. Default::default()
        };
        Self {
            insertion_style : InsertionStyle::Reuse,
            tolerance : Default::default(),
            insertion_cutoff_factor : 1.0,
            ergodic_tolerance : ErgodicTolerance::NonErgodic,
            refinement : Default::default(),
            max_insertions : 100,
            // `None` would disable the bootstrap insertion limit.
            bootstrap_insertions : Some((10, 1)),
            inner,
            // Alternatively `Default::default()`.
            merging : SpikeMergingMethod::None,
            merge_tolerance_mult : 2.0,
            final_merging : Default::default(),
            merge_every : 10,
            postprocessing : false,
        }
    }
}

/// Trait for specialising [`generic_pointsource_fb`] to concrete methods such as basic FB and
/// FISTA.
///
/// The forward step uses a "residual". For the basic method this is $Aμ - b$, but it can be
/// replaced by an arbitrary value. The [primal-dual proximal splitting][crate::pdps]
/// implementation, for example, substitutes the dual variable $y$. This also makes alternative
/// data terms possible: the (pre)differential of $F(μ)=F\_0(Aμ-b)$ is
/// $F\'(μ) = A\_*F\_0\'(Aμ-b)$, and for the quadratic fidelity
/// $F_0(y)=\frac{1}{2}\\|y\\|_2^2$ in a Hilbert space we simply have $F\_0\'(Aμ-b)=Aμ-b$,
/// the residual.
pub trait FBSpecialisation<F : Float, Observable : Euclidean<F>, const N : usize> : Sized {
    /// Updates the residual, performing any pruning of `μ` that is needed.
    ///
    /// Returns the new residual and, optionally, a new step length (`None` keeps the current
    /// one).
    ///
    /// Implementations may modify `μ`, e.g., to apply inertia to it; the returned residual
    /// must correspond to the modified `μ`. See the [trait documentation][FBSpecialisation]
    /// for the role of the residual.
    ///
    /// `μ_base` is the base point of the iteration: typically the previous iterate, although
    /// for, e.g., FISTA it has inertia applied to it.
    fn update(
        &mut self,
        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
        μ_base : &DiscreteMeasure<Loc<F, N>, F>,
    ) -> (Observable, Option<F>);

    /// Calculates the data term value for the iterate `μ` and the available residual.
    ///
    /// Inertia and other modifications, as deemed necessary, should be applied to `μ`.
    ///
    /// The default implementation corresponds to the squared 2-norm data fidelity
    /// $\\|\text{residual}\\|\_2^2/2$ and ignores `μ`.
    fn calculate_fit(
        &self,
        _μ : &DiscreteMeasure<Loc<F, N>, F>,
        residual : &Observable
    ) -> F {
        residual.norm2_squared_div2()
    }

    /// Calculates the data term value at $μ$.
    ///
    /// Unlike [`Self::calculate_fit`], no inertia or other modification is applied to `μ`.
    fn calculate_fit_simple(
        &self,
        μ : &DiscreteMeasure<Loc<F, N>, F>,
    ) -> F;

    /// Produces the final iterate from `μ`.
    ///
    /// The default merges spikes according to `merging`, using
    /// [`Self::calculate_fit_simple`] as the fitness, and then prunes the result.
    fn postprocess(self, mut μ : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>)
    -> DiscreteMeasure<Loc<F, N>, F>
    where  DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
        μ.merge_spikes_fitness(
            merging,
            |candidate| self.calculate_fit_simple(candidate),
            |&v| v,
        );
        μ.prune();
        μ
    }

    /// Returns the measure to use for value calculations, which may differ from `μ`.
    ///
    /// The default returns `μ` itself.
    fn value_μ<'c, 'b : 'c>(&'b self, μ : &'c DiscreteMeasure<Loc<F, N>, F>)
    -> &'c DiscreteMeasure<Loc<F, N>, F> {
        μ
    }
}

/// [`FBSpecialisation`] turning [`generic_pointsource_fb`] into basic μFB.
struct BasicFB<
    'a,
    F : Float + ToNalgebraRealField,
    A : ForwardModel<Loc<F, N>, F>,
    const N : usize
> {
    /// Observed data $b$.
    b : &'a A::Observable,
    /// Forward operator $A$.
    opA : &'a A,
}

#[replace_float_literals(F::cast_from(literal))]
impl<'a, F : Float + ToNalgebraRealField, A : ForwardModel<Loc<F, N>, F>, const N : usize>
BasicFB<'a, F, A, N> {
    /// Computes the residual $Aμ - b$ at `μ`.
    fn residual_at(&self, μ : &DiscreteMeasure<Loc<F, N>, F>) -> A::Observable {
        // Start from a copy of b and overwrite it in place with 1·Aμ + (-1)·b.
        let mut r = self.b.clone();
        self.opA.gemv(&mut r, 1.0, μ, -1.0);
        r
    }
}

/// Implementation of [`FBSpecialisation`] for basic μFB forward-backward splitting.
#[replace_float_literals(F::cast_from(literal))]
impl<'a, F : Float + ToNalgebraRealField , A : ForwardModel<Loc<F, N>, F>, const N : usize>
FBSpecialisation<F, A::Observable, N> for BasicFB<'a, F, A, N> {
    /// Prunes `μ`, then returns the residual $Aμ - b$ at the pruned `μ`; the step length is
    /// left unchanged.
    fn update(
        &mut self,
        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
        _μ_base : &DiscreteMeasure<Loc<F, N>, F>
    ) -> (A::Observable, Option<F>) {
        μ.prune();
        (self.residual_at(μ), None)
    }

    /// Returns $\\|Aμ - b\\|^2/2$.
    fn calculate_fit_simple(
        &self,
        μ : &DiscreteMeasure<Loc<F, N>, F>,
    ) -> F {
        self.residual_at(μ).norm2_squared_div2()
    }
}

/// [`FBSpecialisation`] turning [`generic_pointsource_fb`] into μFISTA.
struct FISTA<
    'a,
    F : Float + ToNalgebraRealField,
    A : ForwardModel<Loc<F, N>, F>,
    const N : usize
> {
    /// Observed data $b$.
    b : &'a A::Observable,
    /// Forward operator $A$.
    opA : &'a A,
    /// Inertial parameter $λ$ of the current iteration.
    λ : F,
    /// The previous iterate without inertia.
    ///
    /// It has to be kept here because the `μ_base` passed to [`FBSpecialisation::update`]
    /// already has inertia applied, so it cannot serve this purpose.
    μ_prev : DiscreteMeasure<Loc<F, N>, F>,
}

#[replace_float_literals(F::cast_from(literal))]
impl<'a, F : Float + ToNalgebraRealField, A : ForwardModel<Loc<F, N>, F>, const N : usize>
FISTA<'a, F, A, N> {
    /// Computes the residual $Aμ - b$ at `μ`.
    fn residual_at(&self, μ : &DiscreteMeasure<Loc<F, N>, F>) -> A::Observable {
        // Start from a copy of b and overwrite it in place with 1·Aμ + (-1)·b.
        let mut r = self.b.clone();
        self.opA.gemv(&mut r, 1.0, μ, -1.0);
        r
    }
}

/// Implementation of [`FBSpecialisation`] for μFISTA inertial forward-backward splitting.
#[replace_float_literals(F::cast_from(literal))]
impl<'a, F : Float + ToNalgebraRealField, A : ForwardModel<Loc<F, N>, F>, const N : usize>
FBSpecialisation<F, A::Observable, N> for FISTA<'a, F, A, N> {
    /// Applies FISTA inertia to `μ` and returns the residual at the resulting point.
    fn update(
        &mut self,
        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
        _μ_base : &DiscreteMeasure<Loc<F, N>, F>
    ) -> (A::Observable, Option<F>) {
        // Inertial parameters: λ_{k+1} = 2λ_k / (λ_k + √(4 + λ_k²)), θ = λ_{k+1}/λ_k - λ_{k+1}.
        let λ_old = self.λ;
        let λ_new = 2.0 * λ_old / ( λ_old + (4.0 + λ_old * λ_old).sqrt() );
        self.λ = λ_new;
        let θ = λ_new / λ_old - λ_new;
        // μ ← (1 + θ) * μ - θ * μ_prev. Spikes where both μ and μ_prev have zero weight are
        // pruned; as both come from a finite-dimensional subproblem with a proximal projection,
        // this tends to happen when the spike is not needed. A copy of the pruned μ, before
        // the arithmetic, is stored in μ_prev.
        μ.pruning_sub(1.0 + θ, θ, &mut self.μ_prev);
        (self.residual_at(μ), None)
    }

    /// Returns $\\|Aμ - b\\|^2/2$.
    fn calculate_fit_simple(
        &self,
        μ : &DiscreteMeasure<Loc<F, N>, F>,
    ) -> F {
        self.residual_at(μ).norm2_squared_div2()
    }

    /// Evaluates the fit at the stored non-inertial iterate `μ_prev`, ignoring both arguments.
    fn calculate_fit(
        &self,
        _μ : &DiscreteMeasure<Loc<F, N>, F>,
        _residual : &A::Observable
    ) -> F {
        self.calculate_fit_simple(&self.μ_prev)
    }

    /// Finalises the non-inertial iterate `μ_prev` by merging and pruning.
    ///
    /// FISTA needs a final pruning as well, because only limited pruning can be done on
    /// each step.
    fn postprocess(mut self, μ_base : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>)
    -> DiscreteMeasure<Loc<F, N>, F>
    where  DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
        // Move the non-inertial iterate out, and put `μ_base` in its place so that `self` is
        // whole again and can be borrowed by the fitness closure below.
        let mut μ = self.μ_prev;
        self.μ_prev = μ_base;
        μ.merge_spikes_fitness(
            merging,
            |candidate| self.calculate_fit_simple(candidate),
            |&v| v,
        );
        μ.prune();
        μ
    }

    /// Values are reported for the non-inertial iterate `μ_prev` rather than for `μ`.
    fn value_μ<'c, 'b : 'c>(&'c self, _μ : &'c DiscreteMeasure<Loc<F, N>, F>)
    -> &'c DiscreteMeasure<Loc<F, N>, F> {
        &self.μ_prev
    }
}

/// Iteratively solve the pointsource localisation problem using forward-backward splitting.
///
/// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
/// forward operator $A$, `b` the observable, and `α` the regularisation weight.
/// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
/// operator. `plotter` is used for plotting on verbose iterations. Finally, the `iterator` is
/// an outer loop verbosity and iteration count control as documented in
/// [`alg_tools::iterate`].
///
/// The step length is $τ = τ_0 / L$, where $τ_0$ is `config.τ0` and $L$ is the Lipschitz
/// factor of `opA` with respect to `op𝒟`. The iteration starts from the zero measure, whose
/// residual is $-b$. `config.meta` selects between basic FB and FISTA-type inertia.
///
/// For details on the mathematical formulation, see the [module level](self) documentation.
///
/// # Panics
///
/// Panics if `opA` does not provide a Lipschitz factor with respect to `op𝒟`.
///
/// Returns the final iterate.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_fb<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, const N : usize>(
    opA : &'a A,
    b : &A::Observable,
    α : F,
    op𝒟 : &'a 𝒟,
    config : &FBConfig<F>,
    iterator : I,
    plotter : SeqPlotter<F, N>
) -> DiscreteMeasure<Loc<F, N>, F>
where F : Float + ToNalgebraRealField<MixedType = F>,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
      A::Observable : std::ops::MulAssign<F>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
          + Lipschitz<𝒟, FloatType=F>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
      𝒟::Codomain : RealMapping<F, N>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {

    // Residual Aμ - b of the zero initial measure.
    let initial_residual = -b;
    // The step length cannot be formed without the Lipschitz factor, so its absence is fatal.
    let τ = config.τ0/opA.lipschitz_factor(&op𝒟).expect(
        "pointsource_fb: Lipschitz factor of the forward operator with respect to 𝒟 unavailable"
    );

    match config.meta {
        FBMetaAlgorithm::None => generic_pointsource_fb(
            opA, α, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual,
            BasicFB{ b, opA }
        ),
        FBMetaAlgorithm::InertiaFISTA => generic_pointsource_fb(
            opA, α, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual,
            FISTA{ b, opA, λ : 1.0, μ_prev : DiscreteMeasure::new() }
        ),
    }
}

/// Generic implementation of [`pointsource_fb`].
///
/// The method can be specialised even to primal-dual proximal splitting through the
/// [`FBSpecialisation`] parameter `specialisation`.
/// The settings in `config` have their [respective documentation](FBGenericConfig). `opA` is the
/// forward operator $A$, `α` the regularisation weight, and `τ` the initial step length, which
/// `specialisation` may replace on each iteration. `residual` is the initial residual
/// (e.g., $-b$ for the zero initial measure).
/// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
/// operator. `plotter` is used for plotting on verbose iterations. Finally, the `iterator` is
/// an outer loop verbosity and iteration count control as documented in
/// [`alg_tools::iterate`].
///
/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
/// sums of simple functions using bisection trees, and the related
/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
/// active at specific points, and to maximise their sums. Through the implementation of the
/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
///
/// Spike merging is attempted every `config.merge_every` iterations. A value of `0` disables
/// merging during the iterations; the final merging done by `specialisation` is unaffected.
///
/// Returns the final iterate.
#[replace_float_literals(F::cast_from(literal))]
pub fn generic_pointsource_fb<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, const N : usize>(
    opA : &'a A,
    α : F,
    op𝒟 : &'a 𝒟,
    mut τ : F,
    config : &FBGenericConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
    mut residual : A::Observable,
    mut specialisation : Spec,
) -> DiscreteMeasure<Loc<F, N>, F>
where F : Float + ToNalgebraRealField<MixedType=F>,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      Spec : FBSpecialisation<F, A::Observable, N>,
      A::Observable : std::ops::MulAssign<F>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
          + Lipschitz<𝒟, FloatType=F>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
      𝒟::Codomain : RealMapping<F, N>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {

    // Set up parameters
    let quiet = iterator.is_quiet();
    let op𝒟norm = op𝒟.opnorm_bound();
    // We multiply tolerance by τ for FB since
    // our subproblems depending on tolerances are scaled by τ compared to the conditional
    // gradient approach.
    let mut tolerance = config.tolerance * τ * α;
    let mut ε = tolerance.initial();

    // Initialise operators
    let preadjA = opA.preadjoint();

    // Initialise iterates
    let mut μ = DiscreteMeasure::new();

    let mut after_nth_bound = F::INFINITY;
    // FIXME: Don't allocate if not needed.
    let mut after_nth_accum = opA.zero_observable();

    let mut inner_iters = 0;
    let mut this_iters = 0;
    let mut pruned = 0;
    let mut merged = 0;

    // Computes μ_base - μ_new as a discrete measure (supported on the spikes of μ_new for
    // `Reuse`, where μ_base's spikes form a prefix of μ_new's).
    let μ_diff = |μ_new : &DiscreteMeasure<Loc<F, N>, F>,
                  μ_base : &DiscreteMeasure<Loc<F, N>, F>| {
        let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style {
            InsertionStyle::Reuse => {
                μ_new.iter_spikes()
                        .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0)))
                        .map(|(δ, α_base)| (δ.x, α_base - δ.α))
                        .collect()
            },
            InsertionStyle::Zero => {
                μ_new.iter_spikes()
                        .map(|δ| -δ)
                        .chain(μ_base.iter_spikes().copied())
                        .collect()
            }
        };
        ν.prune(); // Potential small performance improvement
        ν
    };

    // Run the algorithm
    iterator.iterate(|state| {
        // Calculate subproblem tolerances, and update main tolerance for next iteration
        let τα = τ * α;
        // if μ.len() == 0 /*state.iteration() == 1*/ {
        //     let t = minus_τv.bounds().upper() * 0.001;
        //     if t > 0.0 {
        //         let (ξ, v_ξ) = minus_τv.maximise(t, config.refinement.max_steps);
        //         if τα + ε > v_ξ && v_ξ > τα {
        //             // The zero measure is already within bounds, so improve them
        //             tolerance = config.tolerance * (v_ξ - τα);
        //             ε = tolerance.initial();
        //         }
        //         μ += DeltaMeasure { x : ξ, α : 0.0 };
        //     } else {
        //         // Zero is the solution.
        //         return Step::Terminated
        //     }
        // }
        let target_bounds = Bounds(τα - ε,  τα + ε);
        let merge_tolerance = config.merge_tolerance_mult * ε;
        let merge_target_bounds = Bounds(τα - merge_tolerance,  τα + merge_tolerance);
        let inner_tolerance = ε * config.inner.tolerance_mult;
        let refinement_tolerance = ε * config.refinement.tolerance_mult;
        let maximise_above = τα + ε * config.insertion_cutoff_factor;
        let mut ε1 = ε;
        let ε_prev = ε;
        ε = tolerance.update(ε, state.iteration());

        // Maximum insertion count and measure difference calculation depend on insertion style.
        let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
            (i, Some((l, k))) if i <= l => (k, false),
            _ => (config.max_insertions, !quiet),
        };
        let max_insertions = match config.insertion_style {
            InsertionStyle::Zero => {
                todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled.");
                // let n = μ.len();
                // μ = DiscreteMeasure::new();
                // n + m
            },
            InsertionStyle::Reuse => m,
        };

        // Calculate smooth part of surrogate model.
        residual *= -τ;
        if let ErgodicTolerance::AfterNth{ .. } = config.ergodic_tolerance {
            // Negative residual times τ expected here, as set above.
            // TODO: is this the correct location?
            after_nth_accum += &residual;
        }
        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
        // has no significant overhead. For some reason Rust doesn't allow us simply moving
        // the residual and replacing it below before the end of this closure.
        let r = std::mem::replace(&mut residual, opA.empty_observable());
        let minus_τv = preadjA.apply(r);     // minus_τv = -τA^*(Aμ^k-b)
        // TODO: should avoid a second copy of μ here; μ_base already stores a copy.
        let ω0 = op𝒟.apply(μ.clone());       // 𝒟μ^k
        //let g = &minus_τv + ω0;            // Linear term of surrogate model

        // Save current base point
        let μ_base = μ.clone();

        // Add points to support until within error tolerance or maximum insertion count reached.
        let mut count = 0;
        let (within_tolerances, d) = 'insertion: loop {
            if μ.len() > 0 {
                // Form finite-dimensional subproblem. The subproblem references to the original μ^k
                // from the beginning of the iteration are all contained in the immutable c and g.
                let à = op𝒟.findim_matrix(μ.iter_locations());
                let g̃ = DVector::from_iterator(μ.len(),
                                               μ.iter_locations()
                                                .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ))
                                                .map(F::to_nalgebra_mixed));

                // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
                // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
                // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤  sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
                // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2
                // = n |𝒟| |x|_2, where n is the number of points. Therefore
                let inner_τ = config.inner.τ0 / (op𝒟norm * F::cast_from(μ.len()));

                // The mutable view of the masses borrows `μ` until the solver returns, so it is
                // taken only after all immutable uses of `μ` above. The solver updates the
                // masses of μ in place through it.
                let mut x = μ.masses_mut();

                // Solve finite-dimensional subproblem.
                let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
                inner_iters += quadratic_nonneg(config.inner.method, &Ã, &g̃, τα, &mut x,
                                                inner_τ, inner_it);
            }

            // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
            // conditions in the predual space, and finding new points for insertion, if necessary.
            let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ, &μ_base));

            // If no merging heuristic is used, let's be more conservative about spike insertion,
            // and skip it after first round. If merging is done, being more greedy about spike
            // insertion also seems to improve performance.
            let may_break = if let SpikeMergingMethod::None = config.merging {
                false
            } else {
                count > 0
            };

            // First do a rough check whether we are within bounds and can stop.
            let in_bounds = match config.ergodic_tolerance {
                ErgodicTolerance::NonErgodic => {
                    target_bounds.superset(&d.bounds())
                },
                ErgodicTolerance::AfterNth{ n, factor } => {
                    // Bound -τ∑_{k=0}^{N-1}[A_*(Aμ^k-b)+α] from above.
                    match state.iteration().cmp(&n) {
                        Less => true,
                        Equal => {
                            let iter = F::cast_from(state.iteration());
                            let mut tmp = preadjA.apply(&after_nth_accum);
                            let (_, v0) = tmp.maximise(refinement_tolerance,
                                                    config.refinement.max_steps);
                            let v = v0 - iter * τ * α;
                            after_nth_bound = factor * v;
                            println!("Set ergodic tolerance to {}", after_nth_bound);
                            true
                        },
                        Greater => {
                            // TODO: can divide after_nth_accum by N, so use basic tolerance on that.
                            let iter = F::cast_from(state.iteration());
                            let mut tmp = preadjA.apply(&after_nth_accum);
                            tmp.has_upper_bound(after_nth_bound + iter * τ * α,
                                                refinement_tolerance,
                                                config.refinement.max_steps)
                        }
                    }
                }
            };

            // If preliminary check indicates that we are in bounds, and if it otherwise matches
            // the insertion strategy, skip insertion.
            if may_break && in_bounds {
                break 'insertion (true, d)
            }

            // If the rough check didn't indicate stopping, find maximising point, maintaining for
            // the calculations in the beginning of the loop that v_ξ = (ω0-τv-𝒟μ)(ξ) = d(ξ),
            // where μ is now distinct from μ^k after the insertions already performed.
            // We do not need to check lower bounds, as a solution of the finite-dimensional
            // subproblem should always satisfy them.

            // // Find the mimimum over the support of μ.
            // let d_min_supp = d_max;μ.iter_spikes().filter_map(|&DeltaMeasure{ α, ref x }| {
            //    (α != F::ZERO).then(|| d.value(x))
            // }).reduce(F::min).unwrap_or(0.0);

            let (ξ, v_ξ) = if false /* μ.len() == 0*/ /*count == 0 &&*/ {
                // If μ has no spikes, just find the maximum of d. Then adjust the tolerance, if
                // necessary, to adapt it to the problem.
                let (ξ, v_ξ) = d.maximise(refinement_tolerance, config.refinement.max_steps);
                //dbg!((τα, v_ξ, target_bounds.upper(), maximise_above));
                if τα < v_ξ  && v_ξ < target_bounds.upper() {
                    ε1 = v_ξ - τα;
                    ε *= ε1 / ε_prev;
                    tolerance *= ε1 / ε_prev;
                }
                (ξ, v_ξ)
            } else {
                // If μ has some spikes, only find a maximum of d if it is above a threshold
                // defined by the refinement tolerance.
                match d.maximise_above(maximise_above, refinement_tolerance,
                                    config.refinement.max_steps) {
                    None => break 'insertion (true, d),
                    Some(res) => res,
                }
            };

            // // Do a one final check whether we can stop already without inserting more points
            // // because `d` actually in bounds based on a more refined estimate.
            // if may_break && target_bounds.upper() >= v_ξ {
            //     break (true, d)
            // }

            // Break if maximum insertion count reached
            if count >= max_insertions {
                let in_bounds2 = target_bounds.upper() >= v_ξ;
                break 'insertion (in_bounds2, d)
            }

            // No point in optimising the weight here; the finite-dimensional algorithm is fast.
            μ += DeltaMeasure { x : ξ, α : 0.0 };
            count += 1;
        };

        if !within_tolerances && warn_insertions {
            // Complain (but continue) if we failed to get within tolerances
            // by inserting more points.
            let err = "Maximum insertions reached without achieving \
                                subproblem solution tolerance";
            println!("{}", err.red());
        }

        // Merge spikes. `merge_every == 0` disables merging (and avoids a remainder by zero).
        if config.merge_every > 0 && state.iteration() % config.merge_every == 0 {
            let n_before_merge = μ.len();
            μ.merge_spikes(config.merging, |μ_candidate| {
                //println!("Merge attempt!");
                let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ_candidate, &μ_base));

                if merge_target_bounds.superset(&d.bounds()) {
                    //println!("…Early Ok");
                    return Some(())
                }

                let d_min_supp = μ_candidate.iter_spikes().filter_map(|&DeltaMeasure{ α, ref x }| {
                    (α != 0.0).then(|| d.apply(x))
                }).reduce(F::min);

                if d_min_supp.map_or(true, |b| b >= merge_target_bounds.lower()) &&
                d.has_upper_bound(merge_target_bounds.upper(), refinement_tolerance,
                                    config.refinement.max_steps) {
                    //println!("…Ok");
                    Some(())
                } else {
                    //println!("…Fail");
                    None
                }
            });
            debug_assert!(μ.len() >= n_before_merge);
            merged += μ.len() - n_before_merge;
        }

        // Let the specialisation prune μ (and possibly apply inertia), compute the new
        // residual, and possibly update the step length.
        let n_before_prune = μ.len();
        (residual, τ) = match specialisation.update(&mut μ, &μ_base) {
            (r, None) => (r, τ),
            (r, Some(new_τ)) => (r, new_τ)
        };
        debug_assert!(μ.len() <= n_before_prune);
        pruned += n_before_prune - μ.len();

        this_iters += 1;

        // Give function value if needed
        state.if_verbose(|| {
            let value_μ = specialisation.value_μ(&μ);
            // Plot if so requested
            plotter.plot_spikes(
                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
                "start".to_string(), Some(&minus_τv),
                Some(target_bounds), value_μ,
            );
            // Calculate mean inner iterations and reset relevant counters
            // Return the statistics
            let res = IterInfo {
                value : specialisation.calculate_fit(&μ, &residual) + α * value_μ.norm(Radon),
                n_spikes : value_μ.len(),
                inner_iters,
                this_iters,
                merged,
                pruned,
                ε : ε_prev,
                maybe_ε1 : Some(ε1),
                postprocessing: config.postprocessing.then(|| value_μ.clone()),
            };
            inner_iters = 0;
            this_iters = 0;
            merged = 0;
            pruned = 0;
            res
        })
    });

    specialisation.postprocess(μ, config.final_merging)
}



mercurial