Tue, 29 Nov 2022 15:36:12 +0200
fubar
0 | 1 | /*! |
2 | Solver for the point source localisation problem using a conditional gradient method. | |
3 | ||
4 | We implement two variants, the “fully corrective” method from | |
5 | ||
6 | * Pieper K., Walter D. _Linear convergence of accelerated conditional gradient algorithms | |
7 | in spaces of measures_, DOI: [10.1051/cocv/2021042](https://doi.org/10.1051/cocv/2021042), | |
8 | arXiv: [1904.09218](https://doi.org/10.48550/arXiv.1904.09218). | |
9 | ||
10 | and what we call the “relaxed” method from | |
11 | ||
12 | * Bredies K., Pikkarainen H. - _Inverse problems in spaces of measures_, | |
DOI: [10.1051/cocv/2011205](https://doi.org/10.1051/cocv/2011205).
14 | */ | |
15 | ||
16 | use numeric_literals::replace_float_literals; | |
17 | use serde::{Serialize, Deserialize}; | |
18 | //use colored::Colorize; | |
19 | ||
20 | use alg_tools::iterate::{ | |
21 | AlgIteratorFactory, | |
22 | AlgIteratorState, | |
23 | AlgIteratorOptions, | |
24 | }; | |
25 | use alg_tools::euclidean::Euclidean; | |
26 | use alg_tools::norms::Norm; | |
27 | use alg_tools::linops::Apply; | |
28 | use alg_tools::sets::Cube; | |
29 | use alg_tools::loc::Loc; | |
30 | use alg_tools::bisection_tree::{ | |
31 | BTFN, | |
32 | Bounds, | |
33 | BTNodeLookup, | |
34 | BTNode, | |
35 | BTSearch, | |
36 | P2Minimise, | |
37 | SupportGenerator, | |
38 | LocalAnalysis, | |
39 | }; | |
40 | use alg_tools::mapping::RealMapping; | |
41 | use alg_tools::nalgebra_support::ToNalgebraRealField; | |
42 | ||
43 | use crate::types::*; | |
44 | use crate::measures::{ | |
45 | DiscreteMeasure, | |
46 | DeltaMeasure, | |
47 | Radon, | |
48 | }; | |
49 | use crate::measures::merging::{ | |
50 | SpikeMergingMethod, | |
51 | SpikeMerging, | |
52 | }; | |
53 | use crate::forward_model::ForwardModel; | |
54 | #[allow(unused_imports)] // Used in documentation | |
2 | 55 | use crate::subproblem::InnerSettings; |
56 | use crate::weight_optim::{ | |
57 | prepare_optimise_weights, | |
58 | optimise_weights_l2 | |
0 | 59 | }; |
60 | use crate::tolerance::Tolerance; | |
61 | use crate::plot::{ | |
62 | SeqPlotter, | |
63 | Plotting, | |
64 | PlotLookup | |
65 | }; | |
66 | ||
/// Settings for [`pointsource_fw`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FWConfig<F : Float> {
    /// Tolerance for branch-and-bound new spike location discovery
    pub tolerance : Tolerance<F>,
    /// Inner problem solution configuration. Has to have `method` set to [`InnerMethod::FB`]
    /// as the conditional gradient subproblems' optimality conditions do not in general have an
    /// invertible Newton derivative for SSN.
    pub inner : InnerSettings<F>,
    /// Variant of the conditional gradient method; see [`FWVariant`].
    pub variant : FWVariant,
    /// Settings for branch and bound refinement when looking for predual maxima
    pub refinement : RefinementSettings<F>,
    /// Spike merging heuristic; see [`SpikeMergingMethod`].
    pub merging : SpikeMergingMethod<F>,
}
84 | ||
/// Conditional gradient method variant; see also [`FWConfig`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
// NOTE(review): `dead_code` allowed — presumably the variants are only selected via
// (de)serialised configuration rather than constructed directly in code; confirm.
#[allow(dead_code)]
pub enum FWVariant {
    /// The fully corrective method: Algorithm 2 of Pieper–Walter (see the module documentation)
    FullyCorrective,
    /// The relaxed method of Bredies–Pikkarainen. Forces `FWConfig.inner.max_iter = 1`.
    Relaxed,
}
94 | ||
95 | impl<F : Float> Default for FWConfig<F> { | |
96 | fn default() -> Self { | |
97 | FWConfig { | |
98 | tolerance : Default::default(), | |
99 | refinement : Default::default(), | |
100 | inner : Default::default(), | |
101 | variant : FWVariant::FullyCorrective, | |
102 | merging : Default::default(), | |
103 | } | |
104 | } | |
105 | } | |
106 | ||
/// Solve point source localisation problem using a conditional gradient method
/// for the 2-norm-squared data fidelity, i.e., the problem
/// <div>$$
///     \min_μ \frac{1}{2}\|Aμ-b\|_w^2 + α\|μ\|_ℳ + δ_{≥ 0}(μ).
/// $$</div>
///
/// The `opA` parameter is the forward operator $A$, while `b` and `α` are as in the
/// objective above. The method parameters are set in `config` (see [`FWConfig`]), while
/// `iterator` is used to iterate the steps of the method, and `plotter` may be used to
/// save intermediate iteration states as images.
///
/// Returns the final iterate `μ` as a [`DiscreteMeasure`].
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_fw<'a, F, I, A, GA, BTA, S, const N : usize>(
    opA : &'a A,
    b : &A::Observable,
    α : F,
    //domain : Cube<F, N>,
    config : &FWConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
) -> DiscreteMeasure<Loc<F, N>, F>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
                                  //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
      A::Observable : std::ops::MulAssign<F>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {

    // Set up parameters
    // We multiply tolerance by α for all algorithms.
    let tolerance = config.tolerance * α;
    let mut ε = tolerance.initial();
    // Precomputed data for the finite-dimensional weight-optimisation subproblem.
    let findim_data = prepare_optimise_weights(opA);
    // m0 = ‖b‖²/(2α); presumably an a priori bound on the mass of a minimiser,
    // used in the relaxed step below — confirm against the Bredies–Pikkarainen reference.
    let m0 = b.norm2_squared() / (2.0 * α);
    // φ(t) is linear (α·t) below m0 and quadratic above; used in the relaxed step size rule.
    let φ = |t| if t <= m0 { α * t } else { α / (2.0 * m0) * (t*t + m0 * m0) };

    // Initialise operators
    let preadjA = opA.preadjoint();

    // Initialise iterates: μ = 0, so the residual Aμ - b is simply -b.
    let mut μ = DiscreteMeasure::new();
    let mut residual = -b;

    // Statistics accumulated between verbose reports; reset inside `if_verbose` below.
    let mut inner_iters = 0;
    let mut this_iters = 0;
    let mut pruned = 0;
    let mut merged = 0;

    // Run the algorithm
    iterator.iterate(|state| {
        // Update tolerance
        let inner_tolerance = ε * config.inner.tolerance_mult;
        let refinement_tolerance = ε * config.refinement.tolerance_mult;
        let ε_prev = ε;
        ε = tolerance.update(ε, state.iteration());

        // Calculate smooth part of surrogate model: g = -A^*(residual).
        //
        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
        // has no significant overhead. For some reason Rust doesn't allow us simply moving
        // the residual and replacing it below before the end of this closure.
        let r = std::mem::replace(&mut residual, opA.empty_observable());
        let mut g = -preadjA.apply(r);

        // Find absolute value maximising point: check both the maximum and the minimum
        // of g, so that a sufficiently negative minimum can win (new spike location).
        let (ξmax, v_ξmax) = g.maximise(refinement_tolerance,
                                        config.refinement.max_steps);
        let (ξmin, v_ξmin) = g.minimise(refinement_tolerance,
                                        config.refinement.max_steps);
        let (ξ, v_ξ) = if v_ξmin < 0.0 && -v_ξmin > v_ξmax {
            (ξmin, v_ξmin)
        } else {
            (ξmax, v_ξmax)
        };

        // Insert the new spike and choose iterator options for the weight optimisation below.
        let inner_it = match config.variant {
            FWVariant::FullyCorrective => {
                // No point in optimising the weight here: the finite-dimensional algorithm is fast.
                μ += DeltaMeasure { x : ξ, α : 0.0 };
                config.inner.iterator_options.stop_target(inner_tolerance)
            },
            FWVariant::Relaxed => {
                // Perform a relaxed initialisation of μ: a convex combination
                // (1-s)·μ + s·δ of the previous iterate and the new spike δ.
                // The new spike weight is 0 when |v_ξ| ≤ α, else m0/α·v_ξ.
                let v = if v_ξ.abs() <= α { 0.0 } else { m0 / α * v_ξ };
                let δ = DeltaMeasure { x : ξ, α : v };
                let dp = μ.apply(&g) - δ.apply(&g);
                let d = opA.apply(&μ) - opA.apply(&δ);
                let r = d.norm2_squared();
                // Step size s ∈ (0, 1]; closed-form expression, presumably minimising a
                // quadratic model along the segment — confirm against Bredies–Pikkarainen.
                // If A(μ-δ) = 0 there is nothing to trade off and we jump fully to δ.
                let s = if r == 0.0 {
                    1.0
                } else {
                    1.0.min( (α * μ.norm(Radon) - φ(v.abs()) - dp) / r)
                };
                μ *= 1.0 - s;
                μ += δ * s;
                // The stop_target is only needed for the type system.
                AlgIteratorOptions{ max_iter : 1, .. config.inner.iterator_options}.stop_target(0.0)
            }
        };

        // Optimise all spike weights via the finite-dimensional subproblem.
        inner_iters += optimise_weights_l2(&mut μ, opA, b, α, &findim_data, &config.inner, inner_it);

        // Merge spikes and update residual for next step and `if_verbose` below.
        // NOTE(review): the assert expects merging not to *decrease* the spike count;
        // presumably `merge_spikes_fitness` replaces merged spikes with zero-mass ones
        // that the pruning below removes — confirm against the `SpikeMerging` trait.
        let n_before_merge = μ.len();
        residual = μ.merge_spikes_fitness(config.merging,
                                          |μ̃| opA.apply(μ̃) - b,
                                          A::Observable::norm2_squared);
        assert!(μ.len() >= n_before_merge);
        merged += μ.len() - n_before_merge;


        // Prune points with zero mass
        let n_before_prune = μ.len();
        μ.prune();
        debug_assert!(μ.len() <= n_before_prune);
        pruned += n_before_prune - μ.len();

        this_iters +=1;

        // Give function value if needed
        state.if_verbose(|| {
            plotter.plot_spikes(
                format!("iter {} start", state.iteration()), &g,
                "".to_string(), None::<&A::PreadjointCodomain>,
                None, &μ
            );
            let res = IterInfo {
                value : residual.norm2_squared_div2() + α * μ.norm(Radon),
                n_spikes : μ.len(),
                inner_iters,
                this_iters,
                merged,
                pruned,
                ε : ε_prev,
                maybe_ε1 : None,
                postprocessing : None,
            };
            // Reset the per-report statistics now that they have been recorded.
            inner_iters = 0;
            this_iters = 0;
            merged = 0;
            pruned = 0;
            res
        })
    });

    // Return final iterate
    μ
}
261 | ||
262 | ||
263 | ||
264 |