--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/radon_fb.rs Thu Aug 29 00:00:00 2024 -0500 @@ -0,0 +1,455 @@ +/*! +Solver for the point source localisation problem using a simplified forward-backward splitting method. + +Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map. +*/ + +use numeric_literals::replace_float_literals; +use serde::{Serialize, Deserialize}; +use colored::Colorize; +use nalgebra::DVector; + +use alg_tools::iterate::{ + AlgIteratorFactory, + AlgIteratorState, +}; +use alg_tools::euclidean::Euclidean; +use alg_tools::linops::Apply; +use alg_tools::sets::Cube; +use alg_tools::loc::Loc; +use alg_tools::bisection_tree::{ + BTFN, + Bounds, + BTNodeLookup, + BTNode, + BTSearch, + P2Minimise, + SupportGenerator, + LocalAnalysis, +}; +use alg_tools::mapping::RealMapping; +use alg_tools::nalgebra_support::ToNalgebraRealField; + +use crate::types::*; +use crate::measures::{ + DiscreteMeasure, + DeltaMeasure, +}; +use crate::measures::merging::{ + SpikeMergingMethod, + SpikeMerging, +}; +use crate::forward_model::ForwardModel; +use crate::plot::{ + SeqPlotter, + Plotting, + PlotLookup +}; +use crate::regularisation::RegTerm; +use crate::dataterm::{ + calculate_residual, + L2Squared, + DataTerm, +}; + +use crate::fb::{ + FBGenericConfig, + postprocess +}; + +/// Settings for [`pointsource_fb_reg`]. +#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] +#[serde(default)] +pub struct RadonFBConfig<F : Float> { + /// Step length scaling + pub τ0 : F, + /// Generic parameters + pub insertion : FBGenericConfig<F>, +} + +#[replace_float_literals(F::cast_from(literal))] +impl<F : Float> Default for RadonFBConfig<F> { + fn default() -> Self { + RadonFBConfig { + τ0 : 0.99, + insertion : Default::default() + } + } +} + +#[replace_float_literals(F::cast_from(literal))] +pub(crate) fn insert_and_reweigh< + 'a, F, GA, BTA, S, Reg, State, const N : usize +>( + μ : &mut DiscreteMeasure<Loc<F, N>, F>, + minus_τv : &mut BTFN<F, GA, BTA, N>, + μ_base : &mut DiscreteMeasure<Loc<F, N>, F>, + _ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>, + τ : F, + ε : F, + config : &FBGenericConfig<F>, + reg : &Reg, + _state : &State, + stats : &mut IterInfo<F, N>, +) +where F : Float + ToNalgebraRealField, + GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, + BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, + S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, + BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, + DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, + Reg : RegTerm<F, N>, + State : AlgIteratorState { + + 'i_and_w: for i in 0..=1 { + // Optimise weights + if μ.len() > 0 { + // Form finite-dimensional subproblem. The subproblem references to the original μ^k + // from the beginning of the iteration are all contained in the immutable c and g. + let g̃ = DVector::from_iterator(μ.len(), + μ.iter_locations() + .map(|ζ| F::to_nalgebra_mixed(minus_τv.apply(ζ)))); + let mut x = μ.masses_dvector(); + let y = μ_base.masses_dvector(); + + // Solve finite-dimensional subproblem. + stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config); + + // Update masses of μ based on solution of finite-dimensional subproblem. + μ.set_masses_dvector(&x); + } + + if i>0 { + // Simple debugging test to see if more inserts would be needed. Doesn't seem so. + //let n = μ.dist_matching(μ_base); + //println!("{:?}", reg.find_tolerance_violation_slack(minus_τv, τ, ε, false, config, n)); + break 'i_and_w + } + + // Calculate ‖μ - μ_base‖_ℳ + let n = μ.dist_matching(μ_base); + + // Find a spike to insert, if needed. + // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ, + // which are supposed to have been guaranteed by the finite-dimensional weight optimisation. + match reg.find_tolerance_violation_slack(minus_τv, τ, ε, false, config, n) { + None => { break 'i_and_w }, + Some((ξ, _v_ξ, _in_bounds)) => { + // Weight is found out by running the finite-dimensional optimisation algorithm + // above + *μ += DeltaMeasure { x : ξ, α : 0.0 }; + *μ_base += DeltaMeasure { x : ξ, α : 0.0 }; + } + }; + } +} + +#[replace_float_literals(F::cast_from(literal))] +pub(crate) fn prune_and_maybe_simple_merge< + 'a, F, GA, BTA, S, Reg, State, const N : usize +>( + μ : &mut DiscreteMeasure<Loc<F, N>, F>, + minus_τv : &mut BTFN<F, GA, BTA, N>, + μ_base : &DiscreteMeasure<Loc<F, N>, F>, + τ : F, + ε : F, + config : &FBGenericConfig<F>, + reg : &Reg, + state : &State, + stats : &mut IterInfo<F, N>, +) +where F : Float + ToNalgebraRealField, + GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, + BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, + S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, + BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, + DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, + Reg : RegTerm<F, N>, + State : AlgIteratorState { + + assert!(μ_base.len() <= μ.len()); + + if state.iteration() % config.merge_every == 0 { + stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { + // Important: μ_candidate's new points are afterwards, + // and do not conflict with μ_base. + // TODO: could simplify to requiring μ_base instead of μ_radon. + // but may complicate with sliding base's exgtra points that need to be + // after μ_candidate's extra points. + // TODO: doesn't seem to work, maybe need to merge μ_base as well? + // Although that doesn't seem to make sense. + let μ_radon = μ_candidate.sub_matching(μ_base); + reg.verify_merge_candidate_radonsq(minus_τv, μ_candidate, τ, ε, &config, &μ_radon) + //let n = μ_candidate.dist_matching(μ_base); + //reg.find_tolerance_violation_slack(minus_τv, τ, ε, false, config, n).is_none() + }); + } + + let n_before_prune = μ.len(); + μ.prune(); + debug_assert!(μ.len() <= n_before_prune); + stats.pruned += n_before_prune - μ.len(); +} + + +/// Iteratively solve the pointsource localisation problem using simplified forward-backward splitting. +/// +/// The settings in `config` have their [respective documentation](FBConfig). `opA` is the +/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. +/// Finally, the `iterator` is an outer loop verbosity and iteration count control +/// as documented in [`alg_tools::iterate`]. +/// +/// For details on the mathematical formulation, see the [module level](self) documentation. +/// +/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of +/// sums of simple functions usign bisection trees, and the related +/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions +/// active at a specific points, and to maximise their sums. Through the implementation of the +/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features +/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. +/// +/// Returns the final iterate. +#[replace_float_literals(F::cast_from(literal))] +pub fn pointsource_radon_fb_reg< + 'a, F, I, A, GA, BTA, S, Reg, const N : usize +>( + opA : &'a A, + b : &A::Observable, + reg : Reg, + fbconfig : &RadonFBConfig<F>, + iterator : I, + mut _plotter : SeqPlotter<F, N>, +) -> DiscreteMeasure<Loc<F, N>, F> +where F : Float + ToNalgebraRealField, + I : AlgIteratorFactory<IterInfo<F, N>>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, + //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow + A::Observable : std::ops::MulAssign<F>, + GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, + A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, + BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, + S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, + BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, + Cube<F, N>: P2Minimise<Loc<F, N>, F>, + PlotLookup : Plotting<N>, + DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, + Reg : RegTerm<F, N> { + + // Set up parameters + let config = &fbconfig.insertion; + // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ + // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such + // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L. + let τ = fbconfig.τ0/opA.opnorm_bound().powi(2); + // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled + // by τ compared to the conditional gradient approach. + let tolerance = config.tolerance * τ * reg.tolerance_scaling(); + let mut ε = tolerance.initial(); + + // Initialise iterates + let mut μ = DiscreteMeasure::new(); + let mut residual = -b; + let mut stats = IterInfo::new(); + + // Run the algorithm + iterator.iterate(|state| { + // Calculate smooth part of surrogate model. + // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` + // has no significant overhead. For some reosn Rust doesn't allow us simply moving + // the residual and replacing it below before the end of this closure. + residual *= -τ; + let r = std::mem::replace(&mut residual, opA.empty_observable()); + let mut minus_τv = opA.preadjoint().apply(r); + + // Save current base point + let mut μ_base = μ.clone(); + + // Insert and reweigh + insert_and_reweigh( + &mut μ, &mut minus_τv, &mut μ_base, None, + τ, ε, + config, ®, state, &mut stats + ); + + // Prune and possibly merge spikes + prune_and_maybe_simple_merge( + &mut μ, &mut minus_τv, &μ_base, + τ, ε, + config, ®, state, &mut stats + ); + + // Update residual + residual = calculate_residual(&μ, opA, b); + + // Update main tolerance for next iteration + let ε_prev = ε; + ε = tolerance.update(ε, state.iteration()); + stats.this_iters += 1; + + // Give function value if needed + state.if_verbose(|| { + // Plot if so requested + // plotter.plot_spikes( + // format!("iter {} end;", state.iteration()), &d, + // "start".to_string(), Some(&minus_τv), + // reg.target_bounds(τ, ε_prev), &μ, + // ); + // Calculate mean inner iterations and reset relevant counters. + // Return the statistics + let res = IterInfo { + value : residual.norm2_squared_div2() + reg.apply(&μ), + n_spikes : μ.len(), + ε : ε_prev, + postprocessing: config.postprocessing.then(|| μ.clone()), + .. stats + }; + stats = IterInfo::new(); + res + }) + }); + + postprocess(μ, config, L2Squared, opA, b) +} + +/// Iteratively solve the pointsource localisation problem using simplified inertial forward-backward splitting. +/// +/// The settings in `config` have their [respective documentation](FBConfig). `opA` is the +/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. +/// Finally, the `iterator` is an outer loop verbosity and iteration count control +/// as documented in [`alg_tools::iterate`]. +/// +/// For details on the mathematical formulation, see the [module level](self) documentation. +/// +/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of +/// sums of simple functions usign bisection trees, and the related +/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions +/// active at a specific points, and to maximise their sums. Through the implementation of the +/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features +/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. +/// +/// Returns the final iterate. +#[replace_float_literals(F::cast_from(literal))] +pub fn pointsource_radon_fista_reg< + 'a, F, I, A, GA, BTA, S, Reg, const N : usize +>( + opA : &'a A, + b : &A::Observable, + reg : Reg, + fbconfig : &RadonFBConfig<F>, + iterator : I, + mut _plotter : SeqPlotter<F, N>, +) -> DiscreteMeasure<Loc<F, N>, F> +where F : Float + ToNalgebraRealField, + I : AlgIteratorFactory<IterInfo<F, N>>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, + //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow + A::Observable : std::ops::MulAssign<F>, + GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, + A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, + BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, + S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, + BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, + Cube<F, N>: P2Minimise<Loc<F, N>, F>, + PlotLookup : Plotting<N>, + DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, + Reg : RegTerm<F, N> { + + // Set up parameters + let config = &fbconfig.insertion; + // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ + // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such + // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L. + let τ = fbconfig.τ0/opA.opnorm_bound().powi(2); + let mut λ = 1.0; + // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled + // by τ compared to the conditional gradient approach. + let tolerance = config.tolerance * τ * reg.tolerance_scaling(); + let mut ε = tolerance.initial(); + + // Initialise iterates + let mut μ = DiscreteMeasure::new(); + let mut μ_prev = DiscreteMeasure::new(); + let mut residual = -b; + let mut stats = IterInfo::new(); + let mut warned_merging = false; + + // Run the algorithm + iterator.iterate(|state| { + // Calculate smooth part of surrogate model. + // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` + // has no significant overhead. For some reosn Rust doesn't allow us simply moving + // the residual and replacing it below before the end of this closure. + residual *= -τ; + let r = std::mem::replace(&mut residual, opA.empty_observable()); + let mut minus_τv = opA.preadjoint().apply(r); + + // Save current base point + let mut μ_base = μ.clone(); + + // Insert new spikes and reweigh + insert_and_reweigh( + &mut μ, &mut minus_τv, &mut μ_base, None, + τ, ε, + config, ®, state, &mut stats + ); + + // (Do not) merge spikes. + if state.iteration() % config.merge_every == 0 { + match config.merging { + SpikeMergingMethod::None => { }, + _ => if !warned_merging { + let err = format!("Merging not supported for μFISTA"); + println!("{}", err.red()); + warned_merging = true; + } + } + } + + // Update inertial prameters + let λ_prev = λ; + λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); + let θ = λ / λ_prev - λ; + + // Perform inertial update on μ. + // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ + // and μ_prev have zero weight. Since both have weights from the finite-dimensional + // subproblem with a proximal projection step, this is likely to happen when the + // spike is not needed. A copy of the pruned μ without artithmetic performed is + // stored in μ_prev. + let n_before_prune = μ.len(); + μ.pruning_sub(1.0 + θ, θ, &mut μ_prev); + debug_assert!(μ.len() <= n_before_prune); + stats.pruned += n_before_prune - μ.len(); + + // Update residual + residual = calculate_residual(&μ, opA, b); + + // Update main tolerance for next iteration + let ε_prev = ε; + ε = tolerance.update(ε, state.iteration()); + stats.this_iters += 1; + + // Give function value if needed + state.if_verbose(|| { + // Plot if so requested + // plotter.plot_spikes( + // format!("iter {} end;", state.iteration()), &d, + // "start".to_string(), Some(&minus_τv), + // reg.target_bounds(τ, ε_prev), &μ_prev, + // ); + // Calculate mean inner iterations and reset relevant counters. + // Return the statistics + let res = IterInfo { + value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev), + n_spikes : μ_prev.len(), + ε : ε_prev, + postprocessing: config.postprocessing.then(|| μ_prev.clone()), + .. stats + }; + stats = IterInfo::new(); + res + }) + }); + + postprocess(μ_prev, config, L2Squared, opA, b) +}