src/fb.rs

branch
dev
changeset 37
c5d8bd1a7728
parent 35
b087e3eab191
child 39
6316d68b58af
--- a/src/fb.rs	Mon Jan 06 11:32:57 2025 -0500
+++ b/src/fb.rs	Thu Jan 23 23:35:28 2025 +0100
@@ -80,40 +80,18 @@
 use numeric_literals::replace_float_literals;
 use serde::{Serialize, Deserialize};
 use colored::Colorize;
-use nalgebra::DVector;
 
-use alg_tools::iterate::{
-    AlgIteratorFactory,
-    AlgIteratorIteration,
-    AlgIterator,
-};
+use alg_tools::iterate::AlgIteratorFactory;
 use alg_tools::euclidean::Euclidean;
 use alg_tools::linops::{Mapping, GEMV};
-use alg_tools::sets::Cube;
-use alg_tools::loc::Loc;
-use alg_tools::bisection_tree::{
-    BTFN,
-    PreBTFN,
-    Bounds,
-    BTNodeLookup,
-    BTNode,
-    BTSearch,
-    P2Minimise,
-    SupportGenerator,
-    LocalAnalysis,
-    BothGenerators,
-};
 use alg_tools::mapping::RealMapping;
 use alg_tools::nalgebra_support::ToNalgebraRealField;
 use alg_tools::instance::Instance;
-use alg_tools::norms::Linfinity;
 
 use crate::types::*;
 use crate::measures::{
     DiscreteMeasure,
     RNDM,
-    DeltaMeasure,
-    Radon,
 };
 use crate::measures::merging::{
     SpikeMergingMethod,
@@ -121,14 +99,8 @@
 };
 use crate::forward_model::{
     ForwardModel,
-    AdjointProductBoundedBy
+    AdjointProductBoundedBy,
 };
-use crate::seminorms::DiscreteMeasureOp;
-use crate::subproblem::{
-    InnerSettings,
-    InnerMethod,
-};
-use crate::tolerance::Tolerance;
 use crate::plot::{
     SeqPlotter,
     Plotting,
@@ -140,6 +112,10 @@
     L2Squared,
     DataTerm,
 };
+pub use crate::prox_penalty::{
+    FBGenericConfig,
+    ProxPenalty
+};
 
 /// Settings for [`pointsource_fb_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
@@ -151,51 +127,6 @@
     pub generic : FBGenericConfig<F>,
 }
 
-/// Settings for the solution of the stepwise optimality condition.
-#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
-#[serde(default)]
-pub struct FBGenericConfig<F : Float> {
-    /// Tolerance for point insertion.
-    pub tolerance : Tolerance<F>,
-
-    /// Stop looking for predual maximum (where to insert a new point) below
-    /// `tolerance` multiplied by this factor.
-    ///
-    /// Not used by [`super::radon_fb`].
-    pub insertion_cutoff_factor : F,
-
-    /// Settings for branch and bound refinement when looking for predual maxima
-    pub refinement : RefinementSettings<F>,
-
-    /// Maximum insertions within each outer iteration
-    ///
-    /// Not used by [`super::radon_fb`].
-    pub max_insertions : usize,
-
-    /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
-    ///
-    /// Not used by [`super::radon_fb`].
-    pub bootstrap_insertions : Option<(usize, usize)>,
-
-    /// Inner method settings
-    pub inner : InnerSettings<F>,
-
-    /// Spike merging method
-    pub merging : SpikeMergingMethod<F>,
-
-    /// Tolerance multiplier for merges
-    pub merge_tolerance_mult : F,
-
-    /// Spike merging method after the last step
-    pub final_merging : SpikeMergingMethod<F>,
-
-    /// Iterations between merging heuristic tries
-    pub merge_every : usize,
-
-    // /// Save $μ$ for postprocessing optimisation
-    // pub postprocessing : bool
-}
-
 #[replace_float_literals(F::cast_from(literal))]
 impl<F : Float> Default for FBConfig<F> {
     fn default() -> Self {
@@ -206,155 +137,6 @@
     }
 }
 
-#[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> Default for FBGenericConfig<F> {
-    fn default() -> Self {
-        FBGenericConfig {
-            tolerance : Default::default(),
-            insertion_cutoff_factor : 1.0,
-            refinement : Default::default(),
-            max_insertions : 100,
-            //bootstrap_insertions : None,
-            bootstrap_insertions : Some((10, 1)),
-            inner : InnerSettings {
-                method : InnerMethod::Default,
-                .. Default::default()
-            },
-            merging : SpikeMergingMethod::None,
-            //merging : Default::default(),
-            final_merging : Default::default(),
-            merge_every : 10,
-            merge_tolerance_mult : 2.0,
-            // postprocessing : false,
-        }
-    }
-}
-
-impl<F : Float> FBGenericConfig<F> {
-    /// Check if merging should be attempted this iteration
-    pub fn merge_now<I : AlgIterator>(&self, state : &AlgIteratorIteration<I>) -> bool {
-        state.iteration() % self.merge_every == 0
-    }
-}
-
-/// TODO: document.
-/// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike
-/// locations, while `ν_delta` may have different locations.
-#[replace_float_literals(F::cast_from(literal))]
-pub(crate) fn insert_and_reweigh<
-    'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, I, const N : usize
->(
-    μ : &mut RNDM<F, N>,
-    τv : &BTFN<F, GA, BTA, N>,
-    μ_base : &RNDM<F, N>,
-    ν_delta: Option<&RNDM<F, N>>,
-    op𝒟 : &'a 𝒟,
-    op𝒟norm : F,
-    τ : F,
-    ε : F,
-    config : &FBGenericConfig<F>,
-    reg : &Reg,
-    state : &AlgIteratorIteration<I>,
-    stats : &mut IterInfo<F, N>,
-) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool)
-where F : Float + ToNalgebraRealField,
-      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
-      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
-      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
-      𝒟::Codomain : RealMapping<F, N>,
-      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
-      Reg : RegTerm<F, N>,
-      I : AlgIterator {
-
-    // Maximum insertion count and measure difference calculation depend on insertion style.
-    let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
-        (i, Some((l, k))) if i <= l => (k, false),
-        _ => (config.max_insertions, !state.is_quiet()),
-    };
-
-    let ω0 = match ν_delta {
-        None => op𝒟.apply(μ_base),
-        Some(ν) => op𝒟.apply(μ_base + ν),
-    };
-
-    // Add points to support until within error tolerance or maximum insertion count reached.
-    let mut count = 0;
-    let (within_tolerances, d) = 'insertion: loop {
-        if μ.len() > 0 {
-            // Form finite-dimensional subproblem. The subproblem references to the original μ^k
-            // from the beginning of the iteration are all contained in the immutable c and g.
-            // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
-            // problems have not yet been updated to sign change.
-            let à = op𝒟.findim_matrix(μ.iter_locations());
-            let g̃ = DVector::from_iterator(μ.len(),
-                                           μ.iter_locations()
-                                            .map(|ζ| ω0.apply(ζ) - τv.apply(ζ))
-                                            .map(F::to_nalgebra_mixed));
-            let mut x = μ.masses_dvector();
-
-            // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
-            // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
-            // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤  sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
-            // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2
-            // = n |𝒟| |x|_2, where n is the number of points. Therefore
-            let Ã_normest = op𝒟norm * F::cast_from(μ.len());
-
-            // Solve finite-dimensional subproblem.
-            stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config);
-
-            // Update masses of μ based on solution of finite-dimensional subproblem.
-            μ.set_masses_dvector(&x);
-        }
-
-        // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality
-        // conditions in the predual space, and finding new points for insertion, if necessary.
-        let mut d = τv + match ν_delta {
-            None => op𝒟.preapply(μ.sub_matching(μ_base)),
-            Some(ν) => op𝒟.preapply(μ.sub_matching(μ_base) - ν)
-        };
-
-        // If no merging heuristic is used, let's be more conservative about spike insertion,
-        // and skip it after first round. If merging is done, being more greedy about spike
-        // insertion also seems to improve performance.
-        let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging {
-            false
-        } else {
-            count > 0
-        };
-
-        // Find a spike to insert, if needed
-        let (ξ, _v_ξ, in_bounds) =  match reg.find_tolerance_violation(
-            &mut d, τ, ε, skip_by_rough_check, config
-        ) {
-            None => break 'insertion (true, d),
-            Some(res) => res,
-        };
-
-        // Break if maximum insertion count reached
-        if count >= max_insertions {
-            break 'insertion (in_bounds, d)
-        }
-
-        // No point in optimising the weight here; the finite-dimensional algorithm is fast.
-        *μ += DeltaMeasure { x : ξ, α : 0.0 };
-        count += 1;
-        stats.inserted += 1;
-    };
-
-    if !within_tolerances && warn_insertions {
-        // Complain (but continue) if we failed to get within tolerances
-        // by inserting more points.
-        let err = format!("Maximum insertions reached without achieving \
-                            subproblem solution tolerance");
-        println!("{}", err.red());
-    }
-
-    (d, within_tolerances)
-}
-
 pub(crate) fn prune_with_stats<F : Float, const N : usize>(
     μ : &mut RNDM<F, N>,
 ) -> usize {
@@ -409,38 +191,32 @@
 /// Returns the final iterate.
 #[replace_float_literals(F::cast_from(literal))]
 pub fn pointsource_fb_reg<
-    'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize
+    F, I, A, Reg, P, const N : usize
 >(
-    opA : &'a A,
+    opA : &A,
     b : &A::Observable,
     reg : Reg,
-    op𝒟 : &'a 𝒟,
+    prox_penalty : &P,
     fbconfig : &FBConfig<F>,
     iterator : I,
     mut plotter : SeqPlotter<F, N>,
 ) -> RNDM<F, N>
-where F : Float + ToNalgebraRealField,
-      I : AlgIteratorFactory<IterInfo<F, N>>,
-      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
-      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
-          + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
-      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
-      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
-      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
-      𝒟::Codomain : RealMapping<F, N>,
-      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
-      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
-      PlotLookup : Plotting<N>,
-      RNDM<F, N> : SpikeMerging<F>,
-      Reg : RegTerm<F, N> {
+where
+    F : Float + ToNalgebraRealField,
+    I : AlgIteratorFactory<IterInfo<F, N>>,
+    for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
+    A : ForwardModel<RNDM<F, N>, F>
+        + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>,
+    A::PreadjointCodomain : RealMapping<F, N>,
+    PlotLookup : Plotting<N>,
+    RNDM<F, N> : SpikeMerging<F>,
+    Reg : RegTerm<F, N>,
+    P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
+{
 
     // Set up parameters
     let config = &fbconfig.generic;
-    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
-    let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap();
+    let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap();
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
     let tolerance = config.tolerance * τ * reg.tolerance_scaling();
@@ -465,26 +241,23 @@
     // Run the algorithm
     for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
-        let τv = opA.preadjoint().apply(residual * τ);
+        let mut τv = opA.preadjoint().apply(residual * τ);
 
         // Save current base point
         let μ_base = μ.clone();
             
         // Insert and reweigh
-        let (d, _within_tolerances) = insert_and_reweigh(
-            &mut μ, &τv, &μ_base, None,
-            op𝒟, op𝒟norm,
+        let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
+            &mut μ, &mut τv, &μ_base, None,
             τ, ε,
             config, &reg, &state, &mut stats
         );
 
         // Prune and possibly merge spikes
         if config.merge_now(&state) {
-            stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
-                let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base));
-                reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
-            });
+            stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, &reg);
         }
+
         stats.pruned += prune_with_stats(&mut μ);
 
         // Update residual
@@ -495,7 +268,7 @@
 
         // Give statistics if needed
         state.if_verbose(|| {
-            plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ);
+            plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
             full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
         });
         
@@ -526,38 +299,32 @@
 /// Returns the final iterate.
 #[replace_float_literals(F::cast_from(literal))]
 pub fn pointsource_fista_reg<
-    'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize
+    F, I, A, Reg, P, const N : usize
 >(
-    opA : &'a A,
+    opA : &A,
     b : &A::Observable,
     reg : Reg,
-    op𝒟 : &'a 𝒟,
+    prox_penalty : &P,
     fbconfig : &FBConfig<F>,
     iterator : I,
     mut plotter : SeqPlotter<F, N>,
 ) -> RNDM<F, N>
-where F : Float + ToNalgebraRealField,
-      I : AlgIteratorFactory<IterInfo<F, N>>,
-      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
-      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
-          + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
-      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
-      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
-      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
-      𝒟::Codomain : RealMapping<F, N>,
-      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
-      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
-      PlotLookup : Plotting<N>,
-      RNDM<F, N> : SpikeMerging<F>,
-      Reg : RegTerm<F, N> {
+where
+    F : Float + ToNalgebraRealField,
+    I : AlgIteratorFactory<IterInfo<F, N>>,
+    for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
+    A : ForwardModel<RNDM<F, N>, F>
+        + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>,
+    A::PreadjointCodomain : RealMapping<F, N>,
+    PlotLookup : Plotting<N>,
+    RNDM<F, N> : SpikeMerging<F>,
+    Reg : RegTerm<F, N>,
+    P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
+{
 
     // Set up parameters
     let config = &fbconfig.generic;
-    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
-    let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap();
+    let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap();
     let mut λ = 1.0;
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
@@ -583,15 +350,14 @@
     // Run the algorithm
     for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
-        let τv = opA.preadjoint().apply(residual * τ);
+        let mut τv = opA.preadjoint().apply(residual * τ);
 
         // Save current base point
         let μ_base = μ.clone();
             
         // Insert new spikes and reweigh
-        let (d, _within_tolerances) = insert_and_reweigh(
-            &mut μ, &τv, &μ_base, None,
-            op𝒟, op𝒟norm,
+        let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
+            &mut μ, &mut τv, &μ_base, None,
             τ, ε,
             config, &reg, &state, &mut stats
         );
@@ -632,7 +398,7 @@
 
         // Give statistics if needed
         state.if_verbose(|| {
-            plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ_prev);
+            plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ_prev);
             full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new()))
         });
 

mercurial