12 * Bredies K., Pikkarainen H. - _Inverse problems in spaces of measures_, |
12 * Bredies K., Pikkarainen H. - _Inverse problems in spaces of measures_, |
13 DOI: [10.1051/cocv/2011205](https://doi.org/10.1051/cocv/2011205). |
13 DOI: [10.1051/cocv/2011205](https://doi.org/10.1051/cocv/2011205). |
14 */ |
14 */ |
15 |
15 |
16 use numeric_literals::replace_float_literals; |
16 use numeric_literals::replace_float_literals; |
|
17 use nalgebra::{DMatrix, DVector}; |
17 use serde::{Serialize, Deserialize}; |
18 use serde::{Serialize, Deserialize}; |
18 //use colored::Colorize; |
19 //use colored::Colorize; |
19 |
20 |
20 use alg_tools::iterate::{ |
21 use alg_tools::iterate::{ |
21 AlgIteratorFactory, |
22 AlgIteratorFactory, |
22 AlgIteratorState, |
|
23 AlgIteratorOptions, |
23 AlgIteratorOptions, |
24 ValueIteratorFactory, |
24 ValueIteratorFactory, |
25 }; |
25 }; |
26 use alg_tools::euclidean::Euclidean; |
26 use alg_tools::euclidean::Euclidean; |
27 use alg_tools::norms::Norm; |
27 use alg_tools::norms::Norm; |
28 use alg_tools::linops::Apply; |
28 use alg_tools::linops::Mapping; |
29 use alg_tools::sets::Cube; |
29 use alg_tools::sets::Cube; |
30 use alg_tools::loc::Loc; |
30 use alg_tools::loc::Loc; |
31 use alg_tools::bisection_tree::{ |
31 use alg_tools::bisection_tree::{ |
32 BTFN, |
32 BTFN, |
33 Bounds, |
33 Bounds, |
109 merging : Default::default(), |
111 merging : Default::default(), |
110 } |
112 } |
111 } |
113 } |
112 } |
114 } |
113 |
115 |
114 /// Helper struct for pre-initialising the finite-dimensional subproblems solver |
116 pub trait FindimQuadraticModel<Domain, F> : ForwardModel<DiscreteMeasure<Domain, F>, F> |
115 /// [`prepare_optimise_weights`]. |
117 where |
116 /// |
118 F : Float + ToNalgebraRealField, |
117 /// The pre-initialisation is done by [`prepare_optimise_weights`]. |
119 Domain : Clone + PartialEq, |
|
120 { |
|
121 /// Return A_*A and A_* b |
|
122 fn findim_quadratic_model( |
|
123 &self, |
|
124 μ : &DiscreteMeasure<Domain, F>, |
|
125 b : &Self::Observable |
|
126 ) -> (DMatrix<F::MixedType>, DVector<F::MixedType>); |
|
127 } |
|
128 |
|
129 /// Helper struct for pre-initialising the finite-dimensional subproblem solver. |
118 pub struct FindimData<F : Float> { |
130 pub struct FindimData<F : Float> { |
119 /// ‖A‖^2 |
131 /// ‖A‖^2 |
120 opAnorm_squared : F, |
132 opAnorm_squared : F, |
121 /// Bound $M_0$ from the Bredies–Pikkarainen article. |
133 /// Bound $M_0$ from the Bredies–Pikkarainen article. |
122 m0 : F |
134 m0 : F |
123 } |
135 } |
124 |
136 |
125 /// Trait for finite dimensional weight optimisation. |
137 /// Trait for finite dimensional weight optimisation. |
126 pub trait WeightOptim< |
138 pub trait WeightOptim< |
127 F : Float + ToNalgebraRealField, |
139 F : Float + ToNalgebraRealField, |
128 A : ForwardModel<Loc<F, N>, F>, |
140 A : ForwardModel<RNDM<F, N>, F>, |
129 I : AlgIteratorFactory<F>, |
141 I : AlgIteratorFactory<F>, |
130 const N : usize |
142 const N : usize |
131 > { |
143 > { |
132 |
144 |
133 /// Return a pre-initialisation struct for [`Self::optimise_weights`]. |
145 /// Return a pre-initialisation struct for [`Self::optimise_weights`]. |
164 } |
176 } |
165 |
177 |
166 /// Trait for regularisation terms supported by [`pointsource_fw_reg`]. |
178 /// Trait for regularisation terms supported by [`pointsource_fw_reg`]. |
167 pub trait RegTermFW< |
179 pub trait RegTermFW< |
168 F : Float + ToNalgebraRealField, |
180 F : Float + ToNalgebraRealField, |
169 A : ForwardModel<Loc<F, N>, F>, |
181 A : ForwardModel<RNDM<F, N>, F>, |
170 I : AlgIteratorFactory<F>, |
182 I : AlgIteratorFactory<F>, |
171 const N : usize |
183 const N : usize |
172 > : RegTerm<F, N> |
184 > : RegTerm<F, N> |
173 + WeightOptim<F, A, I, N> |
185 + WeightOptim<F, A, I, N> |
174 + for<'a> Apply<&'a DiscreteMeasure<Loc<F, N>, F>, Output = F> { |
186 + Mapping<RNDM<F, N>, Codomain = F> { |
175 |
187 |
176 /// With $g = A\_\*(Aμ-b)$, returns $(x, g(x))$ for $x$ a new point to be inserted |
188 /// With $g = A\_\*(Aμ-b)$, returns $(x, g(x))$ for $x$ a new point to be inserted |
177 /// into $μ$, as determined by the regulariser. |
189 /// into $μ$, as determined by the regulariser. |
178 /// |
190 /// |
179 /// The parameters `refinement_tolerance` and `max_steps` are passed to relevant |
191 /// The parameters `refinement_tolerance` and `max_steps` are passed to relevant |
199 |
211 |
200 #[replace_float_literals(F::cast_from(literal))] |
212 #[replace_float_literals(F::cast_from(literal))] |
201 impl<F : Float + ToNalgebraRealField, A, I, const N : usize> WeightOptim<F, A, I, N> |
213 impl<F : Float + ToNalgebraRealField, A, I, const N : usize> WeightOptim<F, A, I, N> |
202 for RadonRegTerm<F> |
214 for RadonRegTerm<F> |
203 where I : AlgIteratorFactory<F>, |
215 where I : AlgIteratorFactory<F>, |
204 A : ForwardModel<Loc<F, N>, F> { |
216 A : FindimQuadraticModel<Loc<F, N>, F> { |
205 |
217 |
206 fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F> { |
218 fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F> { |
207 FindimData{ |
219 FindimData{ |
208 opAnorm_squared : opA.opnorm_bound().powi(2), |
220 opAnorm_squared : opA.opnorm_bound(Radon, L2).powi(2), |
209 m0 : b.norm2_squared() / (2.0 * self.α()), |
221 m0 : b.norm2_squared() / (2.0 * self.α()), |
210 } |
222 } |
211 } |
223 } |
212 |
224 |
213 fn optimise_weights<'a>( |
225 fn optimise_weights<'a>( |
214 &self, |
226 &self, |
215 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
227 μ : &mut RNDM<F, N>, |
216 opA : &'a A, |
228 opA : &'a A, |
217 b : &A::Observable, |
229 b : &A::Observable, |
218 findim_data : &FindimData<F>, |
230 findim_data : &FindimData<F>, |
219 inner : &InnerSettings<F>, |
231 inner : &InnerSettings<F>, |
220 iterator : I |
232 iterator : I |
243 } |
255 } |
244 |
256 |
245 #[replace_float_literals(F::cast_from(literal))] |
257 #[replace_float_literals(F::cast_from(literal))] |
246 impl<F : Float + ToNalgebraRealField, A, I, S, GA, BTA, const N : usize> RegTermFW<F, A, I, N> |
258 impl<F : Float + ToNalgebraRealField, A, I, S, GA, BTA, const N : usize> RegTermFW<F, A, I, N> |
247 for RadonRegTerm<F> |
259 for RadonRegTerm<F> |
248 where Cube<F, N> : P2Minimise<Loc<F, N>, F>, |
260 where |
249 I : AlgIteratorFactory<F>, |
261 Cube<F, N> : P2Minimise<Loc<F, N>, F>, |
250 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
262 I : AlgIteratorFactory<F>, |
251 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
263 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
252 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, |
264 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
253 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>> { |
265 A : FindimQuadraticModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, |
|
266 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
267 // FIXME: the following *should not* be needed, they are already implied |
|
268 RNDM<F, N> : Mapping<A::PreadjointCodomain, Codomain = F>, |
|
269 DeltaMeasure<Loc<F, N>, F> : Mapping<A::PreadjointCodomain, Codomain = F>, |
|
270 //A : Mapping<RNDM<F, N>, Codomain = A::Observable>, |
|
271 //A : Mapping<DeltaMeasure<Loc<F, N>, F>, Codomain = A::Observable>, |
|
272 { |
254 |
273 |
255 fn find_insertion( |
274 fn find_insertion( |
256 &self, |
275 &self, |
257 g : &mut A::PreadjointCodomain, |
276 g : &mut A::PreadjointCodomain, |
258 refinement_tolerance : F, |
277 refinement_tolerance : F, |
296 |
315 |
297 #[replace_float_literals(F::cast_from(literal))] |
316 #[replace_float_literals(F::cast_from(literal))] |
298 impl<F : Float + ToNalgebraRealField, A, I, const N : usize> WeightOptim<F, A, I, N> |
317 impl<F : Float + ToNalgebraRealField, A, I, const N : usize> WeightOptim<F, A, I, N> |
299 for NonnegRadonRegTerm<F> |
318 for NonnegRadonRegTerm<F> |
300 where I : AlgIteratorFactory<F>, |
319 where I : AlgIteratorFactory<F>, |
301 A : ForwardModel<Loc<F, N>, F> { |
320 A : FindimQuadraticModel<Loc<F, N>, F> { |
302 |
321 |
303 fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F> { |
322 fn prepare_optimise_weights(&self, opA : &A, b : &A::Observable) -> FindimData<F> { |
304 FindimData{ |
323 FindimData{ |
305 opAnorm_squared : opA.opnorm_bound().powi(2), |
324 opAnorm_squared : opA.opnorm_bound(Radon, L2).powi(2), |
306 m0 : b.norm2_squared() / (2.0 * self.α()), |
325 m0 : b.norm2_squared() / (2.0 * self.α()), |
307 } |
326 } |
308 } |
327 } |
309 |
328 |
310 fn optimise_weights<'a>( |
329 fn optimise_weights<'a>( |
311 &self, |
330 &self, |
312 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
331 μ : &mut RNDM<F, N>, |
313 opA : &'a A, |
332 opA : &'a A, |
314 b : &A::Observable, |
333 b : &A::Observable, |
315 findim_data : &FindimData<F>, |
334 findim_data : &FindimData<F>, |
316 inner : &InnerSettings<F>, |
335 inner : &InnerSettings<F>, |
317 iterator : I |
336 iterator : I |
340 } |
359 } |
341 |
360 |
342 #[replace_float_literals(F::cast_from(literal))] |
361 #[replace_float_literals(F::cast_from(literal))] |
343 impl<F : Float + ToNalgebraRealField, A, I, S, GA, BTA, const N : usize> RegTermFW<F, A, I, N> |
362 impl<F : Float + ToNalgebraRealField, A, I, S, GA, BTA, const N : usize> RegTermFW<F, A, I, N> |
344 for NonnegRadonRegTerm<F> |
363 for NonnegRadonRegTerm<F> |
345 where Cube<F, N> : P2Minimise<Loc<F, N>, F>, |
364 where |
346 I : AlgIteratorFactory<F>, |
365 Cube<F, N> : P2Minimise<Loc<F, N>, F>, |
347 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
366 I : AlgIteratorFactory<F>, |
348 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
367 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
349 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, |
368 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
350 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>> { |
369 A : FindimQuadraticModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, |
|
370 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
371 // FIXME: the following *should not* be needed, they are already implied |
|
372 RNDM<F, N> : Mapping<A::PreadjointCodomain, Codomain = F>, |
|
373 DeltaMeasure<Loc<F, N>, F> : Mapping<A::PreadjointCodomain, Codomain = F>, |
|
374 { |
351 |
375 |
352 fn find_insertion( |
376 fn find_insertion( |
353 &self, |
377 &self, |
354 g : &mut A::PreadjointCodomain, |
378 g : &mut A::PreadjointCodomain, |
355 refinement_tolerance : F, |
379 refinement_tolerance : F, |
407 reg : Reg, |
431 reg : Reg, |
408 //domain : Cube<F, N>, |
432 //domain : Cube<F, N>, |
409 config : &FWConfig<F>, |
433 config : &FWConfig<F>, |
410 iterator : I, |
434 iterator : I, |
411 mut plotter : SeqPlotter<F, N>, |
435 mut plotter : SeqPlotter<F, N>, |
412 ) -> DiscreteMeasure<Loc<F, N>, F> |
436 ) -> RNDM<F, N> |
413 where F : Float + ToNalgebraRealField, |
437 where F : Float + ToNalgebraRealField, |
414 I : AlgIteratorFactory<IterInfo<F, N>>, |
438 I : AlgIteratorFactory<IterInfo<F, N>>, |
415 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
439 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
416 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow |
|
417 A::Observable : std::ops::MulAssign<F>, |
|
418 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
440 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
419 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, |
441 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, |
420 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
442 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
421 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
443 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
422 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
444 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
423 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
445 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
424 PlotLookup : Plotting<N>, |
446 PlotLookup : Plotting<N>, |
425 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
447 RNDM<F, N> : SpikeMerging<F>, |
426 Reg : RegTermFW<F, A, ValueIteratorFactory<F, AlgIteratorOptions>, N> { |
448 Reg : RegTermFW<F, A, ValueIteratorFactory<F, AlgIteratorOptions>, N> { |
427 |
449 |
428 // Set up parameters |
450 // Set up parameters |
429 // We multiply tolerance by α for all algoritms. |
451 // We multiply tolerance by α for all algoritms. |
430 let tolerance = config.tolerance * reg.tolerance_scaling(); |
452 let tolerance = config.tolerance * reg.tolerance_scaling(); |
436 |
458 |
437 // Initialise iterates |
459 // Initialise iterates |
438 let mut μ = DiscreteMeasure::new(); |
460 let mut μ = DiscreteMeasure::new(); |
439 let mut residual = -b; |
461 let mut residual = -b; |
440 |
462 |
441 let mut inner_iters = 0; |
463 // Statistics |
442 let mut this_iters = 0; |
464 let full_stats = |residual : &A::Observable, |
443 let mut pruned = 0; |
465 ν : &RNDM<F, N>, |
444 let mut merged = 0; |
466 ε, stats| IterInfo { |
|
467 value : residual.norm2_squared_div2() + reg.apply(ν), |
|
468 n_spikes : ν.len(), |
|
469 ε, |
|
470 .. stats |
|
471 }; |
|
472 let mut stats = IterInfo::new(); |
445 |
473 |
446 // Run the algorithm |
474 // Run the algorithm |
447 iterator.iterate(|state| { |
475 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
448 // Update tolerance |
|
449 let inner_tolerance = ε * config.inner.tolerance_mult; |
476 let inner_tolerance = ε * config.inner.tolerance_mult; |
450 let refinement_tolerance = ε * config.refinement.tolerance_mult; |
477 let refinement_tolerance = ε * config.refinement.tolerance_mult; |
451 let ε_prev = ε; |
|
452 ε = tolerance.update(ε, state.iteration()); |
|
453 |
478 |
454 // Calculate smooth part of surrogate model. |
479 // Calculate smooth part of surrogate model. |
455 // |
480 let mut g = preadjA.apply(residual * (-1.0)); |
456 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
|
457 // has no significant overhead. For some reosn Rust doesn't allow us simply moving |
|
458 // the residual and replacing it below before the end of this closure. |
|
459 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
|
460 let mut g = -preadjA.apply(r); |
|
461 |
481 |
462 // Find absolute value maximising point |
482 // Find absolute value maximising point |
463 let (ξ, v_ξ) = reg.find_insertion(&mut g, refinement_tolerance, |
483 let (ξ, v_ξ) = reg.find_insertion(&mut g, refinement_tolerance, |
464 config.refinement.max_steps); |
484 config.refinement.max_steps); |
465 |
485 |
466 let inner_it = match config.variant { |
486 let inner_it = match config.variant { |
467 FWVariant::FullyCorrective => { |
487 FWVariant::FullyCorrective => { |
468 // No point in optimising the weight here: the finite-dimensional algorithm is fast. |
488 // No point in optimising the weight here: the finite-dimensional algorithm is fast. |
469 μ += DeltaMeasure { x : ξ, α : 0.0 }; |
489 μ += DeltaMeasure { x : ξ, α : 0.0 }; |
|
490 stats.inserted += 1; |
470 config.inner.iterator_options.stop_target(inner_tolerance) |
491 config.inner.iterator_options.stop_target(inner_tolerance) |
471 }, |
492 }, |
472 FWVariant::Relaxed => { |
493 FWVariant::Relaxed => { |
473 // Perform a relaxed initialisation of μ |
494 // Perform a relaxed initialisation of μ |
474 reg.relaxed_insert(&mut μ, &g, opA, ξ, v_ξ, &findim_data); |
495 reg.relaxed_insert(&mut μ, &g, opA, ξ, v_ξ, &findim_data); |
|
496 stats.inserted += 1; |
475 // The stop_target is only needed for the type system. |
497 // The stop_target is only needed for the type system. |
476 AlgIteratorOptions{ max_iter : 1, .. config.inner.iterator_options}.stop_target(0.0) |
498 AlgIteratorOptions{ max_iter : 1, .. config.inner.iterator_options}.stop_target(0.0) |
477 } |
499 } |
478 }; |
500 }; |
479 |
501 |
480 inner_iters += reg.optimise_weights(&mut μ, opA, b, &findim_data, &config.inner, inner_it); |
502 stats.inner_iters += reg.optimise_weights(&mut μ, opA, b, &findim_data, |
|
503 &config.inner, inner_it); |
481 |
504 |
482 // Merge spikes and update residual for next step and `if_verbose` below. |
505 // Merge spikes and update residual for next step and `if_verbose` below. |
483 let (r, count) = μ.merge_spikes_fitness(config.merging, |
506 let (r, count) = μ.merge_spikes_fitness(config.merging, |
484 |μ̃| opA.apply(μ̃) - b, |
507 |μ̃| opA.apply(μ̃) - b, |
485 A::Observable::norm2_squared); |
508 A::Observable::norm2_squared); |
486 residual = r; |
509 residual = r; |
487 merged += count; |
510 stats.merged += count; |
488 |
|
489 |
511 |
490 // Prune points with zero mass |
512 // Prune points with zero mass |
491 let n_before_prune = μ.len(); |
513 let n_before_prune = μ.len(); |
492 μ.prune(); |
514 μ.prune(); |
493 debug_assert!(μ.len() <= n_before_prune); |
515 debug_assert!(μ.len() <= n_before_prune); |
494 pruned += n_before_prune - μ.len(); |
516 stats.pruned += n_before_prune - μ.len(); |
495 |
517 |
496 this_iters +=1; |
518 stats.this_iters += 1; |
497 |
519 let iter = state.iteration(); |
498 // Give function value if needed |
520 |
|
521 // Give statistics if needed |
499 state.if_verbose(|| { |
522 state.if_verbose(|| { |
500 plotter.plot_spikes( |
523 plotter.plot_spikes(iter, Some(&g), Option::<&S>::None, &μ); |
501 format!("iter {} start", state.iteration()), &g, |
524 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
502 "".to_string(), None::<&A::PreadjointCodomain>, |
525 }); |
503 None, &μ |
526 |
504 ); |
527 // Update tolerance |
505 let res = IterInfo { |
528 ε = tolerance.update(ε, iter); |
506 value : residual.norm2_squared_div2() + reg.apply(&μ), |
529 } |
507 n_spikes : μ.len(), |
|
508 inner_iters, |
|
509 this_iters, |
|
510 merged, |
|
511 pruned, |
|
512 ε : ε_prev, |
|
513 postprocessing : None, |
|
514 untransported_fraction : None, |
|
515 transport_error : None, |
|
516 }; |
|
517 inner_iters = 0; |
|
518 this_iters = 0; |
|
519 merged = 0; |
|
520 pruned = 0; |
|
521 res |
|
522 }) |
|
523 }); |
|
524 |
530 |
525 // Return final iterate |
531 // Return final iterate |
526 μ |
532 μ |
527 } |
533 } |