Wed, 30 Nov 2022 23:45:04 +0200
Sketch FBGenericConfig clap
0 | 1 | /*! |
2 | Solver for the point source localisation problem using a conditional gradient method. | |
3 | ||
4 | We implement two variants, the “fully corrective” method from | |
5 | ||
6 | * Pieper K., Walter D. _Linear convergence of accelerated conditional gradient algorithms | |
7 | in spaces of measures_, DOI: [10.1051/cocv/2021042](https://doi.org/10.1051/cocv/2021042), | |
8 | arXiv: [1904.09218](https://doi.org/10.48550/arXiv.1904.09218). | |
9 | ||
10 | and what we call the “relaxed” method from | |
11 | ||
12 | * Bredies K., Pikkarainen H. - _Inverse problems in spaces of measures_, | |
13 | DOI: [10.1051/cocv/2011205](https://doi.org/10.1051/cocv/2011205). | |
14 | */ | |
15 | ||
16 | use numeric_literals::replace_float_literals; | |
17 | use serde::{Serialize, Deserialize}; | |
18 | //use colored::Colorize; | |
19 | ||
20 | use alg_tools::iterate::{ | |
21 | AlgIteratorFactory, | |
22 | AlgIteratorState, | |
23 | AlgIteratorOptions, | |
24 | }; | |
25 | use alg_tools::euclidean::Euclidean; | |
26 | use alg_tools::norms::Norm; | |
27 | use alg_tools::linops::Apply; | |
28 | use alg_tools::sets::Cube; | |
29 | use alg_tools::loc::Loc; | |
30 | use alg_tools::bisection_tree::{ | |
31 | BTFN, | |
32 | Bounds, | |
33 | BTNodeLookup, | |
34 | BTNode, | |
35 | BTSearch, | |
36 | P2Minimise, | |
37 | SupportGenerator, | |
38 | LocalAnalysis, | |
39 | }; | |
40 | use alg_tools::mapping::RealMapping; | |
41 | use alg_tools::nalgebra_support::ToNalgebraRealField; | |
42 | ||
43 | use crate::types::*; | |
44 | use crate::measures::{ | |
45 | DiscreteMeasure, | |
46 | DeltaMeasure, | |
47 | Radon, | |
48 | }; | |
49 | use crate::measures::merging::{ | |
50 | SpikeMergingMethod, | |
51 | SpikeMerging, | |
52 | }; | |
53 | use crate::forward_model::ForwardModel; | |
54 | #[allow(unused_imports)] // Used in documentation | |
55 | use crate::subproblem::{ | |
56 | quadratic_nonneg, | |
57 | InnerSettings, | |
58 | InnerMethod, | |
59 | }; | |
60 | use crate::tolerance::Tolerance; | |
61 | use crate::plot::{ | |
62 | SeqPlotter, | |
63 | Plotting, | |
64 | PlotLookup | |
65 | }; | |
66 | ||
67 | /// Settings for [`pointsource_fw`]. | |
68 | #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] | |
69 | #[serde(default)] | |
70 | pub struct FWConfig<F : Float> { | |
71 | /// Tolerance for branch-and-bound new spike location discovery | |
72 | pub tolerance : Tolerance<F>, | |
73 | /// Inner problem solution configuration. Has to have `method` set to [`InnerMethod::FB`] | |
74 | /// as the conditional gradient subproblems' optimality conditions do not in general have an | |
75 | /// invertible Newton derivative for SSN. | |
76 | pub inner : InnerSettings<F>, | |
77 | /// Variant of the conditional gradient method | |
78 | pub variant : FWVariant, | |
79 | /// Settings for branch and bound refinement when looking for predual maxima | |
80 | pub refinement : RefinementSettings<F>, | |
81 | /// Spike merging heuristic | |
82 | pub merging : SpikeMergingMethod<F>, | |
83 | } | |
84 | ||
85 | /// Conditional gradient method variant; see also [`FWConfig`]. | |
86 | #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] | |
87 | #[allow(dead_code)] | |
88 | pub enum FWVariant { | |
89 | /// Algorithm 2 of Walter-Pieper | |
90 | FullyCorrective, | |
91 | /// Bredies–Pikkarainen. Forces `FWConfig.inner.max_iter = 1`. | |
92 | Relaxed, | |
93 | } | |
94 | ||
95 | impl<F : Float> Default for FWConfig<F> { | |
96 | fn default() -> Self { | |
97 | FWConfig { | |
98 | tolerance : Default::default(), | |
99 | refinement : Default::default(), | |
100 | inner : Default::default(), | |
101 | variant : FWVariant::FullyCorrective, | |
102 | merging : Default::default(), | |
103 | } | |
104 | } | |
105 | } | |
106 | ||
107 | /// Helper struct for pre-initialising the finite-dimensional subproblems solver | |
108 | /// [`prepare_optimise_weights`]. | |
109 | /// | |
110 | /// The pre-initialisation is done by [`prepare_optimise_weights`]. | |
111 | pub struct FindimData<F : Float> { | |
112 | opAnorm_squared : F | |
113 | } | |
114 | ||
115 | /// Return a pre-initialisation struct for [`prepare_optimise_weights`]. | |
116 | /// | |
117 | /// The parameter `opA` is the forward operator $A$. | |
118 | pub fn prepare_optimise_weights<F, A, const N : usize>(opA : &A) -> FindimData<F> | |
119 | where F : Float + ToNalgebraRealField, | |
120 | A : ForwardModel<Loc<F, N>, F> { | |
121 | FindimData{ | |
122 | opAnorm_squared : opA.opnorm_bound().powi(2) | |
123 | } | |
124 | } | |
125 | ||
126 | /// Solve the finite-dimensional weight optimisation problem for the 2-norm-squared data fidelity | |
127 | /// point source localisation problem. | |
128 | /// | |
129 | /// That is, we minimise | |
130 | /// <div>$$ | |
131 | /// μ ↦ \frac{1}{2}\|Aμ-b\|_w^2 + α\|μ\|_ℳ + δ_{≥ 0}(μ) | |
132 | /// $$</div> | |
133 | /// only with respect to the weights of $μ$. | |
134 | /// | |
135 | /// The parameter `μ` is the discrete measure whose weights are to be optimised. | |
136 | /// The `opA` parameter is the forward operator $A$, while `b`$ and `α` are as in the | |
137 | /// objective above. The method parameter are set in `inner` (see [`InnerSettings`]), while | |
138 | /// `iterator` is used to iterate the steps of the method, and `plotter` may be used to | |
139 | /// save intermediate iteration states as images. The parameter `findim_data` should be | |
140 | /// prepared using [`prepare_optimise_weights`]: | |
141 | /// | |
142 | /// Returns the number of iterations taken by the method configured in `inner`. | |
143 | pub fn optimise_weights<'a, F, A, I, const N : usize>( | |
144 | μ : &mut DiscreteMeasure<Loc<F, N>, F>, | |
145 | opA : &'a A, | |
146 | b : &A::Observable, | |
147 | α : F, | |
148 | findim_data : &FindimData<F>, | |
149 | inner : &InnerSettings<F>, | |
150 | iterator : I | |
151 | ) -> usize | |
152 | where F : Float + ToNalgebraRealField, | |
153 | I : AlgIteratorFactory<F>, | |
154 | A : ForwardModel<Loc<F, N>, F> | |
155 | { | |
156 | // Form and solve finite-dimensional subproblem. | |
157 | let (Ã, g̃) = opA.findim_quadratic_model(&μ, b); | |
158 | let mut x = μ.masses_dvector(); | |
159 | ||
160 | // `inner_τ1` is based on an estimate of the operator norm of $A$ from ℳ(Ω) to | |
161 | // ℝ^n. This estimate is a good one for the matrix norm from ℝ^m to ℝ^n when the | |
162 | // former is equipped with the 1-norm. We need the 2-norm. To pass from 1-norm to | |
163 | // 2-norm, we estimate | |
164 | // ‖A‖_{2,2} := sup_{‖x‖_2 ≤ 1} ‖Ax‖_2 ≤ sup_{‖x‖_1 ≤ C} ‖Ax‖_2 | |
165 | // = C sup_{‖x‖_1 ≤ 1} ‖Ax‖_2 = C ‖A‖_{1,2}, | |
166 | // where C = √m satisfies ‖x‖_1 ≤ C ‖x‖_2. Since we are intested in ‖A_*A‖, no | |
167 | // square root is needed when we scale: | |
168 | let inner_τ = inner.τ0 / (findim_data.opAnorm_squared * F::cast_from(μ.len())); | |
169 | let iters = quadratic_nonneg(inner.method, &Ã, &g̃, α, &mut x, inner_τ, iterator); | |
170 | // Update masses of μ based on solution of finite-dimensional subproblem. | |
171 | μ.set_masses_dvector(&x); | |
172 | ||
173 | iters | |
174 | } | |
175 | ||
176 | /// Solve point source localisation problem using a conditional gradient method | |
177 | /// for the 2-norm-squared data fidelity, i.e., the problem | |
178 | /// <div>$$ | |
179 | /// \min_μ \frac{1}{2}\|Aμ-b\|_w^2 + α\|μ\|_ℳ + δ_{≥ 0}(μ). | |
180 | /// $$</div> | |
181 | /// | |
182 | /// The `opA` parameter is the forward operator $A$, while `b`$ and `α` are as in the | |
183 | /// objective above. The method parameter are set in `config` (see [`FWConfig`]), while | |
184 | /// `iterator` is used to iterate the steps of the method, and `plotter` may be used to | |
185 | /// save intermediate iteration states as images. | |
186 | #[replace_float_literals(F::cast_from(literal))] | |
187 | pub fn pointsource_fw<'a, F, I, A, GA, BTA, S, const N : usize>( | |
188 | opA : &'a A, | |
189 | b : &A::Observable, | |
190 | α : F, | |
191 | //domain : Cube<F, N>, | |
192 | config : &FWConfig<F>, | |
193 | iterator : I, | |
194 | mut plotter : SeqPlotter<F, N>, | |
195 | ) -> DiscreteMeasure<Loc<F, N>, F> | |
196 | where F : Float + ToNalgebraRealField, | |
197 | I : AlgIteratorFactory<IterInfo<F, N>>, | |
198 | for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, | |
199 | //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow | |
200 | A::Observable : std::ops::MulAssign<F>, | |
201 | GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, | |
202 | A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, | |
203 | BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, | |
204 | S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, | |
205 | BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, | |
206 | Cube<F, N>: P2Minimise<Loc<F, N>, F>, | |
207 | PlotLookup : Plotting<N>, | |
208 | DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { | |
209 | ||
210 | // Set up parameters | |
211 | // We multiply tolerance by α for all algoritms. | |
212 | let tolerance = config.tolerance * α; | |
213 | let mut ε = tolerance.initial(); | |
214 | let findim_data = prepare_optimise_weights(opA); | |
215 | let m0 = b.norm2_squared() / (2.0 * α); | |
216 | let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) }; | |
217 | ||
218 | // Initialise operators | |
219 | let preadjA = opA.preadjoint(); | |
220 | ||
221 | // Initialise iterates | |
222 | let mut μ = DiscreteMeasure::new(); | |
223 | let mut residual = -b; | |
224 | ||
225 | let mut inner_iters = 0; | |
226 | let mut this_iters = 0; | |
227 | let mut pruned = 0; | |
228 | let mut merged = 0; | |
229 | ||
230 | // Run the algorithm | |
231 | iterator.iterate(|state| { | |
232 | // Update tolerance | |
233 | let inner_tolerance = ε * config.inner.tolerance_mult; | |
234 | let refinement_tolerance = ε * config.refinement.tolerance_mult; | |
235 | let ε_prev = ε; | |
236 | ε = tolerance.update(ε, state.iteration()); | |
237 | ||
238 | // Calculate smooth part of surrogate model. | |
239 | // | |
240 | // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` | |
241 | // has no significant overhead. For some reosn Rust doesn't allow us simply moving | |
242 | // the residual and replacing it below before the end of this closure. | |
243 | let r = std::mem::replace(&mut residual, opA.empty_observable()); | |
244 | let mut g = -preadjA.apply(r); | |
245 | ||
246 | // Find absolute value maximising point | |
247 | let (ξmax, v_ξmax) = g.maximise(refinement_tolerance, | |
248 | config.refinement.max_steps); | |
249 | let (ξmin, v_ξmin) = g.minimise(refinement_tolerance, | |
250 | config.refinement.max_steps); | |
251 | let (ξ, v_ξ) = if v_ξmin < 0.0 && -v_ξmin > v_ξmax { | |
252 | (ξmin, v_ξmin) | |
253 | } else { | |
254 | (ξmax, v_ξmax) | |
255 | }; | |
256 | ||
257 | let inner_it = match config.variant { | |
258 | FWVariant::FullyCorrective => { | |
259 | // No point in optimising the weight here: the finite-dimensional algorithm is fast. | |
260 | μ += DeltaMeasure { x : ξ, α : 0.0 }; | |
261 | config.inner.iterator_options.stop_target(inner_tolerance) | |
262 | }, | |
263 | FWVariant::Relaxed => { | |
264 | // Perform a relaxed initialisation of μ | |
265 | let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ }; | |
266 | let δ = DeltaMeasure { x : ξ, α : v }; | |
267 | let dp = μ.apply(&g) - δ.apply(&g); | |
268 | let d = opA.apply(&μ) - opA.apply(&δ); | |
269 | let r = d.norm2_squared(); | |
270 | let s = if r == 0.0 { | |
271 | 1.0 | |
272 | } else { | |
273 | 1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r) | |
274 | }; | |
275 | μ *= 1.0 - s; | |
276 | μ += δ * s; | |
277 | // The stop_target is only needed for the type system. | |
278 | AlgIteratorOptions{ max_iter : 1, .. config.inner.iterator_options}.stop_target(0.0) | |
279 | } | |
280 | }; | |
281 | ||
282 | inner_iters += optimise_weights(&mut μ, opA, b, α, &findim_data, &config.inner, inner_it); | |
283 | ||
284 | // Merge spikes and update residual for next step and `if_verbose` below. | |
285 | let n_before_merge = μ.len(); | |
286 | residual = μ.merge_spikes_fitness(config.merging, | |
287 | |μ̃| opA.apply(μ̃) - b, | |
288 | A::Observable::norm2_squared); | |
289 | assert!(μ.len() >= n_before_merge); | |
290 | merged += μ.len() - n_before_merge; | |
291 | ||
292 | ||
293 | // Prune points with zero mass | |
294 | let n_before_prune = μ.len(); | |
295 | μ.prune(); | |
296 | debug_assert!(μ.len() <= n_before_prune); | |
297 | pruned += n_before_prune - μ.len(); | |
298 | ||
299 | this_iters +=1; | |
300 | ||
301 | // Give function value if needed | |
302 | state.if_verbose(|| { | |
303 | plotter.plot_spikes( | |
304 | format!("iter {} start", state.iteration()), &g, | |
305 | "".to_string(), None::<&A::PreadjointCodomain>, | |
306 | None, &μ | |
307 | ); | |
308 | let res = IterInfo { | |
309 | value : residual.norm2_squared_div2() + α * μ.norm(Radon), | |
310 | n_spikes : μ.len(), | |
311 | inner_iters, | |
312 | this_iters, | |
313 | merged, | |
314 | pruned, | |
315 | ε : ε_prev, | |
316 | maybe_ε1 : None, | |
317 | postprocessing : None, | |
318 | }; | |
319 | inner_iters = 0; | |
320 | this_iters = 0; | |
321 | merged = 0; | |
322 | pruned = 0; | |
323 | res | |
324 | }) | |
325 | }); | |
326 | ||
327 | // Return final iterate | |
328 | μ | |
329 | } | |
330 | ||
331 | ||
332 | ||
333 |