src/radon_fb.rs

Thu, 29 Aug 2024 00:00:00 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 29 Aug 2024 00:00:00 -0500
branch
dev
changeset 34
efa60bc4f743
child 35
b087e3eab191
permissions
-rw-r--r--

Radon FB + sliding improvements

/*!
Solver for the point source localisation problem using a simplified forward-backward splitting method.

Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map.
*/

use numeric_literals::replace_float_literals;
use serde::{Serialize, Deserialize};
use colored::Colorize;
use nalgebra::DVector;

use alg_tools::iterate::{
    AlgIteratorFactory,
    AlgIteratorState,
};
use alg_tools::euclidean::Euclidean;
use alg_tools::linops::Apply;
use alg_tools::sets::Cube;
use alg_tools::loc::Loc;
use alg_tools::bisection_tree::{
    BTFN,
    Bounds,
    BTNodeLookup,
    BTNode,
    BTSearch,
    P2Minimise,
    SupportGenerator,
    LocalAnalysis,
};
use alg_tools::mapping::RealMapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;

use crate::types::*;
use crate::measures::{
    DiscreteMeasure,
    DeltaMeasure,
};
use crate::measures::merging::{
    SpikeMergingMethod,
    SpikeMerging,
};
use crate::forward_model::ForwardModel;
use crate::plot::{
    SeqPlotter,
    Plotting,
    PlotLookup
};
use crate::regularisation::RegTerm;
use crate::dataterm::{
    calculate_residual,
    L2Squared,
    DataTerm,
};

use crate::fb::{
    FBGenericConfig,
    postprocess
};

/// Settings for [`pointsource_radon_fb_reg`] and [`pointsource_radon_fista_reg`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct RadonFBConfig<F : Float> {
    /// Step length scaling: the solvers use the step length `τ = τ0 / L`, where `L` is the
    /// square of the forward operator's `opnorm_bound`. Should be below 1.
    pub τ0 : F,
    /// Generic parameters ([`FBGenericConfig`]): tolerances, spike insertion and merging,
    /// postprocessing, etc.
    pub insertion : FBGenericConfig<F>,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for RadonFBConfig<F> {
    /// Default settings: generic parameters at their own defaults, and a step length
    /// scaling of `0.99`, i.e. just inside the admissible range.
    fn default() -> Self {
        let insertion = Default::default();
        Self { insertion, τ0 : 0.99 }
    }
}

/// Inserts at most one new spike into `μ` and optimises the spike masses.
///
/// The masses of `μ` are first re-optimised, with spike locations fixed, by solving the
/// finite-dimensional subproblem `reg.solve_findim_l1squared`. Its data are the masses of
/// `μ_base` and the values of `minus_τv` at the spike locations of `μ`.
///
/// `reg.find_tolerance_violation_slack` is then asked for a point violating the tolerance
/// `ε`. If one is found, a zero-mass spike at that point is added to both `μ` and
/// `μ_base`, and the masses are re-optimised once more. The new spike's mass thus comes
/// from the finite-dimensional optimisation.
///
/// Inner iteration counts of the subproblem solver are accumulated into
/// `stats.inner_iters`. The `_ν_delta` and `_state` arguments are currently unused.
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn insert_and_reweigh<
    'a, F, GA, BTA, S, Reg, State, const N : usize
>(
    μ : &mut DiscreteMeasure<Loc<F, N>, F>,
    minus_τv : &mut BTFN<F, GA, BTA, N>,
    μ_base : &mut DiscreteMeasure<Loc<F, N>, F>,
    _ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>,
    τ : F,
    ε : F,
    config : &FBGenericConfig<F>,
    reg : &Reg,
    _state : &State,
    stats : &mut IterInfo<F, N>,
)
where F : Float + ToNalgebraRealField,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      Reg : RegTerm<F, N>,
      State : AlgIteratorState {

    // Set once a spike has been inserted. After that, only one final reweighing is done.
    let mut spike_added = false;

    loop {
        // Reweigh: optimise the masses of μ for fixed locations. All references to the
        // iterate at the start of the outer iteration are carried by `base_masses` and
        // `point_values`. An empty measure has nothing to optimise.
        if μ.len() > 0 {
            let point_values = DVector::from_iterator(
                μ.len(),
                μ.iter_locations().map(|ζ| F::to_nalgebra_mixed(minus_τv.apply(ζ))),
            );
            let base_masses = μ_base.masses_dvector();
            let mut masses = μ.masses_dvector();
            stats.inner_iters += reg.solve_findim_l1squared(
                &base_masses, &point_values, τ, &mut masses, ε, config
            );
            μ.set_masses_dvector(&masses);
        }

        if spike_added {
            // A possible check for further needed insertions would be
            // `reg.find_tolerance_violation_slack(minus_τv, τ, ε, false, config,
            // μ.dist_matching(μ_base))`; in practice it does not seem to be needed.
            break;
        }

        // ‖μ - μ_base‖_ℳ
        let distance = μ.dist_matching(μ_base);

        // Only the overall tolerance is checked here. The tolerances on the supports of μ
        // and μ - μ_base are supposed to be guaranteed by the weight optimisation above.
        let violation = reg.find_tolerance_violation_slack(
            minus_τv, τ, ε, false, config, distance
        );
        match violation {
            None => break,
            Some((ξ, _, _)) => {
                // Zero mass; the actual weight is found by the next reweighing pass.
                *μ += DeltaMeasure { x : ξ, α : 0.0 };
                *μ_base += DeltaMeasure { x : ξ, α : 0.0 };
                spike_added = true;
            }
        }
    }
}

/// Possibly merges spikes of `μ`, then prunes it.
///
/// Merging is attempted every `config.merge_every` iterations, as counted by
/// `state.iteration()`, using the method `config.merging`. A merge candidate `μ_candidate`
/// is accepted only if `reg.verify_merge_candidate_radonsq` approves it. That check is
/// given the matching difference `μ_candidate.sub_matching(μ_base)`. The count returned by
/// `merge_spikes` is added to `stats.merged`.
///
/// Afterwards `μ.prune()` is called (presumably removing zero-mass spikes — see
/// `DiscreteMeasure::prune`). The number of removed spikes is added to `stats.pruned`.
///
/// # Panics
///
/// Panics if `μ_base` has more spikes than `μ`, or if `config.merge_every` is zero
/// (integer remainder by zero).
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn prune_and_maybe_simple_merge<
    'a, F, GA, BTA, S, Reg, State, const N : usize
>(
    μ : &mut DiscreteMeasure<Loc<F, N>, F>,
    minus_τv : &mut BTFN<F, GA, BTA, N>,
    μ_base : &DiscreteMeasure<Loc<F, N>, F>,
    τ : F,
    ε : F,
    config : &FBGenericConfig<F>,
    reg : &Reg,
    state : &State,
    stats : &mut IterInfo<F, N>,
)
where F : Float + ToNalgebraRealField,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      Reg : RegTerm<F, N>,
      State : AlgIteratorState {

    // μ_base's spikes must be a prefix of μ's for the matching operations below.
    assert!(μ_base.len() <= μ.len());

    if state.iteration() % config.merge_every == 0 {
        stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
            // Important: μ_candidate's new points come afterwards,
            // and do not conflict with μ_base.
            // TODO: could simplify to requiring μ_base instead of μ_radon,
            // but may complicate with sliding base's extra points that need to be
            // after μ_candidate's extra points.
            // TODO: doesn't seem to work, maybe need to merge μ_base as well?
            // Although that doesn't seem to make sense.
            let μ_radon = μ_candidate.sub_matching(μ_base);
            reg.verify_merge_candidate_radonsq(minus_τv, μ_candidate, τ, ε, config, &μ_radon)
            //let n = μ_candidate.dist_matching(μ_base);
            //reg.find_tolerance_violation_slack(minus_τv, τ, ε, false, config, n).is_none()
        });
    }

    let n_before_prune = μ.len();
    μ.prune();
    debug_assert!(μ.len() <= n_before_prune);
    stats.pruned += n_before_prune - μ.len();
}


/// Iteratively solve the pointsource localisation problem using simplified forward-backward splitting.
///
/// The settings in `fbconfig` have their [respective documentation](RadonFBConfig). `opA` is
/// the forward operator $A$, `b` the observable, and `reg` the regularisation term.
/// Finally, the `iterator` is an outer loop verbosity and iteration count control
/// as documented in [`alg_tools::iterate`]. The plotter is currently unused.
///
/// For details on the mathematical formulation, see the [module level](self) documentation.
///
/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
/// sums of simple functions using bisection trees, and the related
/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
/// active at specific points, and to maximise their sums. Through the implementation of the
/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
///
/// Returns the final iterate, after `postprocess`.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_radon_fb_reg<
    'a, F, I, A, GA, BTA, S, Reg, const N : usize
>(
    opA : &'a A,
    b : &A::Observable,
    reg : Reg,
    fbconfig : &RadonFBConfig<F>,
    iterator : I,
    mut _plotter : SeqPlotter<F, N>,
) -> DiscreteMeasure<Loc<F, N>, F>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
      A::Observable : std::ops::MulAssign<F>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      Reg : RegTerm<F, N> {

    // Set up parameters
    let config = &fbconfig.insertion;
    // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ
    // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such
    // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L.
    let τ = fbconfig.τ0/opA.opnorm_bound().powi(2);
    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates
    let mut μ = DiscreteMeasure::new();
    let mut residual = -b;
    let mut stats = IterInfo::new();

    // Run the algorithm
    iterator.iterate(|state| {
        // Calculate smooth part of surrogate model.
        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
        // has no significant overhead. For some reason Rust doesn't allow us simply moving
        // the residual and replacing it below before the end of this closure.
        residual *= -τ;
        let r = std::mem::replace(&mut residual, opA.empty_observable());
        let mut minus_τv = opA.preadjoint().apply(r);

        // Save current base point
        let mut μ_base = μ.clone();

        // Insert and reweigh
        insert_and_reweigh(
            &mut μ, &mut minus_τv, &mut μ_base, None,
            τ, ε,
            config, &reg, state, &mut stats
        );

        // Prune and possibly merge spikes
        prune_and_maybe_simple_merge(
            &mut μ, &mut minus_τv, &μ_base,
            τ, ε,
            config, &reg, state, &mut stats
        );

        // Update residual
        residual = calculate_residual(&μ, opA, b);

        // Update main tolerance for next iteration
        let ε_prev = ε;
        ε = tolerance.update(ε, state.iteration());
        stats.this_iters += 1;

        // Give function value if needed
        state.if_verbose(|| {
            // Plot if so requested
            // plotter.plot_spikes(
            //     format!("iter {} end;", state.iteration()), &d,
            //     "start".to_string(), Some(&minus_τv),
            //     reg.target_bounds(τ, ε_prev), &μ,
            // );
            // Return the accumulated statistics and reset the counters.
            let res = IterInfo {
                value : residual.norm2_squared_div2() + reg.apply(&μ),
                n_spikes : μ.len(),
                ε : ε_prev,
                postprocessing: config.postprocessing.then(|| μ.clone()),
                .. stats
            };
            stats = IterInfo::new();
            res
        })
    });

    postprocess(μ, config, L2Squared, opA, b)
}

/// Iteratively solve the pointsource localisation problem using simplified inertial forward-backward splitting.
///
/// The settings in `fbconfig` have their [respective documentation](RadonFBConfig). `opA` is
/// the forward operator $A$, `b` the observable, and `reg` the regularisation term.
/// Finally, the `iterator` is an outer loop verbosity and iteration count control
/// as documented in [`alg_tools::iterate`]. The plotter is currently unused.
///
/// Spike merging is not supported by this method. If `fbconfig.insertion.merging` requests
/// it, a warning is printed once and merging is skipped.
///
/// For details on the mathematical formulation, see the [module level](self) documentation.
///
/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
/// sums of simple functions using bisection trees, and the related
/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
/// active at specific points, and to maximise their sums. Through the implementation of the
/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
///
/// Returns the final iterate before the last inertial extrapolation, after `postprocess`.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_radon_fista_reg<
    'a, F, I, A, GA, BTA, S, Reg, const N : usize
>(
    opA : &'a A,
    b : &A::Observable,
    reg : Reg,
    fbconfig : &RadonFBConfig<F>,
    iterator : I,
    mut _plotter : SeqPlotter<F, N>,
) -> DiscreteMeasure<Loc<F, N>, F>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
      A::Observable : std::ops::MulAssign<F>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      Reg : RegTerm<F, N> {

    // Set up parameters
    let config = &fbconfig.insertion;
    // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ
    // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such
    // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L.
    let τ = fbconfig.τ0/opA.opnorm_bound().powi(2);
    // Inertial parameter λ = 1/t in the usual FISTA notation.
    let mut λ = 1.0;
    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates
    let mut μ = DiscreteMeasure::new();
    let mut μ_prev = DiscreteMeasure::new();
    let mut residual = -b;
    let mut stats = IterInfo::new();
    let mut warned_merging = false;

    // Run the algorithm
    iterator.iterate(|state| {
        // Calculate smooth part of surrogate model.
        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
        // has no significant overhead. For some reason Rust doesn't allow us simply moving
        // the residual and replacing it below before the end of this closure.
        residual *= -τ;
        let r = std::mem::replace(&mut residual, opA.empty_observable());
        let mut minus_τv = opA.preadjoint().apply(r);

        // Save current base point
        let mut μ_base = μ.clone();

        // Insert new spikes and reweigh
        insert_and_reweigh(
            &mut μ, &mut minus_τv, &mut μ_base, None,
            τ, ε,
            config, &reg, state, &mut stats
        );

        // (Do not) merge spikes: merging is unsupported here, so only warn once if requested.
        if state.iteration() % config.merge_every == 0
            && !warned_merging
            && !matches!(config.merging, SpikeMergingMethod::None)
        {
            println!("{}", "Merging not supported for μFISTA".red());
            warned_merging = true;
        }

        // Update inertial parameters
        let λ_prev = λ;
        λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() );
        let θ = λ / λ_prev - λ;

        // Perform inertial update on μ.
        // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ
        // and μ_prev have zero weight. Since both have weights from the finite-dimensional
        // subproblem with a proximal projection step, this is likely to happen when the
        // spike is not needed. A copy of the pruned μ without arithmetic performed is
        // stored in μ_prev.
        let n_before_prune = μ.len();
        μ.pruning_sub(1.0 + θ, θ, &mut μ_prev);
        debug_assert!(μ.len() <= n_before_prune);
        stats.pruned += n_before_prune - μ.len();

        // Update residual
        residual = calculate_residual(&μ, opA, b);

        // Update main tolerance for next iteration
        let ε_prev = ε;
        ε = tolerance.update(ε, state.iteration());
        stats.this_iters += 1;

        // Give function value if needed
        state.if_verbose(|| {
            // Plot if so requested
            // plotter.plot_spikes(
            //     format!("iter {} end;", state.iteration()), &d,
            //     "start".to_string(), Some(&minus_τv),
            //     reg.target_bounds(τ, ε_prev), &μ_prev,
            // );
            // Return the accumulated statistics (for the non-extrapolated iterate μ_prev)
            // and reset the counters.
            let res = IterInfo {
                value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev),
                n_spikes : μ_prev.len(),
                ε : ε_prev,
                postprocessing: config.postprocessing.then(|| μ_prev.clone()),
                .. stats
            };
            stats = IterInfo::new();
            res
        })
    });

    postprocess(μ_prev, config, L2Squared, opA, b)
}

mercurial