src/fb.rs

Tue, 31 Dec 2024 09:34:24 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Tue, 31 Dec 2024 09:34:24 -0500
branch
dev
changeset 32
56c8adc32b09
parent 29
87649ccfa6a8
child 34
efa60bc4f743
permissions
-rw-r--r--

Early transport sketches

/*!
Solver for the point source localisation problem using a forward-backward splitting method.

This corresponds to the manuscript

 * Valkonen T. - _Proximal methods for point source localisation_,
   [arXiv:2212.02991](https://arxiv.org/abs/2212.02991).

The main routine is [`pointsource_fb_reg`]. It is based on [`generic_pointsource_fb_reg`], which is
also used by our [primal-dual proximal splitting][crate::pdps] implementation.

FISTA-type inertia is available through [`pointsource_fista_reg`].

## Problem

<p>
Our objective is to solve
$$
    \min_{μ ∈ ℳ(Ω)}~ F_0(Aμ-b) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ),
$$
where $F_0(y)=\frac{1}{2}\|y\|_2^2$ and the forward operator $A \in 𝕃(ℳ(Ω); ℝ^n)$.
</p>

## Approach

<p>
As documented in more detail in the paper, on each step we approximately solve
$$
    \min_{μ ∈ ℳ(Ω)}~ F(μ) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ) + \frac{1}{2}\|μ-μ^k\|_𝒟^2,
$$
where $F(μ)=F_0(Aμ-b)$ and $𝒟 ∈ 𝕃(ℳ(Ω); C_c(Ω))$ is typically a convolution operator.
</p>

## Finite-dimensional subproblems.

With $C$ a projection from [`DiscreteMeasure`] to the weights, and $x^k$ such that $x^k=Cμ^k$, we
form the discretised linearised inner problem
<p>
$$
    \min_{x ∈ ℝ^n}~ τ\bigl(F(Cx^k) + [C^*∇F(Cx^k)]^⊤(x-x^k) + α {\vec 1}^⊤ x\bigr)
                    + δ_{≥ 0}(x) + \frac{1}{2}\|x-x^k\|_{C^*𝒟C}^2,
$$
equivalently
$$
    \begin{aligned}
    \min_x~ & τF(Cx^k) - τ[C^*∇F(Cx^k)]^⊤x^k + \frac{1}{2} (x^k)^⊤ C^*𝒟C x^k
            \\
            &
            - [C^*𝒟C x^k - τC^*∇F(Cx^k)]^⊤ x
            \\
            &
            + \frac{1}{2} x^⊤ C^*𝒟C x
            + τα {\vec 1}^⊤ x + δ_{≥ 0}(x),
    \end{aligned}
$$
In other words, we obtain the quadratic non-negativity constrained problem
$$
    \min_{x ∈ ℝ^n}~ \frac{1}{2} x^⊤ Ã x - g̃^⊤ x + c + τα {\vec 1}^⊤ x + δ_{≥ 0}(x),
$$
where
$$
   \begin{aligned}
    Ã & = C^*𝒟C,
    \\
    g̃ & = C^*𝒟C x^k - τ C^*∇F(Cx^k)
        = C^* 𝒟 μ^k - τ C^*A^*(Aμ^k - b),
    \\
    c & = τ F(Cx^k) - τ[C^*∇F(Cx^k)]^⊤x^k + \frac{1}{2} (x^k)^⊤ C^*𝒟C x^k
        \\
        &
        = \frac{τ}{2} \|Aμ^k-b\|^2 - τ[Aμ^k-b]^⊤Aμ^k + \frac{1}{2} \|μ^k\|_{𝒟}^2
        \\
        &
        = -\frac{τ}{2} \|Aμ^k-b\|^2 + τ[Aμ^k-b]^⊤ b + \frac{1}{2} \|μ^k\|_{𝒟}^2.
   \end{aligned}
$$
</p>

We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by
[`InnerSettings`] in [`FBGenericConfig::inner`].
*/

use numeric_literals::replace_float_literals;
use serde::{Serialize, Deserialize};
use colored::Colorize;
use nalgebra::DVector;

use alg_tools::iterate::{
    AlgIteratorFactory,
    AlgIteratorState,
};
use alg_tools::euclidean::Euclidean;
use alg_tools::linops::{Apply, GEMV};
use alg_tools::sets::Cube;
use alg_tools::loc::Loc;
use alg_tools::bisection_tree::{
    BTFN,
    PreBTFN,
    Bounds,
    BTNodeLookup,
    BTNode,
    BTSearch,
    P2Minimise,
    SupportGenerator,
    LocalAnalysis,
    BothGenerators,
};
use alg_tools::mapping::RealMapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;

use crate::types::*;
use crate::measures::{
    DiscreteMeasure,
    DeltaMeasure,
};
use crate::measures::merging::{
    SpikeMergingMethod,
    SpikeMerging,
};
use crate::forward_model::ForwardModel;
use crate::seminorms::DiscreteMeasureOp;
use crate::subproblem::{
    InnerSettings,
    InnerMethod,
};
use crate::tolerance::Tolerance;
use crate::plot::{
    SeqPlotter,
    Plotting,
    PlotLookup
};
use crate::regularisation::RegTerm;
use crate::dataterm::{
    calculate_residual,
    L2Squared,
    DataTerm,
};

/// Method for constructing $μ$ on each iteration.
///
/// Selected through [`FBGenericConfig::insertion_style`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
// NOTE(review): presumably silences warnings about variants that are never constructed
// inside the crate — confirm that this is still needed.
#[allow(dead_code)]
pub enum InsertionStyle {
    /// Reuse $μ$ from the previous iteration, optimising weights
    /// before inserting new spikes.
    Reuse,
    /// Start each iteration with $μ=0$.
    ///
    /// Currently disabled: `insert_and_reweigh` hits a `todo!` when this style is selected.
    Zero,
}

/// Settings for [`pointsource_fb_reg`] and [`pointsource_fista_reg`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FBConfig<F : Float> {
    /// Step length scaling: the solvers use the step length `τ0` divided by the Lipschitz
    /// factor reported by the forward operator. Defaults to `0.99`.
    pub τ0 : F,
    /// Generic parameters for spike insertion, reweighting, and merging;
    /// see [`FBGenericConfig`].
    pub insertion : FBGenericConfig<F>,
}

/// Settings for the solution of the stepwise optimality condition, as consumed by the
/// spike insertion, reweighting, and merging helpers in this module
/// (`insert_and_reweigh`, `prune_and_maybe_simple_merge`, and `postprocess`).
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FBGenericConfig<F : Float> {
    /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`].
    /// Defaults to [`InsertionStyle::Reuse`].
    pub insertion_style : InsertionStyle,
    /// Tolerance for point insertion.
    pub tolerance : Tolerance<F>,
    /// Stop looking for predual maximum (where to insert a new point) below
    /// `tolerance` multiplied by this factor. Defaults to `1.0`.
    pub insertion_cutoff_factor : F,
    /// Settings for branch and bound refinement when looking for predual maxima
    pub refinement : RefinementSettings<F>,
    /// Maximum insertions within each outer iteration. Defaults to `100`.
    pub max_insertions : usize,
    /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
    /// While this limit applies, no warning is printed when it is reached.
    /// Defaults to `Some((10, 1))`.
    pub bootstrap_insertions : Option<(usize, usize)>,
    /// Inner method settings. The default selects [`InnerMethod::SSN`].
    pub inner : InnerSettings<F>,
    /// Spike merging method. Defaults to [`SpikeMergingMethod::None`].
    pub merging : SpikeMergingMethod<F>,
    /// Tolerance multiplier for merges. Defaults to `2.0`.
    pub merge_tolerance_mult : F,
    /// Spike merging method after the last step
    pub final_merging : SpikeMergingMethod<F>,
    /// Iterations between merging heuristic tries: merging is attempted on iterations
    /// whose number is divisible by this value. Defaults to `10`.
    pub merge_every : usize,
    /// Save $μ$ for postprocessing optimisation: when set, the solvers attach a copy of
    /// the current iterate to the statistics they report.
    pub postprocessing : bool
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for FBConfig<F> {
    fn default() -> Self {
        FBConfig {
            τ0 : 0.99,
            insertion : Default::default()
        }
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for FBGenericConfig<F> {
    /// Default generic parameters: spikes are reused between iterations, at most one spike
    /// is inserted on each of the first ten iterations and at most a hundred afterwards,
    /// the inner problem is solved with SSN, and no merging is done during the iterations.
    fn default() -> Self {
        // Only the inner method is fixed; other inner settings keep their own defaults.
        let inner = InnerSettings {
            method : InnerMethod::SSN,
            .. Default::default()
        };

        Self {
            // Insertion
            insertion_style : InsertionStyle::Reuse,
            tolerance : Default::default(),
            insertion_cutoff_factor : 1.0,
            refinement : Default::default(),
            max_insertions : 100,
            bootstrap_insertions : Some((10, 1)),
            // Finite-dimensional subproblems
            inner,
            // Merging
            merging : SpikeMergingMethod::None,
            merge_every : 10,
            merge_tolerance_mult : 2.0,
            final_merging : Default::default(),
            // Reporting
            postprocessing : false,
        }
    }
}

/// Forms the difference measure `μ_base - μ_new`, optionally adding `ν_delta` to it.
///
/// How the difference is assembled depends on `config.insertion_style`:
///
///  * [`InsertionStyle::Reuse`]: the $i$-th spike of `μ_new` is paired with the $i$-th mass
///    of `μ_base`, with zero used once `μ_base` runs out of masses. The result has
///    the locations of `μ_new`. Spikes of `μ_base` beyond the length of `μ_new` are ignored.
///    (This presumably relies on `μ_new` extending `μ_base` with unchanged leading
///    locations — verify against the callers.)
///  * [`InsertionStyle::Zero`]: the negated spikes of `μ_new` are followed by the spikes
///    of `μ_base`.
///
/// The difference is pruned before `ν_delta`, if given, is added.
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn μ_diff<F : Float, const N : usize>(
    μ_new : &DiscreteMeasure<Loc<F, N>, F>,
    μ_base : &DiscreteMeasure<Loc<F, N>, F>,
    ν_delta : Option<&DiscreteMeasure<Loc<F, N>, F>>,
    config : &FBGenericConfig<F>
) -> DiscreteMeasure<Loc<F, N>, F> {
    let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style {
        InsertionStyle::Zero => {
            let negated_new = μ_new.iter_spikes().map(|δ| -δ);
            let base_spikes = μ_base.iter_spikes().copied();
            negated_new.chain(base_spikes).collect()
        },
        InsertionStyle::Reuse => {
            // Pad the base masses with zeros for spikes that only exist in μ_new.
            let base_masses = μ_base.iter_masses().chain(std::iter::repeat(0.0));
            μ_new.iter_spikes()
                 .zip(base_masses)
                 .map(|(spike, mass_base)| (spike.x, mass_base - spike.α))
                 .collect()
        },
    };

    // Potential small performance improvement
    ν.prune();

    if let Some(ν_d) = ν_delta {
        ν + ν_d
    } else {
        ν
    }
}

/// Inserts new spikes into `μ` and re-optimises its weights until the stepwise tolerance
/// check passes or the insertion budget is exhausted.
///
/// Each round of the internal loop
///
///  1. solves the finite-dimensional weight subproblem through `reg.solve_findim`
///     (skipped while `μ` has no spikes), and writes the resulting masses back to `μ`;
///  2. forms `d = minus_τv + 𝒟(ν)`, where `ν` is the difference computed by [`μ_diff`];
///  3. asks `reg.find_tolerance_violation` for a point violating the tolerance `ε`, and
///     inserts a zero-weight spike there, unless the insertion budget has been used up.
///
/// The insertion budget is `config.max_insertions`, or the second component of
/// `config.bootstrap_insertions` while the iteration number reported by `state` does not
/// exceed the first component.
///
/// # Parameters
///
///  * `μ`: the measure to update in place.
///  * `minus_τv`: the smooth part of the surrogate model, as named by the callers.
///  * `μ_base`: the base point of the step; the callers pass a copy of `μ` taken before
///    the call.
///  * `ν_delta`: optional extra measure added both to the measure to which `op𝒟` is applied
///    to form `ω0`, and to the difference computed by [`μ_diff`].
///  * `op𝒟`, `op𝒟norm`: the operator used for the proximal term, and a bound on its norm.
///  * `τ`, `ε`: step length and current tolerance.
///  * `stats`: `inner_iters` is incremented by the count returned by `reg.solve_findim`.
///
/// # Returns
///
/// The pair `(d, within_tolerances)` of the last formed `d` and a flag that is `true` when
/// no tolerance violation was found, or otherwise the `in_bounds` flag reported for
/// the last violation found before the budget ran out. If the flag is `false` outside the
/// bootstrap phase and `state` is not quiet, a warning is printed.
///
/// # Panics
///
/// Hits a `todo!` if `config.insertion_style` is [`InsertionStyle::Zero`].
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn insert_and_reweigh<
    'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
>(
    μ : &mut DiscreteMeasure<Loc<F, N>, F>,
    minus_τv : &BTFN<F, GA, BTA, N>,
    μ_base : &DiscreteMeasure<Loc<F, N>, F>,
    ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>,
    op𝒟 : &'a 𝒟,
    op𝒟norm : F,
    τ : F,
    ε : F,
    config : &FBGenericConfig<F>,
    reg : &Reg,
    state : &State,
    stats : &mut IterInfo<F, N>,
) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool)
where F : Float + ToNalgebraRealField,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
      𝒟::Codomain : RealMapping<F, N>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      Reg : RegTerm<F, N>,
      State : AlgIteratorState {

    // Maximum insertion count and measure difference calculation depend on insertion style.
    // During the bootstrap phase the reduced budget is expected to be hit, so warnings about
    // reaching it are suppressed.
    let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
        (i, Some((l, k))) if i <= l => (k, false),
        _ => (config.max_insertions, !state.is_quiet()),
    };
    let max_insertions = match config.insertion_style {
        InsertionStyle::Zero => {
            todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled.");
            // let n = μ.len();
            // μ = DiscreteMeasure::new();
            // n + m
        },
        InsertionStyle::Reuse => m,
    };

    // TODO: should avoid a second copy of μ here; μ_base already stores a copy.
    // ω0 = 𝒟(μ [+ ν_delta]) is evaluated once, before μ is modified by the loop below.
    let ω0 = op𝒟.apply(match ν_delta {
        None => μ.clone(),
        Some(ν_d) => &*μ + ν_d,
    });

    // Add points to support until within error tolerance or maximum insertion count reached.
    let mut count = 0;
    let (within_tolerances, d) = 'insertion: loop {
        if μ.len() > 0 {
            // Form finite-dimensional subproblem. The subproblem references to the original μ^k
            // from the beginning of the iteration are all contained in the immutable c and g.
            let à = op𝒟.findim_matrix(μ.iter_locations());
            let g̃ = DVector::from_iterator(μ.len(),
                                           μ.iter_locations()
                                            .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ))
                                            .map(F::to_nalgebra_mixed));
            let mut x = μ.masses_dvector();

            // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
            // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
            // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤  sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
            // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2
            // = n |𝒟| |x|_2, where n is the number of points. Therefore
            let Ã_normest = op𝒟norm * F::cast_from(μ.len());

            // Solve finite-dimensional subproblem.
            stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config);

            // Update masses of μ based on solution of finite-dimensional subproblem.
            μ.set_masses_dvector(&x);
        }

        // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
        // conditions in the predual space, and finding new points for insertion, if necessary.
        let mut d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_base, ν_delta, config));

        // If no merging heuristic is used, let's be more conservative about spike insertion,
        // and skip it after first round. If merging is done, being more greedy about spike
        // insertion also seems to improve performance.
        // NOTE(review): as written, the rough check is requested (after the first insertion)
        // only when a merging method IS configured, and never when merging is
        // `SpikeMergingMethod::None`. Confirm that this matches the intent described above.
        let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging {
            false
        } else {
            count > 0
        };

        // Find a spike to insert, if needed
        let (ξ, _v_ξ, in_bounds) =  match reg.find_tolerance_violation(
            &mut d, τ, ε, skip_by_rough_check, config
        ) {
            None => break 'insertion (true, d),
            Some(res) => res,
        };

        // Break if maximum insertion count reached
        if count >= max_insertions {
            break 'insertion (in_bounds, d)
        }

        // No point in optimising the weight here; the finite-dimensional algorithm is fast.
        // The new spike gets its weight from the subproblem solved on the next round.
        *μ += DeltaMeasure { x : ξ, α : 0.0 };
        count += 1;
    };

    // TODO: should redo everything if some transports cause a problem.
    // Maybe implementation should call above loop as a closure.

    if !within_tolerances && warn_insertions {
        // Complain (but continue) if we failed to get within tolerances
        // by inserting more points.
        let err = format!("Maximum insertions reached without achieving \
                            subproblem solution tolerance");
        println!("{}", err.red());
    }

    (d, within_tolerances)
}

/// Prunes `μ`, after first trying the spike merging heuristic on scheduled iterations.
///
/// Merging with the method `config.merging` is attempted when the iteration number reported
/// by `state` is divisible by `config.merge_every`. A `merge_every` of zero disables merging
/// instead of causing a division-by-zero panic. Each merge candidate is accepted only if
/// `reg.verify_merge_candidate` approves it for `d = minus_τv + 𝒟(μ_base - μ_candidate)`.
///
/// Afterwards `μ` is always pruned.
///
/// # Parameters
///
///  * `μ`: the measure to merge and prune in place.
///  * `minus_τv`: the smooth part of the surrogate model, as named by the callers.
///  * `μ_base`: the base point of the step, against which candidates are compared.
///  * `op𝒟`: the operator used for the proximal term.
///  * `τ`, `ε`: step length and current tolerance.
///  * `stats`: `merged` is incremented by the growth of the spike count during merging,
///    and `pruned` by the number of spikes removed by pruning.
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn prune_and_maybe_simple_merge<
    'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
>(
    μ : &mut DiscreteMeasure<Loc<F, N>, F>,
    minus_τv : &BTFN<F, GA, BTA, N>,
    μ_base : &DiscreteMeasure<Loc<F, N>, F>,
    op𝒟 : &'a 𝒟,
    τ : F,
    ε : F,
    config : &FBGenericConfig<F>,
    reg : &Reg,
    state : &State,
    stats : &mut IterInfo<F, N>,
)
where F : Float + ToNalgebraRealField,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
      𝒟::Codomain : RealMapping<F, N>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      Reg : RegTerm<F, N>,
      State : AlgIteratorState {
    // `merge_every` is user-configurable; a zero value would make the remainder below panic,
    // so it is interpreted as "never merge".
    let merge_now = config.merge_every > 0
                    && state.iteration() % config.merge_every == 0;

    if merge_now {
        let n_before_merge = μ.len();
        μ.merge_spikes(config.merging, |μ_candidate| {
            let μd = μ_diff(&μ_candidate, μ_base, None, config);
            let mut d = minus_τv + op𝒟.preapply(μd);

            reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
                .then_some(())
        });
        // NOTE(review): this assumes that merging never reduces the spike count at this
        // point — presumably merged-away spikes are only zeroed, and are removed by
        // the pruning below. Confirm against `SpikeMerging::merge_spikes`.
        debug_assert!(μ.len() >= n_before_merge);
        stats.merged += μ.len() - n_before_merge;
    }

    let n_before_prune = μ.len();
    μ.prune();
    debug_assert!(μ.len() <= n_before_prune);
    stats.pruned += n_before_prune - μ.len();
}

/// Postprocesses the final iterate `μ` of a solver.
///
/// Spikes are merged with the method `config.final_merging` — the method configured for use
/// after the last step — using the fit `dataterm.calculate_fit_op(·, opA, b)` as the fitness
/// of merge candidates. The result is then pruned and returned.
///
/// # Parameters
///
///  * `μ`: the final iterate; consumed and returned in postprocessed form.
///  * `config`: generic solver settings; only `final_merging` is read.
///  * `dataterm`: the data term used to evaluate the fit of merge candidates.
///  * `opA`, `b`: the forward operator and the data passed on to the data term.
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn postprocess<
    F : Float,
    V : Euclidean<F> + Clone,
    A : GEMV<F, DiscreteMeasure<Loc<F, N>, F>, Codomain = V>,
    D : DataTerm<F, V, N>,
    const N : usize
> (
    mut μ : DiscreteMeasure<Loc<F, N>, F>,
    config : &FBGenericConfig<F>,
    dataterm : D,
    opA : &A,
    b : &V,
) -> DiscreteMeasure<Loc<F, N>, F>
where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
    // This runs once, after the last step, so the dedicated final merging method applies.
    // The per-iteration `config.merging` defaults to no merging at all, which would
    // silently turn the final merge into a no-op.
    μ.merge_spikes_fitness(config.final_merging,
                           |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
                           |&v| v);
    μ.prune();
    μ
}

/// Iteratively solve the pointsource localisation problem using forward-backward splitting.
///
/// The settings in `fbconfig` have their [respective documentation](FBConfig). `opA` is the
/// forward operator $A$, `b` the observable, and `reg` the regularisation term.
/// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
/// operator. The `iterator` is an outer loop verbosity and iteration count control
/// as documented in [`alg_tools::iterate`]. Finally, the `plotter` is used to plot the
/// iterates on verbose iterations.
///
/// For details on the mathematical formulation, see the [module level](self) documentation.
///
/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
/// sums of simple functions using bisection trees, and the related
/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
/// active at a specific points, and to maximise their sums. Through the implementation of the
/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
///
/// Returns the final iterate after it has been passed through `postprocess`.
///
/// # Panics
///
/// Panics if `opA.lipschitz_factor(&op𝒟)` does not provide a value, as the result is
/// unwrapped to compute the step length.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_fb_reg<
    'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize
>(
    opA : &'a A,
    b : &A::Observable,
    reg : Reg,
    op𝒟 : &'a 𝒟,
    fbconfig : &FBConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
) -> DiscreteMeasure<Loc<F, N>, F>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
      A::Observable : std::ops::MulAssign<F>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
          + Lipschitz<&'a 𝒟, FloatType=F>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
      𝒟::Codomain : RealMapping<F, N>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      Reg : RegTerm<F, N> {

    // Set up parameters
    let config = &fbconfig.insertion;
    let op𝒟norm = op𝒟.opnorm_bound();
    // Step length: the scaling τ0 divided by the Lipschitz factor reported by the forward
    // operator with respect to 𝒟.
    let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates. With μ = 0 initially, the residual is set to -b.
    let mut μ = DiscreteMeasure::new();
    let mut residual = -b;
    let mut stats = IterInfo::new();

    // Run the algorithm
    iterator.iterate(|state| {
        // Calculate smooth part of surrogate model.
        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
        // has no significant overhead. For some reason Rust doesn't allow us simply moving
        // the residual and replacing it below before the end of this closure.
        residual *= -τ;
        let r = std::mem::replace(&mut residual, opA.empty_observable());
        let minus_τv = opA.preadjoint().apply(r);

        // Save current base point
        let μ_base = μ.clone();
            
        // Insert and reweigh
        let (d, within_tolerances) = insert_and_reweigh(
            &mut μ, &minus_τv, &μ_base, None,
            op𝒟, op𝒟norm,
            τ, ε,
            config, &reg, state, &mut stats
        );

        // Prune and possibly merge spikes
        prune_and_maybe_simple_merge(
            &mut μ, &minus_τv, &μ_base,
            op𝒟,
            τ, ε,
            config, &reg, state, &mut stats
        );

        // Update residual
        residual = calculate_residual(&μ, opA, b);

        // Update main tolerance for next iteration
        let ε_prev = ε;
        ε = tolerance.update(ε, state.iteration());
        stats.this_iters += 1;

        // Give function value if needed
        state.if_verbose(|| {
            // Plot if so requested
            plotter.plot_spikes(
                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
                "start".to_string(), Some(&minus_τv),
                reg.target_bounds(τ, ε_prev), &μ,
            );
            // Calculate mean inner iterations and reset relevant counters.
            // Return the statistics
            let res = IterInfo {
                value : residual.norm2_squared_div2() + reg.apply(&μ),
                n_spikes : μ.len(),
                // Report the tolerance that was in effect during this iteration.
                ε : ε_prev,
                postprocessing: config.postprocessing.then(|| μ.clone()),
                .. stats
            };
            stats = IterInfo::new();
            res
        })
    });

    // Final merge (by data fit) and pruning of the last iterate.
    postprocess(μ, config, L2Squared, opA, b)
}

/// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting.
///
/// The settings in `fbconfig` have their [respective documentation](FBConfig). `opA` is the
/// forward operator $A$, `b` the observable, and `reg` the regularisation term.
/// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
/// operator. The `iterator` is an outer loop verbosity and iteration count control
/// as documented in [`alg_tools::iterate`]. Finally, the `plotter` is used to plot the
/// iterates on verbose iterations.
///
/// Spike merging is not supported by this method: if a merging method is configured,
/// a warning is printed once, on the first iteration scheduled for merging by
/// `merge_every`, and no merging is performed during the iterations. A `merge_every` of
/// zero schedules no merging iterations, so then no warning is printed either.
///
/// For details on the mathematical formulation, see the [module level](self) documentation.
///
/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
/// sums of simple functions using bisection trees, and the related
/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
/// active at a specific points, and to maximise their sums. Through the implementation of the
/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
///
/// Returns the final non-inertial iterate (`μ_prev`) after it has been passed through
/// `postprocess`.
///
/// # Panics
///
/// Panics if `opA.lipschitz_factor(&op𝒟)` does not provide a value, as the result is
/// unwrapped to compute the step length.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_fista_reg<
    'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize
>(
    opA : &'a A,
    b : &A::Observable,
    reg : Reg,
    op𝒟 : &'a 𝒟,
    fbconfig : &FBConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
) -> DiscreteMeasure<Loc<F, N>, F>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
      A::Observable : std::ops::MulAssign<F>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
          + Lipschitz<&'a 𝒟, FloatType=F>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
      𝒟::Codomain : RealMapping<F, N>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      Reg : RegTerm<F, N> {

    // Set up parameters
    let config = &fbconfig.insertion;
    let op𝒟norm = op𝒟.opnorm_bound();
    // Step length: the scaling τ0 divided by the Lipschitz factor reported by the forward
    // operator with respect to 𝒟.
    let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
    // Inertial parameter
    let mut λ = 1.0;
    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates. With μ = 0 initially, the residual is set to -b.
    let mut μ = DiscreteMeasure::new();
    let mut μ_prev = DiscreteMeasure::new();
    let mut residual = -b;
    let mut stats = IterInfo::new();
    let mut warned_merging = false;

    // Run the algorithm
    iterator.iterate(|state| {
        // Calculate smooth part of surrogate model.
        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
        // has no significant overhead. For some reason Rust doesn't allow us simply moving
        // the residual and replacing it below before the end of this closure.
        residual *= -τ;
        let r = std::mem::replace(&mut residual, opA.empty_observable());
        let minus_τv = opA.preadjoint().apply(r);

        // Save current base point
        let μ_base = μ.clone();

        // Insert new spikes and reweigh
        let (d, within_tolerances) = insert_and_reweigh(
            &mut μ, &minus_τv, &μ_base, None,
            op𝒟, op𝒟norm,
            τ, ε,
            config, &reg, state, &mut stats
        );

        // (Do not) merge spikes: only warn, once, that a configured merging method is ignored.
        // `merge_every` is user-configurable; a zero value would make the remainder panic,
        // so it is interpreted as "no merging iterations".
        let merge_scheduled = config.merge_every > 0
                              && state.iteration() % config.merge_every == 0;
        if merge_scheduled {
            match config.merging {
                SpikeMergingMethod::None => { },
                _ => if !warned_merging {
                    println!("{}", "Merging not supported for μFISTA".red());
                    warned_merging = true;
                }
            }
        }

        // Update inertial parameters
        let λ_prev = λ;
        λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() );
        let θ = λ / λ_prev - λ;

        // Perform inertial update on μ.
        // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ
        // and μ_prev have zero weight. Since both have weights from the finite-dimensional
        // subproblem with a proximal projection step, this is likely to happen when the
        // spike is not needed. A copy of the pruned μ without arithmetic performed is
        // stored in μ_prev.
        let n_before_prune = μ.len();
        μ.pruning_sub(1.0 + θ, θ, &mut μ_prev);
        debug_assert!(μ.len() <= n_before_prune);
        stats.pruned += n_before_prune - μ.len();

        // Update residual
        residual = calculate_residual(&μ, opA, b);

        // Update main tolerance for next iteration
        let ε_prev = ε;
        ε = tolerance.update(ε, state.iteration());
        stats.this_iters += 1;

        // Give function value if needed. Reporting is based on the non-inertial
        // iterate μ_prev, not on the extrapolated μ.
        state.if_verbose(|| {
            // Plot if so requested
            plotter.plot_spikes(
                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
                "start".to_string(), Some(&minus_τv),
                reg.target_bounds(τ, ε_prev), &μ_prev,
            );
            // Calculate mean inner iterations and reset relevant counters.
            // Return the statistics
            let res = IterInfo {
                value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev),
                n_spikes : μ_prev.len(),
                // Report the tolerance that was in effect during this iteration.
                ε : ε_prev,
                postprocessing: config.postprocessing.then(|| μ_prev.clone()),
                .. stats
            };
            stats = IterInfo::new();
            res
        })
    });

    // Final merge (by data fit) and pruning of the last non-inertial iterate.
    postprocess(μ_prev, config, L2Squared, opA, b)
}

mercurial