src/frank_wolfe.rs

changeset 0
eb3c7813b67a
child 2
7a953a87b6c1
child 4
5aa5c279e341
child 8
ea3ca78873e8
equal deleted inserted replaced
-1:000000000000 0:eb3c7813b67a
1 /*!
2 Solver for the point source localisation problem using a conditional gradient method.
3
4 We implement two variants, the “fully corrective” method from
5
6 * Pieper K., Walter D. _Linear convergence of accelerated conditional gradient algorithms
7 in spaces of measures_, DOI: [10.1051/cocv/2021042](https://doi.org/10.1051/cocv/2021042),
8 arXiv: [1904.09218](https://doi.org/10.48550/arXiv.1904.09218).
9
10 and what we call the “relaxed” method from
11
12 * Bredies K., Pikkarainen H. - _Inverse problems in spaces of measures_,
DOI: [10.1051/cocv/2011205](https://doi.org/10.1051/cocv/2011205).
14 */
15
16 use numeric_literals::replace_float_literals;
17 use serde::{Serialize, Deserialize};
18 //use colored::Colorize;
19
20 use alg_tools::iterate::{
21 AlgIteratorFactory,
22 AlgIteratorState,
23 AlgIteratorOptions,
24 };
25 use alg_tools::euclidean::Euclidean;
26 use alg_tools::norms::Norm;
27 use alg_tools::linops::Apply;
28 use alg_tools::sets::Cube;
29 use alg_tools::loc::Loc;
30 use alg_tools::bisection_tree::{
31 BTFN,
32 Bounds,
33 BTNodeLookup,
34 BTNode,
35 BTSearch,
36 P2Minimise,
37 SupportGenerator,
38 LocalAnalysis,
39 };
40 use alg_tools::mapping::RealMapping;
41 use alg_tools::nalgebra_support::ToNalgebraRealField;
42
43 use crate::types::*;
44 use crate::measures::{
45 DiscreteMeasure,
46 DeltaMeasure,
47 Radon,
48 };
49 use crate::measures::merging::{
50 SpikeMergingMethod,
51 SpikeMerging,
52 };
53 use crate::forward_model::ForwardModel;
54 #[allow(unused_imports)] // Used in documentation
55 use crate::subproblem::{
56 quadratic_nonneg,
57 InnerSettings,
58 InnerMethod,
59 };
60 use crate::tolerance::Tolerance;
61 use crate::plot::{
62 SeqPlotter,
63 Plotting,
64 PlotLookup
65 };
66
/// Settings for [`pointsource_fw`].
///
/// With `#[serde(default)]`, any field missing from a deserialised configuration
/// falls back to the corresponding value from [`Default::default`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FWConfig<F : Float> {
    /// Tolerance for branch-and-bound new spike location discovery
    pub tolerance : Tolerance<F>,
    /// Inner problem solution configuration. Has to have `method` set to [`InnerMethod::FB`]
    /// as the conditional gradient subproblems' optimality conditions do not in general have an
    /// invertible Newton derivative for SSN.
    pub inner : InnerSettings<F>,
    /// Variant of the conditional gradient method
    pub variant : FWVariant,
    /// Settings for branch and bound refinement when looking for predual maxima
    pub refinement : RefinementSettings<F>,
    /// Spike merging heuristic
    pub merging : SpikeMergingMethod<F>,
}
84
/// Conditional gradient method variant; see also [`FWConfig`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[allow(dead_code)]
pub enum FWVariant {
    /// The “fully corrective” method: Algorithm 2 of Walter–Pieper.
    FullyCorrective,
    /// The “relaxed” method of Bredies–Pikkarainen. Forces `FWConfig.inner.max_iter = 1`.
    Relaxed,
}
94
95 impl<F : Float> Default for FWConfig<F> {
96 fn default() -> Self {
97 FWConfig {
98 tolerance : Default::default(),
99 refinement : Default::default(),
100 inner : Default::default(),
101 variant : FWVariant::FullyCorrective,
102 merging : Default::default(),
103 }
104 }
105 }
106
/// Helper struct for pre-initialising the finite-dimensional subproblem solver
/// [`optimise_weights`].
///
/// The pre-initialisation is done by [`prepare_optimise_weights`].
pub struct FindimData<F : Float> {
    // Cached square ‖A‖² of the forward operator norm bound, used for inner-solver
    // step length selection in `optimise_weights`.
    opAnorm_squared : F
}
114
115 /// Return a pre-initialisation struct for [`prepare_optimise_weights`].
116 ///
117 /// The parameter `opA` is the forward operator $A$.
118 pub fn prepare_optimise_weights<F, A, const N : usize>(opA : &A) -> FindimData<F>
119 where F : Float + ToNalgebraRealField,
120 A : ForwardModel<Loc<F, N>, F> {
121 FindimData{
122 opAnorm_squared : opA.opnorm_bound().powi(2)
123 }
124 }
125
126 /// Solve the finite-dimensional weight optimisation problem for the 2-norm-squared data fidelity
127 /// point source localisation problem.
128 ///
129 /// That is, we minimise
130 /// <div>$$
131 /// μ ↦ \frac{1}{2}\|Aμ-b\|_w^2 + α\|μ\|_ℳ + δ_{≥ 0}(μ)
132 /// $$</div>
133 /// only with respect to the weights of $μ$.
134 ///
135 /// The parameter `μ` is the discrete measure whose weights are to be optimised.
136 /// The `opA` parameter is the forward operator $A$, while `b`$ and `α` are as in the
137 /// objective above. The method parameter are set in `inner` (see [`InnerSettings`]), while
138 /// `iterator` is used to iterate the steps of the method, and `plotter` may be used to
139 /// save intermediate iteration states as images. The parameter `findim_data` should be
140 /// prepared using [`prepare_optimise_weights`]:
141 ///
142 /// Returns the number of iterations taken by the method configured in `inner`.
143 pub fn optimise_weights<'a, F, A, I, const N : usize>(
144 μ : &mut DiscreteMeasure<Loc<F, N>, F>,
145 opA : &'a A,
146 b : &A::Observable,
147 α : F,
148 findim_data : &FindimData<F>,
149 inner : &InnerSettings<F>,
150 iterator : I
151 ) -> usize
152 where F : Float + ToNalgebraRealField,
153 I : AlgIteratorFactory<F>,
154 A : ForwardModel<Loc<F, N>, F>
155 {
156 // Form and solve finite-dimensional subproblem.
157 let (Ã, g̃) = opA.findim_quadratic_model(&μ, b);
158 let mut x = μ.masses_dvector();
159
160 // `inner_τ1` is based on an estimate of the operator norm of $A$ from ℳ(Ω) to
161 // ℝ^n. This estimate is a good one for the matrix norm from ℝ^m to ℝ^n when the
162 // former is equipped with the 1-norm. We need the 2-norm. To pass from 1-norm to
163 // 2-norm, we estimate
164 // ‖A‖_{2,2} := sup_{‖x‖_2 ≤ 1} ‖Ax‖_2 ≤ sup_{‖x‖_1 ≤ C} ‖Ax‖_2
165 // = C sup_{‖x‖_1 ≤ 1} ‖Ax‖_2 = C ‖A‖_{1,2},
166 // where C = √m satisfies ‖x‖_1 ≤ C ‖x‖_2. Since we are intested in ‖A_*A‖, no
167 // square root is needed when we scale:
168 let inner_τ = inner.τ0 / (findim_data.opAnorm_squared * F::cast_from(μ.len()));
169 let iters = quadratic_nonneg(inner.method, &Ã, &g̃, α, &mut x, inner_τ, iterator);
170 // Update masses of μ based on solution of finite-dimensional subproblem.
171 μ.set_masses_dvector(&x);
172
173 iters
174 }
175
/// Solve point source localisation problem using a conditional gradient method
/// for the 2-norm-squared data fidelity, i.e., the problem
/// <div>$$
///     \min_μ \frac{1}{2}\|Aμ-b\|_w^2 + α\|μ\|_ℳ + δ_{≥ 0}(μ).
/// $$</div>
///
/// The `opA` parameter is the forward operator $A$, while `b` and `α` are as in the
/// objective above. The method parameters are set in `config` (see [`FWConfig`]), while
/// `iterator` is used to iterate the steps of the method, and `plotter` may be used to
/// save intermediate iteration states as images.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_fw<'a, F, I, A, GA, BTA, S, const N : usize>(
    opA : &'a A,
    b : &A::Observable,
    α : F,
    //domain : Cube<F, N>,
    config : &FWConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
) -> DiscreteMeasure<Loc<F, N>, F>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
      //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
      A::Observable : std::ops::MulAssign<F>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {

    // Set up parameters
    // We multiply tolerance by α for all algorithms.
    let tolerance = config.tolerance * α;
    let mut ε = tolerance.initial();
    let findim_data = prepare_optimise_weights(opA);
    // m0 bounds the Radon norm of any minimiser via the zero-measure objective value;
    // φ is the corresponding (relaxed-variant) surrogate of the regularisation term.
    let m0 = b.norm2_squared() / (2.0 * α);
    let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) };

    // Initialise operators
    let preadjA = opA.preadjoint();

    // Initialise iterates: empty measure, so the initial residual Aμ - b is just -b.
    let mut μ = DiscreteMeasure::new();
    let mut residual = -b;

    // Statistics accumulated between verbose-report points.
    let mut inner_iters = 0;
    let mut this_iters = 0;
    let mut pruned = 0;
    let mut merged = 0;

    // Run the algorithm
    iterator.iterate(|state| {
        // Update tolerance
        let inner_tolerance = ε * config.inner.tolerance_mult;
        let refinement_tolerance = ε * config.refinement.tolerance_mult;
        let ε_prev = ε;
        ε = tolerance.update(ε, state.iteration());

        // Calculate smooth part of surrogate model.
        //
        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
        // has no significant overhead. For some reason Rust doesn't allow us simply moving
        // the residual and replacing it below before the end of this closure.
        let r = std::mem::replace(&mut residual, opA.empty_observable());
        let mut g = -preadjA.apply(r);

        // Find absolute value maximising point: compare the branch-and-bound maximum
        // and minimum of the dual variable g, keeping whichever has larger magnitude.
        let (ξmax, v_ξmax) = g.maximise(refinement_tolerance,
                                        config.refinement.max_steps);
        let (ξmin, v_ξmin) = g.minimise(refinement_tolerance,
                                        config.refinement.max_steps);
        let (ξ, v_ξ) = if v_ξmin < 0.0 && -v_ξmin > v_ξmax {
            (ξmin, v_ξmin)
        } else {
            (ξmax, v_ξmax)
        };

        // Insert the new spike according to the chosen variant, and pick the iterator
        // options for the finite-dimensional weight optimisation below.
        let inner_it = match config.variant {
            FWVariant::FullyCorrective => {
                // No point in optimising the weight here: the finite-dimensional algorithm is fast.
                μ += DeltaMeasure { x : ξ, α : 0.0 };
                config.inner.iterator_options.stop_target(inner_tolerance)
            },
            FWVariant::Relaxed => {
                // Perform a relaxed initialisation of μ: convex combination of the old μ
                // and the new spike δ, with step s chosen by the surrogate model.
                let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ };
                let δ = DeltaMeasure { x : ξ, α : v };
                let dp = μ.apply(&g) - δ.apply(&g);
                let d = opA.apply(&μ) - opA.apply(&δ);
                let r = d.norm2_squared();
                let s = if r == 0.0 {
                    1.0
                } else {
                    1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r)
                };
                μ *= 1.0 - s;
                μ += δ * s;
                // The stop_target is only needed for the type system.
                AlgIteratorOptions{ max_iter : 1, .. config.inner.iterator_options}.stop_target(0.0)
            }
        };

        inner_iters += optimise_weights(&mut μ, opA, b, α, &findim_data, &config.inner, inner_it);

        // Merge spikes and update residual for next step and `if_verbose` below.
        let n_before_merge = μ.len();
        residual = μ.merge_spikes_fitness(config.merging,
                                          |μ̃| opA.apply(μ̃) - b,
                                          A::Observable::norm2_squared);
        // NOTE(review): merging is expected not to reduce the spike count here —
        // presumably merge candidates are appended and the zero-weight originals are
        // removed by the prune below; confirm against `SpikeMerging`.
        assert!(μ.len() >= n_before_merge);
        merged += μ.len() - n_before_merge;


        // Prune points with zero mass
        let n_before_prune = μ.len();
        μ.prune();
        debug_assert!(μ.len() <= n_before_prune);
        pruned += n_before_prune - μ.len();

        this_iters +=1;

        // Give function value if needed
        state.if_verbose(|| {
            plotter.plot_spikes(
                format!("iter {} start", state.iteration()), &g,
                "".to_string(), None::<&A::PreadjointCodomain>,
                None, &μ
            );
            let res = IterInfo {
                value : residual.norm2_squared_div2() + α * μ.norm(Radon),
                n_spikes : μ.len(),
                inner_iters,
                this_iters,
                merged,
                pruned,
                ε : ε_prev,
                maybe_ε1 : None,
                postprocessing : None,
            };
            // Reset the per-report statistics after each verbose report.
            inner_iters = 0;
            this_iters = 0;
            merged = 0;
            pruned = 0;
            res
        })
    });

    // Return final iterate
    μ
}
330
331
332
333

mercurial