src/fb.rs

branch
dev
changeset 34
efa60bc4f743
parent 32
56c8adc32b09
child 35
b087e3eab191
--- a/src/fb.rs	Tue Aug 01 10:32:12 2023 +0300
+++ b/src/fb.rs	Thu Aug 29 00:00:00 2024 -0500
@@ -136,17 +136,6 @@
     DataTerm,
 };
 
-/// Method for constructing $μ$ on each iteration
-#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
-#[allow(dead_code)]
-pub enum InsertionStyle {
-    /// Resuse previous $μ$ from previous iteration, optimising weights
-    /// before inserting new spikes.
-    Reuse,
-    /// Start each iteration with $μ=0$.
-    Zero,
-}
-
 /// Settings for [`pointsource_fb_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
@@ -154,7 +143,7 @@
     /// Step length scaling
     pub τ0 : F,
     /// Generic parameters
-    pub insertion : FBGenericConfig<F>,
+    pub generic : FBGenericConfig<F>,
 }
 
 /// Settings for the solution of the stepwise optimality condition in algorithms based on
@@ -162,29 +151,43 @@
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
 pub struct FBGenericConfig<F : Float> {
-    /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`].
-    pub insertion_style : InsertionStyle,
     /// Tolerance for point insertion.
     pub tolerance : Tolerance<F>,
+
     /// Stop looking for predual maximum (where to isert a new point) below
     /// `tolerance` multiplied by this factor.
+    ///
+    /// Not used by [`super::radon_fb`].
     pub insertion_cutoff_factor : F,
+
     /// Settings for branch and bound refinement when looking for predual maxima
     pub refinement : RefinementSettings<F>,
+
     /// Maximum insertions within each outer iteration
+    ///
+    /// Not used by [`super::radon_fb`].
     pub max_insertions : usize,
+
     /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
+    ///
+    /// Not used by [`super::radon_fb`].
     pub bootstrap_insertions : Option<(usize, usize)>,
+
     /// Inner method settings
     pub inner : InnerSettings<F>,
+
     /// Spike merging method
     pub merging : SpikeMergingMethod<F>,
+
     /// Tolerance multiplier for merges
     pub merge_tolerance_mult : F,
+
     /// Spike merging method after the last step
     pub final_merging : SpikeMergingMethod<F>,
+
     /// Iterations between merging heuristic tries
     pub merge_every : usize,
+
     /// Save $μ$ for postprocessing optimisation
     pub postprocessing : bool
 }
@@ -194,7 +197,7 @@
     fn default() -> Self {
         FBConfig {
             τ0 : 0.99,
-            insertion : Default::default()
+            generic : Default::default(),
         }
     }
 }
@@ -203,7 +206,6 @@
 impl<F : Float> Default for FBGenericConfig<F> {
     fn default() -> Self {
         FBGenericConfig {
-            insertion_style : InsertionStyle::Reuse,
             tolerance : Default::default(),
             insertion_cutoff_factor : 1.0,
             refinement : Default::default(),
@@ -211,7 +213,7 @@
             //bootstrap_insertions : None,
             bootstrap_insertions : Some((10, 1)),
             inner : InnerSettings {
-                method : InnerMethod::SSN,
+                method : InnerMethod::Default,
                 .. Default::default()
             },
             merging : SpikeMergingMethod::None,
@@ -224,35 +226,9 @@
     }
 }
 
-#[replace_float_literals(F::cast_from(literal))]
-pub(crate) fn μ_diff<F : Float, const N : usize>(
-    μ_new : &DiscreteMeasure<Loc<F, N>, F>,
-    μ_base : &DiscreteMeasure<Loc<F, N>, F>,
-    ν_delta : Option<&DiscreteMeasure<Loc<F, N>, F>>,
-    config : &FBGenericConfig<F>
-) -> DiscreteMeasure<Loc<F, N>, F> {
-    let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style {
-        InsertionStyle::Reuse => {
-            μ_new.iter_spikes()
-                 .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0)))
-                 .map(|(δ, α_base)| (δ.x, α_base - δ.α))
-                 .collect()
-        },
-        InsertionStyle::Zero => {
-            μ_new.iter_spikes()
-                 .map(|δ| -δ)
-                 .chain(μ_base.iter_spikes().copied())
-                 .collect()
-        }
-    };
-    ν.prune(); // Potential small performance improvement
-    // Add ν_delta if given
-    match ν_delta {
-        None => ν,
-        Some(ν_d) => ν + ν_d,
-    }
-}
-
+/// TODO: document.
+/// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike
+/// locations, while `ν_delta` may have different locations.
 #[replace_float_literals(F::cast_from(literal))]
 pub(crate) fn insert_and_reweigh<
     'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
@@ -284,24 +260,15 @@
       State : AlgIteratorState {
 
     // Maximum insertion count and measure difference calculation depend on insertion style.
-    let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
+    let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
         (i, Some((l, k))) if i <= l => (k, false),
         _ => (config.max_insertions, !state.is_quiet()),
     };
-    let max_insertions = match config.insertion_style {
-        InsertionStyle::Zero => {
-            todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled.");
-            // let n = μ.len();
-            // μ = DiscreteMeasure::new();
-            // n + m
-        },
-        InsertionStyle::Reuse => m,
-    };
 
-    // TODO: should avoid a second copy of μ here; μ_base already stores a copy.
+    // TODO: should avoid a copy of μ_base here.
     let ω0 = op𝒟.apply(match ν_delta {
-        None => μ.clone(),
-        Some(ν_d) => &*μ + ν_d,
+        None => μ_base.clone(),
+        Some(ν_d) => &*μ_base + ν_d,
     });
 
     // Add points to support until within error tolerance or maximum insertion count reached.
@@ -333,7 +300,10 @@
 
         // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
         // conditions in the predual space, and finding new points for insertion, if necessary.
-        let mut d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_base, ν_delta, config));
+        let mut d = minus_τv + op𝒟.preapply(match ν_delta {
+            None => μ_base.sub_matching(μ),
+            Some(ν) =>  μ_base.sub_matching(μ) + ν
+        });
 
         // If no merging heuristic is used, let's be more conservative about spike insertion,
         // and skip it after first round. If merging is done, being more greedy about spike
@@ -404,16 +374,10 @@
       Reg : RegTerm<F, N>,
       State : AlgIteratorState {
     if state.iteration() % config.merge_every == 0 {
-        let n_before_merge = μ.len();
-        μ.merge_spikes(config.merging, |μ_candidate| {
-            let μd = μ_diff(&μ_candidate, &μ_base, None, config);
-            let mut d = minus_τv + op𝒟.preapply(μd);
-
+        stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
+            let mut d = minus_τv + op𝒟.preapply(μ_base.sub_matching(&μ_candidate));
             reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
-                .then_some(())
         });
-        debug_assert!(μ.len() >= n_before_merge);
-        stats.merged += μ.len() - n_before_merge;
     }
 
     let n_before_prune = μ.len();
@@ -495,7 +459,7 @@
       Reg : RegTerm<F, N> {
 
     // Set up parameters
-    let config = &fbconfig.insertion;
+    let config = &fbconfig.generic;
     let op𝒟norm = op𝒟.opnorm_bound();
     let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
@@ -621,7 +585,7 @@
       Reg : RegTerm<F, N> {
 
     // Set up parameters
-    let config = &fbconfig.insertion;
+    let config = &fbconfig.generic;
     let op𝒟norm = op𝒟.opnorm_bound();
     let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
     let mut λ = 1.0;

mercurial