diff -r bd13c2ae3450 -r 56c8adc32b09 src/fb.rs --- a/src/fb.rs Fri Apr 28 13:15:19 2023 +0300 +++ b/src/fb.rs Tue Dec 31 09:34:24 2024 -0500 @@ -83,17 +83,16 @@ use numeric_literals::replace_float_literals; use serde::{Serialize, Deserialize}; use colored::Colorize; -use nalgebra::{DVector, DMatrix}; +use nalgebra::DVector; use alg_tools::iterate::{ AlgIteratorFactory, AlgIteratorState, }; use alg_tools::euclidean::Euclidean; -use alg_tools::linops::Apply; +use alg_tools::linops::{Apply, GEMV}; use alg_tools::sets::Cube; use alg_tools::loc::Loc; -use alg_tools::mapping::Mapping; use alg_tools::bisection_tree::{ BTFN, PreBTFN, @@ -104,7 +103,7 @@ P2Minimise, SupportGenerator, LocalAnalysis, - Bounded, + BothGenerators, }; use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; @@ -119,12 +118,8 @@ SpikeMerging, }; use crate::forward_model::ForwardModel; -use crate::seminorms::{ - DiscreteMeasureOp, Lipschitz -}; +use crate::seminorms::DiscreteMeasureOp; use crate::subproblem::{ - nonneg::quadratic_nonneg, - unconstrained::quadratic_unconstrained, InnerSettings, InnerMethod, }; @@ -134,9 +129,11 @@ Plotting, PlotLookup }; -use crate::regularisation::{ - NonnegRadonRegTerm, - RadonRegTerm, +use crate::regularisation::RegTerm; +use crate::dataterm::{ + calculate_residual, + L2Squared, + DataTerm, }; /// Method for constructing $μ$ on each iteration @@ -150,24 +147,12 @@ Zero, } -/// Meta-algorithm type -#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] -#[allow(dead_code)] -pub enum FBMetaAlgorithm { - /// No meta-algorithm - None, - /// FISTA-style inertia - InertiaFISTA, -} - /// Settings for [`pointsource_fb_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] pub struct FBConfig { /// Step length scaling pub τ0 : F, - /// Meta-algorithm to apply - pub meta : FBMetaAlgorithm, /// Generic parameters pub insertion : FBGenericConfig, } @@ -209,7 +194,6 @@ fn default() -> Self { FBConfig { τ0 : 0.99, - meta : FBMetaAlgorithm::None, insertion : Default::default() } } @@ -240,486 +224,236 @@ } } -/// Trait for specialisation of [`generic_pointsource_fb_reg`] to basic FB, FISTA. -/// -/// The idea is that the residual $Aμ - b$ in the forward step can be replaced by an arbitrary -/// value. For example, to implement [primal-dual proximal splitting][crate::pdps] we replace it -/// with the dual variable $y$. We can then also implement alternative data terms, as the -/// (pre)differential of $F(μ)=F\_0(Aμ-b)$ is $F\'(μ) = A\_*F\_0\'(Aμ-b)$. In the case of the -/// quadratic fidelity $F_0(y)=\frac{1}{2}\\|y\\|_2^2$ in a Hilbert space, of course, -/// $F\_0\'(Aμ-b)=Aμ-b$ is the residual. -pub trait FBSpecialisation, const N : usize> : Sized { - /// Updates the residual and does any necessary pruning of `μ`. - /// - /// Returns the new residual and possibly a new step length. - /// - /// The measure `μ` may also be modified to apply, e.g., inertia to it. - /// The updated residual should correspond to the residual at `μ`. - /// See the [trait documentation][FBSpecialisation] for the use and meaning of the residual. - /// - /// The parameter `μ_base` is the base point of the iteration, typically the previous iterate, - /// but for, e.g., FISTA has inertia applied to it. - fn update( - &mut self, - μ : &mut DiscreteMeasure, F>, - μ_base : &DiscreteMeasure, F>, - ) -> (Observable, Option); - - /// Calculates the data term value corresponding to iterate `μ` and available residual. 
- /// - /// Inertia and other modifications, as deemed, necessary, should be applied to `μ`. - /// - /// The blanket implementation correspondsn to the 2-norm-squared data fidelity - /// $\\|\text{residual}\\|\_2^2/2$. - fn calculate_fit( - &self, - _μ : &DiscreteMeasure, F>, - residual : &Observable - ) -> F { - residual.norm2_squared_div2() - } - - /// Calculates the data term value at $μ$. - /// - /// Unlike [`Self::calculate_fit`], no inertia, etc., should be applied to `μ`. - fn calculate_fit_simple( - &self, - μ : &DiscreteMeasure, F>, - ) -> F; - - /// Returns the final iterate after any necessary postprocess pruning, merging, etc. - fn postprocess(self, mut μ : DiscreteMeasure, F>, merging : SpikeMergingMethod) - -> DiscreteMeasure, F> - where DiscreteMeasure, F> : SpikeMerging { - μ.merge_spikes_fitness(merging, - |μ̃| self.calculate_fit_simple(μ̃), - |&v| v); - μ.prune(); - μ - } - - /// Returns measure to be used for value calculations, which may differ from μ. - fn value_μ<'c, 'b : 'c>(&'b self, μ : &'c DiscreteMeasure, F>) - -> &'c DiscreteMeasure, F> { - μ - } -} - -/// Specialisation of [`generic_pointsource_fb_reg`] to basic μFB. -struct BasicFB< - 'a, - F : Float + ToNalgebraRealField, - A : ForwardModel, F>, - const N : usize -> { - /// The data - b : &'a A::Observable, - /// The forward operator - opA : &'a A, -} - -/// Implementation of [`FBSpecialisation`] for basic μFB forward-backward splitting. #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float + ToNalgebraRealField , A : ForwardModel, F>, const N : usize> -FBSpecialisation for BasicFB<'a, F, A, N> { - fn update( - &mut self, - μ : &mut DiscreteMeasure, F>, - _μ_base : &DiscreteMeasure, F> - ) -> (A::Observable, Option) { - μ.prune(); - //*residual = self.opA.apply(μ) - self.b; - let mut residual = self.b.clone(); - self.opA.gemv(&mut residual, 1.0, μ, -1.0); - (residual, None) - } - - fn calculate_fit_simple( - &self, - μ : &DiscreteMeasure, F>, - ) -> F { - let mut residual = self.b.clone(); - self.opA.gemv(&mut residual, 1.0, μ, -1.0); - residual.norm2_squared_div2() - } -} - -/// Specialisation of [`generic_pointsource_fb_reg`] to FISTA. -struct FISTA< - 'a, - F : Float + ToNalgebraRealField, - A : ForwardModel, F>, - const N : usize -> { - /// The data - b : &'a A::Observable, - /// The forward operator - opA : &'a A, - /// Current inertial parameter - λ : F, - /// Previous iterate without inertia applied. - /// We need to store this here because `μ_base` passed to [`FBSpecialisation::update`] will - /// have inertia applied to it, so is not useful to use. - μ_prev : DiscreteMeasure, F>, -} - -/// Implementation of [`FBSpecialisation`] for μFISTA inertial forward-backward splitting. -#[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float + ToNalgebraRealField, A : ForwardModel, F>, const N : usize> -FBSpecialisation for FISTA<'a, F, A, N> { - fn update( - &mut self, - μ : &mut DiscreteMeasure, F>, - _μ_base : &DiscreteMeasure, F> - ) -> (A::Observable, Option) { - // Update inertial parameters - let λ_prev = self.λ; - self.λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); - let θ = self.λ / λ_prev - self.λ; - // Perform inertial update on μ. - // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ - // and μ_prev have zero weight. Since both have weights from the finite-dimensional - // subproblem with a proximal projection step, this is likely to happen when the - // spike is not needed. 
A copy of the pruned μ without artithmetic performed is - // stored in μ_prev. - μ.pruning_sub(1.0 + θ, θ, &mut self.μ_prev); - - //*residual = self.opA.apply(μ) - self.b; - let mut residual = self.b.clone(); - self.opA.gemv(&mut residual, 1.0, μ, -1.0); - (residual, None) - } - - fn calculate_fit_simple( - &self, - μ : &DiscreteMeasure, F>, - ) -> F { - let mut residual = self.b.clone(); - self.opA.gemv(&mut residual, 1.0, μ, -1.0); - residual.norm2_squared_div2() - } - - fn calculate_fit( - &self, - _μ : &DiscreteMeasure, F>, - _residual : &A::Observable - ) -> F { - self.calculate_fit_simple(&self.μ_prev) - } - - // For FISTA we need to do a final pruning as well, due to the limited - // pruning that can be done on each step. - fn postprocess(mut self, μ_base : DiscreteMeasure, F>, merging : SpikeMergingMethod) - -> DiscreteMeasure, F> - where DiscreteMeasure, F> : SpikeMerging { - let mut μ = self.μ_prev; - self.μ_prev = μ_base; - μ.merge_spikes_fitness(merging, - |μ̃| self.calculate_fit_simple(μ̃), - |&v| v); - μ.prune(); - μ - } - - fn value_μ<'c, 'b : 'c>(&'c self, _μ : &'c DiscreteMeasure, F>) - -> &'c DiscreteMeasure, F> { - &self.μ_prev - } -} - - -/// Abstraction of regularisation terms for [`generic_pointsource_fb_reg`]. -pub trait RegTerm -: for<'a> Apply<&'a DiscreteMeasure, F>, Output = F> { - /// Approximately solve the problem - ///
$$ - /// \min_{x ∈ ℝ^n} \frac{1}{2} x^⊤Ax - g^⊤ x + τ G(x) - /// $$
- /// for $G$ depending on the trait implementation. - /// - /// The parameter `mA` is $A$. An estimate for its opeator norm should be provided in - /// `mA_normest`. The initial iterate and output is `x`. The current main tolerance is `ε`. - /// - /// Returns the number of iterations taken. - fn solve_findim( - &self, - mA : &DMatrix, - g : &DVector, - τ : F, - x : &mut DVector, - mA_normest : F, - ε : F, - config : &FBGenericConfig - ) -> usize; - - /// Find a point where `d` may violate the tolerance `ε`. - /// - /// If `skip_by_rough_check` is set, do not find the point if a rough check indicates that we - /// are in bounds. `ε` is the current main tolerance and `τ` a scaling factor for the - /// regulariser. - /// - /// Returns `None` if `d` is in bounds either based on the rough check, or a more precise check - /// terminating early. Otherwise returns a possibly violating point, the value of `d` there, - /// and a boolean indicating whether the found point is in bounds. - fn find_tolerance_violation( - &self, - d : &mut BTFN, - τ : F, - ε : F, - skip_by_rough_check : bool, - config : &FBGenericConfig, - ) -> Option<(Loc, F, bool)> - where BT : BTSearch>, - G : SupportGenerator, - G::SupportType : Mapping,Codomain=F> - + LocalAnalysis, N>; - - /// Verify that `d` is in bounds `ε` for a merge candidate `μ` - /// - /// `ε` is the current main tolerance and `τ` a scaling factor for the regulariser. - fn verify_merge_candidate( - &self, - d : &mut BTFN, - μ : &DiscreteMeasure, F>, - τ : F, - ε : F, - config : &FBGenericConfig, - ) -> bool - where BT : BTSearch>, - G : SupportGenerator, - G::SupportType : Mapping,Codomain=F> - + LocalAnalysis, N>; - - fn target_bounds(&self, τ : F, ε : F) -> Option>; - - /// Returns a scaling factor for the tolerance sequence. - /// - /// Typically this is the regularisation parameter. - fn tolerance_scaling(&self) -> F; -} - -#[replace_float_literals(F::cast_from(literal))] -impl RegTerm for NonnegRadonRegTerm -where Cube : P2Minimise, F> { - fn solve_findim( - &self, - mA : &DMatrix, - g : &DVector, - τ : F, - x : &mut DVector, - mA_normest : F, - ε : F, - config : &FBGenericConfig - ) -> usize { - let inner_tolerance = ε * config.inner.tolerance_mult; - let inner_it = config.inner.iterator_options.stop_target(inner_tolerance); - let inner_τ = config.inner.τ0 / mA_normest; - quadratic_nonneg(config.inner.method, mA, g, τ * self.α(), x, - inner_τ, inner_it) - } - - #[inline] - fn find_tolerance_violation( - &self, - d : &mut BTFN, - τ : F, - ε : F, - skip_by_rough_check : bool, - config : &FBGenericConfig, - ) -> Option<(Loc, F, bool)> - where BT : BTSearch>, - G : SupportGenerator, - G::SupportType : Mapping,Codomain=F> - + LocalAnalysis, N> { - let τα = τ * self.α(); - let keep_below = τα + ε; - let maximise_above = τα + ε * config.insertion_cutoff_factor; - let refinement_tolerance = ε * config.refinement.tolerance_mult; - - // If preliminary check indicates that we are in bonds, and if it otherwise matches - // the insertion strategy, skip insertion. - if skip_by_rough_check && d.bounds().upper() <= keep_below { - None - } else { - // If the rough check didn't indicate no insertion needed, find maximising point. 
- d.maximise_above(maximise_above, refinement_tolerance, config.refinement.max_steps) - .map(|(ξ, v_ξ)| (ξ, v_ξ, v_ξ <= keep_below)) +pub(crate) fn μ_diff( + μ_new : &DiscreteMeasure, F>, + μ_base : &DiscreteMeasure, F>, + ν_delta : Option<&DiscreteMeasure, F>>, + config : &FBGenericConfig +) -> DiscreteMeasure, F> { + let mut ν : DiscreteMeasure, F> = match config.insertion_style { + InsertionStyle::Reuse => { + μ_new.iter_spikes() + .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0))) + .map(|(δ, α_base)| (δ.x, α_base - δ.α)) + .collect() + }, + InsertionStyle::Zero => { + μ_new.iter_spikes() + .map(|δ| -δ) + .chain(μ_base.iter_spikes().copied()) + .collect() } - } - - fn verify_merge_candidate( - &self, - d : &mut BTFN, - μ : &DiscreteMeasure, F>, - τ : F, - ε : F, - config : &FBGenericConfig, - ) -> bool - where BT : BTSearch>, - G : SupportGenerator, - G::SupportType : Mapping,Codomain=F> - + LocalAnalysis, N> { - let τα = τ * self.α(); - let refinement_tolerance = ε * config.refinement.tolerance_mult; - let merge_tolerance = config.merge_tolerance_mult * ε; - let keep_below = τα + merge_tolerance; - let keep_supp_above = τα - merge_tolerance; - let bnd = d.bounds(); - - return ( - bnd.lower() >= keep_supp_above - || - μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| { - (β == 0.0) || d.apply(x) >= keep_supp_above - }).all(std::convert::identity) - ) && ( - bnd.upper() <= keep_below - || - d.has_upper_bound(keep_below, refinement_tolerance, config.refinement.max_steps) - ) - } - - fn target_bounds(&self, τ : F, ε : F) -> Option> { - let τα = τ * self.α(); - Some(Bounds(τα - ε, τα + ε)) - } - - fn tolerance_scaling(&self) -> F { - self.α() + }; + ν.prune(); // Potential small performance improvement + // Add ν_delta if given + match ν_delta { + None => ν, + Some(ν_d) => ν + ν_d, } } #[replace_float_literals(F::cast_from(literal))] -impl RegTerm for RadonRegTerm -where Cube : P2Minimise, F> { - fn solve_findim( - &self, - mA : &DMatrix, - g : &DVector, - τ : F, - x : &mut DVector, - mA_normest: F, - ε : F, - config : &FBGenericConfig - ) -> usize { - let inner_tolerance = ε * config.inner.tolerance_mult; - let inner_it = config.inner.iterator_options.stop_target(inner_tolerance); - let inner_τ = config.inner.τ0 / mA_normest; - quadratic_unconstrained(config.inner.method, mA, g, τ * self.α(), x, - inner_τ, inner_it) +pub(crate) fn insert_and_reweigh< + 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize +>( + μ : &mut DiscreteMeasure, F>, + minus_τv : &BTFN, + μ_base : &DiscreteMeasure, F>, + ν_delta: Option<&DiscreteMeasure, F>>, + op𝒟 : &'a 𝒟, + op𝒟norm : F, + τ : F, + ε : F, + config : &FBGenericConfig, + reg : &Reg, + state : &State, + stats : &mut IterInfo, +) -> (BTFN, BTA, N>, bool) +where F : Float + ToNalgebraRealField, + GA : SupportGenerator + Clone, + BTA : BTSearch>, + G𝒟 : SupportGenerator + Clone, + 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>, + 𝒟::Codomain : RealMapping, + S: RealMapping + LocalAnalysis, N>, + K: RealMapping + LocalAnalysis, N>, + BTNodeLookup: BTNode, N>, + DiscreteMeasure, F> : SpikeMerging, + Reg : RegTerm, + State : AlgIteratorState { + + // Maximum insertion count and measure difference calculation depend on insertion style. 
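+    // With `config.bootstrap_insertions == Some((l, k))`, the first `l` iterations allow at most
+    // `k` insertions and suppress the warning below; afterwards `config.max_insertions` applies
+    // and a warning is printed (unless the iterator is quiet) when the tolerance is not reached.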
+    let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
+        (i, Some((l, k))) if i <= l => (k, false),
+        _ => (config.max_insertions, !state.is_quiet()),
+    };
+    let max_insertions = match config.insertion_style {
+        InsertionStyle::Zero => {
+            todo!("InsertionStyle::Zero does not currently work with FISTA, so disabled.");
+            // let n = μ.len();
+            // μ = DiscreteMeasure::new();
+            // n + m
+        },
+        InsertionStyle::Reuse => m,
+    };
+
+    // TODO: should avoid a second copy of μ here; μ_base already stores a copy.
+    let ω0 = op𝒟.apply(match ν_delta {
+        None => μ.clone(),
+        Some(ν_d) => &*μ + ν_d,
+    });
+
+    // Add points to support until within error tolerance or maximum insertion count reached.
+    let mut count = 0;
+    let (within_tolerances, d) = 'insertion: loop {
+        if μ.len() > 0 {
+            // Form finite-dimensional subproblem. The subproblem's references to the original μ^k
+            // from the beginning of the iteration are all contained in the immutable c and g.
+            let Ã = op𝒟.findim_matrix(μ.iter_locations());
+            let g̃ = DVector::from_iterator(μ.len(),
+                                            μ.iter_locations()
+                                             .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ))
+                                             .map(F::to_nalgebra_mixed));
+            let mut x = μ.masses_dvector();
+
+            // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
+            // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
+            // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
+            // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2
+            // = n |𝒟| |x|_2, where n is the number of points. Therefore
+            let Ã_normest = op𝒟norm * F::cast_from(μ.len());
+
+            // Solve finite-dimensional subproblem.
+            stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config);
+
+            // Update masses of μ based on solution of finite-dimensional subproblem.
+            μ.set_masses_dvector(&x);
+        }
+
+        // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
+        // conditions in the predual space, and finding new points for insertion, if necessary.
+        let mut d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_base, ν_delta, config));
+
+        // If no merging heuristic is used, let's be more conservative about spike insertion,
+        // and skip it after first round. If merging is done, being more greedy about spike
+        // insertion also seems to improve performance.
+        let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging {
+            false
+        } else {
+            count > 0
+        };
+
+        // Find a spike to insert, if needed
+        let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation(
+            &mut d, τ, ε, skip_by_rough_check, config
+        ) {
+            None => break 'insertion (true, d),
+            Some(res) => res,
+        };
+
+        // Break if maximum insertion count reached
+        if count >= max_insertions {
+            break 'insertion (in_bounds, d)
+        }
+
+        // No point in optimising the weight here; the finite-dimensional algorithm is fast.
+        *μ += DeltaMeasure { x : ξ, α : 0.0 };
+        count += 1;
+    };
+
+    // TODO: should redo everything if some transports cause a problem.
+    // Maybe the implementation should call the above loop as a closure.
+
+    if !within_tolerances && warn_insertions {
+        // Complain (but continue) if we failed to get within tolerances
+        // by inserting more points.
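+        // The iteration nevertheless continues; as the tolerance ε is updated each iteration,
+        // a later iterate may still satisfy it.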
+ let err = format!("Maximum insertions reached without achieving \ + subproblem solution tolerance"); + println!("{}", err.red()); } - fn find_tolerance_violation( - &self, - d : &mut BTFN, - τ : F, - ε : F, - skip_by_rough_check : bool, - config : &FBGenericConfig, - ) -> Option<(Loc, F, bool)> - where BT : BTSearch>, - G : SupportGenerator, - G::SupportType : Mapping,Codomain=F> - + LocalAnalysis, N> { - let τα = τ * self.α(); - let keep_below = τα + ε; - let keep_above = -τα - ε; - let maximise_above = τα + ε * config.insertion_cutoff_factor; - let minimise_below = -τα - ε * config.insertion_cutoff_factor; - let refinement_tolerance = ε * config.refinement.tolerance_mult; + (d, within_tolerances) +} - // If preliminary check indicates that we are in bonds, and if it otherwise matches - // the insertion strategy, skip insertion. - if skip_by_rough_check && Bounds(keep_above, keep_below).superset(&d.bounds()) { - None - } else { - // If the rough check didn't indicate no insertion needed, find maximising point. - let mx = d.maximise_above(maximise_above, refinement_tolerance, - config.refinement.max_steps); - let mi = d.minimise_below(minimise_below, refinement_tolerance, - config.refinement.max_steps); +#[replace_float_literals(F::cast_from(literal))] +pub(crate) fn prune_and_maybe_simple_merge< + 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize +>( + μ : &mut DiscreteMeasure, F>, + minus_τv : &BTFN, + μ_base : &DiscreteMeasure, F>, + op𝒟 : &'a 𝒟, + τ : F, + ε : F, + config : &FBGenericConfig, + reg : &Reg, + state : &State, + stats : &mut IterInfo, +) +where F : Float + ToNalgebraRealField, + GA : SupportGenerator + Clone, + BTA : BTSearch>, + G𝒟 : SupportGenerator + Clone, + 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>, + 𝒟::Codomain : RealMapping, + S: RealMapping + LocalAnalysis, N>, + K: RealMapping + LocalAnalysis, N>, + BTNodeLookup: BTNode, N>, + DiscreteMeasure, F> : SpikeMerging, + Reg : RegTerm, + State : AlgIteratorState { + if state.iteration() % config.merge_every == 0 { + let n_before_merge = μ.len(); + μ.merge_spikes(config.merging, |μ_candidate| { + let μd = μ_diff(&μ_candidate, &μ_base, None, config); + let mut d = minus_τv + op𝒟.preapply(μd); - match (mx, mi) { - (None, None) => None, - (Some((ξ, v_ξ)), None) => Some((ξ, v_ξ, keep_below >= v_ξ)), - (None, Some((ζ, v_ζ))) => Some((ζ, v_ζ, keep_above <= v_ζ)), - (Some((ξ, v_ξ)), Some((ζ, v_ζ))) => { - if v_ξ - τα > τα - v_ζ { - Some((ξ, v_ξ, keep_below >= v_ξ)) - } else { - Some((ζ, v_ζ, keep_above <= v_ζ)) - } - } - } - } + reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) + .then_some(()) + }); + debug_assert!(μ.len() >= n_before_merge); + stats.merged += μ.len() - n_before_merge; } - fn verify_merge_candidate( - &self, - d : &mut BTFN, - μ : &DiscreteMeasure, F>, - τ : F, - ε : F, - config : &FBGenericConfig, - ) -> bool - where BT : BTSearch>, - G : SupportGenerator, - G::SupportType : Mapping,Codomain=F> - + LocalAnalysis, N> { - let τα = τ * self.α(); - let refinement_tolerance = ε * config.refinement.tolerance_mult; - let merge_tolerance = config.merge_tolerance_mult * ε; - let keep_below = τα + merge_tolerance; - let keep_above = -τα - merge_tolerance; - let keep_supp_pos_above = τα - merge_tolerance; - let keep_supp_neg_below = -τα + merge_tolerance; - let bnd = d.bounds(); - - return ( - (bnd.lower() >= keep_supp_pos_above && bnd.upper() <= keep_supp_neg_below) - || - μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| { - use std::cmp::Ordering::*; - match β.partial_cmp(&0.0) { - 
Some(Greater) => d.apply(x) >= keep_supp_pos_above, - Some(Less) => d.apply(x) <= keep_supp_neg_below, - _ => true, - } - }).all(std::convert::identity) - ) && ( - bnd.upper() <= keep_below - || - d.has_upper_bound(keep_below, refinement_tolerance, - config.refinement.max_steps) - ) && ( - bnd.lower() >= keep_above - || - d.has_lower_bound(keep_above, refinement_tolerance, - config.refinement.max_steps) - ) - } - - fn target_bounds(&self, τ : F, ε : F) -> Option> { - let τα = τ * self.α(); - Some(Bounds(-τα - ε, τα + ε)) - } - - fn tolerance_scaling(&self) -> F { - self.α() - } + let n_before_prune = μ.len(); + μ.prune(); + debug_assert!(μ.len() <= n_before_prune); + stats.pruned += n_before_prune - μ.len(); } +#[replace_float_literals(F::cast_from(literal))] +pub(crate) fn postprocess< + F : Float, + V : Euclidean + Clone, + A : GEMV, F>, Codomain = V>, + D : DataTerm, + const N : usize +> ( + mut μ : DiscreteMeasure, F>, + config : &FBGenericConfig, + dataterm : D, + opA : &A, + b : &V, +) -> DiscreteMeasure, F> +where DiscreteMeasure, F> : SpikeMerging { + μ.merge_spikes_fitness(config.merging, + |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), + |&v| v); + μ.prune(); + μ +} -/// Generic implementation of [`pointsource_fb_reg`]. +/// Iteratively solve the pointsource localisation problem using forward-backward splitting. /// -/// The method can be specialised to even primal-dual proximal splitting through the -/// [`FBSpecialisation`] parameter `specialisation`. -/// The settings in `config` have their [respective documentation](FBGenericConfig). `opA` is the +/// The settings in `config` have their [respective documentation](FBConfig). `opA` is the /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control /// as documented in [`alg_tools::iterate`]. /// +/// For details on the mathematical formulation, see the [module level](self) documentation. +/// /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of /// sums of simple functions usign bisection trees, and the related /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions @@ -729,252 +463,16 @@ /// /// Returns the final iterate. 
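+///
+/// A hypothetical invocation, with the construction of the concrete operators elided, could look like:
+/// ```ignore
+/// let μ = pointsource_fb_reg(&opA, &b, reg, &op𝒟, &fbconfig, iterator, plotter);
+/// ```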
#[replace_float_literals(F::cast_from(literal))] -pub fn generic_pointsource_fb_reg< - 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, Reg, const N : usize +pub fn pointsource_fb_reg< + 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize >( opA : &'a A, - reg : Reg, - op𝒟 : &'a 𝒟, - mut τ : F, - config : &FBGenericConfig, - iterator : I, - mut plotter : SeqPlotter, - mut residual : A::Observable, - mut specialisation : Spec -) -> DiscreteMeasure, F> -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - Spec : FBSpecialisation, - A::Observable : std::ops::MulAssign, - GA : SupportGenerator + Clone, - A : ForwardModel, F, PreadjointCodomain = BTFN> - + Lipschitz<𝒟, FloatType=F>, - BTA : BTSearch>, - G𝒟 : SupportGenerator + Clone, - 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>, - 𝒟::Codomain : RealMapping, - S: RealMapping + LocalAnalysis, N>, - K: RealMapping + LocalAnalysis, N>, - BTNodeLookup: BTNode, N>, - PlotLookup : Plotting, - DiscreteMeasure, F> : SpikeMerging, - Reg : RegTerm { - - // Set up parameters - let quiet = iterator.is_quiet(); - let op𝒟norm = op𝒟.opnorm_bound(); - // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled - // by τ compared to the conditional gradient approach. - let tolerance = config.tolerance * τ * reg.tolerance_scaling(); - let mut ε = tolerance.initial(); - - // Initialise operators - let preadjA = opA.preadjoint(); - - // Initialise iterates - let mut μ = DiscreteMeasure::new(); - - let mut inner_iters = 0; - let mut this_iters = 0; - let mut pruned = 0; - let mut merged = 0; - - let μ_diff = |μ_new : &DiscreteMeasure, F>, - μ_base : &DiscreteMeasure, F>| { - let mut ν : DiscreteMeasure, F> = match config.insertion_style { - InsertionStyle::Reuse => { - μ_new.iter_spikes() - .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0))) - .map(|(δ, α_base)| (δ.x, α_base - δ.α)) - .collect() - }, - InsertionStyle::Zero => { - μ_new.iter_spikes() - .map(|δ| -δ) - .chain(μ_base.iter_spikes().copied()) - .collect() - } - }; - ν.prune(); // Potential small performance improvement - ν - }; - - // Run the algorithm - iterator.iterate(|state| { - // Maximum insertion count and measure difference calculation depend on insertion style. - let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { - (i, Some((l, k))) if i <= l => (k, false), - _ => (config.max_insertions, !quiet), - }; - let max_insertions = match config.insertion_style { - InsertionStyle::Zero => { - todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled."); - // let n = μ.len(); - // μ = DiscreteMeasure::new(); - // n + m - }, - InsertionStyle::Reuse => m, - }; - - // Calculate smooth part of surrogate model. - // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` - // has no significant overhead. For some reosn Rust doesn't allow us simply moving - // the residual and replacing it below before the end of this closure. - residual *= -τ; - let r = std::mem::replace(&mut residual, opA.empty_observable()); - let minus_τv = preadjA.apply(r); // minus_τv = -τA^*(Aμ^k-b) - // TODO: should avoid a second copy of μ here; μ_base already stores a copy. - let ω0 = op𝒟.apply(μ.clone()); // 𝒟μ^k - //let g = &minus_τv + ω0; // Linear term of surrogate model - - // Save current base point - let μ_base = μ.clone(); - - // Add points to support until within error tolerance or maximum insertion count reached. 
- let mut count = 0; - let (within_tolerances, d) = 'insertion: loop { - if μ.len() > 0 { - // Form finite-dimensional subproblem. The subproblem references to the original μ^k - // from the beginning of the iteration are all contained in the immutable c and g. - let à = op𝒟.findim_matrix(μ.iter_locations()); - let g̃ = DVector::from_iterator(μ.len(), - μ.iter_locations() - .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ)) - .map(F::to_nalgebra_mixed)); - let mut x = μ.masses_dvector(); - - // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. - // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ - // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ - // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 - // = n |𝒟| |x|_2, where n is the number of points. Therefore - let Ã_normest = op𝒟norm * F::cast_from(μ.len()); - - // Solve finite-dimensional subproblem. - inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config); - - // Update masses of μ based on solution of finite-dimensional subproblem. - μ.set_masses_dvector(&x); - } - - // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality - // conditions in the predual space, and finding new points for insertion, if necessary. - let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ, &μ_base)); - - // If no merging heuristic is used, let's be more conservative about spike insertion, - // and skip it after first round. If merging is done, being more greedy about spike - // insertion also seems to improve performance. - let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging { - false - } else { - count > 0 - }; - - // Find a spike to insert, if needed - let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( - &mut d, τ, ε, skip_by_rough_check, config - ) { - None => break 'insertion (true, d), - Some(res) => res, - }; - - // Break if maximum insertion count reached - if count >= max_insertions { - break 'insertion (in_bounds, d) - } - - // No point in optimising the weight here; the finite-dimensional algorithm is fast. - μ += DeltaMeasure { x : ξ, α : 0.0 }; - count += 1; - }; - - if !within_tolerances && warn_insertions { - // Complain (but continue) if we failed to get within tolerances - // by inserting more points. 
- let err = format!("Maximum insertions reached without achieving \ - subproblem solution tolerance"); - println!("{}", err.red()); - } - - // Merge spikes - if state.iteration() % config.merge_every == 0 { - let n_before_merge = μ.len(); - μ.merge_spikes(config.merging, |μ_candidate| { - let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ_candidate, &μ_base)); - - reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) - .then_some(()) - }); - debug_assert!(μ.len() >= n_before_merge); - merged += μ.len() - n_before_merge; - } - - let n_before_prune = μ.len(); - (residual, τ) = match specialisation.update(&mut μ, &μ_base) { - (r, None) => (r, τ), - (r, Some(new_τ)) => (r, new_τ) - }; - debug_assert!(μ.len() <= n_before_prune); - pruned += n_before_prune - μ.len(); - - this_iters += 1; - - // Update main tolerance for next iteration - let ε_prev = ε; - ε = tolerance.update(ε, state.iteration()); - - // Give function value if needed - state.if_verbose(|| { - let value_μ = specialisation.value_μ(&μ); - // Plot if so requested - plotter.plot_spikes( - format!("iter {} end; {}", state.iteration(), within_tolerances), &d, - "start".to_string(), Some(&minus_τv), - reg.target_bounds(τ, ε_prev), value_μ, - ); - // Calculate mean inner iterations and reset relevant counters. - // Return the statistics - let res = IterInfo { - value : specialisation.calculate_fit(&μ, &residual) + reg.apply(value_μ), - n_spikes : value_μ.len(), - inner_iters, - this_iters, - merged, - pruned, - ε : ε_prev, - postprocessing: config.postprocessing.then(|| value_μ.clone()), - }; - inner_iters = 0; - this_iters = 0; - merged = 0; - pruned = 0; - res - }) - }); - - specialisation.postprocess(μ, config.final_merging) -} - -/// Iteratively solve the pointsource localisation problem using forward-backward splitting -/// -/// The settings in `config` have their [respective documentation](FBConfig). `opA` is the -/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. -/// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution -/// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control -/// as documented in [`alg_tools::iterate`]. -/// -/// For details on the mathematical formulation, see the [module level](self) documentation. -/// -/// Returns the final iterate. -#[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>( - opA : &'a A, b : &A::Observable, reg : Reg, op𝒟 : &'a 𝒟, - config : &FBConfig, + fbconfig : &FBConfig, iterator : I, - plotter : SeqPlotter, + mut plotter : SeqPlotter, ) -> DiscreteMeasure, F> where F : Float + ToNalgebraRealField, I : AlgIteratorFactory>, @@ -983,7 +481,7 @@ A::Observable : std::ops::MulAssign, GA : SupportGenerator + Clone, A : ForwardModel, F, PreadjointCodomain = BTFN> - + Lipschitz<𝒟, FloatType=F>, + + Lipschitz<&'a 𝒟, FloatType=F>, BTA : BTSearch>, G𝒟 : SupportGenerator + Clone, 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>, @@ -996,17 +494,227 @@ DiscreteMeasure, F> : SpikeMerging, Reg : RegTerm { - let initial_residual = -b; - let τ = config.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); + // Set up parameters + let config = &fbconfig.insertion; + let op𝒟norm = op𝒟.opnorm_bound(); + let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); + // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled + // by τ compared to the conditional gradient approach. 
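+    // `reg.tolerance_scaling()` is typically the regularisation parameter, so the insertion
+    // tolerance sequence below is the configured base tolerance scaled by τ times that parameter.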
+    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
+    let mut ε = tolerance.initial();
+
+    // Initialise iterates
+    let mut μ = DiscreteMeasure::new();
+    let mut residual = -b;
+    let mut stats = IterInfo::new();
+
+    // Run the algorithm
+    iterator.iterate(|state| {
+        // Calculate smooth part of surrogate model.
+        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
+        // has no significant overhead. For some reason Rust doesn't allow us to simply move
+        // the residual and replace it below before the end of this closure.
+        residual *= -τ;
+        let r = std::mem::replace(&mut residual, opA.empty_observable());
+        let minus_τv = opA.preadjoint().apply(r);
+
+        // Save current base point
+        let μ_base = μ.clone();
+
+        // Insert and reweigh
+        let (d, within_tolerances) = insert_and_reweigh(
+            &mut μ, &minus_τv, &μ_base, None,
+            op𝒟, op𝒟norm,
+            τ, ε,
+            config, &reg, state, &mut stats
+        );
+
+        // Prune and possibly merge spikes
+        prune_and_maybe_simple_merge(
+            &mut μ, &minus_τv, &μ_base,
+            op𝒟,
+            τ, ε,
+            config, &reg, state, &mut stats
+        );
+
+        // Update residual
+        residual = calculate_residual(&μ, opA, b);
+
+        // Update main tolerance for next iteration
+        let ε_prev = ε;
+        ε = tolerance.update(ε, state.iteration());
+        stats.this_iters += 1;
+
+        // Give function value if needed
+        state.if_verbose(|| {
+            // Plot if so requested
+            plotter.plot_spikes(
+                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
+                "start".to_string(), Some(&minus_τv),
+                reg.target_bounds(τ, ε_prev), &μ,
+            );
+            // Calculate mean inner iterations and reset relevant counters.
+            // Return the statistics
+            let res = IterInfo {
+                value : residual.norm2_squared_div2() + reg.apply(&μ),
+                n_spikes : μ.len(),
+                ε : ε_prev,
+                postprocessing: config.postprocessing.then(|| μ.clone()),
+                .. stats
+            };
+            stats = IterInfo::new();
+            res
+        })
+    });
+
+    postprocess(μ, config, L2Squared, opA, b)
 }
+
+/// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting.
+///
+/// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
+/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
+/// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
+/// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
+/// as documented in [`alg_tools::iterate`].
+///
+/// For details on the mathematical formulation, see the [module level](self) documentation.
+///
+/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
+/// sums of simple functions using bisection trees, and the related
+/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
+/// active at specific points, and to maximise their sums. Through the implementation of the
+/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
+/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
+///
+/// Returns the final iterate.
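+///
+/// A hypothetical call, analogous to [`pointsource_fb_reg`] (operator construction elided):
+/// ```ignore
+/// let μ = pointsource_fista_reg(&opA, &b, reg, &op𝒟, &fbconfig, iterator, plotter);
+/// ```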
+#[replace_float_literals(F::cast_from(literal))]
+pub fn pointsource_fista_reg<
+    'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize
+>(
+    opA : &'a A,
+    b : &A::Observable,
+    reg : Reg,
+    op𝒟 : &'a 𝒟,
+    fbconfig : &FBConfig,
+    iterator : I,
+    mut plotter : SeqPlotter,
+) -> DiscreteMeasure, F>
+where F : Float + ToNalgebraRealField,
+      I : AlgIteratorFactory>,
+      for<'b> &'b A::Observable : std::ops::Neg,
+          //+ std::ops::Mul, <-- FIXME: compiler overflow
+      A::Observable : std::ops::MulAssign,
+      GA : SupportGenerator + Clone,
+      A : ForwardModel, F, PreadjointCodomain = BTFN>
+          + Lipschitz<&'a 𝒟, FloatType=F>,
+      BTA : BTSearch>,
+      G𝒟 : SupportGenerator + Clone,
+      𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>,
+      𝒟::Codomain : RealMapping,
+      S: RealMapping + LocalAnalysis, N>,
+      K: RealMapping + LocalAnalysis, N>,
+      BTNodeLookup: BTNode, N>,
+      Cube: P2Minimise, F>,
+      PlotLookup : Plotting,
+      DiscreteMeasure, F> : SpikeMerging,
+      Reg : RegTerm {
+
+    // Set up parameters
+    let config = &fbconfig.insertion;
+    let op𝒟norm = op𝒟.opnorm_bound();
+    let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
+    let mut λ = 1.0;
+    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
+    // by τ compared to the conditional gradient approach.
+    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
+    let mut ε = tolerance.initial();
+
+    // Initialise iterates
+    let mut μ = DiscreteMeasure::new();
+    let mut μ_prev = DiscreteMeasure::new();
+    let mut residual = -b;
+    let mut stats = IterInfo::new();
+    let mut warned_merging = false;
+
+    // Run the algorithm
+    iterator.iterate(|state| {
+        // Calculate smooth part of surrogate model.
+        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
+        // has no significant overhead. For some reason Rust doesn't allow us to simply move
+        // the residual and replace it below before the end of this closure.
+        residual *= -τ;
+        let r = std::mem::replace(&mut residual, opA.empty_observable());
+        let minus_τv = opA.preadjoint().apply(r);
+
+        // Save current base point
+        let μ_base = μ.clone();
+
+        // Insert new spikes and reweigh
+        let (d, within_tolerances) = insert_and_reweigh(
+            &mut μ, &minus_τv, &μ_base, None,
+            op𝒟, op𝒟norm,
+            τ, ε,
+            config, &reg, state, &mut stats
+        );
+
+        // (Do not) merge spikes.
+        if state.iteration() % config.merge_every == 0 {
+            match config.merging {
+                SpikeMergingMethod::None => { },
+                _ => if !warned_merging {
+                    let err = format!("Merging not supported for μFISTA");
+                    println!("{}", err.red());
+                    warned_merging = true;
+                }
+            }
+        }
+
+        // Update inertial parameters
+        let λ_prev = λ;
+        λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() );
+        let θ = λ / λ_prev - λ;
+
+        // Perform inertial update on μ.
+        // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ
+        // and μ_prev have zero weight. Since both have weights from the finite-dimensional
+        // subproblem with a proximal projection step, this is likely to happen when the
+        // spike is not needed. A copy of the pruned μ without arithmetic performed is
+        // stored in μ_prev.
+ let n_before_prune = μ.len(); + μ.pruning_sub(1.0 + θ, θ, &mut μ_prev); + debug_assert!(μ.len() <= n_before_prune); + stats.pruned += n_before_prune - μ.len(); + + // Update residual + residual = calculate_residual(&μ, opA, b); + + // Update main tolerance for next iteration + let ε_prev = ε; + ε = tolerance.update(ε, state.iteration()); + stats.this_iters += 1; + + // Give function value if needed + state.if_verbose(|| { + // Plot if so requested + plotter.plot_spikes( + format!("iter {} end; {}", state.iteration(), within_tolerances), &d, + "start".to_string(), Some(&minus_τv), + reg.target_bounds(τ, ε_prev), &μ_prev, + ); + // Calculate mean inner iterations and reset relevant counters. + // Return the statistics + let res = IterInfo { + value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev), + n_spikes : μ_prev.len(), + ε : ε_prev, + postprocessing: config.postprocessing.then(|| μ_prev.clone()), + .. stats + }; + stats = IterInfo::new(); + res + }) + }); + + postprocess(μ_prev, config, L2Squared, opA, b) }
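For reference, the inertial rule in `pointsource_fista_reg` is, in effect, the classical FISTA parameter recursion written in terms of λ_k = 1/t_k. The following self-contained sketch in plain Rust (independent of this crate; the names λ, θ and t are illustrative only) checks that equivalence numerically:

    fn main() {
        // λ-form used in pointsource_fista_reg, starting from λ₁ = 1.
        let mut λ: f64 = 1.0;
        // Classical form: t_{k+1} = (1 + √(1 + 4 t_k²)) / 2, θ_k = (t_k − 1) / t_{k+1}.
        let mut t: f64 = 1.0;
        for k in 1..=10 {
            let λ_prev = λ;
            λ = 2.0 * λ_prev / (λ_prev + (4.0 + λ_prev * λ_prev).sqrt());
            let θ = λ / λ_prev - λ;
            let t_next = (1.0 + (1.0 + 4.0 * t * t).sqrt()) / 2.0;
            let θ_classical = (t - 1.0) / t_next;
            // Both parametrisations should agree up to floating-point rounding.
            assert!((θ - θ_classical).abs() < 1e-12, "mismatch at iteration {k}");
            assert!((λ - 1.0 / t_next).abs() < 1e-12);
            t = t_next;
            println!("k = {k}: θ = {θ:.6}");
        }
    }

On the first iteration θ = 0 (no inertia), and θ grows towards 1 as the iterations proceed.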