Wed, 30 Nov 2022 23:45:04 +0200
Sketch FBGenericConfig clap
0 | 1 | /*! |
2 | Solver for the point source localisation problem using a forward-backward splitting method. | |
3 | ||
4 | This corresponds to the manuscript | |
5 | ||
6 | * Valkonen T. - _Proximal methods for point source localisation_. ARXIV TO INSERT. | |
7 | ||
8 | The main routine is [`pointsource_fb`]. It is based on [`generic_pointsource_fb`], which is also | |
9 | used by our [primal-dual proximal splitting][crate::pdps] implementation. | |
10 | ||
11 | FISTA-type inertia can also be enabled through [`FBConfig::meta`]. | |
12 | ||
13 | ## Problem | |
14 | ||
15 | <p> | |
16 | Our objective is to solve | |
17 | $$ | |
18 | \min_{μ ∈ ℳ(Ω)}~ F_0(Aμ-b) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ), | |
19 | $$ | |
20 | where $F_0(y)=\frac{1}{2}\|y\|_2^2$ and the forward operator $A \in 𝕃(ℳ(Ω); ℝ^n)$. | |
21 | </p> | |
22 | ||
23 | ## Approach | |
24 | ||
25 | <p> | |
26 | As documented in more detail in the paper, on each step we approximately solve | |
27 | $$ | |
28 | \min_{μ ∈ ℳ(Ω)}~ F(μ) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ) + \frac{1}{2}\|μ-μ^k\|_𝒟^2, | |
29 | $$ | |
30 | where $𝒟 ∈ 𝕃(ℳ(Ω); C_c(Ω))$ is typically a convolution operator. | |
31 | </p> | |
32 | ||
33 | ## Finite-dimensional subproblems. | |
34 | ||
35 | With $C$ a projection from [`DiscreteMeasure`] to the weights, and $x^k$ such that $x^k=Cμ^k$, we | |
36 | form the discretised linearised inner problem | |
37 | <p> | |
38 | $$ | |
39 | \min_{x ∈ ℝ^n}~ τ\bigl(F(Cx^k) + [C^*∇F(Cx^k)]^⊤(x-x^k) + α {\vec 1}^⊤ x\bigr) | |
40 | + δ_{≥ 0}(x) + \frac{1}{2}\|x-x^k\|_{C^*𝒟C}^2, | |
41 | $$ | |
42 | equivalently | |
43 | $$ | |
44 | \begin{aligned} | |
45 | \min_x~ & τF(Cx^k) - τ[C^*∇F(Cx^k)]^⊤x^k + \frac{1}{2} (x^k)^⊤ C^*𝒟C x^k | |
46 | \\ | |
47 | & | |
48 | - [C^*𝒟C x^k - τC^*∇F(Cx^k)]^⊤ x | |
49 | \\ | |
50 | & | |
51 | + \frac{1}{2} x^⊤ C^*𝒟C x | |
52 | + τα {\vec 1}^⊤ x + δ_{≥ 0}(x), | |
53 | \end{aligned} | |
54 | $$ | |
55 | In other words, we obtain the quadratic non-negativity constrained problem | |
56 | $$ | |
57 | \min_{x ∈ ℝ^n}~ \frac{1}{2} x^⊤ Ã x - g̃^⊤ x + c + τα {\vec 1}^⊤ x + δ_{≥ 0}(x). | |
58 | $$ | |
59 | where | |
60 | $$ | |
61 | \begin{aligned} | |
62 | Ã & = C^*𝒟C, | |
63 | \\ | |
64 | g̃ & = C^*𝒟C x^k - τ C^*∇F(Cx^k) | |
65 | = C^* 𝒟 μ^k - τ C^*A^*(Aμ^k - b) | |
66 | \\ | |
67 | c & = τ F(Cx^k) - τ[C^*∇F(Cx^k)]^⊤x^k + \frac{1}{2} (x^k)^⊤ C^*𝒟C x^k | |
68 | \\ | |
69 | & | |
70 | = \frac{τ}{2} \|Aμ^k-b\|^2 - τ[Aμ^k-b]^⊤Aμ^k + \frac{1}{2} \|μ^k\|_{𝒟}^2 | |
71 | \\ | |
72 | & | |
73 | = -\frac{τ}{2} \|Aμ^k-b\|^2 - τ[Aμ^k-b]^⊤ b + \frac{1}{2} \|μ^k\|_{𝒟}^2. | |
74 | \end{aligned} | |
75 | $$ | |
76 | </p> | |
77 | ||
78 | We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by | |
79 | [`InnerSettings`] in [`FBGenericConfig::inner`]. | |
80 | */ | |
81 | ||
82 | use numeric_literals::replace_float_literals; | |
83 | use std::cmp::Ordering::*; | |
84 | use serde::{Serialize, Deserialize}; | |
85 | use colored::Colorize; | |
86 | use nalgebra::DVector; | |
1 | 87 | use clap::Parser; |
0 | 88 | |
89 | use alg_tools::iterate::{ | |
90 | AlgIteratorFactory, | |
91 | AlgIteratorState, | |
92 | }; | |
93 | use alg_tools::euclidean::Euclidean; | |
94 | use alg_tools::norms::Norm; | |
95 | use alg_tools::linops::Apply; | |
96 | use alg_tools::sets::Cube; | |
97 | use alg_tools::loc::Loc; | |
98 | use alg_tools::bisection_tree::{ | |
99 | BTFN, | |
100 | PreBTFN, | |
101 | Bounds, | |
102 | BTNodeLookup, | |
103 | BTNode, | |
104 | BTSearch, | |
105 | P2Minimise, | |
106 | SupportGenerator, | |
107 | LocalAnalysis, | |
108 | Bounded, | |
109 | }; | |
110 | use alg_tools::mapping::RealMapping; | |
111 | use alg_tools::nalgebra_support::ToNalgebraRealField; | |
112 | ||
113 | use crate::types::*; | |
114 | use crate::measures::{ | |
115 | DiscreteMeasure, | |
116 | DeltaMeasure, | |
117 | Radon | |
118 | }; | |
119 | use crate::measures::merging::{ | |
120 | SpikeMergingMethod, | |
121 | SpikeMerging, | |
122 | }; | |
123 | use crate::forward_model::ForwardModel; | |
124 | use crate::seminorms::{ | |
125 | DiscreteMeasureOp, Lipschitz | |
126 | }; | |
127 | use crate::subproblem::{ | |
128 | quadratic_nonneg, | |
129 | InnerSettings, | |
130 | InnerMethod, | |
131 | }; | |
132 | use crate::tolerance::Tolerance; | |
133 | use crate::plot::{ | |
134 | SeqPlotter, | |
135 | Plotting, | |
136 | PlotLookup | |
137 | }; | |
138 | ||
/// Method for constructing $μ$ on each iteration
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[allow(dead_code)]
pub enum InsertionStyle {
    /// Reuse $μ$ from the previous iteration, optimising its weights
    /// before inserting new spikes.
    Reuse,
    /// Start each iteration afresh with $μ=0$.
    Zero,
}
149 | ||
1 | 150 | impl Default for InsertionStyle { |
151 | fn default() -> Self { | |
152 | Self::Reuse | |
153 | } | |
154 | } | |
/// Meta-algorithm type
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[allow(dead_code)]
pub enum FBMetaAlgorithm {
    /// No meta-algorithm: plain forward-backward iteration.
    None,
    /// FISTA-style inertia applied to the iterates.
    InertiaFISTA,
}
164 | ||
/// Ergodic tolerance application style
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[allow(dead_code)]
pub enum ErgodicTolerance<F> {
    /// Non-ergodic iteration-wise tolerance
    NonErgodic,
    /// After the `n`th iteration, bound the monitored quantity to `factor`
    /// times its value on that iteration.
    AfterNth{ n : usize, factor : F },
}
174 | ||
1 | 175 | impl<F : ClapFloat> Default for ErgodicTolerance<F> { |
176 | fn default() -> Self { | |
177 | Self::NonErgodic | |
178 | } | |
179 | } | |
180 | ||
/// Settings for [`pointsource_fb`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FBConfig<F : ClapFloat> {
    /// Step length scaling: the step length τ is this factor divided by the
    /// Lipschitz factor estimate obtained from the forward model.
    pub τ0 : F,
    /// Meta-algorithm to apply (e.g. FISTA-style inertia); see [`FBMetaAlgorithm`].
    pub meta : FBMetaAlgorithm,
    /// Generic spike-insertion parameters shared with other algorithms;
    /// see [`FBGenericConfig`].
    pub insertion : FBGenericConfig<F>,
}
192 | ||
/// Settings for the solution of the stepwise optimality condition in algorithms based on
/// [`generic_pointsource_fb`].
///
/// Fields marked `#[clap(skip)]` are not exposed on the command line and take
/// their default values there.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug, Parser)]
#[serde(default)]
pub struct FBGenericConfig<F : ClapFloat> {
    #[clap(skip)]
    /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`].
    pub insertion_style : InsertionStyle,
    #[clap(skip)]
    /// Tolerance for point insertion.
    pub tolerance : Tolerance<F>,
    /// Stop looking for predual maximum (where to insert a new point) below
    /// `tolerance` multiplied by this factor.
    pub insertion_cutoff_factor : F,
    #[clap(skip)]
    /// Apply tolerance ergodically; see [`ErgodicTolerance`].
    pub ergodic_tolerance : ErgodicTolerance<F>,
    #[clap(skip)]
    /// Settings for branch and bound refinement when looking for predual maxima
    pub refinement : RefinementSettings<F>,
    /// Maximum insertions within each outer iteration
    pub max_insertions : usize,
    /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
    pub bootstrap_insertions : Option<(usize, usize)>,
    #[clap(skip)]
    /// Inner method settings; see [`InnerSettings`].
    pub inner : InnerSettings<F>,
    /// Spike merging method
    pub merging : SpikeMergingMethod<F>,
    /// Tolerance multiplier for merges
    pub merge_tolerance_mult : F,
    /// Spike merging method after the last step
    pub final_merging : SpikeMergingMethod<F>,
    /// Iterations between merging heuristic tries
    pub merge_every : usize,
    /// Save $μ$ for postprocessing optimisation
    pub postprocessing : bool
}
231 | ||
232 | #[replace_float_literals(F::cast_from(literal))] | |
1 | 233 | impl<F : ClapFloat> Default for FBConfig<F> { |
0 | 234 | fn default() -> Self { |
235 | FBConfig { | |
236 | τ0 : 0.99, | |
237 | meta : FBMetaAlgorithm::None, | |
238 | insertion : Default::default() | |
239 | } | |
240 | } | |
241 | } | |
242 | ||
243 | #[replace_float_literals(F::cast_from(literal))] | |
1 | 244 | impl<F : ClapFloat> Default for FBGenericConfig<F> { |
0 | 245 | fn default() -> Self { |
246 | FBGenericConfig { | |
247 | insertion_style : InsertionStyle::Reuse, | |
248 | tolerance : Default::default(), | |
249 | insertion_cutoff_factor : 1.0, | |
250 | ergodic_tolerance : ErgodicTolerance::NonErgodic, | |
251 | refinement : Default::default(), | |
252 | max_insertions : 100, | |
253 | //bootstrap_insertions : None, | |
254 | bootstrap_insertions : Some((10, 1)), | |
255 | inner : InnerSettings { | |
256 | method : InnerMethod::SSN, | |
257 | .. Default::default() | |
258 | }, | |
259 | merging : SpikeMergingMethod::None, | |
260 | //merging : Default::default(), | |
261 | final_merging : Default::default(), | |
262 | merge_every : 10, | |
263 | merge_tolerance_mult : 2.0, | |
264 | postprocessing : false, | |
265 | } | |
266 | } | |
267 | } | |
268 | ||
/// Trait for specialisation of [`generic_pointsource_fb`] to basic FB, FISTA.
///
/// The idea is that the residual $Aμ - b$ in the forward step can be replaced by an arbitrary
/// value. For example, to implement [primal-dual proximal splitting][crate::pdps] we replace it
/// with the dual variable $y$. We can then also implement alternative data terms, as the
/// (pre)differential of $F(μ)=F\_0(Aμ-b)$ is $F\'(μ) = A\_*F\_0\'(Aμ-b)$. In the case of the
/// quadratic fidelity $F_0(y)=\frac{1}{2}\\|y\\|_2^2$ in a Hilbert space, of course,
/// $F\_0\'(Aμ-b)=Aμ-b$ is the residual.
pub trait FBSpecialisation<F : Float, Observable : Euclidean<F>, const N : usize> : Sized {
    /// Updates the residual and does any necessary pruning of `μ`.
    ///
    /// Returns the new residual and possibly a new step length.
    ///
    /// The measure `μ` may also be modified to apply, e.g., inertia to it.
    /// The updated residual should correspond to the residual at `μ`.
    /// See the [trait documentation][FBSpecialisation] for the use and meaning of the residual.
    ///
    /// The parameter `μ_base` is the base point of the iteration, typically the previous iterate,
    /// but for, e.g., FISTA has inertia applied to it.
    fn update(
        &mut self,
        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
        μ_base : &DiscreteMeasure<Loc<F, N>, F>,
    ) -> (Observable, Option<F>);

    /// Calculates the data term value corresponding to iterate `μ` and available residual.
    ///
    /// Inertia and other modifications, as deemed necessary, should be applied to `μ`.
    ///
    /// The blanket implementation corresponds to the 2-norm-squared data fidelity
    /// $\\|\text{residual}\\|\_2^2/2$.
    fn calculate_fit(
        &self,
        _μ : &DiscreteMeasure<Loc<F, N>, F>,
        residual : &Observable
    ) -> F {
        residual.norm2_squared_div2()
    }

    /// Calculates the data term value at $μ$.
    ///
    /// Unlike [`Self::calculate_fit`], no inertia, etc., should be applied to `μ`.
    fn calculate_fit_simple(
        &self,
        μ : &DiscreteMeasure<Loc<F, N>, F>,
    ) -> F;

    /// Returns the final iterate after any necessary postprocess pruning, merging, etc.
    fn postprocess(self, mut μ : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>)
    -> DiscreteMeasure<Loc<F, N>, F>
    where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
        // Merge spikes using the data-fit value as the merging criterion,
        // then drop zero-weight spikes.
        μ.merge_spikes_fitness(merging,
                               |μ̃| self.calculate_fit_simple(μ̃),
                               |&v| v);
        μ.prune();
        μ
    }

    /// Returns measure to be used for value calculations, which may differ from μ.
    fn value_μ<'c, 'b : 'c>(&'b self, μ : &'c DiscreteMeasure<Loc<F, N>, F>)
    -> &'c DiscreteMeasure<Loc<F, N>, F> {
        μ
    }
}
333 | ||
/// Specialisation of [`generic_pointsource_fb`] to basic μFB.
struct BasicFB<
    'a,
    F : Float + ToNalgebraRealField,
    A : ForwardModel<Loc<F, N>, F>,
    const N : usize
> {
    /// The data (observable) $b$
    b : &'a A::Observable,
    /// The forward operator $A$
    opA : &'a A,
}
346 | ||
347 | /// Implementation of [`FBSpecialisation`] for basic μFB forward-backward splitting. | |
348 | #[replace_float_literals(F::cast_from(literal))] | |
349 | impl<'a, F : Float + ToNalgebraRealField , A : ForwardModel<Loc<F, N>, F>, const N : usize> | |
350 | FBSpecialisation<F, A::Observable, N> for BasicFB<'a, F, A, N> { | |
351 | fn update( | |
352 | &mut self, | |
353 | μ : &mut DiscreteMeasure<Loc<F, N>, F>, | |
354 | _μ_base : &DiscreteMeasure<Loc<F, N>, F> | |
355 | ) -> (A::Observable, Option<F>) { | |
356 | μ.prune(); | |
357 | //*residual = self.opA.apply(μ) - self.b; | |
358 | let mut residual = self.b.clone(); | |
359 | self.opA.gemv(&mut residual, 1.0, μ, -1.0); | |
360 | (residual, None) | |
361 | } | |
362 | ||
363 | fn calculate_fit_simple( | |
364 | &self, | |
365 | μ : &DiscreteMeasure<Loc<F, N>, F>, | |
366 | ) -> F { | |
367 | let mut residual = self.b.clone(); | |
368 | self.opA.gemv(&mut residual, 1.0, μ, -1.0); | |
369 | residual.norm2_squared_div2() | |
370 | } | |
371 | } | |
372 | ||
/// Specialisation of [`generic_pointsource_fb`] to FISTA.
struct FISTA<
    'a,
    F : Float + ToNalgebraRealField,
    A : ForwardModel<Loc<F, N>, F>,
    const N : usize
> {
    /// The data (observable) $b$
    b : &'a A::Observable,
    /// The forward operator $A$
    opA : &'a A,
    /// Current inertial parameter $λ$
    λ : F,
    /// Previous iterate without inertia applied.
    /// We need to store this here because `μ_base` passed to [`FBSpecialisation::update`] will
    /// have inertia applied to it, so is not useful to use.
    μ_prev : DiscreteMeasure<Loc<F, N>, F>,
}
391 | ||
/// Implementation of [`FBSpecialisation`] for μFISTA inertial forward-backward splitting.
#[replace_float_literals(F::cast_from(literal))]
impl<'a, F : Float + ToNalgebraRealField, A : ForwardModel<Loc<F, N>, F>, const N : usize>
FBSpecialisation<F, A::Observable, N> for FISTA<'a, F, A, N> {
    fn update(
        &mut self,
        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
        _μ_base : &DiscreteMeasure<Loc<F, N>, F>
    ) -> (A::Observable, Option<F>) {
        // Update inertial parameters: λ_{k+1} = 2λ_k/(λ_k + √(4 + λ_k²)),
        // with extrapolation weight θ = λ_{k+1}/λ_k - λ_{k+1}.
        let λ_prev = self.λ;
        self.λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() );
        let θ = self.λ / λ_prev - self.λ;
        // Perform inertial update on μ.
        // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ
        // and μ_prev have zero weight. Since both have weights from the finite-dimensional
        // subproblem with a proximal projection step, this is likely to happen when the
        // spike is not needed. A copy of the pruned μ without arithmetic performed is
        // stored in μ_prev.
        μ.pruning_sub(1.0 + θ, θ, &mut self.μ_prev);

        // Compute the residual Aμ - b at the inertial point μ.
        //*residual = self.opA.apply(μ) - self.b;
        let mut residual = self.b.clone();
        self.opA.gemv(&mut residual, 1.0, μ, -1.0);
        (residual, None)
    }

    fn calculate_fit_simple(
        &self,
        μ : &DiscreteMeasure<Loc<F, N>, F>,
    ) -> F {
        // ½‖Aμ - b‖² evaluated directly at the given μ.
        let mut residual = self.b.clone();
        self.opA.gemv(&mut residual, 1.0, μ, -1.0);
        residual.norm2_squared_div2()
    }

    fn calculate_fit(
        &self,
        _μ : &DiscreteMeasure<Loc<F, N>, F>,
        _residual : &A::Observable
    ) -> F {
        // The passed-in μ and residual correspond to the inertial point, so the
        // fit is instead evaluated at the stored non-inertial iterate μ_prev.
        self.calculate_fit_simple(&self.μ_prev)
    }

    // For FISTA we need to do a final pruning as well, due to the limited
    // pruning that can be done on each step.
    fn postprocess(mut self, μ_base : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>)
    -> DiscreteMeasure<Loc<F, N>, F>
    where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
        // Move μ_prev out of self (replacing it with μ_base) so that `self`
        // can still be borrowed in the fitness closure below.
        let mut μ = self.μ_prev;
        self.μ_prev = μ_base;
        μ.merge_spikes_fitness(merging,
                               |μ̃| self.calculate_fit_simple(μ̃),
                               |&v| v);
        μ.prune();
        μ
    }

    fn value_μ<'c, 'b : 'c>(&'c self, _μ : &'c DiscreteMeasure<Loc<F, N>, F>)
    -> &'c DiscreteMeasure<Loc<F, N>, F> {
        // Report values at the non-inertial iterate μ_prev, not the inertial μ.
        &self.μ_prev
    }
}
455 | ||
456 | /// Iteratively solve the pointsource localisation problem using forward-backward splitting | |
457 | /// | |
458 | /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the | |
459 | /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. | |
460 | /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution | |
461 | /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control | |
462 | /// as documented in [`alg_tools::iterate`]. | |
463 | /// | |
464 | /// For details on the mathematical formulation, see the [module level](self) documentation. | |
465 | /// | |
466 | /// Returns the final iterate. | |
467 | #[replace_float_literals(F::cast_from(literal))] | |
468 | pub fn pointsource_fb<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, const N : usize>( | |
469 | opA : &'a A, | |
470 | b : &A::Observable, | |
471 | α : F, | |
472 | op𝒟 : &'a 𝒟, | |
473 | config : &FBConfig<F>, | |
474 | iterator : I, | |
475 | plotter : SeqPlotter<F, N> | |
476 | ) -> DiscreteMeasure<Loc<F, N>, F> | |
1 | 477 | where F : ClapFloat + ToNalgebraRealField, |
0 | 478 | I : AlgIteratorFactory<IterInfo<F, N>>, |
479 | for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, | |
480 | //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow | |
481 | A::Observable : std::ops::MulAssign<F>, | |
482 | GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, | |
483 | A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> | |
484 | + Lipschitz<𝒟, FloatType=F>, | |
485 | BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, | |
486 | G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, | |
487 | 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, | |
488 | 𝒟::Codomain : RealMapping<F, N>, | |
489 | S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, | |
490 | K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, | |
491 | BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, | |
492 | Cube<F, N>: P2Minimise<Loc<F, N>, F>, | |
493 | PlotLookup : Plotting<N>, | |
494 | DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { | |
495 | ||
496 | let initial_residual = -b; | |
497 | let τ = config.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); | |
498 | ||
499 | match config.meta { | |
500 | FBMetaAlgorithm::None => generic_pointsource_fb( | |
501 | opA, α, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual, | |
502 | BasicFB{ b, opA } | |
503 | ), | |
504 | FBMetaAlgorithm::InertiaFISTA => generic_pointsource_fb( | |
505 | opA, α, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual, | |
506 | FISTA{ b, opA, λ : 1.0, μ_prev : DiscreteMeasure::new() } | |
507 | ), | |
508 | } | |
509 | } | |
510 | ||
511 | /// Generic implementation of [`pointsource_fb`]. | |
512 | /// | |
513 | /// The method can be specialised to even primal-dual proximal splitting through the | |
514 | /// [`FBSpecialisation`] parameter `specialisation`. | |
515 | /// The settings in `config` have their [respective documentation](FBGenericConfig). `opA` is the | |
516 | /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. | |
517 | /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution | |
518 | /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control | |
519 | /// as documented in [`alg_tools::iterate`]. | |
520 | /// | |
521 | /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of | |
522 | /// sums of simple functions usign bisection trees, and the related | |
523 | /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions | |
524 | /// active at a specific points, and to maximise their sums. Through the implementation of the | |
525 | /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features | |
526 | /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. | |
527 | /// | |
528 | /// Returns the final iterate. | |
529 | #[replace_float_literals(F::cast_from(literal))] | |
530 | pub fn generic_pointsource_fb<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, const N : usize>( | |
531 | opA : &'a A, | |
532 | α : F, | |
533 | op𝒟 : &'a 𝒟, | |
534 | mut τ : F, | |
535 | config : &FBGenericConfig<F>, | |
536 | iterator : I, | |
537 | mut plotter : SeqPlotter<F, N>, | |
538 | mut residual : A::Observable, | |
539 | mut specialisation : Spec, | |
540 | ) -> DiscreteMeasure<Loc<F, N>, F> | |
1 | 541 | where F : ClapFloat + ToNalgebraRealField, |
0 | 542 | I : AlgIteratorFactory<IterInfo<F, N>>, |
543 | Spec : FBSpecialisation<F, A::Observable, N>, | |
544 | A::Observable : std::ops::MulAssign<F>, | |
545 | GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, | |
546 | A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> | |
547 | + Lipschitz<𝒟, FloatType=F>, | |
548 | BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, | |
549 | G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, | |
550 | 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, | |
551 | 𝒟::Codomain : RealMapping<F, N>, | |
552 | S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, | |
553 | K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, | |
554 | BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, | |
555 | Cube<F, N>: P2Minimise<Loc<F, N>, F>, | |
556 | PlotLookup : Plotting<N>, | |
557 | DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { | |
558 | ||
559 | // Set up parameters | |
560 | let quiet = iterator.is_quiet(); | |
561 | let op𝒟norm = op𝒟.opnorm_bound(); | |
562 | // We multiply tolerance by τ for FB since | |
563 | // our subproblems depending on tolerances are scaled by τ compared to the conditional | |
564 | // gradient approach. | |
565 | let mut tolerance = config.tolerance * τ * α; | |
566 | let mut ε = tolerance.initial(); | |
567 | ||
568 | // Initialise operators | |
569 | let preadjA = opA.preadjoint(); | |
570 | ||
571 | // Initialise iterates | |
572 | let mut μ = DiscreteMeasure::new(); | |
573 | ||
574 | let mut after_nth_bound = F::INFINITY; | |
575 | // FIXME: Don't allocate if not needed. | |
576 | let mut after_nth_accum = opA.zero_observable(); | |
577 | ||
578 | let mut inner_iters = 0; | |
579 | let mut this_iters = 0; | |
580 | let mut pruned = 0; | |
581 | let mut merged = 0; | |
582 | ||
583 | let μ_diff = |μ_new : &DiscreteMeasure<Loc<F, N>, F>, | |
584 | μ_base : &DiscreteMeasure<Loc<F, N>, F>| { | |
585 | let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style { | |
586 | InsertionStyle::Reuse => { | |
587 | μ_new.iter_spikes() | |
588 | .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0))) | |
589 | .map(|(δ, α_base)| (δ.x, α_base - δ.α)) | |
590 | .collect() | |
591 | }, | |
592 | InsertionStyle::Zero => { | |
593 | μ_new.iter_spikes() | |
594 | .map(|δ| -δ) | |
595 | .chain(μ_base.iter_spikes().copied()) | |
596 | .collect() | |
597 | } | |
598 | }; | |
599 | ν.prune(); // Potential small performance improvement | |
600 | ν | |
601 | }; | |
602 | ||
603 | // Run the algorithm | |
604 | iterator.iterate(|state| { | |
605 | // Calculate subproblem tolerances, and update main tolerance for next iteration | |
606 | let τα = τ * α; | |
607 | // if μ.len() == 0 /*state.iteration() == 1*/ { | |
608 | // let t = minus_τv.bounds().upper() * 0.001; | |
609 | // if t > 0.0 { | |
610 | // let (ξ, v_ξ) = minus_τv.maximise(t, config.refinement.max_steps); | |
611 | // if τα + ε > v_ξ && v_ξ > τα { | |
612 | // // The zero measure is already within bounds, so improve them | |
613 | // tolerance = config.tolerance * (v_ξ - τα); | |
614 | // ε = tolerance.initial(); | |
615 | // } | |
616 | // μ += DeltaMeasure { x : ξ, α : 0.0 }; | |
617 | // } else { | |
618 | // // Zero is the solution. | |
619 | // return Step::Terminated | |
620 | // } | |
621 | // } | |
622 | let target_bounds = Bounds(τα - ε, τα + ε); | |
623 | let merge_tolerance = config.merge_tolerance_mult * ε; | |
624 | let merge_target_bounds = Bounds(τα - merge_tolerance, τα + merge_tolerance); | |
625 | let inner_tolerance = ε * config.inner.tolerance_mult; | |
626 | let refinement_tolerance = ε * config.refinement.tolerance_mult; | |
627 | let maximise_above = τα + ε * config.insertion_cutoff_factor; | |
628 | let mut ε1 = ε; | |
629 | let ε_prev = ε; | |
630 | ε = tolerance.update(ε, state.iteration()); | |
631 | ||
632 | // Maximum insertion count and measure difference calculation depend on insertion style. | |
633 | let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { | |
634 | (i, Some((l, k))) if i <= l => (k, false), | |
635 | _ => (config.max_insertions, !quiet), | |
636 | }; | |
637 | let max_insertions = match config.insertion_style { | |
638 | InsertionStyle::Zero => { | |
639 | todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled."); | |
640 | // let n = μ.len(); | |
641 | // μ = DiscreteMeasure::new(); | |
642 | // n + m | |
643 | }, | |
644 | InsertionStyle::Reuse => m, | |
645 | }; | |
646 | ||
647 | // Calculate smooth part of surrogate model. | |
648 | residual *= -τ; | |
649 | if let ErgodicTolerance::AfterNth{ .. } = config.ergodic_tolerance { | |
650 | // Negative residual times τ expected here, as set above. | |
651 | // TODO: is this the correct location? | |
652 | after_nth_accum += &residual; | |
653 | } | |
654 | // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` | |
655 | // has no significant overhead. For some reosn Rust doesn't allow us simply moving | |
656 | // the residual and replacing it below before the end of this closure. | |
657 | let r = std::mem::replace(&mut residual, opA.empty_observable()); | |
658 | let minus_τv = preadjA.apply(r); // minus_τv = -τA^*(Aμ^k-b) | |
659 | // TODO: should avoid a second copy of μ here; μ_base already stores a copy. | |
660 | let ω0 = op𝒟.apply(μ.clone()); // 𝒟μ^k | |
661 | //let g = &minus_τv + ω0; // Linear term of surrogate model | |
662 | ||
663 | // Save current base point | |
664 | let μ_base = μ.clone(); | |
665 | ||
666 | // Add points to support until within error tolerance or maximum insertion count reached. | |
667 | let mut count = 0; | |
668 | let (within_tolerances, d) = 'insertion: loop { | |
669 | if μ.len() > 0 { | |
670 | // Form finite-dimensional subproblem. The subproblem references to the original μ^k | |
671 | // from the beginning of the iteration are all contained in the immutable c and g. | |
672 | let à = op𝒟.findim_matrix(μ.iter_locations()); | |
673 | let g̃ = DVector::from_iterator(μ.len(), | |
674 | μ.iter_locations() | |
675 | .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ)) | |
676 | .map(F::to_nalgebra_mixed)); | |
677 | let mut x = μ.masses_dvector(); | |
678 | ||
679 | // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. | |
680 | // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ | |
681 | // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ | |
682 | // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 | |
683 | // = n |𝒟| |x|_2, where n is the number of points. Therefore | |
684 | let inner_τ = config.inner.τ0 / (op𝒟norm * F::cast_from(μ.len())); | |
685 | ||
686 | // Solve finite-dimensional subproblem. | |
687 | let inner_it = config.inner.iterator_options.stop_target(inner_tolerance); | |
688 | inner_iters += quadratic_nonneg(config.inner.method, &Ã, &g̃, τ*α, &mut x, | |
689 | inner_τ, inner_it); | |
690 | ||
691 | // Update masses of μ based on solution of finite-dimensional subproblem. | |
692 | μ.set_masses_dvector(&x); | |
693 | } | |
694 | ||
695 | // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality | |
696 | // conditions in the predual space, and finding new points for insertion, if necessary. | |
697 | let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ, &μ_base)); | |
698 | ||
699 | // If no merging heuristic is used, let's be more conservative about spike insertion, | |
700 | // and skip it after first round. If merging is done, being more greedy about spike | |
701 | // insertion also seems to improve performance. | |
702 | let may_break = if let SpikeMergingMethod::None = config.merging { | |
703 | false | |
704 | } else { | |
705 | count > 0 | |
706 | }; | |
707 | ||
708 | // First do a rough check whether we are within bounds and can stop. | |
709 | let in_bounds = match config.ergodic_tolerance { | |
710 | ErgodicTolerance::NonErgodic => { | |
711 | target_bounds.superset(&d.bounds()) | |
712 | }, | |
713 | ErgodicTolerance::AfterNth{ n, factor } => { | |
714 | // Bound -τ∑_{k=0}^{N-1}[A_*(Aμ^k-b)+α] from above. | |
715 | match state.iteration().cmp(&n) { | |
716 | Less => true, | |
717 | Equal => { | |
718 | let iter = F::cast_from(state.iteration()); | |
719 | let mut tmp = preadjA.apply(&after_nth_accum); | |
720 | let (_, v0) = tmp.maximise(refinement_tolerance, | |
721 | config.refinement.max_steps); | |
722 | let v = v0 - iter * τ * α; | |
723 | after_nth_bound = factor * v; | |
724 | println!("{}", format!("Set ergodic tolerance to {}", after_nth_bound)); | |
725 | true | |
726 | }, | |
727 | Greater => { | |
728 | // TODO: can divide after_nth_accum by N, so use basic tolerance on that. | |
729 | let iter = F::cast_from(state.iteration()); | |
730 | let mut tmp = preadjA.apply(&after_nth_accum); | |
731 | tmp.has_upper_bound(after_nth_bound + iter * τ * α, | |
732 | refinement_tolerance, | |
733 | config.refinement.max_steps) | |
734 | } | |
735 | } | |
736 | } | |
737 | }; | |
738 | ||
739 | // If preliminary check indicates that we are in bonds, and if it otherwise matches | |
740 | // the insertion strategy, skip insertion. | |
741 | if may_break && in_bounds { | |
742 | break 'insertion (true, d) | |
743 | } | |
744 | ||
745 | // If the rough check didn't indicate stopping, find maximising point, maintaining for | |
746 | // the calculations in the beginning of the loop that v_ξ = (ω0-τv-𝒟μ)(ξ) = d(ξ), | |
747 | // where 𝒟μ is now distinct from μ0 after the insertions already performed. | |
748 | // We do not need to check lower bounds, as a solution of the finite-dimensional | |
749 | // subproblem should always satisfy them. | |
750 | ||
751 | // // Find the mimimum over the support of μ. | |
752 | // let d_min_supp = d_max;μ.iter_spikes().filter_map(|&DeltaMeasure{ α, ref x }| { | |
753 | // (α != F::ZERO).then(|| d.value(x)) | |
754 | // }).reduce(F::min).unwrap_or(0.0); | |
755 | ||
756 | let (ξ, v_ξ) = if false /* μ.len() == 0*/ /*count == 0 &&*/ { | |
757 | // If μ has no spikes, just find the maximum of d. Then adjust the tolerance, if | |
758 | // necessary, to adapt it to the problem. | |
759 | let (ξ, v_ξ) = d.maximise(refinement_tolerance, config.refinement.max_steps); | |
760 | //dbg!((τα, v_ξ, target_bounds.upper(), maximise_above)); | |
761 | if τα < v_ξ && v_ξ < target_bounds.upper() { | |
762 | ε1 = v_ξ - τα; | |
763 | ε *= ε1 / ε_prev; | |
764 | tolerance *= ε1 / ε_prev; | |
765 | } | |
766 | (ξ, v_ξ) | |
767 | } else { | |
768 | // If μ has some spikes, only find a maximum of d if it is above a threshold | |
769 | // defined by the refinement tolerance. | |
770 | match d.maximise_above(maximise_above, refinement_tolerance, | |
771 | config.refinement.max_steps) { | |
772 | None => break 'insertion (true, d), | |
773 | Some(res) => res, | |
774 | } | |
775 | }; | |
776 | ||
777 | // // Do one final check whether we can stop already without inserting more points | |
778 | // // because `d` actually in bounds based on a more refined estimate. | |
779 | // if may_break && target_bounds.upper() >= v_ξ { | |
780 | // break (true, d) | |
781 | // } | |
782 | ||
783 | // Break if maximum insertion count reached | |
784 | if count >= max_insertions { | |
785 | let in_bounds2 = target_bounds.upper() >= v_ξ; | |
786 | break 'insertion (in_bounds2, d) | |
787 | } | |
788 | ||
789 | // No point in optimising the weight here; the finite-dimensional algorithm is fast. | |
790 | μ += DeltaMeasure { x : ξ, α : 0.0 }; | |
791 | count += 1; | |
792 | }; | |
793 | ||
794 | if !within_tolerances && warn_insertions { | |
795 | // Complain (but continue) if we failed to get within tolerances | |
796 | // by inserting more points. | |
797 | let err = format!("Maximum insertions reached without achieving \ | |
798 | subproblem solution tolerance"); | |
799 | println!("{}", err.red()); | |
800 | } | |
801 | ||
802 | // Merge spikes | |
803 | if state.iteration() % config.merge_every == 0 { | |
804 | let n_before_merge = μ.len(); | |
805 | μ.merge_spikes(config.merging, |μ_candidate| { | |
806 | //println!("Merge attempt!"); | |
807 | let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ_candidate, &μ_base)); | |
808 | ||
809 | if merge_target_bounds.superset(&d.bounds()) { | |
810 | //println!("…Early Ok"); | |
811 | return Some(()) | |
812 | } | |
813 | ||
814 | let d_min_supp = μ_candidate.iter_spikes().filter_map(|&DeltaMeasure{ α, ref x }| { | |
815 | (α != 0.0).then(|| d.apply(x)) | |
816 | }).reduce(F::min); | |
817 | ||
818 | if d_min_supp.map_or(true, |b| b >= merge_target_bounds.lower()) && | |
819 | d.has_upper_bound(merge_target_bounds.upper(), refinement_tolerance, | |
820 | config.refinement.max_steps) { | |
821 | //println!("…Ok"); | |
822 | Some(()) | |
823 | } else { | |
824 | //println!("…Fail"); | |
825 | None | |
826 | } | |
827 | }); | |
828 | debug_assert!(μ.len() >= n_before_merge); | |
829 | merged += μ.len() - n_before_merge; | |
830 | } | |
831 | ||
832 | let n_before_prune = μ.len(); | |
833 | (residual, τ) = match specialisation.update(&mut μ, &μ_base) { | |
834 | (r, None) => (r, τ), | |
835 | (r, Some(new_τ)) => (r, new_τ) | |
836 | }; | |
837 | debug_assert!(μ.len() <= n_before_prune); | |
838 | pruned += n_before_prune - μ.len(); | |
839 | ||
840 | this_iters += 1; | |
841 | ||
842 | // Give function value if needed | |
843 | state.if_verbose(|| { | |
844 | let value_μ = specialisation.value_μ(&μ); | |
845 | // Plot if so requested | |
846 | plotter.plot_spikes( | |
847 | format!("iter {} end; {}", state.iteration(), within_tolerances), &d, | |
848 | "start".to_string(), Some(&minus_τv), | |
849 | Some(target_bounds), value_μ, | |
850 | ); | |
851 | // Calculate mean inner iterations and reset relevant counters | |
852 | // Return the statistics | |
853 | let res = IterInfo { | |
854 | value : specialisation.calculate_fit(&μ, &residual) + α * value_μ.norm(Radon), | |
855 | n_spikes : value_μ.len(), | |
856 | inner_iters, | |
857 | this_iters, | |
858 | merged, | |
859 | pruned, | |
860 | ε : ε_prev, | |
861 | maybe_ε1 : Some(ε1), | |
862 | postprocessing: config.postprocessing.then(|| value_μ.clone()), | |
863 | }; | |
864 | inner_iters = 0; | |
865 | this_iters = 0; | |
866 | merged = 0; | |
867 | pruned = 0; | |
868 | res | |
869 | }) | |
870 | }); | |
871 | ||
872 | specialisation.postprocess(μ, config.final_merging) | |
873 | } | |
874 | ||
875 | ||
876 | ||
877 |