src/fb.rs

changeset 52
f0e8704d3f0e
parent 51
0693cc9ba9f0
equal deleted inserted replaced
31:6105b5cd8d89 52:f0e8704d3f0e
4 This corresponds to the manuscript 4 This corresponds to the manuscript
5 5
6 * Valkonen T. - _Proximal methods for point source localisation_, 6 * Valkonen T. - _Proximal methods for point source localisation_,
7 [arXiv:2212.02991](https://arxiv.org/abs/2212.02991). 7 [arXiv:2212.02991](https://arxiv.org/abs/2212.02991).
8 8
9 The main routine is [`pointsource_fb_reg`]. It is based on [`generic_pointsource_fb_reg`], which is 9 The main routine is [`pointsource_fb_reg`].
10 also used by our [primal-dual proximal splitting][crate::pdps] implementation.
11
12 FISTA-type inertia can also be enabled through [`FBConfig::meta`].
13 10
14 ## Problem 11 ## Problem
15 12
16 <p> 13 <p>
17 Our objective is to solve 14 Our objective is to solve
74 = -\frac{τ}{2} \|Aμ^k-b\|^2 + τ[Aμ^k-b]^⊤ b + \frac{1}{2} \|μ_k\|_{𝒟}^2. 71 = -\frac{τ}{2} \|Aμ^k-b\|^2 + τ[Aμ^k-b]^⊤ b + \frac{1}{2} \|μ_k\|_{𝒟}^2.
75 \end{aligned} 72 \end{aligned}
76 $$ 73 $$
77 </p> 74 </p>
78 75
79 We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by 76 We solve this with either SSN or FB as determined by
80 [`InnerSettings`] in [`FBGenericConfig::inner`]. 77 [`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`].
81 */ 78 */
82 79
80 use colored::Colorize;
83 use numeric_literals::replace_float_literals; 81 use numeric_literals::replace_float_literals;
84 use serde::{Serialize, Deserialize}; 82 use serde::{Deserialize, Serialize};
85 use colored::Colorize; 83
86 use nalgebra::{DVector, DMatrix};
87
88 use alg_tools::iterate::{
89 AlgIteratorFactory,
90 AlgIteratorState,
91 };
92 use alg_tools::euclidean::Euclidean; 84 use alg_tools::euclidean::Euclidean;
93 use alg_tools::linops::Apply; 85 use alg_tools::instance::Instance;
94 use alg_tools::sets::Cube; 86 use alg_tools::iterate::AlgIteratorFactory;
95 use alg_tools::loc::Loc; 87 use alg_tools::linops::{Mapping, GEMV};
96 use alg_tools::mapping::Mapping;
97 use alg_tools::bisection_tree::{
98 BTFN,
99 PreBTFN,
100 Bounds,
101 BTNodeLookup,
102 BTNode,
103 BTSearch,
104 P2Minimise,
105 SupportGenerator,
106 LocalAnalysis,
107 Bounded,
108 };
109 use alg_tools::mapping::RealMapping; 88 use alg_tools::mapping::RealMapping;
110 use alg_tools::nalgebra_support::ToNalgebraRealField; 89 use alg_tools::nalgebra_support::ToNalgebraRealField;
111 90
91 use crate::dataterm::{calculate_residual, DataTerm, L2Squared};
92 use crate::forward_model::{AdjointProductBoundedBy, ForwardModel};
93 use crate::measures::merging::SpikeMerging;
94 use crate::measures::{DiscreteMeasure, RNDM};
95 use crate::plot::{PlotLookup, Plotting, SeqPlotter};
96 pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty};
97 use crate::regularisation::RegTerm;
112 use crate::types::*; 98 use crate::types::*;
113 use crate::measures::{
114 DiscreteMeasure,
115 DeltaMeasure,
116 };
117 use crate::measures::merging::{
118 SpikeMergingMethod,
119 SpikeMerging,
120 };
121 use crate::forward_model::ForwardModel;
122 use crate::seminorms::{
123 DiscreteMeasureOp, Lipschitz
124 };
125 use crate::subproblem::{
126 nonneg::quadratic_nonneg,
127 unconstrained::quadratic_unconstrained,
128 InnerSettings,
129 InnerMethod,
130 };
131 use crate::tolerance::Tolerance;
132 use crate::plot::{
133 SeqPlotter,
134 Plotting,
135 PlotLookup
136 };
137 use crate::regularisation::{
138 NonnegRadonRegTerm,
139 RadonRegTerm,
140 };
141
142 /// Method for constructing $μ$ on each iteration
143 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
144 #[allow(dead_code)]
145 pub enum InsertionStyle {
146 /// Reuse previous $μ$ from previous iteration, optimising weights
147 /// before inserting new spikes.
148 Reuse,
149 /// Start each iteration with $μ=0$.
150 Zero,
151 }
152
153 /// Meta-algorithm type
154 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
155 #[allow(dead_code)]
156 pub enum FBMetaAlgorithm {
157 /// No meta-algorithm
158 None,
159 /// FISTA-style inertia
160 InertiaFISTA,
161 }
162 99
163 /// Settings for [`pointsource_fb_reg`]. 100 /// Settings for [`pointsource_fb_reg`].
164 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 101 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
165 #[serde(default)] 102 #[serde(default)]
166 pub struct FBConfig<F : Float> { 103 pub struct FBConfig<F: Float> {
167 /// Step length scaling 104 /// Step length scaling
168 pub τ0 : F, 105 pub τ0: F,
169 /// Meta-algorithm to apply
170 pub meta : FBMetaAlgorithm,
171 /// Generic parameters 106 /// Generic parameters
172 pub insertion : FBGenericConfig<F>, 107 pub generic: FBGenericConfig<F>,
173 }
174
175 /// Settings for the solution of the stepwise optimality condition in algorithms based on
176 /// [`generic_pointsource_fb_reg`].
177 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
178 #[serde(default)]
179 pub struct FBGenericConfig<F : Float> {
180 /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`].
181 pub insertion_style : InsertionStyle,
182 /// Tolerance for point insertion.
183 pub tolerance : Tolerance<F>,
184 /// Stop looking for predual maximum (where to insert a new point) below
185 /// `tolerance` multiplied by this factor.
186 pub insertion_cutoff_factor : F,
187 /// Settings for branch and bound refinement when looking for predual maxima
188 pub refinement : RefinementSettings<F>,
189 /// Maximum insertions within each outer iteration
190 pub max_insertions : usize,
191 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
192 pub bootstrap_insertions : Option<(usize, usize)>,
193 /// Inner method settings
194 pub inner : InnerSettings<F>,
195 /// Spike merging method
196 pub merging : SpikeMergingMethod<F>,
197 /// Tolerance multiplier for merges
198 pub merge_tolerance_mult : F,
199 /// Spike merging method after the last step
200 pub final_merging : SpikeMergingMethod<F>,
201 /// Iterations between merging heuristic tries
202 pub merge_every : usize,
203 /// Save $μ$ for postprocessing optimisation
204 pub postprocessing : bool
205 } 108 }
206 109
207 #[replace_float_literals(F::cast_from(literal))] 110 #[replace_float_literals(F::cast_from(literal))]
208 impl<F : Float> Default for FBConfig<F> { 111 impl<F: Float> Default for FBConfig<F> {
209 fn default() -> Self { 112 fn default() -> Self {
210 FBConfig { 113 FBConfig {
211 τ0 : 0.99, 114 τ0: 0.99,
212 meta : FBMetaAlgorithm::None, 115 generic: Default::default(),
213 insertion : Default::default()
214 } 116 }
215 } 117 }
216 } 118 }
217 119
120 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize {
121 let n_before_prune = μ.len();
122 μ.prune();
123 debug_assert!(μ.len() <= n_before_prune);
124 n_before_prune - μ.len()
125 }
126
218 #[replace_float_literals(F::cast_from(literal))] 127 #[replace_float_literals(F::cast_from(literal))]
219 impl<F : Float> Default for FBGenericConfig<F> { 128 pub(crate) fn postprocess<
220 fn default() -> Self { 129 F: Float,
221 FBGenericConfig { 130 V: Euclidean<F> + Clone,
222 insertion_style : InsertionStyle::Reuse, 131 A: GEMV<F, RNDM<F, N>, Codomain = V>,
223 tolerance : Default::default(), 132 D: DataTerm<F, V, N>,
224 insertion_cutoff_factor : 1.0, 133 const N: usize,
225 refinement : Default::default(), 134 >(
226 max_insertions : 100, 135 mut μ: RNDM<F, N>,
227 //bootstrap_insertions : None, 136 config: &FBGenericConfig<F>,
228 bootstrap_insertions : Some((10, 1)), 137 dataterm: D,
229 inner : InnerSettings { 138 opA: &A,
230 method : InnerMethod::SSN, 139 b: &V,
231 .. Default::default() 140 ) -> RNDM<F, N>
232 }, 141 where
233 merging : SpikeMergingMethod::None, 142 RNDM<F, N>: SpikeMerging<F>,
234 //merging : Default::default(), 143 for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>,
235 final_merging : Default::default(), 144 {
236 merge_every : 10, 145 μ.merge_spikes_fitness(
237 merge_tolerance_mult : 2.0, 146 config.final_merging_method(),
238 postprocessing : false, 147 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
148 |&v| v,
149 );
150 μ.prune();
151 μ
152 }
153
154 /// Iteratively solve the pointsource localisation problem using forward-backward splitting.
155 ///
156 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
157 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
158 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
159 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
160 /// as documented in [`alg_tools::iterate`].
161 ///
162 /// For details on the mathematical formulation, see the [module level](self) documentation.
163 ///
164 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
165 /// sums of simple functions using bisection trees, and the related
166 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
167 /// active at specific points, and to maximise their sums. Through the implementation of the
168 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
169 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
170 ///
171 /// Returns the final iterate.
172 #[replace_float_literals(F::cast_from(literal))]
173 pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>(
174 opA: &A,
175 b: &A::Observable,
176 reg: Reg,
177 prox_penalty: &P,
178 fbconfig: &FBConfig<F>,
179 iterator: I,
180 mut plotter: SeqPlotter<F, N>,
181 ) -> RNDM<F, N>
182 where
183 F: Float + ToNalgebraRealField,
184 I: AlgIteratorFactory<IterInfo<F, N>>,
185 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>,
186 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>,
187 A::PreadjointCodomain: RealMapping<F, N>,
188 PlotLookup: Plotting<N>,
189 RNDM<F, N>: SpikeMerging<F>,
190 Reg: RegTerm<F, N>,
191 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
192 {
193 // Set up parameters
194 let config = &fbconfig.generic;
195 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
196 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
197 // by τ compared to the conditional gradient approach.
198 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
199 let mut ε = tolerance.initial();
200
201 // Initialise iterates
202 let mut μ = DiscreteMeasure::new();
203 let mut residual = -b;
204
205 // Statistics
206 let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo {
207 value: residual.norm2_squared_div2() + reg.apply(μ),
208 n_spikes: μ.len(),
209 ε,
210 //postprocessing: config.postprocessing.then(|| μ.clone()),
211 ..stats
212 };
213 let mut stats = IterInfo::new();
214
215 // Run the algorithm
216 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
217 // Calculate smooth part of surrogate model.
218 let mut τv = opA.preadjoint().apply(residual * τ);
219
220 // Save current base point
221 let μ_base = μ.clone();
222
223 // Insert and reweigh
224 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
225 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
226 );
227
228 // Prune and possibly merge spikes
229 if config.merge_now(&state) {
230 stats.merged += prox_penalty.merge_spikes(
231 &mut μ,
232 &mut τv,
233 &μ_base,
234 None,
235 τ,
236 ε,
237 config,
238 &reg,
239 Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
240 );
239 } 241 }
242
243 stats.pruned += prune_with_stats(&mut μ);
244
245 // Update residual
246 residual = calculate_residual(&μ, opA, b);
247
248 let iter = state.iteration();
249 stats.this_iters += 1;
250
251 // Give statistics if needed
252 state.if_verbose(|| {
253 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
254 full_stats(
255 &residual,
256 &μ,
257 ε,
258 std::mem::replace(&mut stats, IterInfo::new()),
259 )
260 });
261
262 // Update main tolerance for next iteration
263 ε = tolerance.update(ε, iter);
240 } 264 }
241 } 265
242 266 postprocess(μ, config, L2Squared, opA, b)
243 /// Trait for specialisation of [`generic_pointsource_fb_reg`] to basic FB, FISTA. 267 }
244 /// 268
245 /// The idea is that the residual $Aμ - b$ in the forward step can be replaced by an arbitrary 269 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting.
246 /// value. For example, to implement [primal-dual proximal splitting][crate::pdps] we replace it 270 ///
247 /// with the dual variable $y$. We can then also implement alternative data terms, as the 271 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
248 /// (pre)differential of $F(μ)=F\_0(Aμ-b)$ is $F\'(μ) = A\_*F\_0\'(Aμ-b)$. In the case of the 272 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
249 /// quadratic fidelity $F_0(y)=\frac{1}{2}\\|y\\|_2^2$ in a Hilbert space, of course, 273 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
250 /// $F\_0\'(Aμ-b)=Aμ-b$ is the residual. 274 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
251 pub trait FBSpecialisation<F : Float, Observable : Euclidean<F>, const N : usize> : Sized { 275 /// as documented in [`alg_tools::iterate`].
252 /// Updates the residual and does any necessary pruning of `μ`. 276 ///
253 /// 277 /// For details on the mathematical formulation, see the [module level](self) documentation.
254 /// Returns the new residual and possibly a new step length. 278 ///
255 /// 279 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
256 /// The measure `μ` may also be modified to apply, e.g., inertia to it. 280 /// sums of simple functions using bisection trees, and the related
257 /// The updated residual should correspond to the residual at `μ`. 281 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
258 /// See the [trait documentation][FBSpecialisation] for the use and meaning of the residual. 282 /// active at specific points, and to maximise their sums. Through the implementation of the
259 /// 283 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
260 /// The parameter `μ_base` is the base point of the iteration, typically the previous iterate, 284 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
261 /// but for, e.g., FISTA has inertia applied to it. 285 ///
262 fn update( 286 /// Returns the final iterate.
263 &mut self,
264 μ : &mut DiscreteMeasure<Loc<F, N>, F>,
265 μ_base : &DiscreteMeasure<Loc<F, N>, F>,
266 ) -> (Observable, Option<F>);
267
268 /// Calculates the data term value corresponding to iterate `μ` and available residual.
269 ///
270 /// Inertia and other modifications, as deemed necessary, should be applied to `μ`.
271 ///
272 /// The blanket implementation corresponds to the 2-norm-squared data fidelity
273 /// $\\|\text{residual}\\|\_2^2/2$.
274 fn calculate_fit(
275 &self,
276 _μ : &DiscreteMeasure<Loc<F, N>, F>,
277 residual : &Observable
278 ) -> F {
279 residual.norm2_squared_div2()
280 }
281
282 /// Calculates the data term value at $μ$.
283 ///
284 /// Unlike [`Self::calculate_fit`], no inertia, etc., should be applied to `μ`.
285 fn calculate_fit_simple(
286 &self,
287 μ : &DiscreteMeasure<Loc<F, N>, F>,
288 ) -> F;
289
290 /// Returns the final iterate after any necessary postprocess pruning, merging, etc.
291 fn postprocess(self, mut μ : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>)
292 -> DiscreteMeasure<Loc<F, N>, F>
293 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
294 μ.merge_spikes_fitness(merging,
295 |μ̃| self.calculate_fit_simple(μ̃),
296 |&v| v);
297 μ.prune();
298 μ
299 }
300
301 /// Returns measure to be used for value calculations, which may differ from μ.
302 fn value_μ<'c, 'b : 'c>(&'b self, μ : &'c DiscreteMeasure<Loc<F, N>, F>)
303 -> &'c DiscreteMeasure<Loc<F, N>, F> {
304 μ
305 }
306 }
307
308 /// Specialisation of [`generic_pointsource_fb_reg`] to basic μFB.
309 struct BasicFB<
310 'a,
311 F : Float + ToNalgebraRealField,
312 A : ForwardModel<Loc<F, N>, F>,
313 const N : usize
314 > {
315 /// The data
316 b : &'a A::Observable,
317 /// The forward operator
318 opA : &'a A,
319 }
320
321 /// Implementation of [`FBSpecialisation`] for basic μFB forward-backward splitting.
322 #[replace_float_literals(F::cast_from(literal))] 287 #[replace_float_literals(F::cast_from(literal))]
323 impl<'a, F : Float + ToNalgebraRealField , A : ForwardModel<Loc<F, N>, F>, const N : usize> 288 pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>(
324 FBSpecialisation<F, A::Observable, N> for BasicFB<'a, F, A, N> { 289 opA: &A,
325 fn update( 290 b: &A::Observable,
326 &mut self, 291 reg: Reg,
327 μ : &mut DiscreteMeasure<Loc<F, N>, F>, 292 prox_penalty: &P,
328 _μ_base : &DiscreteMeasure<Loc<F, N>, F> 293 fbconfig: &FBConfig<F>,
329 ) -> (A::Observable, Option<F>) { 294 iterator: I,
330 μ.prune(); 295 mut plotter: SeqPlotter<F, N>,
331 //*residual = self.opA.apply(μ) - self.b; 296 ) -> RNDM<F, N>
332 let mut residual = self.b.clone(); 297 where
333 self.opA.gemv(&mut residual, 1.0, μ, -1.0); 298 F: Float + ToNalgebraRealField,
334 (residual, None) 299 I: AlgIteratorFactory<IterInfo<F, N>>,
335 } 300 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>,
336 301 A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>,
337 fn calculate_fit_simple( 302 A::PreadjointCodomain: RealMapping<F, N>,
338 &self, 303 PlotLookup: Plotting<N>,
339 μ : &DiscreteMeasure<Loc<F, N>, F>, 304 RNDM<F, N>: SpikeMerging<F>,
340 ) -> F { 305 Reg: RegTerm<F, N>,
341 let mut residual = self.b.clone(); 306 P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
342 self.opA.gemv(&mut residual, 1.0, μ, -1.0); 307 {
343 residual.norm2_squared_div2() 308 // Set up parameters
344 } 309 let config = &fbconfig.generic;
345 } 310 let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
346 311 let mut λ = 1.0;
347 /// Specialisation of [`generic_pointsource_fb_reg`] to FISTA. 312 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
348 struct FISTA< 313 // by τ compared to the conditional gradient approach.
349 'a, 314 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
350 F : Float + ToNalgebraRealField, 315 let mut ε = tolerance.initial();
351 A : ForwardModel<Loc<F, N>, F>, 316
352 const N : usize 317 // Initialise iterates
353 > { 318 let mut μ = DiscreteMeasure::new();
354 /// The data 319 let mut μ_prev = DiscreteMeasure::new();
355 b : &'a A::Observable, 320 let mut residual = -b;
356 /// The forward operator 321 let mut warned_merging = false;
357 opA : &'a A, 322
358 /// Current inertial parameter 323 // Statistics
359 λ : F, 324 let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo {
360 /// Previous iterate without inertia applied. 325 value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν),
361 /// We need to store this here because `μ_base` passed to [`FBSpecialisation::update`] will 326 n_spikes: ν.len(),
362 /// have inertia applied to it, so is not useful to use. 327 ε,
363 μ_prev : DiscreteMeasure<Loc<F, N>, F>, 328 // postprocessing: config.postprocessing.then(|| ν.clone()),
364 } 329 ..stats
365 330 };
366 /// Implementation of [`FBSpecialisation`] for μFISTA inertial forward-backward splitting. 331 let mut stats = IterInfo::new();
367 #[replace_float_literals(F::cast_from(literal))] 332
368 impl<'a, F : Float + ToNalgebraRealField, A : ForwardModel<Loc<F, N>, F>, const N : usize> 333 // Run the algorithm
369 FBSpecialisation<F, A::Observable, N> for FISTA<'a, F, A, N> { 334 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
370 fn update( 335 // Calculate smooth part of surrogate model.
371 &mut self, 336 let mut τv = opA.preadjoint().apply(residual * τ);
372 μ : &mut DiscreteMeasure<Loc<F, N>, F>, 337
373 _μ_base : &DiscreteMeasure<Loc<F, N>, F> 338 // Save current base point
374 ) -> (A::Observable, Option<F>) { 339 let μ_base = μ.clone();
375 // Update inertial parameters 340
376 let λ_prev = self.λ; 341 // Insert new spikes and reweigh
377 self.λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); 342 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
378 let θ = self.λ / λ_prev - self.λ; 343 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
344 );
345
346 // (Do not) merge spikes.
347 if config.merge_now(&state) && !warned_merging {
348 let err = format!("Merging not supported for μFISTA");
349 println!("{}", err.red());
350 warned_merging = true;
351 }
352
353 // Update inertial parameters
354 let λ_prev = λ;
355 λ = 2.0 * λ_prev / (λ_prev + (4.0 + λ_prev * λ_prev).sqrt());
356 let θ = λ / λ_prev - λ;
357
379 // Perform inertial update on μ. 358 // Perform inertial update on μ.
380 // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ 359 // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ
381 // and μ_prev have zero weight. Since both have weights from the finite-dimensional 360 // and μ_prev have zero weight. Since both have weights from the finite-dimensional
382 // subproblem with a proximal projection step, this is likely to happen when the 361 // subproblem with a proximal projection step, this is likely to happen when the
383 // spike is not needed. A copy of the pruned μ without arithmetic performed is 362 // spike is not needed. A copy of the pruned μ without arithmetic performed is
384 // stored in μ_prev. 363 // stored in μ_prev.
385 μ.pruning_sub(1.0 + θ, θ, &mut self.μ_prev); 364 let n_before_prune = μ.len();
386 365 μ.pruning_sub(1.0 + θ, θ, &mut μ_prev);
387 //*residual = self.opA.apply(μ) - self.b; 366 //let μ_new = (&μ * (1.0 + θ)).sub_matching(&(&μ_prev * θ));
388 let mut residual = self.b.clone(); 367 // μ_prev = μ;
389 self.opA.gemv(&mut residual, 1.0, μ, -1.0); 368 // μ = μ_new;
390 (residual, None) 369 debug_assert!(μ.len() <= n_before_prune);
370 stats.pruned += n_before_prune - μ.len();
371
372 // Update residual
373 residual = calculate_residual(&μ, opA, b);
374
375 let iter = state.iteration();
376 stats.this_iters += 1;
377
378 // Give statistics if needed
379 state.if_verbose(|| {
380 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ_prev);
381 full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new()))
382 });
383
384 // Update main tolerance for next iteration
385 ε = tolerance.update(ε, iter);
391 } 386 }
392 387
393 fn calculate_fit_simple( 388 postprocess(μ_prev, config, L2Squared, opA, b)
394 &self, 389 }
395 μ : &DiscreteMeasure<Loc<F, N>, F>,
396 ) -> F {
397 let mut residual = self.b.clone();
398 self.opA.gemv(&mut residual, 1.0, μ, -1.0);
399 residual.norm2_squared_div2()
400 }
401
402 fn calculate_fit(
403 &self,
404 _μ : &DiscreteMeasure<Loc<F, N>, F>,
405 _residual : &A::Observable
406 ) -> F {
407 self.calculate_fit_simple(&self.μ_prev)
408 }
409
410 // For FISTA we need to do a final pruning as well, due to the limited
411 // pruning that can be done on each step.
412 fn postprocess(mut self, μ_base : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>)
413 -> DiscreteMeasure<Loc<F, N>, F>
414 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
415 let mut μ = self.μ_prev;
416 self.μ_prev = μ_base;
417 μ.merge_spikes_fitness(merging,
418 |μ̃| self.calculate_fit_simple(μ̃),
419 |&v| v);
420 μ.prune();
421 μ
422 }
423
424 fn value_μ<'c, 'b : 'c>(&'c self, _μ : &'c DiscreteMeasure<Loc<F, N>, F>)
425 -> &'c DiscreteMeasure<Loc<F, N>, F> {
426 &self.μ_prev
427 }
428 }
429
430
431 /// Abstraction of regularisation terms for [`generic_pointsource_fb_reg`].
432 pub trait RegTerm<F : Float + ToNalgebraRealField, const N : usize>
433 : for<'a> Apply<&'a DiscreteMeasure<Loc<F, N>, F>, Output = F> {
434 /// Approximately solve the problem
435 /// <div>$$
436 /// \min_{x ∈ ℝ^n} \frac{1}{2} x^⊤Ax - g^⊤ x + τ G(x)
437 /// $$</div>
438 /// for $G$ depending on the trait implementation.
439 ///
440 /// The parameter `mA` is $A$. An estimate for its operator norm should be provided in
441 /// `mA_normest`. The initial iterate and output is `x`. The current main tolerance is `ε`.
442 ///
443 /// Returns the number of iterations taken.
444 fn solve_findim(
445 &self,
446 mA : &DMatrix<F::MixedType>,
447 g : &DVector<F::MixedType>,
448 τ : F,
449 x : &mut DVector<F::MixedType>,
450 mA_normest : F,
451 ε : F,
452 config : &FBGenericConfig<F>
453 ) -> usize;
454
455 /// Find a point where `d` may violate the tolerance `ε`.
456 ///
457 /// If `skip_by_rough_check` is set, do not find the point if a rough check indicates that we
458 /// are in bounds. `ε` is the current main tolerance and `τ` a scaling factor for the
459 /// regulariser.
460 ///
461 /// Returns `None` if `d` is in bounds either based on the rough check, or a more precise check
462 /// terminating early. Otherwise returns a possibly violating point, the value of `d` there,
463 /// and a boolean indicating whether the found point is in bounds.
464 fn find_tolerance_violation<G, BT>(
465 &self,
466 d : &mut BTFN<F, G, BT, N>,
467 τ : F,
468 ε : F,
469 skip_by_rough_check : bool,
470 config : &FBGenericConfig<F>,
471 ) -> Option<(Loc<F, N>, F, bool)>
472 where BT : BTSearch<F, N, Agg=Bounds<F>>,
473 G : SupportGenerator<F, N, Id=BT::Data>,
474 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
475 + LocalAnalysis<F, Bounds<F>, N>;
476
477 /// Verify that `d` is in bounds `ε` for a merge candidate `μ`
478 ///
479 /// `ε` is the current main tolerance and `τ` a scaling factor for the regulariser.
480 fn verify_merge_candidate<G, BT>(
481 &self,
482 d : &mut BTFN<F, G, BT, N>,
483 μ : &DiscreteMeasure<Loc<F, N>, F>,
484 τ : F,
485 ε : F,
486 config : &FBGenericConfig<F>,
487 ) -> bool
488 where BT : BTSearch<F, N, Agg=Bounds<F>>,
489 G : SupportGenerator<F, N, Id=BT::Data>,
490 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
491 + LocalAnalysis<F, Bounds<F>, N>;
492
493 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>>;
494
495 /// Returns a scaling factor for the tolerance sequence.
496 ///
497 /// Typically this is the regularisation parameter.
498 fn tolerance_scaling(&self) -> F;
499 }
500
501 #[replace_float_literals(F::cast_from(literal))]
502 impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for NonnegRadonRegTerm<F>
503 where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
504 fn solve_findim(
505 &self,
506 mA : &DMatrix<F::MixedType>,
507 g : &DVector<F::MixedType>,
508 τ : F,
509 x : &mut DVector<F::MixedType>,
510 mA_normest : F,
511 ε : F,
512 config : &FBGenericConfig<F>
513 ) -> usize {
514 let inner_tolerance = ε * config.inner.tolerance_mult;
515 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
516 let inner_τ = config.inner.τ0 / mA_normest;
517 quadratic_nonneg(config.inner.method, mA, g, τ * self.α(), x,
518 inner_τ, inner_it)
519 }
520
521 #[inline]
522 fn find_tolerance_violation<G, BT>(
523 &self,
524 d : &mut BTFN<F, G, BT, N>,
525 τ : F,
526 ε : F,
527 skip_by_rough_check : bool,
528 config : &FBGenericConfig<F>,
529 ) -> Option<(Loc<F, N>, F, bool)>
530 where BT : BTSearch<F, N, Agg=Bounds<F>>,
531 G : SupportGenerator<F, N, Id=BT::Data>,
532 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
533 + LocalAnalysis<F, Bounds<F>, N> {
534 let τα = τ * self.α();
535 let keep_below = τα + ε;
536 let maximise_above = τα + ε * config.insertion_cutoff_factor;
537 let refinement_tolerance = ε * config.refinement.tolerance_mult;
538
539 // If preliminary check indicates that we are in bounds, and if it otherwise matches
540 // the insertion strategy, skip insertion.
541 if skip_by_rough_check && d.bounds().upper() <= keep_below {
542 None
543 } else {
544 // If the rough check didn't indicate no insertion needed, find maximising point.
545 d.maximise_above(maximise_above, refinement_tolerance, config.refinement.max_steps)
546 .map(|(ξ, v_ξ)| (ξ, v_ξ, v_ξ <= keep_below))
547 }
548 }
549
550 fn verify_merge_candidate<G, BT>(
551 &self,
552 d : &mut BTFN<F, G, BT, N>,
553 μ : &DiscreteMeasure<Loc<F, N>, F>,
554 τ : F,
555 ε : F,
556 config : &FBGenericConfig<F>,
557 ) -> bool
558 where BT : BTSearch<F, N, Agg=Bounds<F>>,
559 G : SupportGenerator<F, N, Id=BT::Data>,
560 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
561 + LocalAnalysis<F, Bounds<F>, N> {
562 let τα = τ * self.α();
563 let refinement_tolerance = ε * config.refinement.tolerance_mult;
564 let merge_tolerance = config.merge_tolerance_mult * ε;
565 let keep_below = τα + merge_tolerance;
566 let keep_supp_above = τα - merge_tolerance;
567 let bnd = d.bounds();
568
569 return (
570 bnd.lower() >= keep_supp_above
571 ||
572 μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| {
573 (β == 0.0) || d.apply(x) >= keep_supp_above
574 }).all(std::convert::identity)
575 ) && (
576 bnd.upper() <= keep_below
577 ||
578 d.has_upper_bound(keep_below, refinement_tolerance, config.refinement.max_steps)
579 )
580 }
581
582 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> {
583 let τα = τ * self.α();
584 Some(Bounds(τα - ε, τα + ε))
585 }
586
587 fn tolerance_scaling(&self) -> F {
588 self.α()
589 }
590 }
591
592 #[replace_float_literals(F::cast_from(literal))]
593 impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for RadonRegTerm<F>
594 where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
595 fn solve_findim(
596 &self,
597 mA : &DMatrix<F::MixedType>,
598 g : &DVector<F::MixedType>,
599 τ : F,
600 x : &mut DVector<F::MixedType>,
601 mA_normest: F,
602 ε : F,
603 config : &FBGenericConfig<F>
604 ) -> usize {
605 let inner_tolerance = ε * config.inner.tolerance_mult;
606 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
607 let inner_τ = config.inner.τ0 / mA_normest;
608 quadratic_unconstrained(config.inner.method, mA, g, τ * self.α(), x,
609 inner_τ, inner_it)
610 }
611
612 fn find_tolerance_violation<G, BT>(
613 &self,
614 d : &mut BTFN<F, G, BT, N>,
615 τ : F,
616 ε : F,
617 skip_by_rough_check : bool,
618 config : &FBGenericConfig<F>,
619 ) -> Option<(Loc<F, N>, F, bool)>
620 where BT : BTSearch<F, N, Agg=Bounds<F>>,
621 G : SupportGenerator<F, N, Id=BT::Data>,
622 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
623 + LocalAnalysis<F, Bounds<F>, N> {
624 let τα = τ * self.α();
625 let keep_below = τα + ε;
626 let keep_above = -τα - ε;
627 let maximise_above = τα + ε * config.insertion_cutoff_factor;
628 let minimise_below = -τα - ε * config.insertion_cutoff_factor;
629 let refinement_tolerance = ε * config.refinement.tolerance_mult;
630
631 // If preliminary check indicates that we are in bounds, and if it otherwise matches
632 // the insertion strategy, skip insertion.
633 if skip_by_rough_check && Bounds(keep_above, keep_below).superset(&d.bounds()) {
634 None
635 } else {
636 // If the rough check didn't indicate no insertion needed, find maximising point.
637 let mx = d.maximise_above(maximise_above, refinement_tolerance,
638 config.refinement.max_steps);
639 let mi = d.minimise_below(minimise_below, refinement_tolerance,
640 config.refinement.max_steps);
641
642 match (mx, mi) {
643 (None, None) => None,
644 (Some((ξ, v_ξ)), None) => Some((ξ, v_ξ, keep_below >= v_ξ)),
645 (None, Some((ζ, v_ζ))) => Some((ζ, v_ζ, keep_above <= v_ζ)),
646 (Some((ξ, v_ξ)), Some((ζ, v_ζ))) => {
647 if v_ξ - τα > τα - v_ζ {
648 Some((ξ, v_ξ, keep_below >= v_ξ))
649 } else {
650 Some((ζ, v_ζ, keep_above <= v_ζ))
651 }
652 }
653 }
654 }
655 }
656
657 fn verify_merge_candidate<G, BT>(
658 &self,
659 d : &mut BTFN<F, G, BT, N>,
660 μ : &DiscreteMeasure<Loc<F, N>, F>,
661 τ : F,
662 ε : F,
663 config : &FBGenericConfig<F>,
664 ) -> bool
665 where BT : BTSearch<F, N, Agg=Bounds<F>>,
666 G : SupportGenerator<F, N, Id=BT::Data>,
667 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
668 + LocalAnalysis<F, Bounds<F>, N> {
669 let τα = τ * self.α();
670 let refinement_tolerance = ε * config.refinement.tolerance_mult;
671 let merge_tolerance = config.merge_tolerance_mult * ε;
672 let keep_below = τα + merge_tolerance;
673 let keep_above = -τα - merge_tolerance;
674 let keep_supp_pos_above = τα - merge_tolerance;
675 let keep_supp_neg_below = -τα + merge_tolerance;
676 let bnd = d.bounds();
677
678 return (
679 (bnd.lower() >= keep_supp_pos_above && bnd.upper() <= keep_supp_neg_below)
680 ||
681 μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| {
682 use std::cmp::Ordering::*;
683 match β.partial_cmp(&0.0) {
684 Some(Greater) => d.apply(x) >= keep_supp_pos_above,
685 Some(Less) => d.apply(x) <= keep_supp_neg_below,
686 _ => true,
687 }
688 }).all(std::convert::identity)
689 ) && (
690 bnd.upper() <= keep_below
691 ||
692 d.has_upper_bound(keep_below, refinement_tolerance,
693 config.refinement.max_steps)
694 ) && (
695 bnd.lower() >= keep_above
696 ||
697 d.has_lower_bound(keep_above, refinement_tolerance,
698 config.refinement.max_steps)
699 )
700 }
701
702 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> {
703 let τα = τ * self.α();
704 Some(Bounds(-τα - ε, τα + ε))
705 }
706
707 fn tolerance_scaling(&self) -> F {
708 self.α()
709 }
710 }
711
712
713 /// Generic implementation of [`pointsource_fb_reg`].
714 ///
715 /// The method can be specialised to even primal-dual proximal splitting through the
716 /// [`FBSpecialisation`] parameter `specialisation`.
717 /// The settings in `config` have their [respective documentation](FBGenericConfig). `opA` is the
718 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
719 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
720 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
721 /// as documented in [`alg_tools::iterate`].
722 ///
723 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
724 /// sums of simple functions usign bisection trees, and the related
725 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
726 /// active at a specific points, and to maximise their sums. Through the implementation of the
727 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
728 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
729 ///
730 /// Returns the final iterate.
731 #[replace_float_literals(F::cast_from(literal))]
732 pub fn generic_pointsource_fb_reg<
733 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, Reg, const N : usize
734 >(
735 opA : &'a A,
736 reg : Reg,
737 op𝒟 : &'a 𝒟,
738 mut τ : F,
739 config : &FBGenericConfig<F>,
740 iterator : I,
741 mut plotter : SeqPlotter<F, N>,
742 mut residual : A::Observable,
743 mut specialisation : Spec
744 ) -> DiscreteMeasure<Loc<F, N>, F>
745 where F : Float + ToNalgebraRealField,
746 I : AlgIteratorFactory<IterInfo<F, N>>,
747 Spec : FBSpecialisation<F, A::Observable, N>,
748 A::Observable : std::ops::MulAssign<F>,
749 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
750 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
751 + Lipschitz<𝒟, FloatType=F>,
752 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
753 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
754 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
755 𝒟::Codomain : RealMapping<F, N>,
756 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
757 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
758 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
759 PlotLookup : Plotting<N>,
760 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
761 Reg : RegTerm<F, N> {
762
763 // Set up parameters
764 let quiet = iterator.is_quiet();
765 let op𝒟norm = op𝒟.opnorm_bound();
766 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
767 // by τ compared to the conditional gradient approach.
768 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
769 let mut ε = tolerance.initial();
770
771 // Initialise operators
772 let preadjA = opA.preadjoint();
773
774 // Initialise iterates
775 let mut μ = DiscreteMeasure::new();
776
777 let mut inner_iters = 0;
778 let mut this_iters = 0;
779 let mut pruned = 0;
780 let mut merged = 0;
781
782 let μ_diff = |μ_new : &DiscreteMeasure<Loc<F, N>, F>,
783 μ_base : &DiscreteMeasure<Loc<F, N>, F>| {
784 let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style {
785 InsertionStyle::Reuse => {
786 μ_new.iter_spikes()
787 .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0)))
788 .map(|(δ, α_base)| (δ.x, α_base - δ.α))
789 .collect()
790 },
791 InsertionStyle::Zero => {
792 μ_new.iter_spikes()
793 .map(|δ| -δ)
794 .chain(μ_base.iter_spikes().copied())
795 .collect()
796 }
797 };
798 ν.prune(); // Potential small performance improvement
799 ν
800 };
801
802 // Run the algorithm
803 iterator.iterate(|state| {
804 // Maximum insertion count and measure difference calculation depend on insertion style.
805 let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
806 (i, Some((l, k))) if i <= l => (k, false),
807 _ => (config.max_insertions, !quiet),
808 };
809 let max_insertions = match config.insertion_style {
810 InsertionStyle::Zero => {
811 todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled.");
812 // let n = μ.len();
813 // μ = DiscreteMeasure::new();
814 // n + m
815 },
816 InsertionStyle::Reuse => m,
817 };
818
819 // Calculate smooth part of surrogate model.
820 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
821 // has no significant overhead. For some reosn Rust doesn't allow us simply moving
822 // the residual and replacing it below before the end of this closure.
823 residual *= -τ;
824 let r = std::mem::replace(&mut residual, opA.empty_observable());
825 let minus_τv = preadjA.apply(r); // minus_τv = -τA^*(Aμ^k-b)
826 // TODO: should avoid a second copy of μ here; μ_base already stores a copy.
827 let ω0 = op𝒟.apply(μ.clone()); // 𝒟μ^k
828 //let g = &minus_τv + ω0; // Linear term of surrogate model
829
830 // Save current base point
831 let μ_base = μ.clone();
832
833 // Add points to support until within error tolerance or maximum insertion count reached.
834 let mut count = 0;
835 let (within_tolerances, d) = 'insertion: loop {
836 if μ.len() > 0 {
837 // Form finite-dimensional subproblem. The subproblem references to the original μ^k
838 // from the beginning of the iteration are all contained in the immutable c and g.
839 let à = op𝒟.findim_matrix(μ.iter_locations());
840 let g̃ = DVector::from_iterator(μ.len(),
841 μ.iter_locations()
842 .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ))
843 .map(F::to_nalgebra_mixed));
844 let mut x = μ.masses_dvector();
845
846 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
847 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
848 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
849 // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2
850 // = n |𝒟| |x|_2, where n is the number of points. Therefore
851 let Ã_normest = op𝒟norm * F::cast_from(μ.len());
852
853 // Solve finite-dimensional subproblem.
854 inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config);
855
856 // Update masses of μ based on solution of finite-dimensional subproblem.
857 μ.set_masses_dvector(&x);
858 }
859
860 // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
861 // conditions in the predual space, and finding new points for insertion, if necessary.
862 let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ, &μ_base));
863
864 // If no merging heuristic is used, let's be more conservative about spike insertion,
865 // and skip it after first round. If merging is done, being more greedy about spike
866 // insertion also seems to improve performance.
867 let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging {
868 false
869 } else {
870 count > 0
871 };
872
873 // Find a spike to insert, if needed
874 let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation(
875 &mut d, τ, ε, skip_by_rough_check, config
876 ) {
877 None => break 'insertion (true, d),
878 Some(res) => res,
879 };
880
881 // Break if maximum insertion count reached
882 if count >= max_insertions {
883 break 'insertion (in_bounds, d)
884 }
885
886 // No point in optimising the weight here; the finite-dimensional algorithm is fast.
887 μ += DeltaMeasure { x : ξ, α : 0.0 };
888 count += 1;
889 };
890
891 if !within_tolerances && warn_insertions {
892 // Complain (but continue) if we failed to get within tolerances
893 // by inserting more points.
894 let err = format!("Maximum insertions reached without achieving \
895 subproblem solution tolerance");
896 println!("{}", err.red());
897 }
898
899 // Merge spikes
900 if state.iteration() % config.merge_every == 0 {
901 let n_before_merge = μ.len();
902 μ.merge_spikes(config.merging, |μ_candidate| {
903 let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ_candidate, &μ_base));
904
905 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
906 .then_some(())
907 });
908 debug_assert!(μ.len() >= n_before_merge);
909 merged += μ.len() - n_before_merge;
910 }
911
912 let n_before_prune = μ.len();
913 (residual, τ) = match specialisation.update(&mut μ, &μ_base) {
914 (r, None) => (r, τ),
915 (r, Some(new_τ)) => (r, new_τ)
916 };
917 debug_assert!(μ.len() <= n_before_prune);
918 pruned += n_before_prune - μ.len();
919
920 this_iters += 1;
921
922 // Update main tolerance for next iteration
923 let ε_prev = ε;
924 ε = tolerance.update(ε, state.iteration());
925
926 // Give function value if needed
927 state.if_verbose(|| {
928 let value_μ = specialisation.value_μ(&μ);
929 // Plot if so requested
930 plotter.plot_spikes(
931 format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
932 "start".to_string(), Some(&minus_τv),
933 reg.target_bounds(τ, ε_prev), value_μ,
934 );
935 // Calculate mean inner iterations and reset relevant counters.
936 // Return the statistics
937 let res = IterInfo {
938 value : specialisation.calculate_fit(&μ, &residual) + reg.apply(value_μ),
939 n_spikes : value_μ.len(),
940 inner_iters,
941 this_iters,
942 merged,
943 pruned,
944 ε : ε_prev,
945 postprocessing: config.postprocessing.then(|| value_μ.clone()),
946 };
947 inner_iters = 0;
948 this_iters = 0;
949 merged = 0;
950 pruned = 0;
951 res
952 })
953 });
954
955 specialisation.postprocess(μ, config.final_merging)
956 }
957
958 /// Iteratively solve the pointsource localisation problem using forward-backward splitting
959 ///
960 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
961 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
962 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
963 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
964 /// as documented in [`alg_tools::iterate`].
965 ///
966 /// For details on the mathematical formulation, see the [module level](self) documentation.
967 ///
968 /// Returns the final iterate.
969 #[replace_float_literals(F::cast_from(literal))]
970 pub fn pointsource_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>(
971 opA : &'a A,
972 b : &A::Observable,
973 reg : Reg,
974 op𝒟 : &'a 𝒟,
975 config : &FBConfig<F>,
976 iterator : I,
977 plotter : SeqPlotter<F, N>,
978 ) -> DiscreteMeasure<Loc<F, N>, F>
979 where F : Float + ToNalgebraRealField,
980 I : AlgIteratorFactory<IterInfo<F, N>>,
981 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
982 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
983 A::Observable : std::ops::MulAssign<F>,
984 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
985 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
986 + Lipschitz<𝒟, FloatType=F>,
987 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
988 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
989 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
990 𝒟::Codomain : RealMapping<F, N>,
991 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
992 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
993 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
994 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
995 PlotLookup : Plotting<N>,
996 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
997 Reg : RegTerm<F, N> {
998
999 let initial_residual = -b;
1000 let τ = config.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
1001
1002 match config.meta {
1003 FBMetaAlgorithm::None => generic_pointsource_fb_reg(
1004 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual,
1005 BasicFB{ b, opA },
1006 ),
1007 FBMetaAlgorithm::InertiaFISTA => generic_pointsource_fb_reg(
1008 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual,
1009 FISTA{ b, opA, λ : 1.0, μ_prev : DiscreteMeasure::new() },
1010 ),
1011 }
1012 }
1013
1014 //
1015 // Deprecated interfaces
1016 //
1017
1018 #[deprecated(note = "Use `pointsource_fb_reg`")]
1019 pub fn pointsource_fb<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, const N : usize>(
1020 opA : &'a A,
1021 b : &A::Observable,
1022 α : F,
1023 op𝒟 : &'a 𝒟,
1024 config : &FBConfig<F>,
1025 iterator : I,
1026 plotter : SeqPlotter<F, N>
1027 ) -> DiscreteMeasure<Loc<F, N>, F>
1028 where F : Float + ToNalgebraRealField,
1029 I : AlgIteratorFactory<IterInfo<F, N>>,
1030 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
1031 A::Observable : std::ops::MulAssign<F>,
1032 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
1033 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
1034 + Lipschitz<𝒟, FloatType=F>,
1035 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
1036 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
1037 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
1038 𝒟::Codomain : RealMapping<F, N>,
1039 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
1040 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
1041 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
1042 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
1043 PlotLookup : Plotting<N>,
1044 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
1045
1046 pointsource_fb_reg(opA, b, NonnegRadonRegTerm(α), op𝒟, config, iterator, plotter)
1047 }
1048
1049
1050 #[deprecated(note = "Use `generic_pointsource_fb_reg`")]
1051 pub fn generic_pointsource_fb<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, const N : usize>(
1052 opA : &'a A,
1053 α : F,
1054 op𝒟 : &'a 𝒟,
1055 τ : F,
1056 config : &FBGenericConfig<F>,
1057 iterator : I,
1058 plotter : SeqPlotter<F, N>,
1059 residual : A::Observable,
1060 specialisation : Spec,
1061 ) -> DiscreteMeasure<Loc<F, N>, F>
1062 where F : Float + ToNalgebraRealField,
1063 I : AlgIteratorFactory<IterInfo<F, N>>,
1064 Spec : FBSpecialisation<F, A::Observable, N>,
1065 A::Observable : std::ops::MulAssign<F>,
1066 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
1067 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
1068 + Lipschitz<𝒟, FloatType=F>,
1069 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
1070 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
1071 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
1072 𝒟::Codomain : RealMapping<F, N>,
1073 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
1074 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
1075 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
1076 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
1077 PlotLookup : Plotting<N>,
1078 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
1079
1080 generic_pointsource_fb_reg(opA, NonnegRadonRegTerm(α), op𝒟, τ, config, iterator, plotter,
1081 residual, specialisation)
1082 }

mercurial