Tue, 31 Dec 2024 09:25:45 -0500
New version of sliding.
/*!
Solver for the point source localisation problem using a forward-backward splitting method.

This corresponds to the manuscript

 * Valkonen T. - _Proximal methods for point source localisation_,
   [arXiv:2212.02991](https://arxiv.org/abs/2212.02991).

The main routine is [`pointsource_fb_reg`].

## Problem

<p>
Our objective is to solve
$$
    \min_{μ ∈ ℳ(Ω)}~ F_0(Aμ-b) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ),
$$
where $F_0(y)=\frac{1}{2}\|y\|_2^2$ and the forward operator $A \in 𝕃(ℳ(Ω); ℝ^n)$.
</p>

## Approach

<p>
As documented in more detail in the paper, on each step we approximately solve
$$
    \min_{μ ∈ ℳ(Ω)}~ F(μ) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ) + \frac{1}{2}\|μ-μ^k\|_𝒟^2,
$$
where $F(μ) = F_0(Aμ-b)$, and $𝒟 ∈ 𝕃(ℳ(Ω); C_c(Ω))$ is typically a convolution operator.
</p>

## Finite-dimensional subproblems.

With $C$ a projection from [`DiscreteMeasure`] to the weights, and $x^k$ such that $x^k=Cμ^k$,
we form the discretised linearised inner problem
<p>
$$
    \min_{x ∈ ℝ^n}~ τ\bigl(F(Cx^k) + [C^*∇F(Cx^k)]^⊤(x-x^k) + α {\vec 1}^⊤ x\bigr)
                    + δ_{≥ 0}(x) + \frac{1}{2}\|x-x^k\|_{C^*𝒟C}^2,
$$
equivalently
$$
    \begin{aligned}
    \min_x~ & τF(Cx^k) - τ[C^*∇F(Cx^k)]^⊤x^k + \frac{1}{2} (x^k)^⊤ C^*𝒟C x^k
            \\
            &
            - [C^*𝒟C x^k - τC^*∇F(Cx^k)]^⊤ x
            \\
            &
            + \frac{1}{2} x^⊤ C^*𝒟C x
            + τα {\vec 1}^⊤ x + δ_{≥ 0}(x).
    \end{aligned}
$$
In other words, we obtain the quadratic non-negativity constrained problem
$$
    \min_{x ∈ ℝ^n}~ \frac{1}{2} x^⊤ Ã x - g̃^⊤ x + c + τα {\vec 1}^⊤ x + δ_{≥ 0}(x),
$$
where
$$
   \begin{aligned}
    Ã & = C^*𝒟C,
    \\
    g̃ & = C^*𝒟C x^k - τ C^*∇F(Cx^k)
        = C^* 𝒟 μ^k - τ C^*A^*(Aμ^k - b),
    \\
    c & = τ F(Cx^k) - τ[C^*∇F(Cx^k)]^⊤x^k + \frac{1}{2} (x^k)^⊤ C^*𝒟C x^k
        \\
        &
        = \frac{τ}{2} \|Aμ^k-b\|^2 - τ[Aμ^k-b]^⊤Aμ^k + \frac{1}{2} \|μ^k\|_{𝒟}^2
        \\
        &
        = -\frac{τ}{2} \|Aμ^k-b\|^2 - τ[Aμ^k-b]^⊤ b + \frac{1}{2} \|μ^k\|_{𝒟}^2.
   \end{aligned}
$$
</p>

We solve this with either SSN or FB as determined by
[`InnerSettings`] in [`FBGenericConfig::inner`].
*/

use numeric_literals::replace_float_literals;
use serde::{Serialize, Deserialize};
use colored::Colorize;
use nalgebra::DVector;

use alg_tools::iterate::{
    AlgIteratorFactory,
    AlgIteratorIteration,
    AlgIterator,
};
use alg_tools::euclidean::Euclidean;
use alg_tools::linops::{Mapping, GEMV};
use alg_tools::sets::Cube;
use alg_tools::loc::Loc;
use alg_tools::bisection_tree::{
    BTFN,
    PreBTFN,
    Bounds,
    BTNodeLookup,
    BTNode,
    BTSearch,
    P2Minimise,
    SupportGenerator,
    LocalAnalysis,
    BothGenerators,
};
use alg_tools::mapping::RealMapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;
use alg_tools::instance::Instance;
use alg_tools::norms::Linfinity;

use crate::types::*;
use crate::measures::{
    DiscreteMeasure,
    RNDM,
    DeltaMeasure,
    Radon,
};
use crate::measures::merging::{
    SpikeMergingMethod,
    SpikeMerging,
};
use crate::forward_model::{
    ForwardModel,
    AdjointProductBoundedBy
};
use crate::seminorms::DiscreteMeasureOp;
use crate::subproblem::{
    InnerSettings,
    InnerMethod,
};
use crate::tolerance::Tolerance;
use crate::plot::{
    SeqPlotter,
    Plotting,
    PlotLookup
};
use crate::regularisation::RegTerm;
use crate::dataterm::{
    calculate_residual,
    L2Squared,
    DataTerm,
};

/// Settings for [`pointsource_fb_reg`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FBConfig<F : Float> {
    /// Step length scaling
    pub τ0 : F,
    /// Generic parameters
    pub generic : FBGenericConfig<F>,
}

/// Settings for the solution of the stepwise optimality condition.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FBGenericConfig<F : Float> {
    /// Tolerance for point insertion.
    pub tolerance : Tolerance<F>,

    /// Stop looking for predual maximum (where to insert a new point) below
    /// `tolerance` multiplied by this factor.
    ///
    /// Not used by [`super::radon_fb`].
    pub insertion_cutoff_factor : F,

    /// Settings for branch and bound refinement when looking for predual maxima
    pub refinement : RefinementSettings<F>,

    /// Maximum insertions within each outer iteration
    ///
    /// Not used by [`super::radon_fb`].
    pub max_insertions : usize,

    /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
    ///
    /// Not used by [`super::radon_fb`].
    pub bootstrap_insertions : Option<(usize, usize)>,

    /// Inner method settings
    pub inner : InnerSettings<F>,

    /// Spike merging method
    pub merging : SpikeMergingMethod<F>,

    /// Tolerance multiplier for merges
    pub merge_tolerance_mult : F,

    /// Spike merging method after the last step
    pub final_merging : SpikeMergingMethod<F>,

    /// Iterations between merging heuristic tries.
    ///
    /// With the value zero, [`FBGenericConfig::merge_now`] never requests a merge.
    pub merge_every : usize,

    // /// Save $μ$ for postprocessing optimisation
    // pub postprocessing : bool
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for FBConfig<F> {
    fn default() -> Self {
        FBConfig {
            τ0 : 0.99,
            generic : Default::default(),
        }
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for FBGenericConfig<F> {
    fn default() -> Self {
        FBGenericConfig {
            tolerance : Default::default(),
            insertion_cutoff_factor : 1.0,
            refinement : Default::default(),
            max_insertions : 100,
            //bootstrap_insertions : None,
            bootstrap_insertions : Some((10, 1)),
            inner : InnerSettings {
                method : InnerMethod::Default,
                .. Default::default()
            },
            merging : SpikeMergingMethod::None,
            //merging : Default::default(),
            final_merging : Default::default(),
            merge_every : 10,
            merge_tolerance_mult : 2.0,
            // postprocessing : false,
        }
    }
}

impl<F : Float> FBGenericConfig<F> {
    /// Check if merging should be attempted this iteration.
    ///
    /// Returns `true` when the iteration number reported by `state` is a multiple of
    /// [`merge_every`](Self::merge_every). A zero `merge_every` yields `false`.
    pub fn merge_now<I : AlgIterator>(&self, state : &AlgIteratorIteration<I>) -> bool {
        // Integer remainder by zero panics, and `merge_every` can be set to zero through
        // the deserialisable configuration; treat that value as "never merge".
        self.merge_every != 0 && state.iteration() % self.merge_every == 0
    }
}

/// Insert spikes into `μ` and reweigh them until the subproblem tolerance is met or the
/// insertion limit is reached.
///
/// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike /// locations, while `ν_delta` may have different locations. #[replace_float_literals(F::cast_from(literal))] pub(crate) fn insert_and_reweigh< 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, I, const N : usize >( μ : &mut RNDM<F, N>, τv : &BTFN<F, GA, BTA, N>, μ_base : &RNDM<F, N>, ν_delta: Option<&RNDM<F, N>>, op𝒟 : &'a 𝒟, op𝒟norm : F, τ : F, ε : F, config : &FBGenericConfig<F>, reg : &Reg, state : &AlgIteratorIteration<I>, stats : &mut IterInfo<F, N>, ) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool) where F : Float + ToNalgebraRealField, GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 𝒟::Codomain : RealMapping<F, N>, S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, Reg : RegTerm<F, N>, I : AlgIterator { // Maximum insertion count and measure difference calculation depend on insertion style. let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { (i, Some((l, k))) if i <= l => (k, false), _ => (config.max_insertions, !state.is_quiet()), }; let ω0 = match ν_delta { None => op𝒟.apply(μ_base), Some(ν) => op𝒟.apply(μ_base + ν), }; // Add points to support until within error tolerance or maximum insertion count reached. let mut count = 0; let (within_tolerances, d) = 'insertion: loop { if μ.len() > 0 { // Form finite-dimensional subproblem. The subproblem references to the original μ^k // from the beginning of the iteration are all contained in the immutable c and g. // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional // problems have not yet been updated to sign change. 
let à = op𝒟.findim_matrix(μ.iter_locations()); let g̃ = DVector::from_iterator(μ.len(), μ.iter_locations() .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) .map(F::to_nalgebra_mixed)); let mut x = μ.masses_dvector(); // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 // = n |𝒟| |x|_2, where n is the number of points. Therefore let Ã_normest = op𝒟norm * F::cast_from(μ.len()); // Solve finite-dimensional subproblem. stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config); // Update masses of μ based on solution of finite-dimensional subproblem. μ.set_masses_dvector(&x); } // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality // conditions in the predual space, and finding new points for insertion, if necessary. let mut d = τv + match ν_delta { None => op𝒟.preapply(μ.sub_matching(μ_base)), Some(ν) => op𝒟.preapply(μ.sub_matching(μ_base) - ν) }; // If no merging heuristic is used, let's be more conservative about spike insertion, // and skip it after first round. If merging is done, being more greedy about spike // insertion also seems to improve performance. let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging { false } else { count > 0 }; // Find a spike to insert, if needed let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( &mut d, τ, ε, skip_by_rough_check, config ) { None => break 'insertion (true, d), Some(res) => res, }; // Break if maximum insertion count reached if count >= max_insertions { break 'insertion (in_bounds, d) } // No point in optimising the weight here; the finite-dimensional algorithm is fast. 
*μ += DeltaMeasure { x : ξ, α : 0.0 }; count += 1; stats.inserted += 1; }; if !within_tolerances && warn_insertions { // Complain (but continue) if we failed to get within tolerances // by inserting more points. let err = format!("Maximum insertions reached without achieving \ subproblem solution tolerance"); println!("{}", err.red()); } (d, within_tolerances) } pub(crate) fn prune_with_stats<F : Float, const N : usize>( μ : &mut RNDM<F, N>, ) -> usize { let n_before_prune = μ.len(); μ.prune(); debug_assert!(μ.len() <= n_before_prune); n_before_prune - μ.len() } #[replace_float_literals(F::cast_from(literal))] pub(crate) fn postprocess< F : Float, V : Euclidean<F> + Clone, A : GEMV<F, RNDM<F, N>, Codomain = V>, D : DataTerm<F, V, N>, const N : usize > ( mut μ : RNDM<F, N>, config : &FBGenericConfig<F>, dataterm : D, opA : &A, b : &V, ) -> RNDM<F, N> where RNDM<F, N> : SpikeMerging<F>, for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>, { μ.merge_spikes_fitness(config.merging, |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |&v| v); μ.prune(); μ } /// Iteratively solve the pointsource localisation problem using forward-backward splitting. /// /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control /// as documented in [`alg_tools::iterate`]. /// /// For details on the mathematical formulation, see the [module level](self) documentation. /// /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of /// sums of simple functions usign bisection trees, and the related /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions /// active at a specific points, and to maximise their sums. 
Through the implementation of the /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. /// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_fb_reg< 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize >( opA : &'a A, b : &A::Observable, reg : Reg, op𝒟 : &'a 𝒟, fbconfig : &FBConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, ) -> RNDM<F, N> where F : Float + ToNalgebraRealField, I : AlgIteratorFactory<IterInfo<F, N>>, for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 𝒟::Codomain : RealMapping<F, N>, S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, Cube<F, N>: P2Minimise<Loc<F, N>, F>, PlotLookup : Plotting<N>, RNDM<F, N> : SpikeMerging<F>, Reg : RegTerm<F, N> { // Set up parameters let config = &fbconfig.generic; let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap(); // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. 
let tolerance = config.tolerance * τ * reg.tolerance_scaling(); let mut ε = tolerance.initial(); // Initialise iterates let mut μ = DiscreteMeasure::new(); let mut residual = -b; // Statistics let full_stats = |residual : &A::Observable, μ : &RNDM<F, N>, ε, stats| IterInfo { value : residual.norm2_squared_div2() + reg.apply(μ), n_spikes : μ.len(), ε, //postprocessing: config.postprocessing.then(|| μ.clone()), .. stats }; let mut stats = IterInfo::new(); // Run the algorithm for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. let τv = opA.preadjoint().apply(residual * τ); // Save current base point let μ_base = μ.clone(); // Insert and reweigh let (d, _within_tolerances) = insert_and_reweigh( &mut μ, &τv, &μ_base, None, op𝒟, op𝒟norm, τ, ε, config, ®, &state, &mut stats ); // Prune and possibly merge spikes if config.merge_now(&state) { stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base)); reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) }); } stats.pruned += prune_with_stats(&mut μ); // Update residual residual = calculate_residual(&μ, opA, b); let iter = state.iteration(); stats.this_iters += 1; // Give statistics if needed state.if_verbose(|| { plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) }); // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } postprocess(μ, config, L2Squared, opA, b) } /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. /// /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution /// operator. 
Finally, the `iterator` is an outer loop verbosity and iteration count control /// as documented in [`alg_tools::iterate`]. /// /// For details on the mathematical formulation, see the [module level](self) documentation. /// /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of /// sums of simple functions usign bisection trees, and the related /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions /// active at a specific points, and to maximise their sums. Through the implementation of the /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. /// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_fista_reg< 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize >( opA : &'a A, b : &A::Observable, reg : Reg, op𝒟 : &'a 𝒟, fbconfig : &FBConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, ) -> RNDM<F, N> where F : Float + ToNalgebraRealField, I : AlgIteratorFactory<IterInfo<F, N>>, for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 𝒟::Codomain : RealMapping<F, N>, S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, Cube<F, N>: P2Minimise<Loc<F, N>, F>, PlotLookup : Plotting<N>, RNDM<F, N> : SpikeMerging<F>, Reg : RegTerm<F, N> { // Set up parameters let config = &fbconfig.generic; let op𝒟norm = 
op𝒟.opnorm_bound(Radon, Linfinity); let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap(); let mut λ = 1.0; // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. let tolerance = config.tolerance * τ * reg.tolerance_scaling(); let mut ε = tolerance.initial(); // Initialise iterates let mut μ = DiscreteMeasure::new(); let mut μ_prev = DiscreteMeasure::new(); let mut residual = -b; let mut warned_merging = false; // Statistics let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo { value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), n_spikes : ν.len(), ε, // postprocessing: config.postprocessing.then(|| ν.clone()), .. stats }; let mut stats = IterInfo::new(); // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. let τv = opA.preadjoint().apply(residual * τ); // Save current base point let μ_base = μ.clone(); // Insert new spikes and reweigh let (d, _within_tolerances) = insert_and_reweigh( &mut μ, &τv, &μ_base, None, op𝒟, op𝒟norm, τ, ε, config, ®, &state, &mut stats ); // (Do not) merge spikes. if config.merge_now(&state) { match config.merging { SpikeMergingMethod::None => { }, _ => if !warned_merging { let err = format!("Merging not supported for μFISTA"); println!("{}", err.red()); warned_merging = true; } } } // Update inertial prameters let λ_prev = λ; λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); let θ = λ / λ_prev - λ; // Perform inertial update on μ. // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ // and μ_prev have zero weight. Since both have weights from the finite-dimensional // subproblem with a proximal projection step, this is likely to happen when the // spike is not needed. A copy of the pruned μ without artithmetic performed is // stored in μ_prev. 
let n_before_prune = μ.len(); μ.pruning_sub(1.0 + θ, θ, &mut μ_prev); debug_assert!(μ.len() <= n_before_prune); stats.pruned += n_before_prune - μ.len(); // Update residual residual = calculate_residual(&μ, opA, b); let iter = state.iteration(); stats.this_iters += 1; // Give statistics if needed state.if_verbose(|| { plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ_prev); full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new())) }); // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } postprocess(μ_prev, config, L2Squared, opA, b) }