src/fb.rs

changeset 24
d29d1fcf5423
parent 13
bdc57366d4f5
equal deleted inserted replaced
23:9869fa1e0ccd 24:d29d1fcf5423
4 This corresponds to the manuscript 4 This corresponds to the manuscript
5 5
6 * Valkonen T. - _Proximal methods for point source localisation_, 6 * Valkonen T. - _Proximal methods for point source localisation_,
7 [arXiv:2212.02991](https://arxiv.org/abs/2212.02991). 7 [arXiv:2212.02991](https://arxiv.org/abs/2212.02991).
8 8
9 The main routine is [`pointsource_fb`]. It is based on [`generic_pointsource_fb`], which is also 9 The main routine is [`pointsource_fb_reg`]. It is based on [`generic_pointsource_fb_reg`], which is
10 used by our [primal-dual proximal splitting][crate::pdps] implementation. 10 also used by our [primal-dual proximal splitting][crate::pdps] implementation.
11 11
12 FISTA-type inertia can also be enabled through [`FBConfig::meta`]. 12 FISTA-type inertia can also be enabled through [`FBConfig::meta`].
13 13
14 ## Problem 14 ## Problem
15 15
81 */ 81 */
82 82
83 use numeric_literals::replace_float_literals; 83 use numeric_literals::replace_float_literals;
84 use serde::{Serialize, Deserialize}; 84 use serde::{Serialize, Deserialize};
85 use colored::Colorize; 85 use colored::Colorize;
86 use nalgebra::DVector; 86 use nalgebra::{DVector, DMatrix};
87 87
88 use alg_tools::iterate::{ 88 use alg_tools::iterate::{
89 AlgIteratorFactory, 89 AlgIteratorFactory,
90 AlgIteratorState, 90 AlgIteratorState,
91 }; 91 };
92 use alg_tools::euclidean::Euclidean; 92 use alg_tools::euclidean::Euclidean;
93 use alg_tools::norms::Norm;
94 use alg_tools::linops::Apply; 93 use alg_tools::linops::Apply;
95 use alg_tools::sets::Cube; 94 use alg_tools::sets::Cube;
96 use alg_tools::loc::Loc; 95 use alg_tools::loc::Loc;
96 use alg_tools::mapping::Mapping;
97 use alg_tools::bisection_tree::{ 97 use alg_tools::bisection_tree::{
98 BTFN, 98 BTFN,
99 PreBTFN, 99 PreBTFN,
100 Bounds, 100 Bounds,
101 BTNodeLookup, 101 BTNodeLookup,
111 111
112 use crate::types::*; 112 use crate::types::*;
113 use crate::measures::{ 113 use crate::measures::{
114 DiscreteMeasure, 114 DiscreteMeasure,
115 DeltaMeasure, 115 DeltaMeasure,
116 Radon
117 }; 116 };
118 use crate::measures::merging::{ 117 use crate::measures::merging::{
119 SpikeMergingMethod, 118 SpikeMergingMethod,
120 SpikeMerging, 119 SpikeMerging,
121 }; 120 };
122 use crate::forward_model::ForwardModel; 121 use crate::forward_model::ForwardModel;
123 use crate::seminorms::{ 122 use crate::seminorms::{
124 DiscreteMeasureOp, Lipschitz 123 DiscreteMeasureOp, Lipschitz
125 }; 124 };
126 use crate::subproblem::{ 125 use crate::subproblem::{
127 quadratic_nonneg, 126 nonneg::quadratic_nonneg,
127 unconstrained::quadratic_unconstrained,
128 InnerSettings, 128 InnerSettings,
129 InnerMethod, 129 InnerMethod,
130 }; 130 };
131 use crate::tolerance::Tolerance; 131 use crate::tolerance::Tolerance;
132 use crate::plot::{ 132 use crate::plot::{
133 SeqPlotter, 133 SeqPlotter,
134 Plotting, 134 Plotting,
135 PlotLookup 135 PlotLookup
136 }; 136 };
137 use crate::regularisation::{
138 NonnegRadonRegTerm,
139 RadonRegTerm,
140 };
137 141
138 /// Method for constructing $μ$ on each iteration 142 /// Method for constructing $μ$ on each iteration
139 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 143 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
140 #[allow(dead_code)] 144 #[allow(dead_code)]
141 pub enum InsertionStyle { 145 pub enum InsertionStyle {
154 None, 158 None,
155 /// FISTA-style inertia 159 /// FISTA-style inertia
156 InertiaFISTA, 160 InertiaFISTA,
157 } 161 }
158 162
159 /// Settings for [`pointsource_fb`]. 163 /// Settings for [`pointsource_fb_reg`].
160 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 164 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
161 #[serde(default)] 165 #[serde(default)]
162 pub struct FBConfig<F : Float> { 166 pub struct FBConfig<F : Float> {
163 /// Step length scaling 167 /// Step length scaling
164 pub τ0 : F, 168 pub τ0 : F,
167 /// Generic parameters 171 /// Generic parameters
168 pub insertion : FBGenericConfig<F>, 172 pub insertion : FBGenericConfig<F>,
169 } 173 }
170 174
171 /// Settings for the solution of the stepwise optimality condition in algorithms based on 175 /// Settings for the solution of the stepwise optimality condition in algorithms based on
172 /// [`generic_pointsource_fb`]. 176 /// [`generic_pointsource_fb_reg`].
173 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 177 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
174 #[serde(default)] 178 #[serde(default)]
175 pub struct FBGenericConfig<F : Float> { 179 pub struct FBGenericConfig<F : Float> {
176 /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`]. 180 /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`].
177 pub insertion_style : InsertionStyle, 181 pub insertion_style : InsertionStyle,
234 postprocessing : false, 238 postprocessing : false,
235 } 239 }
236 } 240 }
237 } 241 }
238 242
239 /// Trait for specialisation of [`generic_pointsource_fb`] to basic FB, FISTA. 243 /// Trait for specialisation of [`generic_pointsource_fb_reg`] to basic FB, FISTA.
240 /// 244 ///
241 /// The idea is that the residual $Aμ - b$ in the forward step can be replaced by an arbitrary 245 /// The idea is that the residual $Aμ - b$ in the forward step can be replaced by an arbitrary
242 /// value. For example, to implement [primal-dual proximal splitting][crate::pdps] we replace it 246 /// value. For example, to implement [primal-dual proximal splitting][crate::pdps] we replace it
243 /// with the dual variable $y$. We can then also implement alternative data terms, as the 247 /// with the dual variable $y$. We can then also implement alternative data terms, as the
244 /// (pre)differential of $F(μ)=F\_0(Aμ-b)$ is $F\'(μ) = A\_*F\_0\'(Aμ-b)$. In the case of the 248 /// (pre)differential of $F(μ)=F\_0(Aμ-b)$ is $F\'(μ) = A\_*F\_0\'(Aμ-b)$. In the case of the
299 -> &'c DiscreteMeasure<Loc<F, N>, F> { 303 -> &'c DiscreteMeasure<Loc<F, N>, F> {
300 μ 304 μ
301 } 305 }
302 } 306 }
303 307
304 /// Specialisation of [`generic_pointsource_fb`] to basic μFB. 308 /// Specialisation of [`generic_pointsource_fb_reg`] to basic μFB.
305 struct BasicFB< 309 struct BasicFB<
306 'a, 310 'a,
307 F : Float + ToNalgebraRealField, 311 F : Float + ToNalgebraRealField,
308 A : ForwardModel<Loc<F, N>, F>, 312 A : ForwardModel<Loc<F, N>, F>,
309 const N : usize 313 const N : usize
338 self.opA.gemv(&mut residual, 1.0, μ, -1.0); 342 self.opA.gemv(&mut residual, 1.0, μ, -1.0);
339 residual.norm2_squared_div2() 343 residual.norm2_squared_div2()
340 } 344 }
341 } 345 }
342 346
343 /// Specialisation of [`generic_pointsource_fb`] to FISTA. 347 /// Specialisation of [`generic_pointsource_fb_reg`] to FISTA.
344 struct FISTA< 348 struct FISTA<
345 'a, 349 'a,
346 F : Float + ToNalgebraRealField, 350 F : Float + ToNalgebraRealField,
347 A : ForwardModel<Loc<F, N>, F>, 351 A : ForwardModel<Loc<F, N>, F>,
348 const N : usize 352 const N : usize
421 -> &'c DiscreteMeasure<Loc<F, N>, F> { 425 -> &'c DiscreteMeasure<Loc<F, N>, F> {
422 &self.μ_prev 426 &self.μ_prev
423 } 427 }
424 } 428 }
425 429
426 /// Iteratively solve the pointsource localisation problem using forward-backward splitting 430
427 /// 431 /// Abstraction of regularisation terms for [`generic_pointsource_fb_reg`].
428 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the 432 pub trait RegTerm<F : Float + ToNalgebraRealField, const N : usize>
429 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. 433 : for<'a> Apply<&'a DiscreteMeasure<Loc<F, N>, F>, Output = F> {
430 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution 434 /// Approximately solve the problem
431 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control 435 /// <div>$$
432 /// as documented in [`alg_tools::iterate`]. 436 /// \min_{x ∈ ℝ^n} \frac{1}{2} x^⊤Ax - g^⊤ x + τ G(x)
433 /// 437 /// $$</div>
434 /// For details on the mathematical formulation, see the [module level](self) documentation. 438 /// for $G$ depending on the trait implementation.
435 /// 439 ///
436 /// Returns the final iterate. 440 /// The parameter `mA` is $A$. An estimate for its operator norm should be provided in
441 /// `mA_normest`. The initial iterate and output is `x`. The current main tolerance is `ε`.
442 ///
443 /// Returns the number of iterations taken.
444 fn solve_findim(
445 &self,
446 mA : &DMatrix<F::MixedType>,
447 g : &DVector<F::MixedType>,
448 τ : F,
449 x : &mut DVector<F::MixedType>,
450 mA_normest : F,
451 ε : F,
452 config : &FBGenericConfig<F>
453 ) -> usize;
454
455 /// Find a point where `d` may violate the tolerance `ε`.
456 ///
457 /// If `skip_by_rough_check` is set, do not find the point if a rough check indicates that we
458 /// are in bounds. `ε` is the current main tolerance and `τ` a scaling factor for the
459 /// regulariser.
460 ///
461 /// Returns `None` if `d` is in bounds either based on the rough check, or a more precise check
462 /// terminating early. Otherwise returns a possibly violating point, the value of `d` there,
463 /// and a boolean indicating whether the found point is in bounds.
464 fn find_tolerance_violation<G, BT>(
465 &self,
466 d : &mut BTFN<F, G, BT, N>,
467 τ : F,
468 ε : F,
469 skip_by_rough_check : bool,
470 config : &FBGenericConfig<F>,
471 ) -> Option<(Loc<F, N>, F, bool)>
472 where BT : BTSearch<F, N, Agg=Bounds<F>>,
473 G : SupportGenerator<F, N, Id=BT::Data>,
474 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
475 + LocalAnalysis<F, Bounds<F>, N>;
476
477 /// Verify that `d` is in bounds `ε` for a merge candidate `μ`
478 ///
479 /// `ε` is the current main tolerance and `τ` a scaling factor for the regulariser.
480 fn verify_merge_candidate<G, BT>(
481 &self,
482 d : &mut BTFN<F, G, BT, N>,
483 μ : &DiscreteMeasure<Loc<F, N>, F>,
484 τ : F,
485 ε : F,
486 config : &FBGenericConfig<F>,
487 ) -> bool
488 where BT : BTSearch<F, N, Agg=Bounds<F>>,
489 G : SupportGenerator<F, N, Id=BT::Data>,
490 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
491 + LocalAnalysis<F, Bounds<F>, N>;
492
493 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>>;
494
495 /// Returns a scaling factor for the tolerance sequence.
496 ///
497 /// Typically this is the regularisation parameter.
498 fn tolerance_scaling(&self) -> F;
499 }
500
437 #[replace_float_literals(F::cast_from(literal))] 501 #[replace_float_literals(F::cast_from(literal))]
438 pub fn pointsource_fb<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, const N : usize>( 502 impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for NonnegRadonRegTerm<F>
439 opA : &'a A, 503 where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
440 b : &A::Observable, 504 fn solve_findim(
441 α : F, 505 &self,
442 op𝒟 : &'a 𝒟, 506 mA : &DMatrix<F::MixedType>,
443 config : &FBConfig<F>, 507 g : &DVector<F::MixedType>,
444 iterator : I, 508 τ : F,
445 plotter : SeqPlotter<F, N> 509 x : &mut DVector<F::MixedType>,
446 ) -> DiscreteMeasure<Loc<F, N>, F> 510 mA_normest : F,
447 where F : Float + ToNalgebraRealField, 511 ε : F,
448 I : AlgIteratorFactory<IterInfo<F, N>>, 512 config : &FBGenericConfig<F>
449 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, 513 ) -> usize {
450 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow 514 let inner_tolerance = ε * config.inner.tolerance_mult;
451 A::Observable : std::ops::MulAssign<F>, 515 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
452 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 516 let inner_τ = config.inner.τ0 / mA_normest;
453 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> 517 quadratic_nonneg(config.inner.method, mA, g, τ * self.α(), x,
454 + Lipschitz<𝒟, FloatType=F>, 518 inner_τ, inner_it)
455 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 519 }
456 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 520
457 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 521 #[inline]
458 𝒟::Codomain : RealMapping<F, N>, 522 fn find_tolerance_violation<G, BT>(
459 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 523 &self,
460 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 524 d : &mut BTFN<F, G, BT, N>,
461 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 525 τ : F,
462 Cube<F, N>: P2Minimise<Loc<F, N>, F>, 526 ε : F,
463 PlotLookup : Plotting<N>, 527 skip_by_rough_check : bool,
464 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { 528 config : &FBGenericConfig<F>,
465 529 ) -> Option<(Loc<F, N>, F, bool)>
466 let initial_residual = -b; 530 where BT : BTSearch<F, N, Agg=Bounds<F>>,
467 let τ = config.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); 531 G : SupportGenerator<F, N, Id=BT::Data>,
468 532 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
469 match config.meta { 533 + LocalAnalysis<F, Bounds<F>, N> {
470 FBMetaAlgorithm::None => generic_pointsource_fb( 534 let τα = τ * self.α();
471 opA, α, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual, 535 let keep_below = τα + ε;
472 BasicFB{ b, opA } 536 let maximise_above = τα + ε * config.insertion_cutoff_factor;
473 ), 537 let refinement_tolerance = ε * config.refinement.tolerance_mult;
474 FBMetaAlgorithm::InertiaFISTA => generic_pointsource_fb( 538
475 opA, α, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual, 539 // If preliminary check indicates that we are in bounds, and if it otherwise matches
476 FISTA{ b, opA, λ : 1.0, μ_prev : DiscreteMeasure::new() } 540 // the insertion strategy, skip insertion.
477 ), 541 if skip_by_rough_check && d.bounds().upper() <= keep_below {
478 } 542 None
479 } 543 } else {
480 544 // If the rough check did not rule out insertion, find the maximising point.
481 /// Generic implementation of [`pointsource_fb`]. 545 d.maximise_above(maximise_above, refinement_tolerance, config.refinement.max_steps)
546 .map(|(ξ, v_ξ)| (ξ, v_ξ, v_ξ <= keep_below))
547 }
548 }
549
550 fn verify_merge_candidate<G, BT>(
551 &self,
552 d : &mut BTFN<F, G, BT, N>,
553 μ : &DiscreteMeasure<Loc<F, N>, F>,
554 τ : F,
555 ε : F,
556 config : &FBGenericConfig<F>,
557 ) -> bool
558 where BT : BTSearch<F, N, Agg=Bounds<F>>,
559 G : SupportGenerator<F, N, Id=BT::Data>,
560 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
561 + LocalAnalysis<F, Bounds<F>, N> {
562 let τα = τ * self.α();
563 let refinement_tolerance = ε * config.refinement.tolerance_mult;
564 let merge_tolerance = config.merge_tolerance_mult * ε;
565 let keep_below = τα + merge_tolerance;
566 let keep_supp_above = τα - merge_tolerance;
567 let bnd = d.bounds();
568
569 return (
570 bnd.lower() >= keep_supp_above
571 ||
572 μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| {
573 (β == 0.0) || d.apply(x) >= keep_supp_above
574 }).all(std::convert::identity)
575 ) && (
576 bnd.upper() <= keep_below
577 ||
578 d.has_upper_bound(keep_below, refinement_tolerance, config.refinement.max_steps)
579 )
580 }
581
582 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> {
583 let τα = τ * self.α();
584 Some(Bounds(τα - ε, τα + ε))
585 }
586
587 fn tolerance_scaling(&self) -> F {
588 self.α()
589 }
590 }
591
592 #[replace_float_literals(F::cast_from(literal))]
593 impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for RadonRegTerm<F>
594 where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
595 fn solve_findim(
596 &self,
597 mA : &DMatrix<F::MixedType>,
598 g : &DVector<F::MixedType>,
599 τ : F,
600 x : &mut DVector<F::MixedType>,
601 mA_normest: F,
602 ε : F,
603 config : &FBGenericConfig<F>
604 ) -> usize {
605 let inner_tolerance = ε * config.inner.tolerance_mult;
606 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
607 let inner_τ = config.inner.τ0 / mA_normest;
608 quadratic_unconstrained(config.inner.method, mA, g, τ * self.α(), x,
609 inner_τ, inner_it)
610 }
611
612 fn find_tolerance_violation<G, BT>(
613 &self,
614 d : &mut BTFN<F, G, BT, N>,
615 τ : F,
616 ε : F,
617 skip_by_rough_check : bool,
618 config : &FBGenericConfig<F>,
619 ) -> Option<(Loc<F, N>, F, bool)>
620 where BT : BTSearch<F, N, Agg=Bounds<F>>,
621 G : SupportGenerator<F, N, Id=BT::Data>,
622 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
623 + LocalAnalysis<F, Bounds<F>, N> {
624 let τα = τ * self.α();
625 let keep_below = τα + ε;
626 let keep_above = -τα - ε;
627 let maximise_above = τα + ε * config.insertion_cutoff_factor;
628 let minimise_below = -τα - ε * config.insertion_cutoff_factor;
629 let refinement_tolerance = ε * config.refinement.tolerance_mult;
630
631 // If preliminary check indicates that we are in bounds, and if it otherwise matches
632 // the insertion strategy, skip insertion.
633 if skip_by_rough_check && Bounds(keep_above, keep_below).superset(&d.bounds()) {
634 None
635 } else {
636 // If the rough check did not rule out insertion, find the maximising and minimising points.
637 let mx = d.maximise_above(maximise_above, refinement_tolerance,
638 config.refinement.max_steps);
639 let mi = d.minimise_below(minimise_below, refinement_tolerance,
640 config.refinement.max_steps);
641
642 match (mx, mi) {
643 (None, None) => None,
644 (Some((ξ, v_ξ)), None) => Some((ξ, v_ξ, keep_below >= v_ξ)),
645 (None, Some((ζ, v_ζ))) => Some((ζ, v_ζ, keep_above <= v_ζ)),
646 (Some((ξ, v_ξ)), Some((ζ, v_ζ))) => {
647 if v_ξ - τα > τα - v_ζ {
648 Some((ξ, v_ξ, keep_below >= v_ξ))
649 } else {
650 Some((ζ, v_ζ, keep_above <= v_ζ))
651 }
652 }
653 }
654 }
655 }
656
657 fn verify_merge_candidate<G, BT>(
658 &self,
659 d : &mut BTFN<F, G, BT, N>,
660 μ : &DiscreteMeasure<Loc<F, N>, F>,
661 τ : F,
662 ε : F,
663 config : &FBGenericConfig<F>,
664 ) -> bool
665 where BT : BTSearch<F, N, Agg=Bounds<F>>,
666 G : SupportGenerator<F, N, Id=BT::Data>,
667 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
668 + LocalAnalysis<F, Bounds<F>, N> {
669 let τα = τ * self.α();
670 let refinement_tolerance = ε * config.refinement.tolerance_mult;
671 let merge_tolerance = config.merge_tolerance_mult * ε;
672 let keep_below = τα + merge_tolerance;
673 let keep_above = -τα - merge_tolerance;
674 let keep_supp_pos_above = τα - merge_tolerance;
675 let keep_supp_neg_below = -τα + merge_tolerance;
676 let bnd = d.bounds();
677
678 return (
679 (bnd.lower() >= keep_supp_pos_above && bnd.upper() <= keep_supp_neg_below)
680 ||
681 μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| {
682 use std::cmp::Ordering::*;
683 match β.partial_cmp(&0.0) {
684 Some(Greater) => d.apply(x) >= keep_supp_pos_above,
685 Some(Less) => d.apply(x) <= keep_supp_neg_below,
686 _ => true,
687 }
688 }).all(std::convert::identity)
689 ) && (
690 bnd.upper() <= keep_below
691 ||
692 d.has_upper_bound(keep_below, refinement_tolerance,
693 config.refinement.max_steps)
694 ) && (
695 bnd.lower() >= keep_above
696 ||
697 d.has_lower_bound(keep_above, refinement_tolerance,
698 config.refinement.max_steps)
699 )
700 }
701
702 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> {
703 let τα = τ * self.α();
704 Some(Bounds(-τα - ε, τα + ε))
705 }
706
707 fn tolerance_scaling(&self) -> F {
708 self.α()
709 }
710 }
711
712
713 /// Generic implementation of [`pointsource_fb_reg`].
482 /// 714 ///
483 /// The method can be specialised to even primal-dual proximal splitting through the 715 /// The method can be specialised to even primal-dual proximal splitting through the
484 /// [`FBSpecialisation`] parameter `specialisation`. 716 /// [`FBSpecialisation`] parameter `specialisation`.
485 /// The settings in `config` have their [respective documentation](FBGenericConfig). `opA` is the 717 /// The settings in `config` have their [respective documentation](FBGenericConfig). `opA` is the
486 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. 718 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
495 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features 727 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
496 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. 728 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
497 /// 729 ///
498 /// Returns the final iterate. 730 /// Returns the final iterate.
499 #[replace_float_literals(F::cast_from(literal))] 731 #[replace_float_literals(F::cast_from(literal))]
500 pub fn generic_pointsource_fb<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, const N : usize>( 732 pub fn generic_pointsource_fb_reg<
733 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, Reg, const N : usize
734 >(
501 opA : &'a A, 735 opA : &'a A,
502 α : F, 736 reg : Reg,
503 op𝒟 : &'a 𝒟, 737 op𝒟 : &'a 𝒟,
504 mut τ : F, 738 mut τ : F,
505 config : &FBGenericConfig<F>, 739 config : &FBGenericConfig<F>,
506 iterator : I, 740 iterator : I,
507 mut plotter : SeqPlotter<F, N>, 741 mut plotter : SeqPlotter<F, N>,
508 mut residual : A::Observable, 742 mut residual : A::Observable,
509 mut specialisation : Spec, 743 mut specialisation : Spec
510 ) -> DiscreteMeasure<Loc<F, N>, F> 744 ) -> DiscreteMeasure<Loc<F, N>, F>
511 where F : Float + ToNalgebraRealField, 745 where F : Float + ToNalgebraRealField,
512 I : AlgIteratorFactory<IterInfo<F, N>>, 746 I : AlgIteratorFactory<IterInfo<F, N>>,
513 Spec : FBSpecialisation<F, A::Observable, N>, 747 Spec : FBSpecialisation<F, A::Observable, N>,
514 A::Observable : std::ops::MulAssign<F>, 748 A::Observable : std::ops::MulAssign<F>,
520 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 754 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
521 𝒟::Codomain : RealMapping<F, N>, 755 𝒟::Codomain : RealMapping<F, N>,
522 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 756 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
523 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 757 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
524 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 758 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
525 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
526 PlotLookup : Plotting<N>, 759 PlotLookup : Plotting<N>,
527 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { 760 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
761 Reg : RegTerm<F, N> {
528 762
529 // Set up parameters 763 // Set up parameters
530 let quiet = iterator.is_quiet(); 764 let quiet = iterator.is_quiet();
531 let op𝒟norm = op𝒟.opnorm_bound(); 765 let op𝒟norm = op𝒟.opnorm_bound();
532 // We multiply tolerance by τ for FB since 766 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
533 // our subproblems depending on tolerances are scaled by τ compared to the conditional 767 // by τ compared to the conditional gradient approach.
534 // gradient approach. 768 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
535 let tolerance = config.tolerance * τ * α;
536 let mut ε = tolerance.initial(); 769 let mut ε = tolerance.initial();
537 770
538 // Initialise operators 771 // Initialise operators
539 let preadjA = opA.preadjoint(); 772 let preadjA = opA.preadjoint();
540 773
566 ν 799 ν
567 }; 800 };
568 801
569 // Run the algorithm 802 // Run the algorithm
570 iterator.iterate(|state| { 803 iterator.iterate(|state| {
571 // Calculate subproblem tolerances, and update main tolerance for next iteration
572 let τα = τ * α;
573 let target_bounds = Bounds(τα - ε, τα + ε);
574 let merge_tolerance = config.merge_tolerance_mult * ε;
575 let merge_target_bounds = Bounds(τα - merge_tolerance, τα + merge_tolerance);
576 let inner_tolerance = ε * config.inner.tolerance_mult;
577 let refinement_tolerance = ε * config.refinement.tolerance_mult;
578 let maximise_above = τα + ε * config.insertion_cutoff_factor;
579 let ε_prev = ε;
580 ε = tolerance.update(ε, state.iteration());
581
582 // Maximum insertion count and measure difference calculation depend on insertion style. 804 // Maximum insertion count and measure difference calculation depend on insertion style.
583 let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { 805 let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
584 (i, Some((l, k))) if i <= l => (k, false), 806 (i, Some((l, k))) if i <= l => (k, false),
585 _ => (config.max_insertions, !quiet), 807 _ => (config.max_insertions, !quiet),
586 }; 808 };
624 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. 846 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
625 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ 847 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
626 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ 848 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
627 // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 849 // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2
628 // = n |𝒟| |x|_2, where n is the number of points. Therefore 850 // = n |𝒟| |x|_2, where n is the number of points. Therefore
629 let inner_τ = config.inner.τ0 / (op𝒟norm * F::cast_from(μ.len())); 851 let Ã_normest = op𝒟norm * F::cast_from(μ.len());
630 852
631 // Solve finite-dimensional subproblem. 853 // Solve finite-dimensional subproblem.
632 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance); 854 inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config);
633 inner_iters += quadratic_nonneg(config.inner.method, &Ã, &g̃, τ*α, &mut x,
634 inner_τ, inner_it);
635 855
636 // Update masses of μ based on solution of finite-dimensional subproblem. 856 // Update masses of μ based on solution of finite-dimensional subproblem.
637 μ.set_masses_dvector(&x); 857 μ.set_masses_dvector(&x);
638 } 858 }
639 859
642 let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ, &μ_base)); 862 let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ, &μ_base));
643 863
644 // If no merging heuristic is used, let's be more conservative about spike insertion, 864 // If no merging heuristic is used, let's be more conservative about spike insertion,
645 // and skip it after first round. If merging is done, being more greedy about spike 865 // and skip it after first round. If merging is done, being more greedy about spike
646 // insertion also seems to improve performance. 866 // insertion also seems to improve performance.
647 let may_break = if let SpikeMergingMethod::None = config.merging { 867 let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging {
648 false 868 false
649 } else { 869 } else {
650 count > 0 870 count > 0
651 }; 871 };
652 872
653 // If preliminary check indicates that we are in bounds, and if it otherwise matches 873 // Find a spike to insert, if needed
654 // the insertion strategy, skip insertion. 874 let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation(
655 if may_break && target_bounds.superset(&d.bounds()) { 875 &mut d, τ, ε, skip_by_rough_check, config
656 break 'insertion (true, d) 876 ) {
657 }
658
659 // If the rough check didn't indicate stopping, find maximising point, maintaining for
660 // the calculations in the beginning of the loop that v_ξ = (ω0-τv-𝒟μ)(ξ) = d(ξ),
661 // where 𝒟μ is now distinct from μ0 after the insertions already performed.
662 // We do not need to check lower bounds, as a solution of the finite-dimensional
663 // subproblem should always satisfy them.
664
665 // If μ has some spikes, only find a maximum of d if it is above a threshold
666 // defined by the refinement tolerance.
667 let (ξ, v_ξ) = match d.maximise_above(maximise_above, refinement_tolerance,
668 config.refinement.max_steps) {
669 None => break 'insertion (true, d), 877 None => break 'insertion (true, d),
670 Some(res) => res, 878 Some(res) => res,
671 }; 879 };
672 880
673 // Break if maximum insertion count reached 881 // Break if maximum insertion count reached
674 if count >= max_insertions { 882 if count >= max_insertions {
675 let in_bounds2 = target_bounds.upper() >= v_ξ; 883 break 'insertion (in_bounds, d)
676 break 'insertion (in_bounds2, d)
677 } 884 }
678 885
679 // No point in optimising the weight here; the finite-dimensional algorithm is fast. 886 // No point in optimising the weight here; the finite-dimensional algorithm is fast.
680 μ += DeltaMeasure { x : ξ, α : 0.0 }; 887 μ += DeltaMeasure { x : ξ, α : 0.0 };
681 count += 1; 888 count += 1;
693 if state.iteration() % config.merge_every == 0 { 900 if state.iteration() % config.merge_every == 0 {
694 let n_before_merge = μ.len(); 901 let n_before_merge = μ.len();
695 μ.merge_spikes(config.merging, |μ_candidate| { 902 μ.merge_spikes(config.merging, |μ_candidate| {
696 let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ_candidate, &μ_base)); 903 let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ_candidate, &μ_base));
697 904
698 if merge_target_bounds.superset(&d.bounds()) { 905 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
699 return Some(()) 906 .then_some(())
700 }
701
702 let d_min_supp = μ_candidate.iter_spikes().filter_map(|&DeltaMeasure{ α, ref x }| {
703 (α != 0.0).then(|| d.apply(x))
704 }).reduce(F::min);
705
706 if d_min_supp.map_or(true, |b| b >= merge_target_bounds.lower()) &&
707 d.has_upper_bound(merge_target_bounds.upper(), refinement_tolerance,
708 config.refinement.max_steps) {
709 Some(())
710 } else {
711 None
712 }
713 }); 907 });
714 debug_assert!(μ.len() >= n_before_merge); 908 debug_assert!(μ.len() >= n_before_merge);
715 merged += μ.len() - n_before_merge; 909 merged += μ.len() - n_before_merge;
716 } 910 }
717 911
723 debug_assert!(μ.len() <= n_before_prune); 917 debug_assert!(μ.len() <= n_before_prune);
724 pruned += n_before_prune - μ.len(); 918 pruned += n_before_prune - μ.len();
725 919
726 this_iters += 1; 920 this_iters += 1;
727 921
922 // Update main tolerance for next iteration
923 let ε_prev = ε;
924 ε = tolerance.update(ε, state.iteration());
925
728 // Give function value if needed 926 // Give function value if needed
729 state.if_verbose(|| { 927 state.if_verbose(|| {
730 let value_μ = specialisation.value_μ(&μ); 928 let value_μ = specialisation.value_μ(&μ);
731 // Plot if so requested 929 // Plot if so requested
732 plotter.plot_spikes( 930 plotter.plot_spikes(
733 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, 931 format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
734 "start".to_string(), Some(&minus_τv), 932 "start".to_string(), Some(&minus_τv),
735 Some(target_bounds), value_μ, 933 reg.target_bounds(τ, ε_prev), value_μ,
736 ); 934 );
737 // Calculate mean inner iterations and reset relevant counters. 935 // Calculate mean inner iterations and reset relevant counters.
738 // Return the statistics 936 // Return the statistics
739 let res = IterInfo { 937 let res = IterInfo {
740 value : specialisation.calculate_fit(&μ, &residual) + α * value_μ.norm(Radon), 938 value : specialisation.calculate_fit(&μ, &residual) + reg.apply(value_μ),
741 n_spikes : value_μ.len(), 939 n_spikes : value_μ.len(),
742 inner_iters, 940 inner_iters,
743 this_iters, 941 this_iters,
744 merged, 942 merged,
745 pruned, 943 pruned,
755 }); 953 });
756 954
757 specialisation.postprocess(μ, config.final_merging) 955 specialisation.postprocess(μ, config.final_merging)
758 } 956 }
759 957
760 958 /// Iteratively solve the pointsource localisation problem using forward-backward splitting
761 959 ///
762 960 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
961 /// forward operator $A$, $b$ the observable, and `reg` the regularisation term ([`RegTerm`]).
962 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
963 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
964 /// as documented in [`alg_tools::iterate`].
965 ///
966 /// For details on the mathematical formulation, see the [module level](self) documentation.
967 ///
968 /// Returns the final iterate.
///
/// # Panics
///
/// Panics if `opA.lipschitz_factor(&op𝒟)` does not produce a value, because the result is
/// unwrapped when the step length is computed.
969 #[replace_float_literals(F::cast_from(literal))]
970 pub fn pointsource_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>(
971 opA : &'a A,
972 b : &A::Observable,
973 reg : Reg,
974 op𝒟 : &'a 𝒟,
975 config : &FBConfig<F>,
976 iterator : I,
977 plotter : SeqPlotter<F, N>,
978 ) -> DiscreteMeasure<Loc<F, N>, F>
979 where F : Float + ToNalgebraRealField,
980 I : AlgIteratorFactory<IterInfo<F, N>>,
981 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
982 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
983 A::Observable : std::ops::MulAssign<F>,
984 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
985 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
986 + Lipschitz<𝒟, FloatType=F>,
987 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
988 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
989 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
990 𝒟::Codomain : RealMapping<F, N>,
991 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
992 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
993 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
994 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
995 PlotLookup : Plotting<N>,
996 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
997 Reg : RegTerm<F, N> {
998
// The residual handed to the generic solver starts as `-b`
// (presumably $Aμ - b$ for an empty initial measure — TODO confirm).
999 let initial_residual = -b;
// Step length: `config.τ0` divided by the Lipschitz factor of `opA` with respect to `op𝒟`.
1000 let τ = config.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
1001
// Pick the specialisation according to the configured meta-algorithm. Both branches call
// `generic_pointsource_fb_reg` with the same arguments; only the specialisation differs.
1002 match config.meta {
// No meta-algorithm: the `BasicFB` specialisation.
1003 FBMetaAlgorithm::None => generic_pointsource_fb_reg(
1004 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual,
1005 BasicFB{ b, opA },
1006 ),
// FISTA-type inertia: starts with `λ = 1.0` (cast to `F` by `replace_float_literals`)
// and a fresh `DiscreteMeasure::new()` as the previous iterate `μ_prev`.
1007 FBMetaAlgorithm::InertiaFISTA => generic_pointsource_fb_reg(
1008 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual,
1009 FISTA{ b, opA, λ : 1.0, μ_prev : DiscreteMeasure::new() },
1010 ),
1011 }
1012 }
1013
1014 //
1015 // Deprecated interfaces
1016 //
1017
/// Deprecated variant of [`pointsource_fb_reg`] that takes the weight `α` directly.
///
/// Wraps `α` in [`NonnegRadonRegTerm`] and forwards `opA`, `b`, `op𝒟`, `config`, `iterator`
/// and `plotter` unchanged to [`pointsource_fb_reg`], returning its result.
1018 #[deprecated(note = "Use `pointsource_fb_reg`")]
1019 pub fn pointsource_fb<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, const N : usize>(
1020 opA : &'a A,
1021 b : &A::Observable,
1022 α : F,
1023 op𝒟 : &'a 𝒟,
1024 config : &FBConfig<F>,
1025 iterator : I,
1026 plotter : SeqPlotter<F, N>
1027 ) -> DiscreteMeasure<Loc<F, N>, F>
1028 where F : Float + ToNalgebraRealField,
1029 I : AlgIteratorFactory<IterInfo<F, N>>,
1030 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
1031 A::Observable : std::ops::MulAssign<F>,
1032 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
1033 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
1034 + Lipschitz<𝒟, FloatType=F>,
1035 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
1036 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
1037 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
1038 𝒟::Codomain : RealMapping<F, N>,
1039 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
1040 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
1041 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
1042 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
1043 PlotLookup : Plotting<N>,
1044 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
1045
// Kept for backward compatibility: the only work done here is wrapping `α`.
1046 pointsource_fb_reg(opA, b, NonnegRadonRegTerm(α), op𝒟, config, iterator, plotter)
1047 }
1048
1049
/// Deprecated variant of [`generic_pointsource_fb_reg`] that takes the weight `α` directly.
///
/// Wraps `α` in [`NonnegRadonRegTerm`] and forwards `opA`, `op𝒟`, the step length `τ`,
/// `config`, `iterator`, `plotter`, `residual` and `specialisation` unchanged to
/// [`generic_pointsource_fb_reg`], returning its result.
1050 #[deprecated(note = "Use `generic_pointsource_fb_reg`")]
1051 pub fn generic_pointsource_fb<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, const N : usize>(
1052 opA : &'a A,
1053 α : F,
1054 op𝒟 : &'a 𝒟,
1055 τ : F,
1056 config : &FBGenericConfig<F>,
1057 iterator : I,
1058 plotter : SeqPlotter<F, N>,
1059 residual : A::Observable,
1060 specialisation : Spec,
1061 ) -> DiscreteMeasure<Loc<F, N>, F>
1062 where F : Float + ToNalgebraRealField,
1063 I : AlgIteratorFactory<IterInfo<F, N>>,
1064 Spec : FBSpecialisation<F, A::Observable, N>,
1065 A::Observable : std::ops::MulAssign<F>,
1066 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
1067 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
1068 + Lipschitz<𝒟, FloatType=F>,
1069 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
1070 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
1071 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
1072 𝒟::Codomain : RealMapping<F, N>,
1073 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
1074 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
1075 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
1076 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
1077 PlotLookup : Plotting<N>,
1078 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
1079
// Kept for backward compatibility: the only work done here is wrapping `α`.
1080 generic_pointsource_fb_reg(opA, NonnegRadonRegTerm(α), op𝒟, τ, config, iterator, plotter,
1081 residual, specialisation)
1082 }

mercurial