src/fb.rs

branch
dev
changeset 35
b087e3eab191
parent 34
efa60bc4f743
child 37
c5d8bd1a7728
--- a/src/fb.rs	Thu Aug 29 00:00:00 2024 -0500
+++ b/src/fb.rs	Tue Dec 31 09:25:45 2024 -0500
@@ -6,10 +6,7 @@
  * Valkonen T. - _Proximal methods for point source localisation_,
    [arXiv:2212.02991](https://arxiv.org/abs/2212.02991).
 
-The main routine is [`pointsource_fb_reg`]. It is based on [`generic_pointsource_fb_reg`], which is
-also used by our [primal-dual proximal splitting][crate::pdps] implementation.
-
-FISTA-type inertia can also be enabled through [`FBConfig::meta`].
+The main routine is [`pointsource_fb_reg`].
 
 ## Problem
 
@@ -76,7 +73,7 @@
 $$
 </p>
 
-We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by
+We solve this with either SSN or FB as determined by
 [`InnerSettings`] in [`FBGenericConfig::inner`].
 */
 
@@ -87,10 +84,11 @@
 
 use alg_tools::iterate::{
     AlgIteratorFactory,
-    AlgIteratorState,
+    AlgIteratorIteration,
+    AlgIterator,
 };
 use alg_tools::euclidean::Euclidean;
-use alg_tools::linops::{Apply, GEMV};
+use alg_tools::linops::{Mapping, GEMV};
 use alg_tools::sets::Cube;
 use alg_tools::loc::Loc;
 use alg_tools::bisection_tree::{
@@ -107,17 +105,24 @@
 };
 use alg_tools::mapping::RealMapping;
 use alg_tools::nalgebra_support::ToNalgebraRealField;
+use alg_tools::instance::Instance;
+use alg_tools::norms::Linfinity;
 
 use crate::types::*;
 use crate::measures::{
     DiscreteMeasure,
+    RNDM,
     DeltaMeasure,
+    Radon,
 };
 use crate::measures::merging::{
     SpikeMergingMethod,
     SpikeMerging,
 };
-use crate::forward_model::ForwardModel;
+use crate::forward_model::{
+    ForwardModel,
+    AdjointProductBoundedBy
+};
 use crate::seminorms::DiscreteMeasureOp;
 use crate::subproblem::{
     InnerSettings,
@@ -146,8 +151,7 @@
     pub generic : FBGenericConfig<F>,
 }
 
-/// Settings for the solution of the stepwise optimality condition in algorithms based on
-/// [`generic_pointsource_fb_reg`].
+/// Settings for the solution of the stepwise optimality condition.
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
 pub struct FBGenericConfig<F : Float> {
@@ -188,8 +192,8 @@
     /// Iterations between merging heuristic tries
     pub merge_every : usize,
 
-    /// Save $μ$ for postprocessing optimisation
-    pub postprocessing : bool
+    // /// Save $μ$ for postprocessing optimisation
+    // pub postprocessing : bool
 }
 
 #[replace_float_literals(F::cast_from(literal))]
@@ -221,29 +225,36 @@
             final_merging : Default::default(),
             merge_every : 10,
             merge_tolerance_mult : 2.0,
-            postprocessing : false,
+            // postprocessing : false,
         }
     }
 }
 
+impl<F : Float> FBGenericConfig<F> {
+    /// Check if merging should be attempted this iteration
+    pub fn merge_now<I : AlgIterator>(&self, state : &AlgIteratorIteration<I>) -> bool {
+        state.iteration() % self.merge_every == 0
+    }
+}
+
 /// TODO: document.
 /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike
 /// locations, while `ν_delta` may have different locations.
 #[replace_float_literals(F::cast_from(literal))]
 pub(crate) fn insert_and_reweigh<
-    'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
+    'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, I, const N : usize
 >(
-    μ : &mut DiscreteMeasure<Loc<F, N>, F>,
-    minus_τv : &BTFN<F, GA, BTA, N>,
-    μ_base : &DiscreteMeasure<Loc<F, N>, F>,
-    ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>,
+    μ : &mut RNDM<F, N>,
+    τv : &BTFN<F, GA, BTA, N>,
+    μ_base : &RNDM<F, N>,
+    ν_delta: Option<&RNDM<F, N>>,
     op𝒟 : &'a 𝒟,
     op𝒟norm : F,
     τ : F,
     ε : F,
     config : &FBGenericConfig<F>,
     reg : &Reg,
-    state : &State,
+    state : &AlgIteratorIteration<I>,
     stats : &mut IterInfo<F, N>,
 ) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool)
 where F : Float + ToNalgebraRealField,
@@ -255,9 +266,8 @@
       S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
       K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
       BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
-      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
       Reg : RegTerm<F, N>,
-      State : AlgIteratorState {
+      I : AlgIterator {
 
     // Maximum insertion count and measure difference calculation depend on insertion style.
     let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
@@ -265,11 +275,10 @@
         _ => (config.max_insertions, !state.is_quiet()),
     };
 
-    // TODO: should avoid a copy of μ_base here.
-    let ω0 = op𝒟.apply(match ν_delta {
-        None => μ_base.clone(),
-        Some(ν_d) => &*μ_base + ν_d,
-    });
+    let ω0 = match ν_delta {
+        None => op𝒟.apply(μ_base),
+        Some(ν) => op𝒟.apply(μ_base + ν),
+    };
 
     // Add points to support until within error tolerance or maximum insertion count reached.
     let mut count = 0;
@@ -277,10 +286,12 @@
         if μ.len() > 0 {
             // Form finite-dimensional subproblem. The subproblem references to the original μ^k
             // from the beginning of the iteration are all contained in the immutable c and g.
+            // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
+            // problems have not yet been updated to sign change.
             let à = op𝒟.findim_matrix(μ.iter_locations());
             let g̃ = DVector::from_iterator(μ.len(),
                                            μ.iter_locations()
-                                            .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ))
+                                            .map(|ζ| ω0.apply(ζ) - τv.apply(ζ))
                                             .map(F::to_nalgebra_mixed));
             let mut x = μ.masses_dvector();
 
@@ -298,12 +309,12 @@
             μ.set_masses_dvector(&x);
         }
 
-        // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
+        // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality
         // conditions in the predual space, and finding new points for insertion, if necessary.
-        let mut d = minus_τv + op𝒟.preapply(match ν_delta {
-            None => μ_base.sub_matching(μ),
-            Some(ν) =>  μ_base.sub_matching(μ) + ν
-        });
+        let mut d = τv + match ν_delta {
+            None => op𝒟.preapply(μ.sub_matching(μ_base)),
+            Some(ν) => op𝒟.preapply(μ.sub_matching(μ_base) - ν)
+        };
 
         // If no merging heuristic is used, let's be more conservative about spike insertion,
         // and skip it after first round. If merging is done, being more greedy about spike
@@ -330,11 +341,9 @@
         // No point in optimising the weight here; the finite-dimensional algorithm is fast.
         *μ += DeltaMeasure { x : ξ, α : 0.0 };
         count += 1;
+        stats.inserted += 1;
     };
 
-    // TODO: should redo everything if some transports cause a problem.
-    // Maybe implementation should call above loop as a closure.
-
     if !within_tolerances && warn_insertions {
         // Complain (but continue) if we failed to get within tolerances
         // by inserting more points.
@@ -346,61 +355,33 @@
     (d, within_tolerances)
 }
 
-#[replace_float_literals(F::cast_from(literal))]
-pub(crate) fn prune_and_maybe_simple_merge<
-    'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
->(
-    μ : &mut DiscreteMeasure<Loc<F, N>, F>,
-    minus_τv : &BTFN<F, GA, BTA, N>,
-    μ_base : &DiscreteMeasure<Loc<F, N>, F>,
-    op𝒟 : &'a 𝒟,
-    τ : F,
-    ε : F,
-    config : &FBGenericConfig<F>,
-    reg : &Reg,
-    state : &State,
-    stats : &mut IterInfo<F, N>,
-)
-where F : Float + ToNalgebraRealField,
-      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
-      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
-      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
-      𝒟::Codomain : RealMapping<F, N>,
-      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
-      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
-      Reg : RegTerm<F, N>,
-      State : AlgIteratorState {
-    if state.iteration() % config.merge_every == 0 {
-        stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
-            let mut d = minus_τv + op𝒟.preapply(μ_base.sub_matching(&μ_candidate));
-            reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
-        });
-    }
-
+pub(crate) fn prune_with_stats<F : Float, const N : usize>(
+    μ : &mut RNDM<F, N>,
+) -> usize {
     let n_before_prune = μ.len();
     μ.prune();
     debug_assert!(μ.len() <= n_before_prune);
-    stats.pruned += n_before_prune - μ.len();
+    n_before_prune - μ.len()
 }
 
 #[replace_float_literals(F::cast_from(literal))]
 pub(crate) fn postprocess<
     F : Float,
     V : Euclidean<F> + Clone,
-    A : GEMV<F, DiscreteMeasure<Loc<F, N>, F>, Codomain = V>,
+    A : GEMV<F, RNDM<F, N>, Codomain = V>,
     D : DataTerm<F, V, N>,
     const N : usize
 > (
-    mut μ : DiscreteMeasure<Loc<F, N>, F>,
+    mut μ : RNDM<F, N>,
     config : &FBGenericConfig<F>,
     dataterm : D,
     opA : &A,
     b : &V,
-) -> DiscreteMeasure<Loc<F, N>, F>
-where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
+) -> RNDM<F, N>
+where
+    RNDM<F, N> : SpikeMerging<F>,
+    for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>,
+{
     μ.merge_spikes_fitness(config.merging,
                            |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
                            |&v| v);
@@ -437,15 +418,13 @@
     fbconfig : &FBConfig<F>,
     iterator : I,
     mut plotter : SeqPlotter<F, N>,
-) -> DiscreteMeasure<Loc<F, N>, F>
+) -> RNDM<F, N>
 where F : Float + ToNalgebraRealField,
       I : AlgIteratorFactory<IterInfo<F, N>>,
       for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
-                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
-      A::Observable : std::ops::MulAssign<F>,
       GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
-          + Lipschitz<&'a 𝒟, FloatType=F>,
+      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
+          + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
       𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
@@ -455,13 +434,13 @@
       BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
       Cube<F, N>: P2Minimise<Loc<F, N>, F>,
       PlotLookup : Plotting<N>,
-      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
+      RNDM<F, N> : SpikeMerging<F>,
       Reg : RegTerm<F, N> {
 
     // Set up parameters
     let config = &fbconfig.generic;
-    let op𝒟norm = op𝒟.opnorm_bound();
-    let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
+    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
+    let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap();
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
     let tolerance = config.tolerance * τ * reg.tolerance_scaling();
@@ -470,66 +449,59 @@
     // Initialise iterates
     let mut μ = DiscreteMeasure::new();
     let mut residual = -b;
+
+    // Statistics
+    let full_stats = |residual : &A::Observable,
+                      μ : &RNDM<F, N>,
+                      ε, stats| IterInfo {
+        value : residual.norm2_squared_div2() + reg.apply(μ),
+        n_spikes : μ.len(),
+        ε,
+        //postprocessing: config.postprocessing.then(|| μ.clone()),
+        .. stats
+    };
     let mut stats = IterInfo::new();
 
     // Run the algorithm
-    iterator.iterate(|state| {
+    for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
-        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
-        // has no significant overhead. For some reosn Rust doesn't allow us simply moving
-        // the residual and replacing it below before the end of this closure.
-        residual *= -τ;
-        let r = std::mem::replace(&mut residual, opA.empty_observable());
-        let minus_τv = opA.preadjoint().apply(r);
+        let τv = opA.preadjoint().apply(residual * τ);
 
         // Save current base point
         let μ_base = μ.clone();
             
         // Insert and reweigh
-        let (d, within_tolerances) = insert_and_reweigh(
-            &mut μ, &minus_τv, &μ_base, None,
+        let (d, _within_tolerances) = insert_and_reweigh(
+            &mut μ, &τv, &μ_base, None,
             op𝒟, op𝒟norm,
             τ, ε,
-            config, &reg, state, &mut stats
+            config, &reg, &state, &mut stats
         );
 
         // Prune and possibly merge spikes
-        prune_and_maybe_simple_merge(
-            &mut μ, &minus_τv, &μ_base,
-            op𝒟,
-            τ, ε,
-            config, &reg, state, &mut stats
-        );
+        if config.merge_now(&state) {
+            stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
+                let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base));
+                reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
+            });
+        }
+        stats.pruned += prune_with_stats(&mut μ);
 
         // Update residual
         residual = calculate_residual(&μ, opA, b);
 
-        // Update main tolerance for next iteration
-        let ε_prev = ε;
-        ε = tolerance.update(ε, state.iteration());
+        let iter = state.iteration();
         stats.this_iters += 1;
 
-        // Give function value if needed
+        // Give statistics if needed
         state.if_verbose(|| {
-            // Plot if so requested
-            plotter.plot_spikes(
-                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
-                "start".to_string(), Some(&minus_τv),
-                reg.target_bounds(τ, ε_prev), &μ,
-            );
-            // Calculate mean inner iterations and reset relevant counters.
-            // Return the statistics
-            let res = IterInfo {
-                value : residual.norm2_squared_div2() + reg.apply(&μ),
-                n_spikes : μ.len(),
-                ε : ε_prev,
-                postprocessing: config.postprocessing.then(|| μ.clone()),
-                .. stats
-            };
-            stats = IterInfo::new();
-            res
-        })
-    });
+            plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ);
+            full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
+        });
+        
+        // Update main tolerance for next iteration
+        ε = tolerance.update(ε, iter);
+    }
 
     postprocess(μ, config, L2Squared, opA, b)
 }
@@ -563,15 +535,13 @@
     fbconfig : &FBConfig<F>,
     iterator : I,
     mut plotter : SeqPlotter<F, N>,
-) -> DiscreteMeasure<Loc<F, N>, F>
+) -> RNDM<F, N>
 where F : Float + ToNalgebraRealField,
       I : AlgIteratorFactory<IterInfo<F, N>>,
       for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
-                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
-      A::Observable : std::ops::MulAssign<F>,
       GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
-          + Lipschitz<&'a 𝒟, FloatType=F>,
+      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
+          + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
       𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
@@ -581,13 +551,13 @@
       BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
       Cube<F, N>: P2Minimise<Loc<F, N>, F>,
       PlotLookup : Plotting<N>,
-      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
+      RNDM<F, N> : SpikeMerging<F>,
       Reg : RegTerm<F, N> {
 
     // Set up parameters
     let config = &fbconfig.generic;
-    let op𝒟norm = op𝒟.opnorm_bound();
-    let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
+    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
+    let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap();
     let mut λ = 1.0;
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
@@ -598,32 +568,36 @@
     let mut μ = DiscreteMeasure::new();
     let mut μ_prev = DiscreteMeasure::new();
     let mut residual = -b;
+    let mut warned_merging = false;
+
+    // Statistics
+    let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo {
+        value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν),
+        n_spikes : ν.len(),
+        ε,
+        // postprocessing: config.postprocessing.then(|| ν.clone()),
+        .. stats
+    };
     let mut stats = IterInfo::new();
-    let mut warned_merging = false;
 
     // Run the algorithm
-    iterator.iterate(|state| {
+    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
-        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
-        // has no significant overhead. For some reosn Rust doesn't allow us simply moving
-        // the residual and replacing it below before the end of this closure.
-        residual *= -τ;
-        let r = std::mem::replace(&mut residual, opA.empty_observable());
-        let minus_τv = opA.preadjoint().apply(r);
+        let τv = opA.preadjoint().apply(residual * τ);
 
         // Save current base point
         let μ_base = μ.clone();
             
         // Insert new spikes and reweigh
-        let (d, within_tolerances) = insert_and_reweigh(
-            &mut μ, &minus_τv, &μ_base, None,
+        let (d, _within_tolerances) = insert_and_reweigh(
+            &mut μ, &τv, &μ_base, None,
             op𝒟, op𝒟norm,
             τ, ε,
-            config, &reg, state, &mut stats
+            config, &reg, &state, &mut stats
         );
 
         // (Do not) merge spikes.
-        if state.iteration() % config.merge_every == 0 {
+        if config.merge_now(&state) {
             match config.merging {
                 SpikeMergingMethod::None => { },
                 _ => if !warned_merging {
@@ -653,32 +627,18 @@
         // Update residual
         residual = calculate_residual(&μ, opA, b);
 
-        // Update main tolerance for next iteration
-        let ε_prev = ε;
-        ε = tolerance.update(ε, state.iteration());
+        let iter = state.iteration();
         stats.this_iters += 1;
 
-        // Give function value if needed
+        // Give statistics if needed
         state.if_verbose(|| {
-            // Plot if so requested
-            plotter.plot_spikes(
-                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
-                "start".to_string(), Some(&minus_τv),
-                reg.target_bounds(τ, ε_prev), &μ_prev,
-            );
-            // Calculate mean inner iterations and reset relevant counters.
-            // Return the statistics
-            let res = IterInfo {
-                value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev),
-                n_spikes : μ_prev.len(),
-                ε : ε_prev,
-                postprocessing: config.postprocessing.then(|| μ_prev.clone()),
-                .. stats
-            };
-            stats = IterInfo::new();
-            res
-        })
-    });
+            plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ_prev);
+            full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new()))
+        });
+
+        // Update main tolerance for next iteration
+        ε = tolerance.update(ε, iter);
+    }
 
     postprocess(μ_prev, config, L2Squared, opA, b)
 }

mercurial