src/fb.rs

branch
dev
changeset 51
0693cc9ba9f0
parent 39
6316d68b58af
--- a/src/fb.rs	Mon Feb 17 13:45:11 2025 -0500
+++ b/src/fb.rs	Mon Feb 17 13:51:50 2025 -0500
@@ -74,69 +74,50 @@
 </p>
 
 We solve this with either SSN or FB as determined by
-[`InnerSettings`] in [`FBGenericConfig::inner`].
+[`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`].
 */
 
+use colored::Colorize;
 use numeric_literals::replace_float_literals;
-use serde::{Serialize, Deserialize};
-use colored::Colorize;
+use serde::{Deserialize, Serialize};
 
+use alg_tools::euclidean::Euclidean;
+use alg_tools::instance::Instance;
 use alg_tools::iterate::AlgIteratorFactory;
-use alg_tools::euclidean::Euclidean;
 use alg_tools::linops::{Mapping, GEMV};
 use alg_tools::mapping::RealMapping;
 use alg_tools::nalgebra_support::ToNalgebraRealField;
-use alg_tools::instance::Instance;
 
-use crate::types::*;
-use crate::measures::{
-    DiscreteMeasure,
-    RNDM,
-};
+use crate::dataterm::{calculate_residual, DataTerm, L2Squared};
+use crate::forward_model::{AdjointProductBoundedBy, ForwardModel};
 use crate::measures::merging::SpikeMerging;
-use crate::forward_model::{
-    ForwardModel,
-    AdjointProductBoundedBy,
-};
-use crate::plot::{
-    SeqPlotter,
-    Plotting,
-    PlotLookup
-};
+use crate::measures::{DiscreteMeasure, RNDM};
+use crate::plot::{PlotLookup, Plotting, SeqPlotter};
+pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty};
 use crate::regularisation::RegTerm;
-use crate::dataterm::{
-    calculate_residual,
-    L2Squared,
-    DataTerm,
-};
-pub use crate::prox_penalty::{
-    FBGenericConfig,
-    ProxPenalty
-};
+use crate::types::*;
 
 /// Settings for [`pointsource_fb_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
-pub struct FBConfig<F : Float> {
+pub struct FBConfig<F: Float> {
     /// Step length scaling
-    pub τ0 : F,
+    pub τ0: F,
     /// Generic parameters
-    pub generic : FBGenericConfig<F>,
+    pub generic: FBGenericConfig<F>,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> Default for FBConfig<F> {
+impl<F: Float> Default for FBConfig<F> {
     fn default() -> Self {
         FBConfig {
-            τ0 : 0.99,
-            generic : Default::default(),
+            τ0: 0.99,
+            generic: Default::default(),
         }
     }
 }
 
-pub(crate) fn prune_with_stats<F : Float, const N : usize>(
-    μ : &mut RNDM<F, N>,
-) -> usize {
+pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize {
     let n_before_prune = μ.len();
     μ.prune();
     debug_assert!(μ.len() <= n_before_prune);
@@ -145,25 +126,27 @@
 
 #[replace_float_literals(F::cast_from(literal))]
 pub(crate) fn postprocess<
-    F : Float,
-    V : Euclidean<F> + Clone,
-    A : GEMV<F, RNDM<F, N>, Codomain = V>,
-    D : DataTerm<F, V, N>,
-    const N : usize
-> (
-    mut μ : RNDM<F, N>,
-    config : &FBGenericConfig<F>,
-    dataterm : D,
-    opA : &A,
-    b : &V,
+    F: Float,
+    V: Euclidean<F> + Clone,
+    A: GEMV<F, RNDM<F, N>, Codomain = V>,
+    D: DataTerm<F, V, N>,
+    const N: usize,
+>(
+    mut μ: RNDM<F, N>,
+    config: &FBGenericConfig<F>,
+    dataterm: D,
+    opA: &A,
+    b: &V,
 ) -> RNDM<F, N>
 where
-    RNDM<F, N> : SpikeMerging<F>,
-    for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>,
+    RNDM<F, N>: SpikeMerging<F>,
+    for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>,
 {
-    μ.merge_spikes_fitness(config.final_merging_method(),
-                           |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
-                           |&v| v);
+    μ.merge_spikes_fitness(
+        config.final_merging_method(),
+        |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
+        |&v| v,
+    );
     μ.prune();
     μ
 }
@@ -187,33 +170,29 @@
 ///
 /// Returns the final iterate.
 #[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_fb_reg<
-    F, I, A, Reg, P, const N : usize
->(
-    opA : &A,
-    b : &A::Observable,
-    reg : Reg,
-    prox_penalty : &P,
-    fbconfig : &FBConfig<F>,
-    iterator : I,
-    mut plotter : SeqPlotter<F, N>,
+pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>(
+    opA: &A,
+    b: &A::Observable,
+    reg: Reg,
+    prox_penalty: &P,
+    fbconfig: &FBConfig<F>,
+    iterator: I,
+    mut plotter: SeqPlotter<F, N>,
 ) -> RNDM<F, N>
 where
-    F : Float + ToNalgebraRealField,
-    I : AlgIteratorFactory<IterInfo<F, N>>,
-    for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
-    A : ForwardModel<RNDM<F, N>, F>
-        + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>,
-    A::PreadjointCodomain : RealMapping<F, N>,
-    PlotLookup : Plotting<N>,
-    RNDM<F, N> : SpikeMerging<F>,
-    Reg : RegTerm<F, N>,
-    P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
+    F: Float + ToNalgebraRealField,
+    I: AlgIteratorFactory<IterInfo<F, N>>,
+    for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>,
+    A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>,
+    A::PreadjointCodomain: RealMapping<F, N>,
+    PlotLookup: Plotting<N>,
+    RNDM<F, N>: SpikeMerging<F>,
+    Reg: RegTerm<F, N>,
+    P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
 {
-
     // Set up parameters
     let config = &fbconfig.generic;
-    let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap();
+    let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
     let tolerance = config.tolerance * τ * reg.tolerance_scaling();
@@ -224,14 +203,12 @@
     let mut residual = -b;
 
     // Statistics
-    let full_stats = |residual : &A::Observable,
-                      μ : &RNDM<F, N>,
-                      ε, stats| IterInfo {
-        value : residual.norm2_squared_div2() + reg.apply(μ),
-        n_spikes : μ.len(),
+    let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo {
+        value: residual.norm2_squared_div2() + reg.apply(μ),
+        n_spikes: μ.len(),
         ε,
         //postprocessing: config.postprocessing.then(|| μ.clone()),
-        .. stats
+        ..stats
     };
     let mut stats = IterInfo::new();
 
@@ -242,19 +219,24 @@
 
         // Save current base point
         let μ_base = μ.clone();
-            
+
         // Insert and reweigh
         let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
-            &mut μ, &mut τv, &μ_base, None,
-            τ, ε,
-            config, &reg, &state, &mut stats
+            &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
         );
 
         // Prune and possibly merge spikes
         if config.merge_now(&state) {
             stats.merged += prox_penalty.merge_spikes(
-                &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg,
-                Some(|μ̃ : &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
+                &mut μ,
+                &mut τv,
+                &μ_base,
+                None,
+                τ,
+                ε,
+                config,
+                &reg,
+                Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
             );
         }
 
@@ -269,9 +251,14 @@
         // Give statistics if needed
         state.if_verbose(|| {
             plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
-            full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
+            full_stats(
+                &residual,
+                &μ,
+                ε,
+                std::mem::replace(&mut stats, IterInfo::new()),
+            )
         });
-        
+
         // Update main tolerance for next iteration
         ε = tolerance.update(ε, iter);
     }
@@ -298,33 +285,29 @@
 ///
 /// Returns the final iterate.
 #[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_fista_reg<
-    F, I, A, Reg, P, const N : usize
->(
-    opA : &A,
-    b : &A::Observable,
-    reg : Reg,
-    prox_penalty : &P,
-    fbconfig : &FBConfig<F>,
-    iterator : I,
-    mut plotter : SeqPlotter<F, N>,
+pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>(
+    opA: &A,
+    b: &A::Observable,
+    reg: Reg,
+    prox_penalty: &P,
+    fbconfig: &FBConfig<F>,
+    iterator: I,
+    mut plotter: SeqPlotter<F, N>,
 ) -> RNDM<F, N>
 where
-    F : Float + ToNalgebraRealField,
-    I : AlgIteratorFactory<IterInfo<F, N>>,
-    for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
-    A : ForwardModel<RNDM<F, N>, F>
-        + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>,
-    A::PreadjointCodomain : RealMapping<F, N>,
-    PlotLookup : Plotting<N>,
-    RNDM<F, N> : SpikeMerging<F>,
-    Reg : RegTerm<F, N>,
-    P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
+    F: Float + ToNalgebraRealField,
+    I: AlgIteratorFactory<IterInfo<F, N>>,
+    for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>,
+    A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>,
+    A::PreadjointCodomain: RealMapping<F, N>,
+    PlotLookup: Plotting<N>,
+    RNDM<F, N>: SpikeMerging<F>,
+    Reg: RegTerm<F, N>,
+    P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
 {
-
     // Set up parameters
     let config = &fbconfig.generic;
-    let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap();
+    let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
     let mut λ = 1.0;
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
@@ -338,12 +321,12 @@
     let mut warned_merging = false;
 
     // Statistics
-    let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo {
-        value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν),
-        n_spikes : ν.len(),
+    let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo {
+        value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν),
+        n_spikes: ν.len(),
         ε,
         // postprocessing: config.postprocessing.then(|| ν.clone()),
-        .. stats
+        ..stats
     };
     let mut stats = IterInfo::new();
 
@@ -354,12 +337,10 @@
 
         // Save current base point
         let μ_base = μ.clone();
-            
+
         // Insert new spikes and reweigh
         let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
-            &mut μ, &mut τv, &μ_base, None,
-            τ, ε,
-            config, &reg, &state, &mut stats
+            &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
         );
 
         // (Do not) merge spikes.
@@ -371,7 +352,7 @@
 
         // Update inertial prameters
         let λ_prev = λ;
-        λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() );
+        λ = 2.0 * λ_prev / (λ_prev + (4.0 + λ_prev * λ_prev).sqrt());
         let θ = λ / λ_prev - λ;
 
         // Perform inertial update on μ.

mercurial