|
1 /*! |
|
2 Solver for the point source localisation problem using a conditional gradient method. |
|
3 |
|
4 We implement two variants, the “fully corrective” method from |
|
5 |
|
6 * Pieper K., Walter D. _Linear convergence of accelerated conditional gradient algorithms |
|
7 in spaces of measures_, DOI: [10.1051/cocv/2021042](https://doi.org/10.1051/cocv/2021042), |
|
8 arXiv: [1904.09218](https://doi.org/10.48550/arXiv.1904.09218). |
|
9 |
|
10 and what we call the “relaxed” method from |
|
11 |
|
12 * Bredies K., Pikkarainen H. - _Inverse problems in spaces of measures_, |
|
13 DOI: [10.1051/cocv/2011205](https://doi.org/10.1051/cocv/2011205).
|
14 */ |
|
15 |
|
16 use numeric_literals::replace_float_literals; |
|
17 use serde::{Serialize, Deserialize}; |
|
18 //use colored::Colorize; |
|
19 |
|
20 use alg_tools::iterate::{ |
|
21 AlgIteratorFactory, |
|
22 AlgIteratorState, |
|
23 AlgIteratorOptions, |
|
24 }; |
|
25 use alg_tools::euclidean::Euclidean; |
|
26 use alg_tools::norms::Norm; |
|
27 use alg_tools::linops::Apply; |
|
28 use alg_tools::sets::Cube; |
|
29 use alg_tools::loc::Loc; |
|
30 use alg_tools::bisection_tree::{ |
|
31 BTFN, |
|
32 Bounds, |
|
33 BTNodeLookup, |
|
34 BTNode, |
|
35 BTSearch, |
|
36 P2Minimise, |
|
37 SupportGenerator, |
|
38 LocalAnalysis, |
|
39 }; |
|
40 use alg_tools::mapping::RealMapping; |
|
41 use alg_tools::nalgebra_support::ToNalgebraRealField; |
|
42 |
|
43 use crate::types::*; |
|
44 use crate::measures::{ |
|
45 DiscreteMeasure, |
|
46 DeltaMeasure, |
|
47 Radon, |
|
48 }; |
|
49 use crate::measures::merging::{ |
|
50 SpikeMergingMethod, |
|
51 SpikeMerging, |
|
52 }; |
|
53 use crate::forward_model::ForwardModel; |
|
54 #[allow(unused_imports)] // Used in documentation |
|
55 use crate::subproblem::{ |
|
56 quadratic_nonneg, |
|
57 InnerSettings, |
|
58 InnerMethod, |
|
59 }; |
|
60 use crate::tolerance::Tolerance; |
|
61 use crate::plot::{ |
|
62 SeqPlotter, |
|
63 Plotting, |
|
64 PlotLookup |
|
65 }; |
|
66 |
|
/// Settings for [`pointsource_fw`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FWConfig<F : Float> {
    /// Tolerance for branch-and-bound new spike location discovery
    pub tolerance : Tolerance<F>,
    /// Inner problem solution configuration. Has to have `method` set to [`InnerMethod::FB`]
    /// as the conditional gradient subproblems' optimality conditions do not in general have an
    /// invertible Newton derivative for SSN.
    pub inner : InnerSettings<F>,
    /// Variant of the conditional gradient method; see [`FWVariant`].
    pub variant : FWVariant,
    /// Settings for branch and bound refinement when looking for predual maxima
    pub refinement : RefinementSettings<F>,
    /// Spike merging heuristic applied after each outer iteration
    pub merging : SpikeMergingMethod<F>,
}
|
84 |
|
/// Conditional gradient method variant; see also [`FWConfig`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[allow(dead_code)]
pub enum FWVariant {
    /// The fully corrective variant: Algorithm 2 of Pieper–Walter (see module docs).
    FullyCorrective,
    /// The relaxed variant of Bredies–Pikkarainen. Forces `FWConfig.inner.max_iter = 1`.
    Relaxed,
}
|
94 |
|
95 impl<F : Float> Default for FWConfig<F> { |
|
96 fn default() -> Self { |
|
97 FWConfig { |
|
98 tolerance : Default::default(), |
|
99 refinement : Default::default(), |
|
100 inner : Default::default(), |
|
101 variant : FWVariant::FullyCorrective, |
|
102 merging : Default::default(), |
|
103 } |
|
104 } |
|
105 } |
|
106 |
|
/// Helper struct for pre-initialising the finite-dimensional subproblem solver
/// [`optimise_weights`].
///
/// The pre-initialisation is done by [`prepare_optimise_weights`].
pub struct FindimData<F : Float> {
    /// Cached bound on ‖A‖², computed from [`ForwardModel::opnorm_bound`];
    /// used to choose the inner solver step length parameter.
    opAnorm_squared : F
}
|
114 |
|
115 /// Return a pre-initialisation struct for [`prepare_optimise_weights`]. |
|
116 /// |
|
117 /// The parameter `opA` is the forward operator $A$. |
|
118 pub fn prepare_optimise_weights<F, A, const N : usize>(opA : &A) -> FindimData<F> |
|
119 where F : Float + ToNalgebraRealField, |
|
120 A : ForwardModel<Loc<F, N>, F> { |
|
121 FindimData{ |
|
122 opAnorm_squared : opA.opnorm_bound().powi(2) |
|
123 } |
|
124 } |
|
125 |
|
/// Solve the finite-dimensional weight optimisation problem for the 2-norm-squared data fidelity
/// point source localisation problem.
///
/// That is, we minimise
/// <div>$$
///     μ ↦ \frac{1}{2}\|Aμ-b\|_w^2 + α\|μ\|_ℳ + δ_{≥ 0}(μ)
/// $$</div>
/// only with respect to the weights of $μ$; the spike locations stay fixed.
///
/// The parameter `μ` is the discrete measure whose weights are to be optimised.
/// The `opA` parameter is the forward operator $A$, while `b` and `α` are as in the
/// objective above. The method parameters are set in `inner` (see [`InnerSettings`]),
/// while `iterator` is used to iterate the steps of the method. The parameter
/// `findim_data` should be prepared using [`prepare_optimise_weights`].
///
/// Returns the number of iterations taken by the method configured in `inner`.
pub fn optimise_weights<'a, F, A, I, const N : usize>(
    μ : &mut DiscreteMeasure<Loc<F, N>, F>,
    opA : &'a A,
    b : &A::Observable,
    α : F,
    findim_data : &FindimData<F>,
    inner : &InnerSettings<F>,
    iterator : I
) -> usize
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<F>,
      A : ForwardModel<Loc<F, N>, F>
{
    // Form and solve finite-dimensional subproblem.
    let (Ã, g̃) = opA.findim_quadratic_model(&μ, b);
    let mut x = μ.masses_dvector();

    // `inner_τ` is based on an estimate of the operator norm of $A$ from ℳ(Ω) to
    // ℝ^n. This estimate is a good one for the matrix norm from ℝ^m to ℝ^n when the
    // former is equipped with the 1-norm. We need the 2-norm. To pass from 1-norm to
    // 2-norm, we estimate
    //      ‖A‖_{2,2} := sup_{‖x‖_2 ≤ 1} ‖Ax‖_2 ≤ sup_{‖x‖_1 ≤ C} ‖Ax‖_2
    //                 = C sup_{‖x‖_1 ≤ 1} ‖Ax‖_2 = C ‖A‖_{1,2},
    // where C = √m satisfies ‖x‖_1 ≤ C ‖x‖_2. Since we are interested in ‖A_*A‖, no
    // square root is needed when we scale:
    let inner_τ = inner.τ0 / (findim_data.opAnorm_squared * F::cast_from(μ.len()));
    let iters = quadratic_nonneg(inner.method, &Ã, &g̃, α, &mut x, inner_τ, iterator);
    // Update masses of μ based on solution of finite-dimensional subproblem.
    μ.set_masses_dvector(&x);

    iters
}
|
175 |
|
/// Solve the point source localisation problem using a conditional gradient method
/// for the 2-norm-squared data fidelity, i.e., the problem
/// <div>$$
///     \min_μ \frac{1}{2}\|Aμ-b\|_w^2 + α\|μ\|_ℳ + δ_{≥ 0}(μ).
/// $$</div>
///
/// The `opA` parameter is the forward operator $A$, while `b` and `α` are as in the
/// objective above. The method parameters are set in `config` (see [`FWConfig`]), while
/// `iterator` is used to iterate the steps of the method, and `plotter` may be used to
/// save intermediate iteration states as images.
///
/// Returns the final iterate, a discrete measure.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_fw<'a, F, I, A, GA, BTA, S, const N : usize>(
    opA : &'a A,
    b : &A::Observable,
    α : F,
    //domain : Cube<F, N>,
    config : &FWConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
) -> DiscreteMeasure<Loc<F, N>, F>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
                                  //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
      A::Observable : std::ops::MulAssign<F>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {

    // Set up parameters
    // We multiply tolerance by α for all algorithms, so stopping criteria scale
    // with the regularisation parameter.
    let tolerance = config.tolerance * α;
    let mut ε = tolerance.initial();
    let findim_data = prepare_optimise_weights(opA);
    // m0 = ‖b‖²/(2α); presumably the a priori mass bound of the relaxed variant
    // (cf. Bredies–Pikkarainen in the module docs) — verify against the paper.
    let m0 = b.norm2_squared() / (2.0 * α);
    // Surrogate of the regularisation term: linear up to m0, quadratic beyond;
    // used in the relaxed variant's step length rule below.
    let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) };

    // Initialise operators
    let preadjA = opA.preadjoint();

    // Initialise iterates: μ = 0, so the residual Aμ - b is simply -b.
    let mut μ = DiscreteMeasure::new();
    let mut residual = -b;

    // Statistics accumulated between verbose reports; reset inside `if_verbose`.
    let mut inner_iters = 0;
    let mut this_iters = 0;
    let mut pruned = 0;
    let mut merged = 0;

    // Run the algorithm
    iterator.iterate(|state| {
        // Update tolerance
        let inner_tolerance = ε * config.inner.tolerance_mult;
        let refinement_tolerance = ε * config.refinement.tolerance_mult;
        let ε_prev = ε;
        ε = tolerance.update(ε, state.iteration());

        // Calculate smooth part of surrogate model.
        //
        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
        // has no significant overhead. For some reason Rust doesn't allow us simply moving
        // the residual and replacing it below before the end of this closure.
        let r = std::mem::replace(&mut residual, opA.empty_observable());
        // g = -A^*(Aμ - b): the negative gradient of the data fidelity term.
        let mut g = -preadjA.apply(r);

        // Find the absolute-value-maximising point of g via branch-and-bound:
        // take whichever of the maximum and the minimum has larger magnitude.
        let (ξmax, v_ξmax) = g.maximise(refinement_tolerance,
                                        config.refinement.max_steps);
        let (ξmin, v_ξmin) = g.minimise(refinement_tolerance,
                                        config.refinement.max_steps);
        let (ξ, v_ξ) = if v_ξmin < 0.0 && -v_ξmin > v_ξmax {
            (ξmin, v_ξmin)
        } else {
            (ξmax, v_ξmax)
        };

        // Insert the new spike at ξ, and pick the inner-iterator configuration
        // for the weight optimisation step below.
        let inner_it = match config.variant {
            FWVariant::FullyCorrective => {
                // No point in optimising the weight here: the finite-dimensional algorithm is fast.
                μ += DeltaMeasure { x : ξ, α : 0.0 };
                config.inner.iterator_options.stop_target(inner_tolerance)
            },
            FWVariant::Relaxed => {
                // Perform a relaxed initialisation of μ: give the new spike a
                // heuristic weight and form a convex combination (1-s)μ + sδ.
                let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ };
                let δ = DeltaMeasure { x : ξ, α : v };
                let dp = μ.apply(&g) - δ.apply(&g);
                let d = opA.apply(&μ) - opA.apply(&δ);
                let r = d.norm2_squared();
                // Step length clamped to at most 1; s = 1 when A(μ - δ) = 0.
                let s = if r == 0.0 {
                    1.0
                } else {
                    1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r)
                };
                μ *= 1.0 - s;
                μ += δ * s;
                // The stop_target is only needed for the type system.
                AlgIteratorOptions{ max_iter : 1, .. config.inner.iterator_options}.stop_target(0.0)
            }
        };

        // Fully solve (or, for the relaxed variant, take one step of) the
        // finite-dimensional weight optimisation problem.
        inner_iters += optimise_weights(&mut μ, opA, b, α, &findim_data, &config.inner, inner_it);

        // Merge spikes and update residual for next step and `if_verbose` below.
        let n_before_merge = μ.len();
        residual = μ.merge_spikes_fitness(config.merging,
                                          |μ̃| opA.apply(μ̃) - b,
                                          A::Observable::norm2_squared);
        // NOTE(review): this assumes merging never *decreases* the spike count,
        // i.e. that merged spikes are inserted alongside the originals rather than
        // replacing them in place — confirm against `merge_spikes_fitness`; the
        // unsigned subtraction below would otherwise underflow.
        assert!(μ.len() >= n_before_merge);
        merged += μ.len() - n_before_merge;


        // Prune points with zero mass
        let n_before_prune = μ.len();
        μ.prune();
        debug_assert!(μ.len() <= n_before_prune);
        pruned += n_before_prune - μ.len();

        this_iters +=1;

        // Give function value if needed
        state.if_verbose(|| {
            // Plot the current iterate and the predual variable g.
            plotter.plot_spikes(
                format!("iter {} start", state.iteration()), &g,
                "".to_string(), None::<&A::PreadjointCodomain>,
                None, &μ
            );
            // Report the full objective value and per-report statistics, then
            // reset the accumulators for the next reporting window.
            let res = IterInfo {
                value : residual.norm2_squared_div2() + α * μ.norm(Radon),
                n_spikes : μ.len(),
                inner_iters,
                this_iters,
                merged,
                pruned,
                ε : ε_prev,
                maybe_ε1 : None,
                postprocessing : None,
            };
            inner_iters = 0;
            this_iters = 0;
            merged = 0;
            pruned = 0;
            res
        })
    });

    // Return final iterate
    μ
}
|
330 |
|
331 |
|
332 |
|
333 |