--- a/src/radon_fb.rs Mon Jan 06 11:32:57 2025 -0500 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,404 +0,0 @@ -/*! -Solver for the point source localisation problem using a simplified forward-backward splitting method. - -Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map. -*/ - -use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; -use colored::Colorize; -use nalgebra::DVector; - -use alg_tools::iterate::{ - AlgIteratorFactory, - AlgIteratorIteration, - AlgIterator -}; -use alg_tools::euclidean::Euclidean; -use alg_tools::linops::Mapping; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::bisection_tree::{ - BTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, -}; -use alg_tools::mapping::RealMapping; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::norms::L2; - -use crate::types::*; -use crate::measures::{ - RNDM, - DiscreteMeasure, - DeltaMeasure, - Radon, -}; -use crate::measures::merging::{ - SpikeMergingMethod, - SpikeMerging, -}; -use crate::forward_model::ForwardModel; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; -use crate::regularisation::RegTerm; -use crate::dataterm::{ - calculate_residual, - L2Squared, - DataTerm, -}; - -use crate::fb::{ - FBGenericConfig, - postprocess, - prune_with_stats -}; - -/// Settings for [`pointsource_radon_fb_reg`]. -#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] -#[serde(default)] -pub struct RadonFBConfig<F : Float> { - /// Step length scaling - pub τ0 : F, - /// Generic parameters - pub insertion : FBGenericConfig<F>, -} - -#[replace_float_literals(F::cast_from(literal))] -impl<F : Float> Default for RadonFBConfig<F> { - fn default() -> Self { - RadonFBConfig { - τ0 : 0.99, - insertion : Default::default() - } - } -} - -#[replace_float_literals(F::cast_from(literal))] -pub(crate) fn insert_and_reweigh< - 'a, F, GA, BTA, S, Reg, I, const N : usize ->( - μ : &mut RNDM<F, N>, - τv : &mut BTFN<F, GA, BTA, N>, - μ_base : &mut RNDM<F, N>, - //_ν_delta: Option<&RNDM<F, N>>, - τ : F, - ε : F, - config : &FBGenericConfig<F>, - reg : &Reg, - _state : &AlgIteratorIteration<I>, - stats : &mut IterInfo<F, N>, -) -where F : Float + ToNalgebraRealField, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - RNDM<F, N> : SpikeMerging<F>, - Reg : RegTerm<F, N>, - I : AlgIterator { - - 'i_and_w: for i in 0..=1 { - // Optimise weights - if μ.len() > 0 { - // Form finite-dimensional subproblem. The subproblem references to the original μ^k - // from the beginning of the iteration are all contained in the immutable c and g. - // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional - // problems have not yet been updated to sign change. - let g̃ = DVector::from_iterator(μ.len(), - μ.iter_locations() - .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ)))); - let mut x = μ.masses_dvector(); - let y = μ_base.masses_dvector(); - - // Solve finite-dimensional subproblem. - stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config); - - // Update masses of μ based on solution of finite-dimensional subproblem. - μ.set_masses_dvector(&x); - } - - if i>0 { - // Simple debugging test to see if more inserts would be needed. Doesn't seem so. - //let n = μ.dist_matching(μ_base); - //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n)); - break 'i_and_w - } - - // Calculate ‖μ - μ_base‖_ℳ - let n = μ.dist_matching(μ_base); - - // Find a spike to insert, if needed. - // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ, - // which are supposed to have been guaranteed by the finite-dimensional weight optimisation. - match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) { - None => { break 'i_and_w }, - Some((ξ, _v_ξ, _in_bounds)) => { - // Weight is found out by running the finite-dimensional optimisation algorithm - // above - *μ += DeltaMeasure { x : ξ, α : 0.0 }; - *μ_base += DeltaMeasure { x : ξ, α : 0.0 }; - stats.inserted += 1; - } - }; - } -} - - -/// Iteratively solve the pointsource localisation problem using simplified forward-backward splitting. -/// -/// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the -/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. -/// Finally, the `iterator` is an outer loop verbosity and iteration count control -/// as documented in [`alg_tools::iterate`]. -/// -/// For details on the mathematical formulation, see the [module level](self) documentation. -/// -/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of -/// sums of simple functions usign bisection trees, and the related -/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions -/// active at a specific points, and to maximise their sums. Through the implementation of the -/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features -/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. -/// -/// Returns the final iterate. -#[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_radon_fb_reg< - 'a, F, I, A, GA, BTA, S, Reg, const N : usize ->( - opA : &'a A, - b : &A::Observable, - reg : Reg, - fbconfig : &RadonFBConfig<F>, - iterator : I, - mut _plotter : SeqPlotter<F, N>, -) -> RNDM<F, N> -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Cube<F, N>: P2Minimise<Loc<F, N>, F>, - RNDM<F, N> : SpikeMerging<F>, - Reg : RegTerm<F, N> { - - // Set up parameters - let config = &fbconfig.insertion; - // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ - // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such - // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L. - let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2); - // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled - // by τ compared to the conditional gradient approach. - let tolerance = config.tolerance * τ * reg.tolerance_scaling(); - let mut ε = tolerance.initial(); - - // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut residual = -b; - - // Statistics - let full_stats = |residual : &A::Observable, - μ : &RNDM<F, N>, - ε, stats| IterInfo { - value : residual.norm2_squared_div2() + reg.apply(μ), - n_spikes : μ.len(), - ε, - // postprocessing: config.postprocessing.then(|| μ.clone()), - .. stats - }; - let mut stats = IterInfo::new(); - - // Run the algorithm - for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { - // Calculate smooth part of surrogate model. - let mut τv = opA.preadjoint().apply(residual * τ); - - // Save current base point - let mut μ_base = μ.clone(); - - // Insert and reweigh - insert_and_reweigh( - &mut μ, &mut τv, &mut μ_base, //None, - τ, ε, - config, ®, &state, &mut stats - ); - - // Prune and possibly merge spikes - assert!(μ_base.len() <= μ.len()); - if config.merge_now(&state) { - stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { - // Important: μ_candidate's new points are afterwards, - // and do not conflict with μ_base. - // TODO: could simplify to requiring μ_base instead of μ_radon. - // but may complicate with sliding base's exgtra points that need to be - // after μ_candidate's extra points. - // TODO: doesn't seem to work, maybe need to merge μ_base as well? - // Although that doesn't seem to make sense. - let μ_radon = μ_candidate.sub_matching(&μ_base); - reg.verify_merge_candidate_radonsq(&mut τv, μ_candidate, τ, ε, &config, &μ_radon) - //let n = μ_candidate.dist_matching(μ_base); - //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none() - }); - } - stats.pruned += prune_with_stats(&mut μ); - - // Update residual - residual = calculate_residual(&μ, opA, b); - - let iter = state.iteration(); - stats.this_iters += 1; - - // Give statistics if needed - state.if_verbose(|| { - full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) - }); - - // Update main tolerance for next iteration - ε = tolerance.update(ε, iter); - } - - postprocess(μ, config, L2Squared, opA, b) -} - -/// Iteratively solve the pointsource localisation problem using simplified inertial forward-backward splitting. -/// -/// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the -/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. -/// Finally, the `iterator` is an outer loop verbosity and iteration count control -/// as documented in [`alg_tools::iterate`]. -/// -/// For details on the mathematical formulation, see the [module level](self) documentation. -/// -/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of -/// sums of simple functions usign bisection trees, and the related -/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions -/// active at a specific points, and to maximise their sums. Through the implementation of the -/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features -/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. -/// -/// Returns the final iterate. -#[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_radon_fista_reg< - 'a, F, I, A, GA, BTA, S, Reg, const N : usize ->( - opA : &'a A, - b : &A::Observable, - reg : Reg, - fbconfig : &RadonFBConfig<F>, - iterator : I, - mut plotter : SeqPlotter<F, N>, -) -> RNDM<F, N> -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Cube<F, N>: P2Minimise<Loc<F, N>, F>, - PlotLookup : Plotting<N>, - RNDM<F, N> : SpikeMerging<F>, - Reg : RegTerm<F, N> { - - // Set up parameters - let config = &fbconfig.insertion; - // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ - // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such - // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L. - let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2); - let mut λ = 1.0; - // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled - // by τ compared to the conditional gradient approach. - let tolerance = config.tolerance * τ * reg.tolerance_scaling(); - let mut ε = tolerance.initial(); - - // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut μ_prev = DiscreteMeasure::new(); - let mut residual = -b; - let mut warned_merging = false; - - // Statistics - let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo { - value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), - n_spikes : ν.len(), - ε, - // postprocessing: config.postprocessing.then(|| ν.clone()), - .. stats - }; - let mut stats = IterInfo::new(); - - // Run the algorithm - for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { - // Calculate smooth part of surrogate model. - let mut τv = opA.preadjoint().apply(residual * τ); - - // Save current base point - let mut μ_base = μ.clone(); - - // Insert new spikes and reweigh - insert_and_reweigh( - &mut μ, &mut τv, &mut μ_base, //None, - τ, ε, - config, ®, &state, &mut stats - ); - - // (Do not) merge spikes. - if config.merge_now(&state) { - match config.merging { - SpikeMergingMethod::None => { }, - _ => if !warned_merging { - let err = format!("Merging not supported for μFISTA"); - println!("{}", err.red()); - warned_merging = true; - } - } - } - - // Update inertial prameters - let λ_prev = λ; - λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); - let θ = λ / λ_prev - λ; - - // Perform inertial update on μ. - // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ - // and μ_prev have zero weight. Since both have weights from the finite-dimensional - // subproblem with a proximal projection step, this is likely to happen when the - // spike is not needed. A copy of the pruned μ without artithmetic performed is - // stored in μ_prev. - let n_before_prune = μ.len(); - μ.pruning_sub(1.0 + θ, θ, &mut μ_prev); - debug_assert!(μ.len() <= n_before_prune); - stats.pruned += n_before_prune - μ.len(); - - // Update residual - residual = calculate_residual(&μ, opA, b); - - let iter = state.iteration(); - stats.this_iters += 1; - - // Give statistics if needed - state.if_verbose(|| { - plotter.plot_spikes(iter, Option::<&S>::None, Some(&τv), &μ_prev); - full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new())) - }); - - // Update main tolerance for next iteration - ε = tolerance.update(ε, iter); - } - - postprocess(μ_prev, config, L2Squared, opA, b) -}