Tue, 31 Dec 2024 09:34:24 -0500
Early transport sketches
0 | 1 | /*! |
2 | Solver for the point source localisation problem using a forward-backward splitting method. | |
3 | ||
4 | This corresponds to the manuscript | |
5 | ||
13
bdc57366d4f5
arXiv links, README beautification
Tuomo Valkonen <tuomov@iki.fi>
parents:
8
diff
changeset
|
6 | * Valkonen T. - _Proximal methods for point source localisation_, |
bdc57366d4f5
arXiv links, README beautification
Tuomo Valkonen <tuomov@iki.fi>
parents:
8
diff
changeset
|
7 | [arXiv:2212.02991](https://arxiv.org/abs/2212.02991). |
0 | 8 | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
9 | The main routine is [`pointsource_fb_reg`]. It is built from helper routines in this module, which are |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
10 | also used by our [primal-dual proximal splitting][crate::pdps] implementation. |
0 | 11 | |
12 | FISTA-type inertia is available through [`pointsource_fista_reg`]. | |
13 | ||
14 | ## Problem | |
15 | ||
16 | <p> | |
17 | Our objective is to solve | |
18 | $$ | |
19 | \min_{μ ∈ ℳ(Ω)}~ F_0(Aμ-b) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ), | |
20 | $$ | |
21 | where $F_0(y)=\frac{1}{2}\|y\|_2^2$ and the forward operator $A \in 𝕃(ℳ(Ω); ℝ^n)$. | |
22 | </p> | |
23 | ||
24 | ## Approach | |
25 | ||
26 | <p> | |
27 | As documented in more detail in the paper, on each step we approximately solve | |
28 | $$ | |
29 | \min_{μ ∈ ℳ(Ω)}~ F(μ) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ) + \frac{1}{2}\|μ-μ^k\|_𝒟^2, | |
30 | $$ | |
31 | where $𝒟 ∈ 𝕃(ℳ(Ω); C_c(Ω))$ is typically a convolution operator. | |
32 | </p> | |
33 | ||
34 | ## Finite-dimensional subproblems | |
35 | ||
36 | With $C$ a projection from [`DiscreteMeasure`] to the weights, and $x^k$ such that $x^k=Cμ^k$, we | |
37 | form the discretised linearised inner problem | |
38 | <p> | |
39 | $$ | |
40 | \min_{x ∈ ℝ^n}~ τ\bigl(F(Cx^k) + [C^*∇F(Cx^k)]^⊤(x-x^k) + α {\vec 1}^⊤ x\bigr) | |
41 | + δ_{≥ 0}(x) + \frac{1}{2}\|x-x^k\|_{C^*𝒟C}^2, | |
42 | $$ | |
43 | equivalently | |
44 | $$ | |
45 | \begin{aligned} | |
46 | \min_x~ & τF(Cx^k) - τ[C^*∇F(Cx^k)]^⊤x^k + \frac{1}{2} (x^k)^⊤ C^*𝒟C x^k | |
47 | \\ | |
48 | & | |
49 | - [C^*𝒟C x^k - τC^*∇F(Cx^k)]^⊤ x | |
50 | \\ | |
51 | & | |
52 | + \frac{1}{2} x^⊤ C^*𝒟C x | |
53 | + τα {\vec 1}^⊤ x + δ_{≥ 0}(x), | |
54 | \end{aligned} | |
55 | $$ | |
56 | In other words, we obtain the quadratic non-negativity constrained problem | |
57 | $$ | |
58 | \min_{x ∈ ℝ^n}~ \frac{1}{2} x^⊤ Ã x - g̃^⊤ x + c + τα {\vec 1}^⊤ x + δ_{≥ 0}(x), | |
59 | $$ | |
60 | where | |
61 | $$ | |
62 | \begin{aligned} | |
63 | Ã & = C^*𝒟C, | |
64 | \\ | |
65 | g̃ & = C^*𝒟C x^k - τ C^*∇F(Cx^k) | |
66 | = C^* 𝒟 μ^k - τ C^*A^*(Aμ^k - b) | |
67 | \\ | |
68 | c & = τ F(Cx^k) - τ[C^*∇F(Cx^k)]^⊤x^k + \frac{1}{2} (x^k)^⊤ C^*𝒟C x^k | |
69 | \\ | |
70 | & | |
71 | = \frac{τ}{2} \|Aμ^k-b\|^2 - τ[Aμ^k-b]^⊤Aμ^k + \frac{1}{2} \|μ^k\|_{𝒟}^2 | |
72 | \\ | |
73 | & | |
74 | = -\frac{τ}{2} \|Aμ^k-b\|^2 - τ[Aμ^k-b]^⊤ b + \frac{1}{2} \|μ^k\|_{𝒟}^2. | |
75 | \end{aligned} | |
76 | $$ | |
77 | </p> | |
78 | ||
79 | We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by | |
80 | [`InnerSettings`] in [`FBGenericConfig::inner`]. | |
81 | */ | |
82 | ||
83 | use numeric_literals::replace_float_literals; | |
84 | use serde::{Serialize, Deserialize}; | |
85 | use colored::Colorize; | |
32 | 86 | use nalgebra::DVector; |
0 | 87 | |
88 | use alg_tools::iterate::{ | |
89 | AlgIteratorFactory, | |
90 | AlgIteratorState, | |
91 | }; | |
92 | use alg_tools::euclidean::Euclidean; | |
32 | 93 | use alg_tools::linops::{Apply, GEMV}; |
0 | 94 | use alg_tools::sets::Cube; |
95 | use alg_tools::loc::Loc; | |
96 | use alg_tools::bisection_tree::{ | |
97 | BTFN, | |
98 | PreBTFN, | |
99 | Bounds, | |
100 | BTNodeLookup, | |
101 | BTNode, | |
102 | BTSearch, | |
103 | P2Minimise, | |
104 | SupportGenerator, | |
105 | LocalAnalysis, | |
32 | 106 | BothGenerators, |
0 | 107 | }; |
108 | use alg_tools::mapping::RealMapping; | |
109 | use alg_tools::nalgebra_support::ToNalgebraRealField; | |
110 | ||
111 | use crate::types::*; | |
112 | use crate::measures::{ | |
113 | DiscreteMeasure, | |
114 | DeltaMeasure, | |
115 | }; | |
116 | use crate::measures::merging::{ | |
117 | SpikeMergingMethod, | |
118 | SpikeMerging, | |
119 | }; | |
120 | use crate::forward_model::ForwardModel; | |
32 | 121 | use crate::seminorms::DiscreteMeasureOp; |
0 | 122 | use crate::subproblem::{ |
123 | InnerSettings, | |
124 | InnerMethod, | |
125 | }; | |
126 | use crate::tolerance::Tolerance; | |
127 | use crate::plot::{ | |
128 | SeqPlotter, | |
129 | Plotting, | |
130 | PlotLookup | |
131 | }; | |
32 | 132 | use crate::regularisation::RegTerm; |
133 | use crate::dataterm::{ | |
134 | calculate_residual, | |
135 | L2Squared, | |
136 | DataTerm, | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
137 | }; |
0 | 138 | |
139 | /// Method for constructing $μ$ on each iteration | |
140 | #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] | |
141 | #[allow(dead_code)] | |
142 | pub enum InsertionStyle { | |
143 | /// Resuse previous $μ$ from previous iteration, optimising weights | |
144 | /// before inserting new spikes. | |
145 | Reuse, | |
146 | /// Start each iteration with $μ=0$. | |
147 | Zero, | |
148 | } | |
149 | ||
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
150 | /// Settings for [`pointsource_fb_reg`]. |
0 | 151 | #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
152 | #[serde(default)] | |
153 | pub struct FBConfig<F : Float> { | |
154 | /// Step length scaling | |
155 | pub τ0 : F, | |
156 | /// Generic parameters | |
157 | pub insertion : FBGenericConfig<F>, | |
158 | } | |
159 | ||
160 | /// Settings for the solution of the stepwise optimality condition in algorithms based on | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
161 | /// [`generic_pointsource_fb_reg`]. |
0 | 162 | #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
163 | #[serde(default)] | |
164 | pub struct FBGenericConfig<F : Float> { | |
165 | /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`]. | |
166 | pub insertion_style : InsertionStyle, | |
167 | /// Tolerance for point insertion. | |
168 | pub tolerance : Tolerance<F>, | |
169 | /// Stop looking for predual maximum (where to isert a new point) below | |
170 | /// `tolerance` multiplied by this factor. | |
171 | pub insertion_cutoff_factor : F, | |
172 | /// Settings for branch and bound refinement when looking for predual maxima | |
173 | pub refinement : RefinementSettings<F>, | |
174 | /// Maximum insertions within each outer iteration | |
175 | pub max_insertions : usize, | |
176 | /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. | |
177 | pub bootstrap_insertions : Option<(usize, usize)>, | |
178 | /// Inner method settings | |
179 | pub inner : InnerSettings<F>, | |
180 | /// Spike merging method | |
181 | pub merging : SpikeMergingMethod<F>, | |
182 | /// Tolerance multiplier for merges | |
183 | pub merge_tolerance_mult : F, | |
184 | /// Spike merging method after the last step | |
185 | pub final_merging : SpikeMergingMethod<F>, | |
186 | /// Iterations between merging heuristic tries | |
187 | pub merge_every : usize, | |
188 | /// Save $μ$ for postprocessing optimisation | |
189 | pub postprocessing : bool | |
190 | } | |
191 | ||
192 | #[replace_float_literals(F::cast_from(literal))] | |
193 | impl<F : Float> Default for FBConfig<F> { | |
194 | fn default() -> Self { | |
195 | FBConfig { | |
196 | τ0 : 0.99, | |
197 | insertion : Default::default() | |
198 | } | |
199 | } | |
200 | } | |
201 | ||
202 | #[replace_float_literals(F::cast_from(literal))] | |
203 | impl<F : Float> Default for FBGenericConfig<F> { | |
204 | fn default() -> Self { | |
205 | FBGenericConfig { | |
206 | insertion_style : InsertionStyle::Reuse, | |
207 | tolerance : Default::default(), | |
208 | insertion_cutoff_factor : 1.0, | |
209 | refinement : Default::default(), | |
210 | max_insertions : 100, | |
211 | //bootstrap_insertions : None, | |
212 | bootstrap_insertions : Some((10, 1)), | |
213 | inner : InnerSettings { | |
214 | method : InnerMethod::SSN, | |
215 | .. Default::default() | |
216 | }, | |
217 | merging : SpikeMergingMethod::None, | |
218 | //merging : Default::default(), | |
219 | final_merging : Default::default(), | |
220 | merge_every : 10, | |
221 | merge_tolerance_mult : 2.0, | |
222 | postprocessing : false, | |
223 | } | |
224 | } | |
225 | } | |
226 | ||
227 | #[replace_float_literals(F::cast_from(literal))] | |
32 | 228 | pub(crate) fn μ_diff<F : Float, const N : usize>( |
229 | μ_new : &DiscreteMeasure<Loc<F, N>, F>, | |
230 | μ_base : &DiscreteMeasure<Loc<F, N>, F>, | |
231 | ν_delta : Option<&DiscreteMeasure<Loc<F, N>, F>>, | |
232 | config : &FBGenericConfig<F> | |
233 | ) -> DiscreteMeasure<Loc<F, N>, F> { | |
234 | let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style { | |
235 | InsertionStyle::Reuse => { | |
236 | μ_new.iter_spikes() | |
237 | .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0))) | |
238 | .map(|(δ, α_base)| (δ.x, α_base - δ.α)) | |
239 | .collect() | |
240 | }, | |
241 | InsertionStyle::Zero => { | |
242 | μ_new.iter_spikes() | |
243 | .map(|δ| -δ) | |
244 | .chain(μ_base.iter_spikes().copied()) | |
245 | .collect() | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
246 | } |
32 | 247 | }; |
248 | ν.prune(); // Potential small performance improvement | |
249 | // Add ν_delta if given | |
250 | match ν_delta { | |
251 | None => ν, | |
252 | Some(ν_d) => ν + ν_d, | |
0 | 253 | } |
254 | } | |
255 | ||
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
256 | #[replace_float_literals(F::cast_from(literal))] |
32 | 257 | pub(crate) fn insert_and_reweigh< |
258 | 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize | |
259 | >( | |
260 | μ : &mut DiscreteMeasure<Loc<F, N>, F>, | |
261 | minus_τv : &BTFN<F, GA, BTA, N>, | |
262 | μ_base : &DiscreteMeasure<Loc<F, N>, F>, | |
263 | ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>, | |
264 | op𝒟 : &'a 𝒟, | |
265 | op𝒟norm : F, | |
266 | τ : F, | |
267 | ε : F, | |
268 | config : &FBGenericConfig<F>, | |
269 | reg : &Reg, | |
270 | state : &State, | |
271 | stats : &mut IterInfo<F, N>, | |
272 | ) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool) | |
273 | where F : Float + ToNalgebraRealField, | |
274 | GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, | |
275 | BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, | |
276 | G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, | |
277 | 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, | |
278 | 𝒟::Codomain : RealMapping<F, N>, | |
279 | S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, | |
280 | K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, | |
281 | BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, | |
282 | DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, | |
283 | Reg : RegTerm<F, N>, | |
284 | State : AlgIteratorState { | |
285 | ||
286 | // Maximum insertion count and measure difference calculation depend on insertion style. | |
287 | let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { | |
288 | (i, Some((l, k))) if i <= l => (k, false), | |
289 | _ => (config.max_insertions, !state.is_quiet()), | |
290 | }; | |
291 | let max_insertions = match config.insertion_style { | |
292 | InsertionStyle::Zero => { | |
293 | todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled."); | |
294 | // let n = μ.len(); | |
295 | // μ = DiscreteMeasure::new(); | |
296 | // n + m | |
297 | }, | |
298 | InsertionStyle::Reuse => m, | |
299 | }; | |
300 | ||
301 | // TODO: should avoid a second copy of μ here; μ_base already stores a copy. | |
302 | let ω0 = op𝒟.apply(match ν_delta { | |
303 | None => μ.clone(), | |
304 | Some(ν_d) => &*μ + ν_d, | |
305 | }); | |
306 | ||
307 | // Add points to support until within error tolerance or maximum insertion count reached. | |
308 | let mut count = 0; | |
309 | let (within_tolerances, d) = 'insertion: loop { | |
310 | if μ.len() > 0 { | |
311 | // Form finite-dimensional subproblem. The subproblem references to the original μ^k | |
312 | // from the beginning of the iteration are all contained in the immutable c and g. | |
313 | let à = op𝒟.findim_matrix(μ.iter_locations()); | |
314 | let g̃ = DVector::from_iterator(μ.len(), | |
315 | μ.iter_locations() | |
316 | .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ)) | |
317 | .map(F::to_nalgebra_mixed)); | |
318 | let mut x = μ.masses_dvector(); | |
319 | ||
320 | // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. | |
321 | // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ | |
322 | // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ | |
323 | // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 | |
324 | // = n |𝒟| |x|_2, where n is the number of points. Therefore | |
325 | let Ã_normest = op𝒟norm * F::cast_from(μ.len()); | |
326 | ||
327 | // Solve finite-dimensional subproblem. | |
328 | stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config); | |
329 | ||
330 | // Update masses of μ based on solution of finite-dimensional subproblem. | |
331 | μ.set_masses_dvector(&x); | |
332 | } | |
333 | ||
334 | // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality | |
335 | // conditions in the predual space, and finding new points for insertion, if necessary. | |
336 | let mut d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_base, ν_delta, config)); | |
337 | ||
338 | // If no merging heuristic is used, let's be more conservative about spike insertion, | |
339 | // and skip it after first round. If merging is done, being more greedy about spike | |
340 | // insertion also seems to improve performance. | |
341 | let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging { | |
342 | false | |
343 | } else { | |
344 | count > 0 | |
345 | }; | |
346 | ||
347 | // Find a spike to insert, if needed | |
348 | let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( | |
349 | &mut d, τ, ε, skip_by_rough_check, config | |
350 | ) { | |
351 | None => break 'insertion (true, d), | |
352 | Some(res) => res, | |
353 | }; | |
354 | ||
355 | // Break if maximum insertion count reached | |
356 | if count >= max_insertions { | |
357 | break 'insertion (in_bounds, d) | |
358 | } | |
359 | ||
360 | // No point in optimising the weight here; the finite-dimensional algorithm is fast. | |
361 | *μ += DeltaMeasure { x : ξ, α : 0.0 }; | |
362 | count += 1; | |
363 | }; | |
364 | ||
365 | // TODO: should redo everything if some transports cause a problem. | |
366 | // Maybe implementation should call above loop as a closure. | |
367 | ||
368 | if !within_tolerances && warn_insertions { | |
369 | // Complain (but continue) if we failed to get within tolerances | |
370 | // by inserting more points. | |
371 | let err = format!("Maximum insertions reached without achieving \ | |
372 | subproblem solution tolerance"); | |
373 | println!("{}", err.red()); | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
374 | } |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
375 | |
32 | 376 | (d, within_tolerances) |
377 | } | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
378 | |
32 | 379 | #[replace_float_literals(F::cast_from(literal))] |
380 | pub(crate) fn prune_and_maybe_simple_merge< | |
381 | 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize | |
382 | >( | |
383 | μ : &mut DiscreteMeasure<Loc<F, N>, F>, | |
384 | minus_τv : &BTFN<F, GA, BTA, N>, | |
385 | μ_base : &DiscreteMeasure<Loc<F, N>, F>, | |
386 | op𝒟 : &'a 𝒟, | |
387 | τ : F, | |
388 | ε : F, | |
389 | config : &FBGenericConfig<F>, | |
390 | reg : &Reg, | |
391 | state : &State, | |
392 | stats : &mut IterInfo<F, N>, | |
393 | ) | |
394 | where F : Float + ToNalgebraRealField, | |
395 | GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, | |
396 | BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, | |
397 | G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, | |
398 | 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, | |
399 | 𝒟::Codomain : RealMapping<F, N>, | |
400 | S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, | |
401 | K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, | |
402 | BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, | |
403 | DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, | |
404 | Reg : RegTerm<F, N>, | |
405 | State : AlgIteratorState { | |
406 | if state.iteration() % config.merge_every == 0 { | |
407 | let n_before_merge = μ.len(); | |
408 | μ.merge_spikes(config.merging, |μ_candidate| { | |
409 | let μd = μ_diff(&μ_candidate, &μ_base, None, config); | |
410 | let mut d = minus_τv + op𝒟.preapply(μd); | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
411 | |
32 | 412 | reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
413 | .then_some(()) | |
414 | }); | |
415 | debug_assert!(μ.len() >= n_before_merge); | |
416 | stats.merged += μ.len() - n_before_merge; | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
417 | } |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
418 | |
32 | 419 | let n_before_prune = μ.len(); |
420 | μ.prune(); | |
421 | debug_assert!(μ.len() <= n_before_prune); | |
422 | stats.pruned += n_before_prune - μ.len(); | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
423 | } |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
424 | |
32 | 425 | #[replace_float_literals(F::cast_from(literal))] |
426 | pub(crate) fn postprocess< | |
427 | F : Float, | |
428 | V : Euclidean<F> + Clone, | |
429 | A : GEMV<F, DiscreteMeasure<Loc<F, N>, F>, Codomain = V>, | |
430 | D : DataTerm<F, V, N>, | |
431 | const N : usize | |
432 | > ( | |
433 | mut μ : DiscreteMeasure<Loc<F, N>, F>, | |
434 | config : &FBGenericConfig<F>, | |
435 | dataterm : D, | |
436 | opA : &A, | |
437 | b : &V, | |
438 | ) -> DiscreteMeasure<Loc<F, N>, F> | |
439 | where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { | |
440 | μ.merge_spikes_fitness(config.merging, | |
441 | |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), | |
442 | |&v| v); | |
443 | μ.prune(); | |
444 | μ | |
445 | } | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
446 | |
32 | 447 | /// Iteratively solve the pointsource localisation problem using forward-backward splitting. |
0 | 448 | /// |
32 | 449 | /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
0 | 450 | /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
451 | /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution | |
452 | /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control | |
453 | /// as documented in [`alg_tools::iterate`]. | |
454 | /// | |
32 | 455 | /// For details on the mathematical formulation, see the [module level](self) documentation. |
456 | /// | |
0 | 457 | /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
458 | /// sums of simple functions usign bisection trees, and the related | |
459 | /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions | |
460 | /// active at a specific points, and to maximise their sums. Through the implementation of the | |
461 | /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features | |
462 | /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. | |
463 | /// | |
464 | /// Returns the final iterate. | |
465 | #[replace_float_literals(F::cast_from(literal))] | |
32 | 466 | pub fn pointsource_fb_reg< |
467 | 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
468 | >( |
0 | 469 | opA : &'a A, |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
470 | b : &A::Observable, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
471 | reg : Reg, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
472 | op𝒟 : &'a 𝒟, |
32 | 473 | fbconfig : &FBConfig<F>, |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
474 | iterator : I, |
32 | 475 | mut plotter : SeqPlotter<F, N>, |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
476 | ) -> DiscreteMeasure<Loc<F, N>, F> |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
477 | where F : Float + ToNalgebraRealField, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
478 | I : AlgIteratorFactory<IterInfo<F, N>>, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
479 | for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
480 | //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
481 | A::Observable : std::ops::MulAssign<F>, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
482 | GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
483 | A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
32 | 484 | + Lipschitz<&'a 𝒟, FloatType=F>, |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
485 | BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
486 | G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
487 | 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
488 | 𝒟::Codomain : RealMapping<F, N>, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
489 | S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
490 | K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
491 | BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
492 | Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
493 | PlotLookup : Plotting<N>, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
494 | DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
495 | Reg : RegTerm<F, N> { |
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
496 | |
32 | 497 | // Set up parameters |
498 | let config = &fbconfig.insertion; | |
499 | let op𝒟norm = op𝒟.opnorm_bound(); | |
500 | let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); | |
501 | // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled | |
502 | // by τ compared to the conditional gradient approach. | |
503 | let tolerance = config.tolerance * τ * reg.tolerance_scaling(); | |
504 | let mut ε = tolerance.initial(); | |
505 | ||
506 | // Initialise iterates | |
507 | let mut μ = DiscreteMeasure::new(); | |
508 | let mut residual = -b; | |
509 | let mut stats = IterInfo::new(); | |
510 | ||
511 | // Run the algorithm | |
512 | iterator.iterate(|state| { | |
513 | // Calculate smooth part of surrogate model. | |
514 | // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` | |
515 | // has no significant overhead. For some reosn Rust doesn't allow us simply moving | |
516 | // the residual and replacing it below before the end of this closure. | |
517 | residual *= -τ; | |
518 | let r = std::mem::replace(&mut residual, opA.empty_observable()); | |
519 | let minus_τv = opA.preadjoint().apply(r); | |
520 | ||
521 | // Save current base point | |
522 | let μ_base = μ.clone(); | |
523 | ||
524 | // Insert and reweigh | |
525 | let (d, within_tolerances) = insert_and_reweigh( | |
526 | &mut μ, &minus_τv, &μ_base, None, | |
527 | op𝒟, op𝒟norm, | |
528 | τ, ε, | |
529 | config, ®, state, &mut stats | |
530 | ); | |
531 | ||
532 | // Prune and possibly merge spikes | |
533 | prune_and_maybe_simple_merge( | |
534 | &mut μ, &minus_τv, &μ_base, | |
535 | op𝒟, | |
536 | τ, ε, | |
537 | config, ®, state, &mut stats | |
538 | ); | |
539 | ||
540 | // Update residual | |
541 | residual = calculate_residual(&μ, opA, b); | |
542 | ||
543 | // Update main tolerance for next iteration | |
544 | let ε_prev = ε; | |
545 | ε = tolerance.update(ε, state.iteration()); | |
546 | stats.this_iters += 1; | |
547 | ||
548 | // Give function value if needed | |
549 | state.if_verbose(|| { | |
550 | // Plot if so requested | |
551 | plotter.plot_spikes( | |
552 | format!("iter {} end; {}", state.iteration(), within_tolerances), &d, | |
553 | "start".to_string(), Some(&minus_τv), | |
554 | reg.target_bounds(τ, ε_prev), &μ, | |
555 | ); | |
556 | // Calculate mean inner iterations and reset relevant counters. | |
557 | // Return the statistics | |
558 | let res = IterInfo { | |
559 | value : residual.norm2_squared_div2() + reg.apply(&μ), | |
560 | n_spikes : μ.len(), | |
561 | ε : ε_prev, | |
562 | postprocessing: config.postprocessing.then(|| μ.clone()), | |
563 | .. stats | |
564 | }; | |
565 | stats = IterInfo::new(); | |
566 | res | |
567 | }) | |
568 | }); | |
569 | ||
570 | postprocess(μ, config, L2Squared, opA, b) | |
571 | } | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
572 | |
32 | 573 | /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. |
574 | /// | |
575 | /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the | |
576 | /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. | |
577 | /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution | |
578 | /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control | |
579 | /// as documented in [`alg_tools::iterate`]. | |
580 | /// | |
581 | /// For details on the mathematical formulation, see the [module level](self) documentation. | |
582 | /// | |
583 | /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of | |
584 | /// sums of simple functions usign bisection trees, and the related | |
585 | /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions | |
586 | /// active at a specific points, and to maximise their sums. Through the implementation of the | |
587 | /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features | |
588 | /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. | |
589 | /// | |
590 | /// Returns the final iterate. | |
591 | #[replace_float_literals(F::cast_from(literal))] | |
592 | pub fn pointsource_fista_reg< | |
593 | 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize | |
594 | >( | |
595 | opA : &'a A, | |
596 | b : &A::Observable, | |
597 | reg : Reg, | |
598 | op𝒟 : &'a 𝒟, | |
599 | fbconfig : &FBConfig<F>, | |
600 | iterator : I, | |
601 | mut plotter : SeqPlotter<F, N>, | |
602 | ) -> DiscreteMeasure<Loc<F, N>, F> | |
603 | where F : Float + ToNalgebraRealField, | |
604 | I : AlgIteratorFactory<IterInfo<F, N>>, | |
605 | for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, | |
606 | //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow | |
607 | A::Observable : std::ops::MulAssign<F>, | |
608 | GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, | |
609 | A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> | |
610 | + Lipschitz<&'a 𝒟, FloatType=F>, | |
611 | BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, | |
612 | G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, | |
613 | 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, | |
614 | 𝒟::Codomain : RealMapping<F, N>, | |
615 | S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, | |
616 | K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, | |
617 | BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, | |
618 | Cube<F, N>: P2Minimise<Loc<F, N>, F>, | |
619 | PlotLookup : Plotting<N>, | |
620 | DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, | |
621 | Reg : RegTerm<F, N> { | |
622 | ||
623 | // Set up parameters | |
624 | let config = &fbconfig.insertion; | |
625 | let op𝒟norm = op𝒟.opnorm_bound(); | |
626 | let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); | |
627 | let mut λ = 1.0; | |
628 | // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled | |
629 | // by τ compared to the conditional gradient approach. | |
630 | let tolerance = config.tolerance * τ * reg.tolerance_scaling(); | |
631 | let mut ε = tolerance.initial(); | |
632 | ||
633 | // Initialise iterates | |
634 | let mut μ = DiscreteMeasure::new(); | |
635 | let mut μ_prev = DiscreteMeasure::new(); | |
636 | let mut residual = -b; | |
637 | let mut stats = IterInfo::new(); | |
638 | let mut warned_merging = false; | |
639 | ||
640 | // Run the algorithm | |
641 | iterator.iterate(|state| { | |
642 | // Calculate smooth part of surrogate model. | |
643 | // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` | |
644 | // has no significant overhead. For some reosn Rust doesn't allow us simply moving | |
645 | // the residual and replacing it below before the end of this closure. | |
646 | residual *= -τ; | |
647 | let r = std::mem::replace(&mut residual, opA.empty_observable()); | |
648 | let minus_τv = opA.preadjoint().apply(r); | |
649 | ||
650 | // Save current base point | |
651 | let μ_base = μ.clone(); | |
652 | ||
653 | // Insert new spikes and reweigh | |
654 | let (d, within_tolerances) = insert_and_reweigh( | |
655 | &mut μ, &minus_τv, &μ_base, None, | |
656 | op𝒟, op𝒟norm, | |
657 | τ, ε, | |
658 | config, ®, state, &mut stats | |
659 | ); | |
660 | ||
661 | // (Do not) merge spikes. | |
662 | if state.iteration() % config.merge_every == 0 { | |
663 | match config.merging { | |
664 | SpikeMergingMethod::None => { }, | |
665 | _ => if !warned_merging { | |
666 | let err = format!("Merging not supported for μFISTA"); | |
667 | println!("{}", err.red()); | |
668 | warned_merging = true; | |
669 | } | |
670 | } | |
671 | } | |
672 | ||
673 | // Update inertial prameters | |
674 | let λ_prev = λ; | |
675 | λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); | |
676 | let θ = λ / λ_prev - λ; | |
677 | ||
678 | // Perform inertial update on μ. | |
679 | // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ | |
680 | // and μ_prev have zero weight. Since both have weights from the finite-dimensional | |
681 | // subproblem with a proximal projection step, this is likely to happen when the | |
682 | // spike is not needed. A copy of the pruned μ without artithmetic performed is | |
683 | // stored in μ_prev. | |
684 | let n_before_prune = μ.len(); | |
685 | μ.pruning_sub(1.0 + θ, θ, &mut μ_prev); | |
686 | debug_assert!(μ.len() <= n_before_prune); | |
687 | stats.pruned += n_before_prune - μ.len(); | |
688 | ||
689 | // Update residual | |
690 | residual = calculate_residual(&μ, opA, b); | |
691 | ||
692 | // Update main tolerance for next iteration | |
693 | let ε_prev = ε; | |
694 | ε = tolerance.update(ε, state.iteration()); | |
695 | stats.this_iters += 1; | |
696 | ||
697 | // Give function value if needed | |
698 | state.if_verbose(|| { | |
699 | // Plot if so requested | |
700 | plotter.plot_spikes( | |
701 | format!("iter {} end; {}", state.iteration(), within_tolerances), &d, | |
702 | "start".to_string(), Some(&minus_τv), | |
703 | reg.target_bounds(τ, ε_prev), &μ_prev, | |
704 | ); | |
705 | // Calculate mean inner iterations and reset relevant counters. | |
706 | // Return the statistics | |
707 | let res = IterInfo { | |
708 | value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev), | |
709 | n_spikes : μ_prev.len(), | |
710 | ε : ε_prev, | |
711 | postprocessing: config.postprocessing.then(|| μ_prev.clone()), | |
712 | .. stats | |
713 | }; | |
714 | stats = IterInfo::new(); | |
715 | res | |
716 | }) | |
717 | }); | |
718 | ||
719 | postprocess(μ_prev, config, L2Squared, opA, b) | |
24
d29d1fcf5423
Support arbitrary regularisation terms; implement non-positivity-constrained regularisation.
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
720 | } |