src/frank_wolfe.rs

Mon, 17 Feb 2025 14:10:52 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Mon, 17 Feb 2025 14:10:52 -0500
changeset 54
b3312eee105c
parent 39
6316d68b58af
permissions
-rw-r--r--

Make some math in documentation render

/*!
Solver for the point source localisation problem using a conditional gradient method.

We implement two variants, the “fully corrective” method from

  * Pieper K., Walter D. _Linear convergence of accelerated conditional gradient algorithms
    in spaces of measures_, DOI: [10.1051/cocv/2021042](https://doi.org/10.1051/cocv/2021042),
    arXiv: [1904.09218](https://doi.org/10.48550/arXiv.1904.09218).

and what we call the “relaxed” method from

  * Bredies K., Pikkarainen H. - _Inverse problems in spaces of measures_,
    DOI: [10.1051/cocv/2011205](https://doi.org/10.1051/cocv/2011205).
*/

use numeric_literals::replace_float_literals;
use nalgebra::{DMatrix, DVector};
use serde::{Serialize, Deserialize};
//use colored::Colorize;

use alg_tools::iterate::{
    AlgIteratorFactory,
    AlgIteratorOptions,
    ValueIteratorFactory,
};
use alg_tools::euclidean::Euclidean;
use alg_tools::norms::Norm;
use alg_tools::linops::Mapping;
use alg_tools::sets::Cube;
use alg_tools::loc::Loc;
use alg_tools::bisection_tree::{
    BTFN,
    Bounds,
    BTNodeLookup,
    BTNode,
    BTSearch,
    P2Minimise,
    SupportGenerator,
    LocalAnalysis,
};
use alg_tools::mapping::RealMapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;
use alg_tools::norms::L2;

use crate::types::*;
use crate::measures::{
    RNDM,
    DiscreteMeasure,
    DeltaMeasure,
    Radon,
};
use crate::measures::merging::{
    SpikeMergingMethod,
    SpikeMerging,
};
use crate::forward_model::ForwardModel;
#[allow(unused_imports)] // Used in documentation
use crate::subproblem::{
    unconstrained::quadratic_unconstrained,
    nonneg::quadratic_nonneg,
    InnerSettings,
    InnerMethod,
};
use crate::tolerance::Tolerance;
use crate::plot::{
    SeqPlotter,
    Plotting,
    PlotLookup
};
use crate::regularisation::{
    NonnegRadonRegTerm,
    RadonRegTerm,
    RegTerm
};

/// Settings for [`pointsource_fw_reg`].
///
/// Because of `#[serde(default)]`, fields missing from a deserialised configuration are
/// taken from the [`Default`] implementation.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FWConfig<F : Float> {
    /// Base tolerance ε. It is scaled by the regulariser's `tolerance_scaling()`, updated every
    /// iteration, and multiplied by `inner.tolerance_mult` and `refinement.tolerance_mult` to
    /// obtain the inner-solver and branch-and-bound (new spike location discovery) tolerances.
    pub tolerance : Tolerance<F>,
    /// Inner problem solution configuration. Has to have `method` set to [`InnerMethod::FB`]
    /// as the conditional gradient subproblems' optimality conditions do not in general have an
    /// invertible Newton derivative for SSN.
    pub inner : InnerSettings<F>,
    /// Variant of the conditional gradient method
    pub variant : FWVariant,
    /// Settings for branch and bound refinement when looking for predual maxima
    pub refinement : RefinementSettings<F>,
    /// Spike merging heuristic, applied after each weight optimisation
    pub merging : SpikeMergingMethod<F>,
}

/// Conditional gradient method variant; see also [`FWConfig`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[allow(dead_code)]
pub enum FWVariant {
    /// Algorithm 2 of Pieper–Walter (“fully corrective”): a zero-mass spike is inserted at
    /// the new point, after which all weights are re-optimised with the inner method.
    FullyCorrective,
    /// Bredies–Pikkarainen (“relaxed”): the new spike is inserted with a relaxed step, and
    /// only a single inner weight-optimisation step is taken, i.e., this variant forces
    /// `FWConfig.inner.iterator_options.max_iter = 1`.
    Relaxed,
}

impl<F : Float> Default for FWConfig<F> {
    /// Fully corrective variant with spike merging enabled; every other setting,
    /// including the remaining merging options, takes its own default value.
    fn default() -> Self {
        let merging = SpikeMergingMethod { enabled : true, ..Default::default() };
        Self {
            variant : FWVariant::FullyCorrective,
            merging,
            tolerance : Default::default(),
            inner : Default::default(),
            refinement : Default::default(),
        }
    }
}

/// Forward models that can produce a finite-dimensional quadratic model of the data term,
/// as needed by the [`WeightOptim`] implementations in this file.
pub trait FindimQuadraticModel<Domain, F> : ForwardModel<DiscreteMeasure<Domain, F>, F>
where
    F : Float + ToNalgebraRealField,
    Domain : Clone + PartialEq,
{
    /// Return the pair $(A\_\*A, A\_\*b)$ as an nalgebra matrix and vector.
    ///
    /// NOTE(review): presumably these are restricted to the spike locations of `μ`, since
    /// callers combine them with `μ.masses_dvector()`; confirm against implementations.
    fn findim_quadratic_model(
        &self,
        μ : &DiscreteMeasure<Domain, F>,
        b : &Self::Observable
    ) -> (DMatrix<F::MixedType>, DVector<F::MixedType>);
}

/// Helper struct for pre-initialising the finite-dimensional subproblem solver.
///
/// Produced by [`WeightOptim::prepare_optimise_weights`] once per run, and then passed to
/// every weight optimisation and relaxed insertion step.
pub struct FindimData<F : Float> {
    /// Square of an upper bound on ‖A‖ from the Radon norm to the 2-norm
    /// (`opA.opnorm_bound(Radon, L2)` in the implementations in this file).
    opAnorm_squared : F,
    /// Bound $M_0$ from the Bredies–Pikkarainen article; the implementations in this file
    /// compute it as ‖b‖²/(2α).
    m0 : F
}

/// Trait for finite dimensional weight optimisation.
///
/// Implemented by regularisation terms. `F` is the floating point type, `A` the forward
/// operator, `I` the iterator factory driving the inner method, and `N` the dimension of
/// the spike locations.
pub trait WeightOptim<
    F : Float + ToNalgebraRealField,
    A : ForwardModel<RNDM<F, N>, F>,
    I : AlgIteratorFactory<F>,
    const N : usize
> {

    /// Return a pre-initialisation struct for [`Self::optimise_weights`].
    ///
    /// The parameter `opA` is the forward operator $A$, and `b` is as in the objective of
    /// [`Self::optimise_weights`].
    fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F>;

    /// Solve the finite-dimensional weight optimisation problem for the 2-norm-squared data fidelity
    /// point source localisation problem.
    ///
    /// That is, we minimise
    /// <div>$$
    ///     μ ↦ \frac{1}{2}\|Aμ-b\|_w^2 + G(μ)
    /// $$</div>
    /// only with respect to the weights of $μ$. Here $G$ is a regulariser modelled by `Self`.
    ///
    /// The parameter `μ` is the discrete measure whose weights are to be optimised; the
    /// implementations in this file overwrite its weights in place. The `opA` parameter is
    /// the forward operator $A$, while `b` is as in the objective above. The method
    /// parameters are set in `inner` (see [`InnerSettings`]), while `iterator` is used to
    /// iterate the steps of the method. The parameter `findim_data` should be prepared using
    /// [`Self::prepare_optimise_weights`].
    ///
    /// Returns the number of iterations taken by the method configured in `inner`.
    fn optimise_weights<'a>(
        &self,
        μ : &mut RNDM<F, N>,
        opA : &'a A,
        b : &A::Observable,
        findim_data : &FindimData<F>,
        inner : &InnerSettings<F>,
        iterator : I
    ) -> usize;
}

/// Trait for regularisation terms supported by [`pointsource_fw_reg`].
pub trait RegTermFW<
    F : Float + ToNalgebraRealField,
    A : ForwardModel<RNDM<F, N>, F>,
    I : AlgIteratorFactory<F>,
    const N : usize
> : RegTerm<F, N>
    + WeightOptim<F, A, I, N>
    + Mapping<RNDM<F, N>, Codomain = F> {

    /// With $g = -A\_\*(Aμ-b)$, returns $(x, g(x))$ for $x$ a new point to be inserted
    /// into $μ$, as determined by the regulariser.
    ///
    /// ([`pointsource_fw_reg`] computes `g` by applying the preadjoint of $A$ to the
    /// negated residual, i.e., to $b-Aμ$.)
    ///
    /// The parameters `refinement_tolerance` and `max_steps` are passed to relevant
    /// [`BTFN`] minimisation and maximisation routines.
    fn find_insertion(
        &self,
        g : &mut A::PreadjointCodomain,
        refinement_tolerance : F,
        max_steps : usize
    ) -> (Loc<F, N>, F);

    /// Insert point `ξ` into `μ` for the relaxed algorithm from Bredies–Pikkarainen.
    ///
    /// Here `g` and `v_ξ` are as in and returned by [`Self::find_insertion`], `opA` is the
    /// forward operator $A$, and `findim_data` supplies the bound $M_0$ (see [`FindimData`]).
    /// The implementations in this file replace `μ` by $(1-s)μ + sδ$, where $δ$ is a new
    /// spike at `ξ` and $s ≤ 1$ is a step length.
    fn relaxed_insert<'a>(
        &self,
        μ : &mut RNDM<F, N>,
        g : &A::PreadjointCodomain,
        opA : &'a A,
        ξ : Loc<F, N>,
        v_ξ : F,
        findim_data : &FindimData<F>
    );
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float + ToNalgebraRealField, A, I, const N : usize> WeightOptim<F, A, I, N>
for RadonRegTerm<F>
where I : AlgIteratorFactory<F>,
      A : FindimQuadraticModel<Loc<F, N>, F>  {

    /// Precomputes ‖A‖² (via the Radon-to-2-norm operator norm bound) and the
    /// Bredies–Pikkarainen bound $M_0 = ‖b‖²/(2α)$.
    fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F> {
        let opAnorm_squared = opA.opnorm_bound(Radon, L2).powi(2);
        let m0 = b.norm2_squared() / (2.0 * self.α());
        FindimData { opAnorm_squared, m0 }
    }

    /// Optimises the (signed) weights of `μ` with [`quadratic_unconstrained`], writing the
    /// result back into `μ` and returning the inner iteration count.
    fn optimise_weights<'a>(
        &self,
        μ : &mut RNDM<F, N>,
        opA : &'a A,
        b : &A::Observable,
        findim_data : &FindimData<F>,
        inner : &InnerSettings<F>,
        iterator : I
    ) -> usize {
        // Starting point for the inner method: the current weights.
        let mut weights = μ.masses_dvector();
        // Quadratic model of the data term with respect to those weights.
        let (Ã, g̃) = opA.findim_quadratic_model(&*μ, b);

        // The stored bound is for A as a map from the Radon norm (the 1-norm on weights)
        // into the 2-norm, but the inner method needs the 2-norm on weights. With m spikes,
        // ‖x‖_1 ≤ √m ‖x‖_2, so ‖A‖_{2,2} ≤ √m ‖A‖_{1,2}. Since ‖A_*A‖ = ‖A‖_{2,2}², the
        // squared bound is simply scaled by m, without a square root.
        let spike_count = F::cast_from(μ.len());
        let normest = findim_data.opAnorm_squared * spike_count;

        let iters = quadratic_unconstrained(
            &Ã, &g̃, self.α(), &mut weights, normest, inner, iterator
        );
        // Write the optimised weights back into μ.
        μ.set_masses_dvector(&weights);
        iters
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float + ToNalgebraRealField, A, I, S, GA, BTA, const N : usize> RegTermFW<F, A, I, N>
for RadonRegTerm<F>
where
    Cube<F, N> : P2Minimise<Loc<F, N>, F>,
    I : AlgIteratorFactory<F>,
    S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
    GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
    A : FindimQuadraticModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
    BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
    // FIXME: the following *should not* be needed, they are already implied
    RNDM<F, N> : Mapping<A::PreadjointCodomain, Codomain = F>,
    DeltaMeasure<Loc<F, N>, F> : Mapping<A::PreadjointCodomain, Codomain = F>,
    //A : Mapping<RNDM<F, N>, Codomain = A::Observable>,
    //A : Mapping<DeltaMeasure<Loc<F, N>, F>, Codomain = A::Observable>,
{

    /// For the signed Radon norm, spikes of either sign may be inserted, so we look for the
    /// point where $|g|$ is (approximately) largest. Both the maximum and the minimum of `g`
    /// are located. The minimiser is chosen when its value is negative and exceeds the
    /// maximum in absolute value.
    fn find_insertion(
        &self,
        g : &mut A::PreadjointCodomain,
        refinement_tolerance : F,
        max_steps : usize
    ) -> (Loc<F, N>, F) {
        let (ξmax, v_ξmax) = g.maximise(refinement_tolerance, max_steps);
        let (ξmin, v_ξmin) = g.minimise(refinement_tolerance, max_steps);
        if v_ξmin < 0.0 && -v_ξmin > v_ξmax {
            (ξmin, v_ξmin)
        } else {
            (ξmax, v_ξmax)
        }
    }

    /// Relaxed Bredies–Pikkarainen insertion for the signed Radon norm.
    ///
    /// The regulariser $t ↦ αt$ is replaced by a surrogate $φ$ that is linear up to $M_0$
    /// and quadratic beyond. The new spike at `ξ` gets mass $0$ if $|g(ξ)| ≤ α$, and
    /// $M_0 g(ξ)/α$ otherwise. `μ` is then replaced by $(1-s)μ + sδ$.
    fn relaxed_insert<'a>(
        &self,
        μ : &mut RNDM<F, N>,
        g : &A::PreadjointCodomain,
        opA : &'a A,
        ξ : Loc<F, N>,
        v_ξ : F,
        findim_data : &FindimData<F>
    ) {
        let α = self.α();
        let m0 = findim_data.m0;
        // Surrogate of t ↦ αt: linear on [0, M_0], quadratic beyond.
        let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) };
        // Mass of the new spike; zero unless |g(ξ)| exceeds α.
        let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ };
        let δ = DeltaMeasure { x : ξ, α : v };
        // dp = ⟨g, μ - δ⟩ and d = A(μ - δ). `δ` is passed by reference because it is
        // still needed for the final update below.
        let dp = μ.apply(g) - δ.apply(g);
        let d = opA.apply(&*μ) - opA.apply(&δ);
        let r = d.norm2_squared();
        // Step length s = min(1, (α‖μ‖ - φ(|v|) - ⟨g, μ - δ⟩) / ‖A(μ - δ)‖²),
        // or s = 1 when the denominator vanishes.
        let s = if r == 0.0 {
            1.0
        } else {
            1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r)
        };
        *μ *= 1.0 - s;
        *μ += δ * s;
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float + ToNalgebraRealField, A, I, const N : usize> WeightOptim<F, A, I, N>
for NonnegRadonRegTerm<F>
where I : AlgIteratorFactory<F>,
      A : FindimQuadraticModel<Loc<F, N>, F> {

    /// Precomputes ‖A‖² (via the Radon-to-2-norm operator norm bound) and the
    /// Bredies–Pikkarainen bound $M_0 = ‖b‖²/(2α)$.
    fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F> {
        let opAnorm_squared = opA.opnorm_bound(Radon, L2).powi(2);
        let m0 = b.norm2_squared() / (2.0 * self.α());
        FindimData { opAnorm_squared, m0 }
    }

    /// Optimises the nonnegative weights of `μ` with [`quadratic_nonneg`], writing the
    /// result back into `μ` and returning the inner iteration count.
    fn optimise_weights<'a>(
        &self,
        μ : &mut RNDM<F, N>,
        opA : &'a A,
        b : &A::Observable,
        findim_data : &FindimData<F>,
        inner : &InnerSettings<F>,
        iterator : I
    ) -> usize {
        // Starting point for the inner method: the current weights.
        let mut weights = μ.masses_dvector();
        // Quadratic model of the data term with respect to those weights.
        let (Ã, g̃) = opA.findim_quadratic_model(&*μ, b);

        // The stored bound is for A as a map from the Radon norm (the 1-norm on weights)
        // into the 2-norm, but the inner method needs the 2-norm on weights. With m spikes,
        // ‖x‖_1 ≤ √m ‖x‖_2, so ‖A‖_{2,2} ≤ √m ‖A‖_{1,2}. Since ‖A_*A‖ = ‖A‖_{2,2}², the
        // squared bound is simply scaled by m, without a square root.
        let spike_count = F::cast_from(μ.len());
        let normest = findim_data.opAnorm_squared * spike_count;

        let iters = quadratic_nonneg(
            &Ã, &g̃, self.α(), &mut weights, normest, inner, iterator
        );
        // Write the optimised weights back into μ.
        μ.set_masses_dvector(&weights);
        iters
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float + ToNalgebraRealField, A, I, S, GA, BTA, const N : usize> RegTermFW<F, A, I, N>
for NonnegRadonRegTerm<F>
where
    Cube<F, N> : P2Minimise<Loc<F, N>, F>,
    I : AlgIteratorFactory<F>,
    S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
    GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
    A : FindimQuadraticModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
    BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
    // FIXME: the following *should not* be needed, they are already implied
    RNDM<F, N> : Mapping<A::PreadjointCodomain, Codomain = F>,
    DeltaMeasure<Loc<F, N>, F> : Mapping<A::PreadjointCodomain, Codomain = F>,
{

    /// For the nonnegative Radon norm only positive spikes can be inserted, so the
    /// insertion point is the (approximate) maximiser of `g`.
    fn find_insertion(
        &self,
        g : &mut A::PreadjointCodomain,
        refinement_tolerance : F,
        max_steps : usize
    ) -> (Loc<F, N>, F) {
        g.maximise(refinement_tolerance, max_steps)
    }

    /// Relaxed Bredies–Pikkarainen insertion for the nonnegative Radon norm.
    ///
    /// Same as the signed case, except that the linearised subproblem is solved over
    /// nonnegative masses. The new spike at `ξ` therefore gets mass $0$ unless $g(ξ) > α$,
    /// in which case it gets mass $M_0 g(ξ)/α > 0$.
    fn relaxed_insert<'a>(
        &self,
        μ : &mut RNDM<F, N>,
        g : &A::PreadjointCodomain,
        opA : &'a A,
        ξ : Loc<F, N>,
        v_ξ : F,
        findim_data : &FindimData<F>
    ) {
        let α = self.α();
        let m0 = findim_data.m0;
        // Surrogate of t ↦ αt: linear on [0, M_0], quadratic beyond.
        let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) };
        // Mass of the new spike. Unlike the signed case, compare g(ξ) itself (not |g(ξ)|)
        // against α: `find_insertion` returns the maximum of g, which may be below -α, and
        // that must not produce a negative spike in a nonnegative measure.
        let v = if v_ξ <= α { 0.0 } else { m0 / α * v_ξ };
        let δ = DeltaMeasure { x : ξ, α : v };
        // dp = ⟨g, μ - δ⟩ and d = A(μ - δ).
        let dp = μ.apply(g) - δ.apply(g);
        let d = opA.apply(&*μ) - opA.apply(&δ);
        let r = d.norm2_squared();
        // Step length s = min(1, (α‖μ‖ - φ(|v|) - ⟨g, μ - δ⟩) / ‖A(μ - δ)‖²),
        // or s = 1 when the denominator vanishes.
        let s = if r == 0.0 {
            1.0
        } else {
            1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r)
        };
        *μ *= 1.0 - s;
        *μ += δ * s;
    }
}


/// Solve point source localisation problem using a conditional gradient method
/// for the 2-norm-squared data fidelity, i.e., the problem
/// <div>$$
///     \min_μ \frac{1}{2}\|Aμ-b\|_w^2 + G(μ),
/// $$</div>
/// where $G$ is the regularisation term modelled by `reg`.
///
/// The `opA` parameter is the forward operator $A$, while `b` is as in the
/// objective above. The method parameters are set in `config` (see [`FWConfig`]), while
/// `iterator` is used to iterate the steps of the method, and `plotter` may be used to
/// save intermediate iteration states as images.
///
/// Each iteration does the following:
/// 1. Inserts one new spike at the point returned by [`RegTermFW::find_insertion`].
/// 2. Re-optimises the weights with [`WeightOptim::optimise_weights`], taking a single
///    inner step for [`FWVariant::Relaxed`].
/// 3. Merges spikes according to `config.merging`.
/// 4. Prunes zero-mass spikes.
///
/// Returns the final iterate $μ$.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_fw_reg<F, I, A, GA, BTA, S, Reg, const N : usize>(
    opA : &A,
    b : &A::Observable,
    reg : Reg,
    //domain : Cube<F, N>,
    config : &FWConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
) -> RNDM<F, N>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      RNDM<F, N> : SpikeMerging<F>,
      Reg : RegTermFW<F, A, ValueIteratorFactory<F, AlgIteratorOptions>, N> {

    // Set up parameters
    // The tolerance is scaled by the regulariser's `tolerance_scaling()` (presumably α),
    // as for all algorithms.
    let tolerance = config.tolerance * reg.tolerance_scaling();
    let mut ε = tolerance.initial();
    let findim_data = reg.prepare_optimise_weights(opA, b);

    // Initialise operators
    let preadjA = opA.preadjoint();

    // Initialise iterates
    // The residual Aμ - b equals -b for the initial zero measure.
    let mut μ = DiscreteMeasure::new();
    let mut residual = -b;

    // Statistics: objective value ½‖residual‖² + G(ν), spike count and current tolerance,
    // combined with the counters accumulated in `stats`.
    let full_stats = |residual : &A::Observable,
                      ν : &RNDM<F, N>,
                      ε, stats| IterInfo {
        value : residual.norm2_squared_div2() + reg.apply(ν),
        n_spikes : ν.len(),
        ε,
        .. stats
    };
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
        let inner_tolerance = ε * config.inner.tolerance_mult;
        let refinement_tolerance = ε * config.refinement.tolerance_mult;

        // Calculate smooth part of surrogate model: g = A_*(b - Aμ) = -A_*(Aμ - b).
        // `residual` is consumed here and recomputed by the spike merging below.
        let mut g = preadjA.apply(residual * (-1.0));

        // Find the insertion point and the value of g there, as determined by the
        // regulariser (the point of largest |g| for the Radon norm, of largest g for the
        // nonnegative Radon norm).
        let (ξ, v_ξ) = reg.find_insertion(&mut g, refinement_tolerance,
                                          config.refinement.max_steps);

        let inner_it = match config.variant {
            FWVariant::FullyCorrective => {
                // No point in optimising the weight here: the finite-dimensional algorithm is fast.
                μ += DeltaMeasure { x : ξ, α : 0.0 };
                stats.inserted += 1;
                config.inner.iterator_options.stop_target(inner_tolerance)
            },
            FWVariant::Relaxed => {
                // Perform a relaxed initialisation of μ
                reg.relaxed_insert(&mut μ, &g, opA, ξ, v_ξ, &findim_data);
                stats.inserted += 1;
                // The stop_target is only needed for the type system.
                AlgIteratorOptions{ max_iter : 1, .. config.inner.iterator_options}.stop_target(0.0)
            }
        };

        stats.inner_iters += reg.optimise_weights(&mut μ, opA, b, &findim_data,
                                                  &config.inner, inner_it);
   
        // Merge spikes and update residual for next step and `if_verbose` below.
        let (r, count) = μ.merge_spikes_fitness(config.merging,
                                                |μ̃| opA.apply(μ̃) - b,
                                                A::Observable::norm2_squared);
        residual = r;
        stats.merged += count;

        // Prune points with zero mass
        let n_before_prune = μ.len();
        μ.prune();
        debug_assert!(μ.len() <= n_before_prune);
        stats.pruned += n_before_prune - μ.len();

        stats.this_iters += 1;
        let iter = state.iteration();

        // Give statistics if needed; the accumulated counters are reset once reported.
        state.if_verbose(|| {
            plotter.plot_spikes(iter, Some(&g), Option::<&S>::None, &μ);
            full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        // Update tolerance
        ε = tolerance.update(ε, iter);
    }

    // Return final iterate
    μ
}

mercurial