219 merging : SpikeMergingMethod::None, |
223 merging : SpikeMergingMethod::None, |
220 //merging : Default::default(), |
224 //merging : Default::default(), |
221 final_merging : Default::default(), |
225 final_merging : Default::default(), |
222 merge_every : 10, |
226 merge_every : 10, |
223 merge_tolerance_mult : 2.0, |
227 merge_tolerance_mult : 2.0, |
224 postprocessing : false, |
228 // postprocessing : false, |
225 } |
229 } |
|
230 } |
|
231 } |
|
232 |
|
233 impl<F : Float> FBGenericConfig<F> { |
|
234 /// Check if merging should be attempted this iteration |
|
235 pub fn merge_now<I : AlgIterator>(&self, state : &AlgIteratorIteration<I>) -> bool { |
|
236 state.iteration() % self.merge_every == 0 |
226 } |
237 } |
227 } |
238 } |
228 |
239 |
229 /// TODO: document. |
240 /// TODO: document. |
230 /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike |
241 /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike |
231 /// locations, while `ν_delta` may have different locations. |
242 /// locations, while `ν_delta` may have different locations. |
232 #[replace_float_literals(F::cast_from(literal))] |
243 #[replace_float_literals(F::cast_from(literal))] |
233 pub(crate) fn insert_and_reweigh< |
244 pub(crate) fn insert_and_reweigh< |
234 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize |
245 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, I, const N : usize |
235 >( |
246 >( |
236 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
247 μ : &mut RNDM<F, N>, |
237 minus_τv : &BTFN<F, GA, BTA, N>, |
248 τv : &BTFN<F, GA, BTA, N>, |
238 μ_base : &DiscreteMeasure<Loc<F, N>, F>, |
249 μ_base : &RNDM<F, N>, |
239 ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>, |
250 ν_delta: Option<&RNDM<F, N>>, |
240 op𝒟 : &'a 𝒟, |
251 op𝒟 : &'a 𝒟, |
241 op𝒟norm : F, |
252 op𝒟norm : F, |
242 τ : F, |
253 τ : F, |
243 ε : F, |
254 ε : F, |
244 config : &FBGenericConfig<F>, |
255 config : &FBGenericConfig<F>, |
245 reg : &Reg, |
256 reg : &Reg, |
246 state : &State, |
257 state : &AlgIteratorIteration<I>, |
247 stats : &mut IterInfo<F, N>, |
258 stats : &mut IterInfo<F, N>, |
248 ) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool) |
259 ) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool) |
249 where F : Float + ToNalgebraRealField, |
260 where F : Float + ToNalgebraRealField, |
250 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
261 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
251 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
262 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
253 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
264 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
254 𝒟::Codomain : RealMapping<F, N>, |
265 𝒟::Codomain : RealMapping<F, N>, |
255 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
266 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
256 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
267 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
257 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
268 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
258 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
259 Reg : RegTerm<F, N>, |
269 Reg : RegTerm<F, N>, |
260 State : AlgIteratorState { |
270 I : AlgIterator { |
261 |
271 |
262 // Maximum insertion count and measure difference calculation depend on insertion style. |
272 // Maximum insertion count and measure difference calculation depend on insertion style. |
263 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
273 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
264 (i, Some((l, k))) if i <= l => (k, false), |
274 (i, Some((l, k))) if i <= l => (k, false), |
265 _ => (config.max_insertions, !state.is_quiet()), |
275 _ => (config.max_insertions, !state.is_quiet()), |
266 }; |
276 }; |
267 |
277 |
268 // TODO: should avoid a copy of μ_base here. |
278 let ω0 = match ν_delta { |
269 let ω0 = op𝒟.apply(match ν_delta { |
279 None => op𝒟.apply(μ_base), |
270 None => μ_base.clone(), |
280 Some(ν) => op𝒟.apply(μ_base + ν), |
271 Some(ν_d) => &*μ_base + ν_d, |
281 }; |
272 }); |
|
273 |
282 |
274 // Add points to support until within error tolerance or maximum insertion count reached. |
283 // Add points to support until within error tolerance or maximum insertion count reached. |
275 let mut count = 0; |
284 let mut count = 0; |
276 let (within_tolerances, d) = 'insertion: loop { |
285 let (within_tolerances, d) = 'insertion: loop { |
277 if μ.len() > 0 { |
286 if μ.len() > 0 { |
278 // Form finite-dimensional subproblem. The subproblem references to the original μ^k |
287 // Form finite-dimensional subproblem. The subproblem references to the original μ^k |
279 // from the beginning of the iteration are all contained in the immutable c and g. |
288 // from the beginning of the iteration are all contained in the immutable c and g. |
|
289 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional |
|
290 // problems have not yet been updated to sign change. |
280 let à = op𝒟.findim_matrix(μ.iter_locations()); |
291 let à = op𝒟.findim_matrix(μ.iter_locations()); |
281 let g̃ = DVector::from_iterator(μ.len(), |
292 let g̃ = DVector::from_iterator(μ.len(), |
282 μ.iter_locations() |
293 μ.iter_locations() |
283 .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ)) |
294 .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) |
284 .map(F::to_nalgebra_mixed)); |
295 .map(F::to_nalgebra_mixed)); |
285 let mut x = μ.masses_dvector(); |
296 let mut x = μ.masses_dvector(); |
286 |
297 |
287 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. |
298 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. |
288 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ |
299 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ |
344 } |
353 } |
345 |
354 |
346 (d, within_tolerances) |
355 (d, within_tolerances) |
347 } |
356 } |
348 |
357 |
349 #[replace_float_literals(F::cast_from(literal))] |
358 pub(crate) fn prune_with_stats<F : Float, const N : usize>( |
350 pub(crate) fn prune_and_maybe_simple_merge< |
359 μ : &mut RNDM<F, N>, |
351 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize |
360 ) -> usize { |
352 >( |
|
353 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
354 minus_τv : &BTFN<F, GA, BTA, N>, |
|
355 μ_base : &DiscreteMeasure<Loc<F, N>, F>, |
|
356 op𝒟 : &'a 𝒟, |
|
357 τ : F, |
|
358 ε : F, |
|
359 config : &FBGenericConfig<F>, |
|
360 reg : &Reg, |
|
361 state : &State, |
|
362 stats : &mut IterInfo<F, N>, |
|
363 ) |
|
364 where F : Float + ToNalgebraRealField, |
|
365 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
366 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
367 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
|
368 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
|
369 𝒟::Codomain : RealMapping<F, N>, |
|
370 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
371 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
372 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
373 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
374 Reg : RegTerm<F, N>, |
|
375 State : AlgIteratorState { |
|
376 if state.iteration() % config.merge_every == 0 { |
|
377 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { |
|
378 let mut d = minus_τv + op𝒟.preapply(μ_base.sub_matching(&μ_candidate)); |
|
379 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
|
380 }); |
|
381 } |
|
382 |
|
383 let n_before_prune = μ.len(); |
361 let n_before_prune = μ.len(); |
384 μ.prune(); |
362 μ.prune(); |
385 debug_assert!(μ.len() <= n_before_prune); |
363 debug_assert!(μ.len() <= n_before_prune); |
386 stats.pruned += n_before_prune - μ.len(); |
364 n_before_prune - μ.len() |
387 } |
365 } |
388 |
366 |
389 #[replace_float_literals(F::cast_from(literal))] |
367 #[replace_float_literals(F::cast_from(literal))] |
390 pub(crate) fn postprocess< |
368 pub(crate) fn postprocess< |
391 F : Float, |
369 F : Float, |
392 V : Euclidean<F> + Clone, |
370 V : Euclidean<F> + Clone, |
393 A : GEMV<F, DiscreteMeasure<Loc<F, N>, F>, Codomain = V>, |
371 A : GEMV<F, RNDM<F, N>, Codomain = V>, |
394 D : DataTerm<F, V, N>, |
372 D : DataTerm<F, V, N>, |
395 const N : usize |
373 const N : usize |
396 > ( |
374 > ( |
397 mut μ : DiscreteMeasure<Loc<F, N>, F>, |
375 mut μ : RNDM<F, N>, |
398 config : &FBGenericConfig<F>, |
376 config : &FBGenericConfig<F>, |
399 dataterm : D, |
377 dataterm : D, |
400 opA : &A, |
378 opA : &A, |
401 b : &V, |
379 b : &V, |
402 ) -> DiscreteMeasure<Loc<F, N>, F> |
380 ) -> RNDM<F, N> |
403 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { |
381 where |
|
382 RNDM<F, N> : SpikeMerging<F>, |
|
383 for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>, |
|
384 { |
404 μ.merge_spikes_fitness(config.merging, |
385 μ.merge_spikes_fitness(config.merging, |
405 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
386 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
406 |&v| v); |
387 |&v| v); |
407 μ.prune(); |
388 μ.prune(); |
408 μ |
389 μ |
435 reg : Reg, |
416 reg : Reg, |
436 op𝒟 : &'a 𝒟, |
417 op𝒟 : &'a 𝒟, |
437 fbconfig : &FBConfig<F>, |
418 fbconfig : &FBConfig<F>, |
438 iterator : I, |
419 iterator : I, |
439 mut plotter : SeqPlotter<F, N>, |
420 mut plotter : SeqPlotter<F, N>, |
440 ) -> DiscreteMeasure<Loc<F, N>, F> |
421 ) -> RNDM<F, N> |
441 where F : Float + ToNalgebraRealField, |
422 where F : Float + ToNalgebraRealField, |
442 I : AlgIteratorFactory<IterInfo<F, N>>, |
423 I : AlgIteratorFactory<IterInfo<F, N>>, |
443 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
424 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
444 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow |
|
445 A::Observable : std::ops::MulAssign<F>, |
|
446 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
425 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
447 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
426 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
448 + Lipschitz<&'a 𝒟, FloatType=F>, |
427 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, |
449 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
428 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
450 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
429 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
451 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
430 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
452 𝒟::Codomain : RealMapping<F, N>, |
431 𝒟::Codomain : RealMapping<F, N>, |
453 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
432 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
454 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
433 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
455 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
434 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
456 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
435 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
457 PlotLookup : Plotting<N>, |
436 PlotLookup : Plotting<N>, |
458 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
437 RNDM<F, N> : SpikeMerging<F>, |
459 Reg : RegTerm<F, N> { |
438 Reg : RegTerm<F, N> { |
460 |
439 |
461 // Set up parameters |
440 // Set up parameters |
462 let config = &fbconfig.generic; |
441 let config = &fbconfig.generic; |
463 let op𝒟norm = op𝒟.opnorm_bound(); |
442 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); |
464 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); |
443 let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap(); |
465 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
444 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
466 // by τ compared to the conditional gradient approach. |
445 // by τ compared to the conditional gradient approach. |
467 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
446 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
468 let mut ε = tolerance.initial(); |
447 let mut ε = tolerance.initial(); |
469 |
448 |
470 // Initialise iterates |
449 // Initialise iterates |
471 let mut μ = DiscreteMeasure::new(); |
450 let mut μ = DiscreteMeasure::new(); |
472 let mut residual = -b; |
451 let mut residual = -b; |
|
452 |
|
453 // Statistics |
|
454 let full_stats = |residual : &A::Observable, |
|
455 μ : &RNDM<F, N>, |
|
456 ε, stats| IterInfo { |
|
457 value : residual.norm2_squared_div2() + reg.apply(μ), |
|
458 n_spikes : μ.len(), |
|
459 ε, |
|
460 //postprocessing: config.postprocessing.then(|| μ.clone()), |
|
461 .. stats |
|
462 }; |
473 let mut stats = IterInfo::new(); |
463 let mut stats = IterInfo::new(); |
474 |
464 |
475 // Run the algorithm |
465 // Run the algorithm |
476 iterator.iterate(|state| { |
466 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
477 // Calculate smooth part of surrogate model. |
467 // Calculate smooth part of surrogate model. |
478 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
468 let τv = opA.preadjoint().apply(residual * τ); |
479 // has no significant overhead. For some reosn Rust doesn't allow us simply moving |
|
480 // the residual and replacing it below before the end of this closure. |
|
481 residual *= -τ; |
|
482 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
|
483 let minus_τv = opA.preadjoint().apply(r); |
|
484 |
469 |
485 // Save current base point |
470 // Save current base point |
486 let μ_base = μ.clone(); |
471 let μ_base = μ.clone(); |
487 |
472 |
488 // Insert and reweigh |
473 // Insert and reweigh |
489 let (d, within_tolerances) = insert_and_reweigh( |
474 let (d, _within_tolerances) = insert_and_reweigh( |
490 &mut μ, &minus_τv, &μ_base, None, |
475 &mut μ, &τv, &μ_base, None, |
491 op𝒟, op𝒟norm, |
476 op𝒟, op𝒟norm, |
492 τ, ε, |
477 τ, ε, |
493 config, ®, state, &mut stats |
478 config, ®, &state, &mut stats |
494 ); |
479 ); |
495 |
480 |
496 // Prune and possibly merge spikes |
481 // Prune and possibly merge spikes |
497 prune_and_maybe_simple_merge( |
482 if config.merge_now(&state) { |
498 &mut μ, &minus_τv, &μ_base, |
483 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { |
499 op𝒟, |
484 let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base)); |
500 τ, ε, |
485 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
501 config, ®, state, &mut stats |
486 }); |
502 ); |
487 } |
|
488 stats.pruned += prune_with_stats(&mut μ); |
503 |
489 |
504 // Update residual |
490 // Update residual |
505 residual = calculate_residual(&μ, opA, b); |
491 residual = calculate_residual(&μ, opA, b); |
506 |
492 |
|
493 let iter = state.iteration(); |
|
494 stats.this_iters += 1; |
|
495 |
|
496 // Give statistics if needed |
|
497 state.if_verbose(|| { |
|
498 plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); |
|
499 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
|
500 }); |
|
501 |
507 // Update main tolerance for next iteration |
502 // Update main tolerance for next iteration |
508 let ε_prev = ε; |
503 ε = tolerance.update(ε, iter); |
509 ε = tolerance.update(ε, state.iteration()); |
504 } |
510 stats.this_iters += 1; |
|
511 |
|
512 // Give function value if needed |
|
513 state.if_verbose(|| { |
|
514 // Plot if so requested |
|
515 plotter.plot_spikes( |
|
516 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, |
|
517 "start".to_string(), Some(&minus_τv), |
|
518 reg.target_bounds(τ, ε_prev), &μ, |
|
519 ); |
|
520 // Calculate mean inner iterations and reset relevant counters. |
|
521 // Return the statistics |
|
522 let res = IterInfo { |
|
523 value : residual.norm2_squared_div2() + reg.apply(&μ), |
|
524 n_spikes : μ.len(), |
|
525 ε : ε_prev, |
|
526 postprocessing: config.postprocessing.then(|| μ.clone()), |
|
527 .. stats |
|
528 }; |
|
529 stats = IterInfo::new(); |
|
530 res |
|
531 }) |
|
532 }); |
|
533 |
505 |
534 postprocess(μ, config, L2Squared, opA, b) |
506 postprocess(μ, config, L2Squared, opA, b) |
535 } |
507 } |
536 |
508 |
537 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. |
509 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. |
561 reg : Reg, |
533 reg : Reg, |
562 op𝒟 : &'a 𝒟, |
534 op𝒟 : &'a 𝒟, |
563 fbconfig : &FBConfig<F>, |
535 fbconfig : &FBConfig<F>, |
564 iterator : I, |
536 iterator : I, |
565 mut plotter : SeqPlotter<F, N>, |
537 mut plotter : SeqPlotter<F, N>, |
566 ) -> DiscreteMeasure<Loc<F, N>, F> |
538 ) -> RNDM<F, N> |
567 where F : Float + ToNalgebraRealField, |
539 where F : Float + ToNalgebraRealField, |
568 I : AlgIteratorFactory<IterInfo<F, N>>, |
540 I : AlgIteratorFactory<IterInfo<F, N>>, |
569 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
541 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
570 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow |
|
571 A::Observable : std::ops::MulAssign<F>, |
|
572 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
542 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
573 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
543 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
574 + Lipschitz<&'a 𝒟, FloatType=F>, |
544 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, |
575 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
545 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
576 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
546 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
577 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
547 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
578 𝒟::Codomain : RealMapping<F, N>, |
548 𝒟::Codomain : RealMapping<F, N>, |
579 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
549 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
580 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
550 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
581 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
551 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
582 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
552 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
583 PlotLookup : Plotting<N>, |
553 PlotLookup : Plotting<N>, |
584 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
554 RNDM<F, N> : SpikeMerging<F>, |
585 Reg : RegTerm<F, N> { |
555 Reg : RegTerm<F, N> { |
586 |
556 |
587 // Set up parameters |
557 // Set up parameters |
588 let config = &fbconfig.generic; |
558 let config = &fbconfig.generic; |
589 let op𝒟norm = op𝒟.opnorm_bound(); |
559 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); |
590 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); |
560 let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap(); |
591 let mut λ = 1.0; |
561 let mut λ = 1.0; |
592 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
562 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
593 // by τ compared to the conditional gradient approach. |
563 // by τ compared to the conditional gradient approach. |
594 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
564 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
595 let mut ε = tolerance.initial(); |
565 let mut ε = tolerance.initial(); |
596 |
566 |
597 // Initialise iterates |
567 // Initialise iterates |
598 let mut μ = DiscreteMeasure::new(); |
568 let mut μ = DiscreteMeasure::new(); |
599 let mut μ_prev = DiscreteMeasure::new(); |
569 let mut μ_prev = DiscreteMeasure::new(); |
600 let mut residual = -b; |
570 let mut residual = -b; |
|
571 let mut warned_merging = false; |
|
572 |
|
573 // Statistics |
|
574 let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo { |
|
575 value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), |
|
576 n_spikes : ν.len(), |
|
577 ε, |
|
578 // postprocessing: config.postprocessing.then(|| ν.clone()), |
|
579 .. stats |
|
580 }; |
601 let mut stats = IterInfo::new(); |
581 let mut stats = IterInfo::new(); |
602 let mut warned_merging = false; |
|
603 |
582 |
604 // Run the algorithm |
583 // Run the algorithm |
605 iterator.iterate(|state| { |
584 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
606 // Calculate smooth part of surrogate model. |
585 // Calculate smooth part of surrogate model. |
607 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
586 let τv = opA.preadjoint().apply(residual * τ); |
608 // has no significant overhead. For some reosn Rust doesn't allow us simply moving |
|
609 // the residual and replacing it below before the end of this closure. |
|
610 residual *= -τ; |
|
611 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
|
612 let minus_τv = opA.preadjoint().apply(r); |
|
613 |
587 |
614 // Save current base point |
588 // Save current base point |
615 let μ_base = μ.clone(); |
589 let μ_base = μ.clone(); |
616 |
590 |
617 // Insert new spikes and reweigh |
591 // Insert new spikes and reweigh |
618 let (d, within_tolerances) = insert_and_reweigh( |
592 let (d, _within_tolerances) = insert_and_reweigh( |
619 &mut μ, &minus_τv, &μ_base, None, |
593 &mut μ, &τv, &μ_base, None, |
620 op𝒟, op𝒟norm, |
594 op𝒟, op𝒟norm, |
621 τ, ε, |
595 τ, ε, |
622 config, ®, state, &mut stats |
596 config, ®, &state, &mut stats |
623 ); |
597 ); |
624 |
598 |
625 // (Do not) merge spikes. |
599 // (Do not) merge spikes. |
626 if state.iteration() % config.merge_every == 0 { |
600 if config.merge_now(&state) { |
627 match config.merging { |
601 match config.merging { |
628 SpikeMergingMethod::None => { }, |
602 SpikeMergingMethod::None => { }, |
629 _ => if !warned_merging { |
603 _ => if !warned_merging { |
630 let err = format!("Merging not supported for μFISTA"); |
604 let err = format!("Merging not supported for μFISTA"); |
631 println!("{}", err.red()); |
605 println!("{}", err.red()); |
651 stats.pruned += n_before_prune - μ.len(); |
625 stats.pruned += n_before_prune - μ.len(); |
652 |
626 |
653 // Update residual |
627 // Update residual |
654 residual = calculate_residual(&μ, opA, b); |
628 residual = calculate_residual(&μ, opA, b); |
655 |
629 |
|
630 let iter = state.iteration(); |
|
631 stats.this_iters += 1; |
|
632 |
|
633 // Give statistics if needed |
|
634 state.if_verbose(|| { |
|
635 plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ_prev); |
|
636 full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new())) |
|
637 }); |
|
638 |
656 // Update main tolerance for next iteration |
639 // Update main tolerance for next iteration |
657 let ε_prev = ε; |
640 ε = tolerance.update(ε, iter); |
658 ε = tolerance.update(ε, state.iteration()); |
641 } |
659 stats.this_iters += 1; |
|
660 |
|
661 // Give function value if needed |
|
662 state.if_verbose(|| { |
|
663 // Plot if so requested |
|
664 plotter.plot_spikes( |
|
665 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, |
|
666 "start".to_string(), Some(&minus_τv), |
|
667 reg.target_bounds(τ, ε_prev), &μ_prev, |
|
668 ); |
|
669 // Calculate mean inner iterations and reset relevant counters. |
|
670 // Return the statistics |
|
671 let res = IterInfo { |
|
672 value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev), |
|
673 n_spikes : μ_prev.len(), |
|
674 ε : ε_prev, |
|
675 postprocessing: config.postprocessing.then(|| μ_prev.clone()), |
|
676 .. stats |
|
677 }; |
|
678 stats = IterInfo::new(); |
|
679 res |
|
680 }) |
|
681 }); |
|
682 |
642 |
683 postprocess(μ_prev, config, L2Squared, opA, b) |
643 postprocess(μ_prev, config, L2Squared, opA, b) |
684 } |
644 } |