src/fb.rs

branch
dev
changeset 35
b087e3eab191
parent 34
efa60bc4f743
child 37
c5d8bd1a7728
equal deleted inserted replaced
34:efa60bc4f743 35:b087e3eab191
4 This corresponds to the manuscript 4 This corresponds to the manuscript
5 5
6 * Valkonen T. - _Proximal methods for point source localisation_, 6 * Valkonen T. - _Proximal methods for point source localisation_,
7 [arXiv:2212.02991](https://arxiv.org/abs/2212.02991). 7 [arXiv:2212.02991](https://arxiv.org/abs/2212.02991).
8 8
9 The main routine is [`pointsource_fb_reg`]. It is based on [`generic_pointsource_fb_reg`], which is 9 The main routine is [`pointsource_fb_reg`].
10 also used by our [primal-dual proximal splitting][crate::pdps] implementation.
11
12 FISTA-type inertia can also be enabled through [`FBConfig::meta`].
13 10
14 ## Problem 11 ## Problem
15 12
16 <p> 13 <p>
17 Our objective is to solve 14 Our objective is to solve
74 = -\frac{τ}{2} \|Aμ^k-b\|^2 + τ[Aμ^k-b]^⊤ b + \frac{1}{2} \|μ_k\|_{𝒟}^2. 71 = -\frac{τ}{2} \|Aμ^k-b\|^2 + τ[Aμ^k-b]^⊤ b + \frac{1}{2} \|μ_k\|_{𝒟}^2.
75 \end{aligned} 72 \end{aligned}
76 $$ 73 $$
77 </p> 74 </p>
78 75
79 We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by 76 We solve this with either SSN or FB as determined by
80 [`InnerSettings`] in [`FBGenericConfig::inner`]. 77 [`InnerSettings`] in [`FBGenericConfig::inner`].
81 */ 78 */
82 79
83 use numeric_literals::replace_float_literals; 80 use numeric_literals::replace_float_literals;
84 use serde::{Serialize, Deserialize}; 81 use serde::{Serialize, Deserialize};
85 use colored::Colorize; 82 use colored::Colorize;
86 use nalgebra::DVector; 83 use nalgebra::DVector;
87 84
88 use alg_tools::iterate::{ 85 use alg_tools::iterate::{
89 AlgIteratorFactory, 86 AlgIteratorFactory,
90 AlgIteratorState, 87 AlgIteratorIteration,
88 AlgIterator,
91 }; 89 };
92 use alg_tools::euclidean::Euclidean; 90 use alg_tools::euclidean::Euclidean;
93 use alg_tools::linops::{Apply, GEMV}; 91 use alg_tools::linops::{Mapping, GEMV};
94 use alg_tools::sets::Cube; 92 use alg_tools::sets::Cube;
95 use alg_tools::loc::Loc; 93 use alg_tools::loc::Loc;
96 use alg_tools::bisection_tree::{ 94 use alg_tools::bisection_tree::{
97 BTFN, 95 BTFN,
98 PreBTFN, 96 PreBTFN,
105 LocalAnalysis, 103 LocalAnalysis,
106 BothGenerators, 104 BothGenerators,
107 }; 105 };
108 use alg_tools::mapping::RealMapping; 106 use alg_tools::mapping::RealMapping;
109 use alg_tools::nalgebra_support::ToNalgebraRealField; 107 use alg_tools::nalgebra_support::ToNalgebraRealField;
108 use alg_tools::instance::Instance;
109 use alg_tools::norms::Linfinity;
110 110
111 use crate::types::*; 111 use crate::types::*;
112 use crate::measures::{ 112 use crate::measures::{
113 DiscreteMeasure, 113 DiscreteMeasure,
114 RNDM,
114 DeltaMeasure, 115 DeltaMeasure,
116 Radon,
115 }; 117 };
116 use crate::measures::merging::{ 118 use crate::measures::merging::{
117 SpikeMergingMethod, 119 SpikeMergingMethod,
118 SpikeMerging, 120 SpikeMerging,
119 }; 121 };
120 use crate::forward_model::ForwardModel; 122 use crate::forward_model::{
123 ForwardModel,
124 AdjointProductBoundedBy
125 };
121 use crate::seminorms::DiscreteMeasureOp; 126 use crate::seminorms::DiscreteMeasureOp;
122 use crate::subproblem::{ 127 use crate::subproblem::{
123 InnerSettings, 128 InnerSettings,
124 InnerMethod, 129 InnerMethod,
125 }; 130 };
144 pub τ0 : F, 149 pub τ0 : F,
145 /// Generic parameters 150 /// Generic parameters
146 pub generic : FBGenericConfig<F>, 151 pub generic : FBGenericConfig<F>,
147 } 152 }
148 153
149 /// Settings for the solution of the stepwise optimality condition in algorithms based on 154 /// Settings for the solution of the stepwise optimality condition.
150 /// [`generic_pointsource_fb_reg`].
151 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 155 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
152 #[serde(default)] 156 #[serde(default)]
153 pub struct FBGenericConfig<F : Float> { 157 pub struct FBGenericConfig<F : Float> {
154 /// Tolerance for point insertion. 158 /// Tolerance for point insertion.
155 pub tolerance : Tolerance<F>, 159 pub tolerance : Tolerance<F>,
186 pub final_merging : SpikeMergingMethod<F>, 190 pub final_merging : SpikeMergingMethod<F>,
187 191
188 /// Iterations between merging heuristic tries 192 /// Iterations between merging heuristic tries
189 pub merge_every : usize, 193 pub merge_every : usize,
190 194
191 /// Save $μ$ for postprocessing optimisation 195 // /// Save $μ$ for postprocessing optimisation
192 pub postprocessing : bool 196 // pub postprocessing : bool
193 } 197 }
194 198
195 #[replace_float_literals(F::cast_from(literal))] 199 #[replace_float_literals(F::cast_from(literal))]
196 impl<F : Float> Default for FBConfig<F> { 200 impl<F : Float> Default for FBConfig<F> {
197 fn default() -> Self { 201 fn default() -> Self {
219 merging : SpikeMergingMethod::None, 223 merging : SpikeMergingMethod::None,
220 //merging : Default::default(), 224 //merging : Default::default(),
221 final_merging : Default::default(), 225 final_merging : Default::default(),
222 merge_every : 10, 226 merge_every : 10,
223 merge_tolerance_mult : 2.0, 227 merge_tolerance_mult : 2.0,
224 postprocessing : false, 228 // postprocessing : false,
225 } 229 }
230 }
231 }
232
233 impl<F : Float> FBGenericConfig<F> {
234 /// Check if merging should be attempted this iteration
235 pub fn merge_now<I : AlgIterator>(&self, state : &AlgIteratorIteration<I>) -> bool {
236 state.iteration() % self.merge_every == 0
226 } 237 }
227 } 238 }
228 239
229 /// TODO: document. 240 /// TODO: document.
230 /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike 241 /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike
231 /// locations, while `ν_delta` may have different locations. 242 /// locations, while `ν_delta` may have different locations.
232 #[replace_float_literals(F::cast_from(literal))] 243 #[replace_float_literals(F::cast_from(literal))]
233 pub(crate) fn insert_and_reweigh< 244 pub(crate) fn insert_and_reweigh<
234 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize 245 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, I, const N : usize
235 >( 246 >(
236 μ : &mut DiscreteMeasure<Loc<F, N>, F>, 247 μ : &mut RNDM<F, N>,
237 minus_τv : &BTFN<F, GA, BTA, N>, 248 τv : &BTFN<F, GA, BTA, N>,
238 μ_base : &DiscreteMeasure<Loc<F, N>, F>, 249 μ_base : &RNDM<F, N>,
239 ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>, 250 ν_delta: Option<&RNDM<F, N>>,
240 op𝒟 : &'a 𝒟, 251 op𝒟 : &'a 𝒟,
241 op𝒟norm : F, 252 op𝒟norm : F,
242 τ : F, 253 τ : F,
243 ε : F, 254 ε : F,
244 config : &FBGenericConfig<F>, 255 config : &FBGenericConfig<F>,
245 reg : &Reg, 256 reg : &Reg,
246 state : &State, 257 state : &AlgIteratorIteration<I>,
247 stats : &mut IterInfo<F, N>, 258 stats : &mut IterInfo<F, N>,
248 ) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool) 259 ) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool)
249 where F : Float + ToNalgebraRealField, 260 where F : Float + ToNalgebraRealField,
250 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 261 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
251 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 262 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
253 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 264 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
254 𝒟::Codomain : RealMapping<F, N>, 265 𝒟::Codomain : RealMapping<F, N>,
255 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 266 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
256 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 267 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
257 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 268 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
258 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
259 Reg : RegTerm<F, N>, 269 Reg : RegTerm<F, N>,
260 State : AlgIteratorState { 270 I : AlgIterator {
261 271
262 // Maximum insertion count and measure difference calculation depend on insertion style. 272 // Maximum insertion count and measure difference calculation depend on insertion style.
263 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { 273 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
264 (i, Some((l, k))) if i <= l => (k, false), 274 (i, Some((l, k))) if i <= l => (k, false),
265 _ => (config.max_insertions, !state.is_quiet()), 275 _ => (config.max_insertions, !state.is_quiet()),
266 }; 276 };
267 277
268 // TODO: should avoid a copy of μ_base here. 278 let ω0 = match ν_delta {
269 let ω0 = op𝒟.apply(match ν_delta { 279 None => op𝒟.apply(μ_base),
270 None => μ_base.clone(), 280 Some(ν) => op𝒟.apply(μ_base + ν),
271 Some(ν_d) => &*μ_base + ν_d, 281 };
272 });
273 282
274 // Add points to support until within error tolerance or maximum insertion count reached. 283 // Add points to support until within error tolerance or maximum insertion count reached.
275 let mut count = 0; 284 let mut count = 0;
276 let (within_tolerances, d) = 'insertion: loop { 285 let (within_tolerances, d) = 'insertion: loop {
277 if μ.len() > 0 { 286 if μ.len() > 0 {
278 // Form finite-dimensional subproblem. The subproblem references to the original μ^k 287 // Form finite-dimensional subproblem. The subproblem references to the original μ^k
279 // from the beginning of the iteration are all contained in the immutable c and g. 288 // from the beginning of the iteration are all contained in the immutable c and g.
289 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
290 // problems have not yet been updated to sign change.
280 let à = op𝒟.findim_matrix(μ.iter_locations()); 291 let à = op𝒟.findim_matrix(μ.iter_locations());
281 let g̃ = DVector::from_iterator(μ.len(), 292 let g̃ = DVector::from_iterator(μ.len(),
282 μ.iter_locations() 293 μ.iter_locations()
283 .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ)) 294 .map(|ζ| ω0.apply(ζ) - τv.apply(ζ))
284 .map(F::to_nalgebra_mixed)); 295 .map(F::to_nalgebra_mixed));
285 let mut x = μ.masses_dvector(); 296 let mut x = μ.masses_dvector();
286 297
287 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. 298 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
288 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ 299 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
296 307
297 // Update masses of μ based on solution of finite-dimensional subproblem. 308 // Update masses of μ based on solution of finite-dimensional subproblem.
298 μ.set_masses_dvector(&x); 309 μ.set_masses_dvector(&x);
299 } 310 }
300 311
301 // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality 312 // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality
302 // conditions in the predual space, and finding new points for insertion, if necessary. 313 // conditions in the predual space, and finding new points for insertion, if necessary.
303 let mut d = minus_τv + op𝒟.preapply(match ν_delta { 314 let mut d = τv + match ν_delta {
304 None => μ_base.sub_matching(μ), 315 None => op𝒟.preapply(μ.sub_matching(μ_base)),
305 Some(ν) => μ_base.sub_matching(μ) + ν 316 Some(ν) => op𝒟.preapply(μ.sub_matching(μ_base) - ν)
306 }); 317 };
307 318
308 // If no merging heuristic is used, let's be more conservative about spike insertion, 319 // If no merging heuristic is used, let's be more conservative about spike insertion,
309 // and skip it after first round. If merging is done, being more greedy about spike 320 // and skip it after first round. If merging is done, being more greedy about spike
310 // insertion also seems to improve performance. 321 // insertion also seems to improve performance.
311 let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging { 322 let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging {
328 } 339 }
329 340
330 // No point in optimising the weight here; the finite-dimensional algorithm is fast. 341 // No point in optimising the weight here; the finite-dimensional algorithm is fast.
331 *μ += DeltaMeasure { x : ξ, α : 0.0 }; 342 *μ += DeltaMeasure { x : ξ, α : 0.0 };
332 count += 1; 343 count += 1;
344 stats.inserted += 1;
333 }; 345 };
334
335 // TODO: should redo everything if some transports cause a problem.
336 // Maybe implementation should call above loop as a closure.
337 346
338 if !within_tolerances && warn_insertions { 347 if !within_tolerances && warn_insertions {
339 // Complain (but continue) if we failed to get within tolerances 348 // Complain (but continue) if we failed to get within tolerances
340 // by inserting more points. 349 // by inserting more points.
341 let err = format!("Maximum insertions reached without achieving \ 350 let err = format!("Maximum insertions reached without achieving \
344 } 353 }
345 354
346 (d, within_tolerances) 355 (d, within_tolerances)
347 } 356 }
348 357
349 #[replace_float_literals(F::cast_from(literal))] 358 pub(crate) fn prune_with_stats<F : Float, const N : usize>(
350 pub(crate) fn prune_and_maybe_simple_merge< 359 μ : &mut RNDM<F, N>,
351 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize 360 ) -> usize {
352 >(
353 μ : &mut DiscreteMeasure<Loc<F, N>, F>,
354 minus_τv : &BTFN<F, GA, BTA, N>,
355 μ_base : &DiscreteMeasure<Loc<F, N>, F>,
356 op𝒟 : &'a 𝒟,
357 τ : F,
358 ε : F,
359 config : &FBGenericConfig<F>,
360 reg : &Reg,
361 state : &State,
362 stats : &mut IterInfo<F, N>,
363 )
364 where F : Float + ToNalgebraRealField,
365 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
366 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
367 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
368 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
369 𝒟::Codomain : RealMapping<F, N>,
370 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
371 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
372 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
373 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
374 Reg : RegTerm<F, N>,
375 State : AlgIteratorState {
376 if state.iteration() % config.merge_every == 0 {
377 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
378 let mut d = minus_τv + op𝒟.preapply(μ_base.sub_matching(&μ_candidate));
379 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
380 });
381 }
382
383 let n_before_prune = μ.len(); 361 let n_before_prune = μ.len();
384 μ.prune(); 362 μ.prune();
385 debug_assert!(μ.len() <= n_before_prune); 363 debug_assert!(μ.len() <= n_before_prune);
386 stats.pruned += n_before_prune - μ.len(); 364 n_before_prune - μ.len()
387 } 365 }
388 366
389 #[replace_float_literals(F::cast_from(literal))] 367 #[replace_float_literals(F::cast_from(literal))]
390 pub(crate) fn postprocess< 368 pub(crate) fn postprocess<
391 F : Float, 369 F : Float,
392 V : Euclidean<F> + Clone, 370 V : Euclidean<F> + Clone,
393 A : GEMV<F, DiscreteMeasure<Loc<F, N>, F>, Codomain = V>, 371 A : GEMV<F, RNDM<F, N>, Codomain = V>,
394 D : DataTerm<F, V, N>, 372 D : DataTerm<F, V, N>,
395 const N : usize 373 const N : usize
396 > ( 374 > (
397 mut μ : DiscreteMeasure<Loc<F, N>, F>, 375 mut μ : RNDM<F, N>,
398 config : &FBGenericConfig<F>, 376 config : &FBGenericConfig<F>,
399 dataterm : D, 377 dataterm : D,
400 opA : &A, 378 opA : &A,
401 b : &V, 379 b : &V,
402 ) -> DiscreteMeasure<Loc<F, N>, F> 380 ) -> RNDM<F, N>
403 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { 381 where
382 RNDM<F, N> : SpikeMerging<F>,
383 for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>,
384 {
404 μ.merge_spikes_fitness(config.merging, 385 μ.merge_spikes_fitness(config.merging,
405 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), 386 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
406 |&v| v); 387 |&v| v);
407 μ.prune(); 388 μ.prune();
408 μ 389 μ
435 reg : Reg, 416 reg : Reg,
436 op𝒟 : &'a 𝒟, 417 op𝒟 : &'a 𝒟,
437 fbconfig : &FBConfig<F>, 418 fbconfig : &FBConfig<F>,
438 iterator : I, 419 iterator : I,
439 mut plotter : SeqPlotter<F, N>, 420 mut plotter : SeqPlotter<F, N>,
440 ) -> DiscreteMeasure<Loc<F, N>, F> 421 ) -> RNDM<F, N>
441 where F : Float + ToNalgebraRealField, 422 where F : Float + ToNalgebraRealField,
442 I : AlgIteratorFactory<IterInfo<F, N>>, 423 I : AlgIteratorFactory<IterInfo<F, N>>,
443 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, 424 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
444 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
445 A::Observable : std::ops::MulAssign<F>,
446 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 425 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
447 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> 426 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
448 + Lipschitz<&'a 𝒟, FloatType=F>, 427 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
449 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 428 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
450 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 429 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
451 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 430 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
452 𝒟::Codomain : RealMapping<F, N>, 431 𝒟::Codomain : RealMapping<F, N>,
453 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 432 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
454 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 433 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
455 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 434 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
456 Cube<F, N>: P2Minimise<Loc<F, N>, F>, 435 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
457 PlotLookup : Plotting<N>, 436 PlotLookup : Plotting<N>,
458 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 437 RNDM<F, N> : SpikeMerging<F>,
459 Reg : RegTerm<F, N> { 438 Reg : RegTerm<F, N> {
460 439
461 // Set up parameters 440 // Set up parameters
462 let config = &fbconfig.generic; 441 let config = &fbconfig.generic;
463 let op𝒟norm = op𝒟.opnorm_bound(); 442 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
464 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); 443 let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap();
465 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 444 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
466 // by τ compared to the conditional gradient approach. 445 // by τ compared to the conditional gradient approach.
467 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 446 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
468 let mut ε = tolerance.initial(); 447 let mut ε = tolerance.initial();
469 448
470 // Initialise iterates 449 // Initialise iterates
471 let mut μ = DiscreteMeasure::new(); 450 let mut μ = DiscreteMeasure::new();
472 let mut residual = -b; 451 let mut residual = -b;
452
453 // Statistics
454 let full_stats = |residual : &A::Observable,
455 μ : &RNDM<F, N>,
456 ε, stats| IterInfo {
457 value : residual.norm2_squared_div2() + reg.apply(μ),
458 n_spikes : μ.len(),
459 ε,
460 //postprocessing: config.postprocessing.then(|| μ.clone()),
461 .. stats
462 };
473 let mut stats = IterInfo::new(); 463 let mut stats = IterInfo::new();
474 464
475 // Run the algorithm 465 // Run the algorithm
476 iterator.iterate(|state| { 466 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
477 // Calculate smooth part of surrogate model. 467 // Calculate smooth part of surrogate model.
478 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` 468 let τv = opA.preadjoint().apply(residual * τ);
479 // has no significant overhead. For some reason Rust doesn't allow us simply moving
480 // the residual and replacing it below before the end of this closure.
481 residual *= -τ;
482 let r = std::mem::replace(&mut residual, opA.empty_observable());
483 let minus_τv = opA.preadjoint().apply(r);
484 469
485 // Save current base point 470 // Save current base point
486 let μ_base = μ.clone(); 471 let μ_base = μ.clone();
487 472
488 // Insert and reweigh 473 // Insert and reweigh
489 let (d, within_tolerances) = insert_and_reweigh( 474 let (d, _within_tolerances) = insert_and_reweigh(
490 &mut μ, &minus_τv, &μ_base, None, 475 &mut μ, &τv, &μ_base, None,
491 op𝒟, op𝒟norm, 476 op𝒟, op𝒟norm,
492 τ, ε, 477 τ, ε,
493 config, &reg, state, &mut stats 478 config, &reg, &state, &mut stats
494 ); 479 );
495 480
496 // Prune and possibly merge spikes 481 // Prune and possibly merge spikes
497 prune_and_maybe_simple_merge( 482 if config.merge_now(&state) {
498 &mut μ, &minus_τv, &μ_base, 483 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
499 op𝒟, 484 let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base));
500 τ, ε, 485 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
501 config, &reg, state, &mut stats 486 });
502 ); 487 }
488 stats.pruned += prune_with_stats(&mut μ);
503 489
504 // Update residual 490 // Update residual
505 residual = calculate_residual(&μ, opA, b); 491 residual = calculate_residual(&μ, opA, b);
506 492
493 let iter = state.iteration();
494 stats.this_iters += 1;
495
496 // Give statistics if needed
497 state.if_verbose(|| {
498 plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ);
499 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
500 });
501
507 // Update main tolerance for next iteration 502 // Update main tolerance for next iteration
508 let ε_prev = ε; 503 ε = tolerance.update(ε, iter);
509 ε = tolerance.update(ε, state.iteration()); 504 }
510 stats.this_iters += 1;
511
512 // Give function value if needed
513 state.if_verbose(|| {
514 // Plot if so requested
515 plotter.plot_spikes(
516 format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
517 "start".to_string(), Some(&minus_τv),
518 reg.target_bounds(τ, ε_prev), &μ,
519 );
520 // Calculate mean inner iterations and reset relevant counters.
521 // Return the statistics
522 let res = IterInfo {
523 value : residual.norm2_squared_div2() + reg.apply(&μ),
524 n_spikes : μ.len(),
525 ε : ε_prev,
526 postprocessing: config.postprocessing.then(|| μ.clone()),
527 .. stats
528 };
529 stats = IterInfo::new();
530 res
531 })
532 });
533 505
534 postprocess(μ, config, L2Squared, opA, b) 506 postprocess(μ, config, L2Squared, opA, b)
535 } 507 }
536 508
537 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. 509 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting.
561 reg : Reg, 533 reg : Reg,
562 op𝒟 : &'a 𝒟, 534 op𝒟 : &'a 𝒟,
563 fbconfig : &FBConfig<F>, 535 fbconfig : &FBConfig<F>,
564 iterator : I, 536 iterator : I,
565 mut plotter : SeqPlotter<F, N>, 537 mut plotter : SeqPlotter<F, N>,
566 ) -> DiscreteMeasure<Loc<F, N>, F> 538 ) -> RNDM<F, N>
567 where F : Float + ToNalgebraRealField, 539 where F : Float + ToNalgebraRealField,
568 I : AlgIteratorFactory<IterInfo<F, N>>, 540 I : AlgIteratorFactory<IterInfo<F, N>>,
569 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, 541 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
570 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
571 A::Observable : std::ops::MulAssign<F>,
572 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 542 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
573 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> 543 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
574 + Lipschitz<&'a 𝒟, FloatType=F>, 544 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
575 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 545 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
576 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 546 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
577 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 547 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
578 𝒟::Codomain : RealMapping<F, N>, 548 𝒟::Codomain : RealMapping<F, N>,
579 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 549 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
580 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 550 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
581 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 551 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
582 Cube<F, N>: P2Minimise<Loc<F, N>, F>, 552 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
583 PlotLookup : Plotting<N>, 553 PlotLookup : Plotting<N>,
584 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 554 RNDM<F, N> : SpikeMerging<F>,
585 Reg : RegTerm<F, N> { 555 Reg : RegTerm<F, N> {
586 556
587 // Set up parameters 557 // Set up parameters
588 let config = &fbconfig.generic; 558 let config = &fbconfig.generic;
589 let op𝒟norm = op𝒟.opnorm_bound(); 559 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
590 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); 560 let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap();
591 let mut λ = 1.0; 561 let mut λ = 1.0;
592 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 562 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
593 // by τ compared to the conditional gradient approach. 563 // by τ compared to the conditional gradient approach.
594 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 564 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
595 let mut ε = tolerance.initial(); 565 let mut ε = tolerance.initial();
596 566
597 // Initialise iterates 567 // Initialise iterates
598 let mut μ = DiscreteMeasure::new(); 568 let mut μ = DiscreteMeasure::new();
599 let mut μ_prev = DiscreteMeasure::new(); 569 let mut μ_prev = DiscreteMeasure::new();
600 let mut residual = -b; 570 let mut residual = -b;
571 let mut warned_merging = false;
572
573 // Statistics
574 let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo {
575 value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν),
576 n_spikes : ν.len(),
577 ε,
578 // postprocessing: config.postprocessing.then(|| ν.clone()),
579 .. stats
580 };
601 let mut stats = IterInfo::new(); 581 let mut stats = IterInfo::new();
602 let mut warned_merging = false;
603 582
604 // Run the algorithm 583 // Run the algorithm
605 iterator.iterate(|state| { 584 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
606 // Calculate smooth part of surrogate model. 585 // Calculate smooth part of surrogate model.
607 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` 586 let τv = opA.preadjoint().apply(residual * τ);
608 // has no significant overhead. For some reason Rust doesn't allow us simply moving
609 // the residual and replacing it below before the end of this closure.
610 residual *= -τ;
611 let r = std::mem::replace(&mut residual, opA.empty_observable());
612 let minus_τv = opA.preadjoint().apply(r);
613 587
614 // Save current base point 588 // Save current base point
615 let μ_base = μ.clone(); 589 let μ_base = μ.clone();
616 590
617 // Insert new spikes and reweigh 591 // Insert new spikes and reweigh
618 let (d, within_tolerances) = insert_and_reweigh( 592 let (d, _within_tolerances) = insert_and_reweigh(
619 &mut μ, &minus_τv, &μ_base, None, 593 &mut μ, &τv, &μ_base, None,
620 op𝒟, op𝒟norm, 594 op𝒟, op𝒟norm,
621 τ, ε, 595 τ, ε,
622 config, &reg, state, &mut stats 596 config, &reg, &state, &mut stats
623 ); 597 );
624 598
625 // (Do not) merge spikes. 599 // (Do not) merge spikes.
626 if state.iteration() % config.merge_every == 0 { 600 if config.merge_now(&state) {
627 match config.merging { 601 match config.merging {
628 SpikeMergingMethod::None => { }, 602 SpikeMergingMethod::None => { },
629 _ => if !warned_merging { 603 _ => if !warned_merging {
630 let err = format!("Merging not supported for μFISTA"); 604 let err = format!("Merging not supported for μFISTA");
631 println!("{}", err.red()); 605 println!("{}", err.red());
651 stats.pruned += n_before_prune - μ.len(); 625 stats.pruned += n_before_prune - μ.len();
652 626
653 // Update residual 627 // Update residual
654 residual = calculate_residual(&μ, opA, b); 628 residual = calculate_residual(&μ, opA, b);
655 629
630 let iter = state.iteration();
631 stats.this_iters += 1;
632
633 // Give statistics if needed
634 state.if_verbose(|| {
635 plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ_prev);
636 full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new()))
637 });
638
656 // Update main tolerance for next iteration 639 // Update main tolerance for next iteration
657 let ε_prev = ε; 640 ε = tolerance.update(ε, iter);
658 ε = tolerance.update(ε, state.iteration()); 641 }
659 stats.this_iters += 1;
660
661 // Give function value if needed
662 state.if_verbose(|| {
663 // Plot if so requested
664 plotter.plot_spikes(
665 format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
666 "start".to_string(), Some(&minus_τv),
667 reg.target_bounds(τ, ε_prev), &μ_prev,
668 );
669 // Calculate mean inner iterations and reset relevant counters.
670 // Return the statistics
671 let res = IterInfo {
672 value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev),
673 n_spikes : μ_prev.len(),
674 ε : ε_prev,
675 postprocessing: config.postprocessing.then(|| μ_prev.clone()),
676 .. stats
677 };
678 stats = IterInfo::new();
679 res
680 })
681 });
682 642
683 postprocess(μ_prev, config, L2Squared, opA, b) 643 postprocess(μ_prev, config, L2Squared, opA, b)
684 } 644 }

mercurial