Fri, 02 Dec 2022 18:08:40 +0200
Remove ergodic tolerance; it's not useful.
/*! Solver for the point source localisation problem using a conditional gradient method. We implement two variants, the “fully corrective” method from * Pieper K., Walter D. _Linear convergence of accelerated conditional gradient algorithms in spaces of measures_, DOI: [10.1051/cocv/2021042](https://doi.org/10.1051/cocv/2021042), arXiv: [1904.09218](https://doi.org/10.48550/arXiv.1904.09218). and what we call the “relaxed” method from * Bredies K., Pikkarainen H. - _Inverse problems in spaces of measures_, DOI: [10.1051/cocv/2011205](https://doi.org/0.1051/cocv/2011205). */ use numeric_literals::replace_float_literals; use serde::{Serialize, Deserialize}; //use colored::Colorize; use alg_tools::iterate::{ AlgIteratorFactory, AlgIteratorState, AlgIteratorOptions, }; use alg_tools::euclidean::Euclidean; use alg_tools::norms::Norm; use alg_tools::linops::Apply; use alg_tools::sets::Cube; use alg_tools::loc::Loc; use alg_tools::bisection_tree::{ BTFN, Bounds, BTNodeLookup, BTNode, BTSearch, P2Minimise, SupportGenerator, LocalAnalysis, }; use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; use crate::types::*; use crate::measures::{ DiscreteMeasure, DeltaMeasure, Radon, }; use crate::measures::merging::{ SpikeMergingMethod, SpikeMerging, }; use crate::forward_model::ForwardModel; #[allow(unused_imports)] // Used in documentation use crate::subproblem::{ quadratic_nonneg, InnerSettings, InnerMethod, }; use crate::tolerance::Tolerance; use crate::plot::{ SeqPlotter, Plotting, PlotLookup }; /// Settings for [`pointsource_fw`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] pub struct FWConfig<F : Float> { /// Tolerance for branch-and-bound new spike location discovery pub tolerance : Tolerance<F>, /// Inner problem solution configuration. 
Has to have `method` set to [`InnerMethod::FB`] /// as the conditional gradient subproblems' optimality conditions do not in general have an /// invertible Newton derivative for SSN. pub inner : InnerSettings<F>, /// Variant of the conditional gradient method pub variant : FWVariant, /// Settings for branch and bound refinement when looking for predual maxima pub refinement : RefinementSettings<F>, /// Spike merging heuristic pub merging : SpikeMergingMethod<F>, } /// Conditional gradient method variant; see also [`FWConfig`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[allow(dead_code)] pub enum FWVariant { /// Algorithm 2 of Walter-Pieper FullyCorrective, /// Bredies–Pikkarainen. Forces `FWConfig.inner.max_iter = 1`. Relaxed, } impl<F : Float> Default for FWConfig<F> { fn default() -> Self { FWConfig { tolerance : Default::default(), refinement : Default::default(), inner : Default::default(), variant : FWVariant::FullyCorrective, merging : Default::default(), } } } /// Helper struct for pre-initialising the finite-dimensional subproblems solver /// [`prepare_optimise_weights`]. /// /// The pre-initialisation is done by [`prepare_optimise_weights`]. pub struct FindimData<F : Float> { opAnorm_squared : F } /// Return a pre-initialisation struct for [`prepare_optimise_weights`]. /// /// The parameter `opA` is the forward operator $A$. pub fn prepare_optimise_weights<F, A, const N : usize>(opA : &A) -> FindimData<F> where F : Float + ToNalgebraRealField, A : ForwardModel<Loc<F, N>, F> { FindimData{ opAnorm_squared : opA.opnorm_bound().powi(2) } } /// Solve the finite-dimensional weight optimisation problem for the 2-norm-squared data fidelity /// point source localisation problem. /// /// That is, we minimise /// <div>$$ /// μ ↦ \frac{1}{2}\|Aμ-b\|_w^2 + α\|μ\|_ℳ + δ_{≥ 0}(μ) /// $$</div> /// only with respect to the weights of $μ$. /// /// The parameter `μ` is the discrete measure whose weights are to be optimised. 
/// The `opA` parameter is the forward operator $A$, while `b` and `α` are as in the
/// objective above. The method parameters are set in `inner` (see [`InnerSettings`]), while
/// `iterator` is used to iterate the steps of the method. The parameter `findim_data` should be
/// prepared using [`prepare_optimise_weights`].
///
/// Returns the number of iterations taken by the method configured in `inner`.
pub fn optimise_weights<'a, F, A, I, const N : usize>(
    μ : &mut DiscreteMeasure<Loc<F, N>, F>,
    opA : &'a A,
    b : &A::Observable,
    α : F,
    findim_data : &FindimData<F>,
    inner : &InnerSettings<F>,
    iterator : I
) -> usize
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<F>,
      A : ForwardModel<Loc<F, N>, F> {
    // Build the finite-dimensional quadratic model for the current spike locations,
    // and take the current masses as the starting point of the inner solver.
    let (model_a, model_g) = opA.findim_quadratic_model(&μ, b);
    let mut weights = μ.masses_dvector();

    // Step length for the inner method. The stored bound estimates the operator norm
    // of $A$ from ℳ(Ω) to ℝ^n, which is a good estimate for the matrix norm from ℝ^m
    // to ℝ^n when ℝ^m carries the 1-norm. The inner method needs the 2-norm, so we use
    //   ‖A‖_{2,2} := sup_{‖x‖_2 ≤ 1} ‖Ax‖_2 ≤ sup_{‖x‖_1 ≤ C} ‖Ax‖_2
    //              = C sup_{‖x‖_1 ≤ 1} ‖Ax‖_2 = C ‖A‖_{1,2},
    // where C = √m satisfies ‖x‖_1 ≤ C ‖x‖_2. As we are interested in ‖A_*A‖, the
    // squared bound is scaled by m itself, without a square root.
    let n_spikes = F::cast_from(μ.len());
    let τ = inner.τ0 / (findim_data.opAnorm_squared * n_spikes);

    let n_iters = quadratic_nonneg(inner.method, &model_a, &model_g, α, &mut weights, τ, iterator);

    // Write the optimised masses back into μ.
    μ.set_masses_dvector(&weights);

    n_iters
}

/// Solve point source localisation problem using a conditional gradient method
/// for the 2-norm-squared data fidelity, i.e., the problem
/// <div>$$
///     \min_μ \frac{1}{2}\|Aμ-b\|_w^2 + α\|μ\|_ℳ + δ_{≥ 0}(μ).
/// $$</div> /// /// The `opA` parameter is the forward operator $A$, while `b`$ and `α` are as in the /// objective above. The method parameter are set in `config` (see [`FWConfig`]), while /// `iterator` is used to iterate the steps of the method, and `plotter` may be used to /// save intermediate iteration states as images. #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_fw<'a, F, I, A, GA, BTA, S, const N : usize>( opA : &'a A, b : &A::Observable, α : F, //domain : Cube<F, N>, config : &FWConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, ) -> DiscreteMeasure<Loc<F, N>, F> where F : Float + ToNalgebraRealField, I : AlgIteratorFactory<IterInfo<F, N>>, for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow A::Observable : std::ops::MulAssign<F>, GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, Cube<F, N>: P2Minimise<Loc<F, N>, F>, PlotLookup : Plotting<N>, DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { // Set up parameters // We multiply tolerance by α for all algoritms. 
let tolerance = config.tolerance * α; let mut ε = tolerance.initial(); let findim_data = prepare_optimise_weights(opA); let m0 = b.norm2_squared() / (2.0 * α); let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) }; // Initialise operators let preadjA = opA.preadjoint(); // Initialise iterates let mut μ = DiscreteMeasure::new(); let mut residual = -b; let mut inner_iters = 0; let mut this_iters = 0; let mut pruned = 0; let mut merged = 0; // Run the algorithm iterator.iterate(|state| { // Update tolerance let inner_tolerance = ε * config.inner.tolerance_mult; let refinement_tolerance = ε * config.refinement.tolerance_mult; let ε_prev = ε; ε = tolerance.update(ε, state.iteration()); // Calculate smooth part of surrogate model. // // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` // has no significant overhead. For some reosn Rust doesn't allow us simply moving // the residual and replacing it below before the end of this closure. let r = std::mem::replace(&mut residual, opA.empty_observable()); let mut g = -preadjA.apply(r); // Find absolute value maximising point let (ξmax, v_ξmax) = g.maximise(refinement_tolerance, config.refinement.max_steps); let (ξmin, v_ξmin) = g.minimise(refinement_tolerance, config.refinement.max_steps); let (ξ, v_ξ) = if v_ξmin < 0.0 && -v_ξmin > v_ξmax { (ξmin, v_ξmin) } else { (ξmax, v_ξmax) }; let inner_it = match config.variant { FWVariant::FullyCorrective => { // No point in optimising the weight here: the finite-dimensional algorithm is fast. 
μ += DeltaMeasure { x : ξ, α : 0.0 }; config.inner.iterator_options.stop_target(inner_tolerance) }, FWVariant::Relaxed => { // Perform a relaxed initialisation of μ let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ }; let δ = DeltaMeasure { x : ξ, α : v }; let dp = μ.apply(&g) - δ.apply(&g); let d = opA.apply(&μ) - opA.apply(&δ); let r = d.norm2_squared(); let s = if r == 0.0 { 1.0 } else { 1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r) }; μ *= 1.0 - s; μ += δ * s; // The stop_target is only needed for the type system. AlgIteratorOptions{ max_iter : 1, .. config.inner.iterator_options}.stop_target(0.0) } }; inner_iters += optimise_weights(&mut μ, opA, b, α, &findim_data, &config.inner, inner_it); // Merge spikes and update residual for next step and `if_verbose` below. let n_before_merge = μ.len(); residual = μ.merge_spikes_fitness(config.merging, |μ̃| opA.apply(μ̃) - b, A::Observable::norm2_squared); assert!(μ.len() >= n_before_merge); merged += μ.len() - n_before_merge; // Prune points with zero mass let n_before_prune = μ.len(); μ.prune(); debug_assert!(μ.len() <= n_before_prune); pruned += n_before_prune - μ.len(); this_iters +=1; // Give function value if needed state.if_verbose(|| { plotter.plot_spikes( format!("iter {} start", state.iteration()), &g, "".to_string(), None::<&A::PreadjointCodomain>, None, &μ ); let res = IterInfo { value : residual.norm2_squared_div2() + α * μ.norm(Radon), n_spikes : μ.len(), inner_iters, this_iters, merged, pruned, ε : ε_prev, maybe_ε1 : None, postprocessing : None, }; inner_iters = 0; this_iters = 0; merged = 0; pruned = 0; res }) }); // Return final iterate μ }