src/radon_fb.rs

branch
dev
changeset 37
c5d8bd1a7728
parent 36
fb911f72e698
child 38
0f59c0d02e13
child 39
6316d68b58af
equal deleted inserted replaced
36:fb911f72e698 37:c5d8bd1a7728
1 /*!
2 Solver for the point source localisation problem using a simplified forward-backward splitting method.
3
4 Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map.
5 */
6
7 use numeric_literals::replace_float_literals;
8 use serde::{Serialize, Deserialize};
9 use colored::Colorize;
10 use nalgebra::DVector;
11
12 use alg_tools::iterate::{
13 AlgIteratorFactory,
14 AlgIteratorIteration,
15 AlgIterator
16 };
17 use alg_tools::euclidean::Euclidean;
18 use alg_tools::linops::Mapping;
19 use alg_tools::sets::Cube;
20 use alg_tools::loc::Loc;
21 use alg_tools::bisection_tree::{
22 BTFN,
23 Bounds,
24 BTNodeLookup,
25 BTNode,
26 BTSearch,
27 P2Minimise,
28 SupportGenerator,
29 LocalAnalysis,
30 };
31 use alg_tools::mapping::RealMapping;
32 use alg_tools::nalgebra_support::ToNalgebraRealField;
33 use alg_tools::norms::L2;
34
35 use crate::types::*;
36 use crate::measures::{
37 RNDM,
38 DiscreteMeasure,
39 DeltaMeasure,
40 Radon,
41 };
42 use crate::measures::merging::{
43 SpikeMergingMethod,
44 SpikeMerging,
45 };
46 use crate::forward_model::ForwardModel;
47 use crate::plot::{
48 SeqPlotter,
49 Plotting,
50 PlotLookup
51 };
52 use crate::regularisation::RegTerm;
53 use crate::dataterm::{
54 calculate_residual,
55 L2Squared,
56 DataTerm,
57 };
58
59 use crate::fb::{
60 FBGenericConfig,
61 postprocess,
62 prune_with_stats
63 };
64
65 /// Settings for [`pointsource_radon_fb_reg`].
66 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
67 #[serde(default)]
68 pub struct RadonFBConfig<F : Float> {
69 /// Step length scaling
70 pub τ0 : F,
71 /// Generic parameters
72 pub insertion : FBGenericConfig<F>,
73 }
74
75 #[replace_float_literals(F::cast_from(literal))]
76 impl<F : Float> Default for RadonFBConfig<F> {
77 fn default() -> Self {
78 RadonFBConfig {
79 τ0 : 0.99,
80 insertion : Default::default()
81 }
82 }
83 }
84
85 #[replace_float_literals(F::cast_from(literal))]
86 pub(crate) fn insert_and_reweigh<
87 'a, F, GA, BTA, S, Reg, I, const N : usize
88 >(
89 μ : &mut RNDM<F, N>,
90 τv : &mut BTFN<F, GA, BTA, N>,
91 μ_base : &mut RNDM<F, N>,
92 //_ν_delta: Option<&RNDM<F, N>>,
93 τ : F,
94 ε : F,
95 config : &FBGenericConfig<F>,
96 reg : &Reg,
97 _state : &AlgIteratorIteration<I>,
98 stats : &mut IterInfo<F, N>,
99 )
100 where F : Float + ToNalgebraRealField,
101 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
102 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
103 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
104 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
105 RNDM<F, N> : SpikeMerging<F>,
106 Reg : RegTerm<F, N>,
107 I : AlgIterator {
108
109 'i_and_w: for i in 0..=1 {
110 // Optimise weights
111 if μ.len() > 0 {
112 // Form finite-dimensional subproblem. The subproblem references to the original μ^k
113 // from the beginning of the iteration are all contained in the immutable c and g.
114 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
115 // problems have not yet been updated to sign change.
116 let g̃ = DVector::from_iterator(μ.len(),
117 μ.iter_locations()
118 .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ))));
119 let mut x = μ.masses_dvector();
120 let y = μ_base.masses_dvector();
121
122 // Solve finite-dimensional subproblem.
123 stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config);
124
125 // Update masses of μ based on solution of finite-dimensional subproblem.
126 μ.set_masses_dvector(&x);
127 }
128
129 if i>0 {
130 // Simple debugging test to see if more inserts would be needed. Doesn't seem so.
131 //let n = μ.dist_matching(μ_base);
132 //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n));
133 break 'i_and_w
134 }
135
136 // Calculate ‖μ - μ_base‖_ℳ
137 let n = μ.dist_matching(μ_base);
138
139 // Find a spike to insert, if needed.
140 // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ,
141 // which are supposed to have been guaranteed by the finite-dimensional weight optimisation.
142 match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) {
143 None => { break 'i_and_w },
144 Some((ξ, _v_ξ, _in_bounds)) => {
145 // Weight is found out by running the finite-dimensional optimisation algorithm
146 // above
147 *μ += DeltaMeasure { x : ξ, α : 0.0 };
148 *μ_base += DeltaMeasure { x : ξ, α : 0.0 };
149 stats.inserted += 1;
150 }
151 };
152 }
153 }
154
155
156 /// Iteratively solve the pointsource localisation problem using simplified forward-backward splitting.
157 ///
158 /// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the
159 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
160 /// Finally, the `iterator` is an outer loop verbosity and iteration count control
161 /// as documented in [`alg_tools::iterate`].
162 ///
163 /// For details on the mathematical formulation, see the [module level](self) documentation.
164 ///
165 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
166 /// sums of simple functions usign bisection trees, and the related
167 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
168 /// active at a specific points, and to maximise their sums. Through the implementation of the
169 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
170 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
171 ///
172 /// Returns the final iterate.
173 #[replace_float_literals(F::cast_from(literal))]
174 pub fn pointsource_radon_fb_reg<
175 'a, F, I, A, GA, BTA, S, Reg, const N : usize
176 >(
177 opA : &'a A,
178 b : &A::Observable,
179 reg : Reg,
180 fbconfig : &RadonFBConfig<F>,
181 iterator : I,
182 mut _plotter : SeqPlotter<F, N>,
183 ) -> RNDM<F, N>
184 where F : Float + ToNalgebraRealField,
185 I : AlgIteratorFactory<IterInfo<F, N>>,
186 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
187 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
188 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
189 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
190 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
191 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
192 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
193 RNDM<F, N> : SpikeMerging<F>,
194 Reg : RegTerm<F, N> {
195
196 // Set up parameters
197 let config = &fbconfig.insertion;
198 // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ
199 // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such
200 // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L.
201 let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2);
202 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
203 // by τ compared to the conditional gradient approach.
204 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
205 let mut ε = tolerance.initial();
206
207 // Initialise iterates
208 let mut μ = DiscreteMeasure::new();
209 let mut residual = -b;
210
211 // Statistics
212 let full_stats = |residual : &A::Observable,
213 μ : &RNDM<F, N>,
214 ε, stats| IterInfo {
215 value : residual.norm2_squared_div2() + reg.apply(μ),
216 n_spikes : μ.len(),
217 ε,
218 // postprocessing: config.postprocessing.then(|| μ.clone()),
219 .. stats
220 };
221 let mut stats = IterInfo::new();
222
223 // Run the algorithm
224 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
225 // Calculate smooth part of surrogate model.
226 let mut τv = opA.preadjoint().apply(residual * τ);
227
228 // Save current base point
229 let mut μ_base = μ.clone();
230
231 // Insert and reweigh
232 insert_and_reweigh(
233 &mut μ, &mut τv, &mut μ_base, //None,
234 τ, ε,
235 config, &reg, &state, &mut stats
236 );
237
238 // Prune and possibly merge spikes
239 assert!(μ_base.len() <= μ.len());
240 if config.merge_now(&state) {
241 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
242 // Important: μ_candidate's new points are afterwards,
243 // and do not conflict with μ_base.
244 // TODO: could simplify to requiring μ_base instead of μ_radon.
245 // but may complicate with sliding base's exgtra points that need to be
246 // after μ_candidate's extra points.
247 // TODO: doesn't seem to work, maybe need to merge μ_base as well?
248 // Although that doesn't seem to make sense.
249 let μ_radon = μ_candidate.sub_matching(&μ_base);
250 reg.verify_merge_candidate_radonsq(&mut τv, μ_candidate, τ, ε, &config, &μ_radon)
251 //let n = μ_candidate.dist_matching(μ_base);
252 //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none()
253 });
254 }
255 stats.pruned += prune_with_stats(&mut μ);
256
257 // Update residual
258 residual = calculate_residual(&μ, opA, b);
259
260 let iter = state.iteration();
261 stats.this_iters += 1;
262
263 // Give statistics if needed
264 state.if_verbose(|| {
265 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
266 });
267
268 // Update main tolerance for next iteration
269 ε = tolerance.update(ε, iter);
270 }
271
272 postprocess(μ, config, L2Squared, opA, b)
273 }
274
275 /// Iteratively solve the pointsource localisation problem using simplified inertial forward-backward splitting.
276 ///
277 /// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the
278 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
279 /// Finally, the `iterator` is an outer loop verbosity and iteration count control
280 /// as documented in [`alg_tools::iterate`].
281 ///
282 /// For details on the mathematical formulation, see the [module level](self) documentation.
283 ///
284 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
285 /// sums of simple functions usign bisection trees, and the related
286 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
287 /// active at a specific points, and to maximise their sums. Through the implementation of the
288 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
289 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
290 ///
291 /// Returns the final iterate.
292 #[replace_float_literals(F::cast_from(literal))]
293 pub fn pointsource_radon_fista_reg<
294 'a, F, I, A, GA, BTA, S, Reg, const N : usize
295 >(
296 opA : &'a A,
297 b : &A::Observable,
298 reg : Reg,
299 fbconfig : &RadonFBConfig<F>,
300 iterator : I,
301 mut plotter : SeqPlotter<F, N>,
302 ) -> RNDM<F, N>
303 where F : Float + ToNalgebraRealField,
304 I : AlgIteratorFactory<IterInfo<F, N>>,
305 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
306 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
307 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>,
308 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
309 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
310 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
311 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
312 PlotLookup : Plotting<N>,
313 RNDM<F, N> : SpikeMerging<F>,
314 Reg : RegTerm<F, N> {
315
316 // Set up parameters
317 let config = &fbconfig.insertion;
318 // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ
319 // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such
320 // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L.
321 let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2);
322 let mut λ = 1.0;
323 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
324 // by τ compared to the conditional gradient approach.
325 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
326 let mut ε = tolerance.initial();
327
328 // Initialise iterates
329 let mut μ = DiscreteMeasure::new();
330 let mut μ_prev = DiscreteMeasure::new();
331 let mut residual = -b;
332 let mut warned_merging = false;
333
334 // Statistics
335 let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo {
336 value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν),
337 n_spikes : ν.len(),
338 ε,
339 // postprocessing: config.postprocessing.then(|| ν.clone()),
340 .. stats
341 };
342 let mut stats = IterInfo::new();
343
344 // Run the algorithm
345 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
346 // Calculate smooth part of surrogate model.
347 let mut τv = opA.preadjoint().apply(residual * τ);
348
349 // Save current base point
350 let mut μ_base = μ.clone();
351
352 // Insert new spikes and reweigh
353 insert_and_reweigh(
354 &mut μ, &mut τv, &mut μ_base, //None,
355 τ, ε,
356 config, &reg, &state, &mut stats
357 );
358
359 // (Do not) merge spikes.
360 if config.merge_now(&state) {
361 match config.merging {
362 SpikeMergingMethod::None => { },
363 _ => if !warned_merging {
364 let err = format!("Merging not supported for μFISTA");
365 println!("{}", err.red());
366 warned_merging = true;
367 }
368 }
369 }
370
371 // Update inertial prameters
372 let λ_prev = λ;
373 λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() );
374 let θ = λ / λ_prev - λ;
375
376 // Perform inertial update on μ.
377 // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ
378 // and μ_prev have zero weight. Since both have weights from the finite-dimensional
379 // subproblem with a proximal projection step, this is likely to happen when the
380 // spike is not needed. A copy of the pruned μ without artithmetic performed is
381 // stored in μ_prev.
382 let n_before_prune = μ.len();
383 μ.pruning_sub(1.0 + θ, θ, &mut μ_prev);
384 debug_assert!(μ.len() <= n_before_prune);
385 stats.pruned += n_before_prune - μ.len();
386
387 // Update residual
388 residual = calculate_residual(&μ, opA, b);
389
390 let iter = state.iteration();
391 stats.this_iters += 1;
392
393 // Give statistics if needed
394 state.if_verbose(|| {
395 plotter.plot_spikes(iter, Option::<&S>::None, Some(&τv), &μ_prev);
396 full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new()))
397 });
398
399 // Update main tolerance for next iteration
400 ε = tolerance.update(ε, iter);
401 }
402
403 postprocess(μ_prev, config, L2Squared, opA, b)
404 }

mercurial