diff -r 39c5e6c7759d -r 0693cc9ba9f0 src/fb.rs --- a/src/fb.rs Mon Feb 17 13:45:11 2025 -0500 +++ b/src/fb.rs Mon Feb 17 13:51:50 2025 -0500 @@ -74,69 +74,50 @@

We solve this with either SSN or FB as determined by -[`InnerSettings`] in [`FBGenericConfig::inner`]. +[`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`]. */ +use colored::Colorize; use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; -use colored::Colorize; +use serde::{Deserialize, Serialize}; +use alg_tools::euclidean::Euclidean; +use alg_tools::instance::Instance; use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::euclidean::Euclidean; use alg_tools::linops::{Mapping, GEMV}; use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::instance::Instance; -use crate::types::*; -use crate::measures::{ - DiscreteMeasure, - RNDM, -}; +use crate::dataterm::{calculate_residual, DataTerm, L2Squared}; +use crate::forward_model::{AdjointProductBoundedBy, ForwardModel}; use crate::measures::merging::SpikeMerging; -use crate::forward_model::{ - ForwardModel, - AdjointProductBoundedBy, -}; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; +use crate::measures::{DiscreteMeasure, RNDM}; +use crate::plot::{PlotLookup, Plotting, SeqPlotter}; +pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty}; use crate::regularisation::RegTerm; -use crate::dataterm::{ - calculate_residual, - L2Squared, - DataTerm, -}; -pub use crate::prox_penalty::{ - FBGenericConfig, - ProxPenalty -}; +use crate::types::*; /// Settings for [`pointsource_fb_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct FBConfig { +pub struct FBConfig { /// Step length scaling - pub τ0 : F, + pub τ0: F, /// Generic parameters - pub generic : FBGenericConfig, + pub generic: FBGenericConfig, } #[replace_float_literals(F::cast_from(literal))] -impl Default for FBConfig { +impl Default for FBConfig { fn default() -> Self { FBConfig { - τ0 : 0.99, - generic : Default::default(), + τ0: 0.99, + generic: Default::default(), } } } -pub(crate) fn prune_with_stats( - μ : &mut RNDM, -) -> usize { +pub(crate) fn prune_with_stats(μ: &mut RNDM) -> usize { let n_before_prune = μ.len(); μ.prune(); debug_assert!(μ.len() <= n_before_prune); @@ -145,25 +126,27 @@ #[replace_float_literals(F::cast_from(literal))] pub(crate) fn postprocess< - F : Float, - V : Euclidean + Clone, - A : GEMV, Codomain = V>, - D : DataTerm, - const N : usize -> ( - mut μ : RNDM, - config : &FBGenericConfig, - dataterm : D, - opA : &A, - b : &V, + F: Float, + V: Euclidean + Clone, + A: GEMV, Codomain = V>, + D: DataTerm, + const N: usize, +>( + mut μ: RNDM, + config: &FBGenericConfig, + dataterm: D, + opA: &A, + b: &V, ) -> RNDM where - RNDM : SpikeMerging, - for<'a> &'a RNDM : Instance>, + RNDM: SpikeMerging, + for<'a> &'a RNDM: Instance>, { - μ.merge_spikes_fitness(config.final_merging_method(), - |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), - |&v| v); + μ.merge_spikes_fitness( + config.final_merging_method(), + |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), + |&v| v, + ); μ.prune(); μ } @@ -187,33 +170,29 @@ /// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fb_reg< - F, I, A, Reg, P, const N : usize ->( - opA : &A, - b : &A::Observable, - reg : Reg, - prox_penalty : &P, - fbconfig : &FBConfig, - iterator : I, - mut plotter : SeqPlotter, +pub fn pointsource_fb_reg( + opA: &A, + b: &A::Observable, + reg: Reg, + prox_penalty: &P, + fbconfig: &FBConfig, + iterator: I, + mut plotter: SeqPlotter, ) -> RNDM where - F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg, - A : ForwardModel, F> - + AdjointProductBoundedBy, P, FloatType=F>, - A::PreadjointCodomain : RealMapping, - PlotLookup : Plotting, - RNDM : SpikeMerging, - Reg : RegTerm, - P : ProxPenalty, + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + for<'b> &'b A::Observable: std::ops::Neg, + A: ForwardModel, F> + AdjointProductBoundedBy, P, FloatType = F>, + A::PreadjointCodomain: RealMapping, + PlotLookup: Plotting, + RNDM: SpikeMerging, + Reg: RegTerm, + P: ProxPenalty, { - // Set up parameters let config = &fbconfig.generic; - let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); + let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. let tolerance = config.tolerance * τ * reg.tolerance_scaling(); @@ -224,14 +203,12 @@ let mut residual = -b; // Statistics - let full_stats = |residual : &A::Observable, - μ : &RNDM, - ε, stats| IterInfo { - value : residual.norm2_squared_div2() + reg.apply(μ), - n_spikes : μ.len(), + let full_stats = |residual: &A::Observable, μ: &RNDM, ε, stats| IterInfo { + value: residual.norm2_squared_div2() + reg.apply(μ), + n_spikes: μ.len(), ε, //postprocessing: config.postprocessing.then(|| μ.clone()), - .. stats + ..stats }; let mut stats = IterInfo::new(); @@ -242,19 +219,24 @@ // Save current base point let μ_base = μ.clone(); - + // Insert and reweigh let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( - &mut μ, &mut τv, &μ_base, None, - τ, ε, - config, ®, &state, &mut stats + &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, ); // Prune and possibly merge spikes if config.merge_now(&state) { stats.merged += prox_penalty.merge_spikes( - &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, - Some(|μ̃ : &RNDM| L2Squared.calculate_fit_op(μ̃, opA, b)), + &mut μ, + &mut τv, + &μ_base, + None, + τ, + ε, + config, + ®, + Some(|μ̃: &RNDM| L2Squared.calculate_fit_op(μ̃, opA, b)), ); } @@ -269,9 +251,14 @@ // Give statistics if needed state.if_verbose(|| { plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); - full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) + full_stats( + &residual, + &μ, + ε, + std::mem::replace(&mut stats, IterInfo::new()), + ) }); - + // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } @@ -298,33 +285,29 @@ /// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fista_reg< - F, I, A, Reg, P, const N : usize ->( - opA : &A, - b : &A::Observable, - reg : Reg, - prox_penalty : &P, - fbconfig : &FBConfig, - iterator : I, - mut plotter : SeqPlotter, +pub fn pointsource_fista_reg( + opA: &A, + b: &A::Observable, + reg: Reg, + prox_penalty: &P, + fbconfig: &FBConfig, + iterator: I, + mut plotter: SeqPlotter, ) -> RNDM where - F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg, - A : ForwardModel, F> - + AdjointProductBoundedBy, P, FloatType=F>, - A::PreadjointCodomain : RealMapping, - PlotLookup : Plotting, - RNDM : SpikeMerging, - Reg : RegTerm, - P : ProxPenalty, + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + for<'b> &'b A::Observable: std::ops::Neg, + A: ForwardModel, F> + AdjointProductBoundedBy, P, FloatType = F>, + A::PreadjointCodomain: RealMapping, + PlotLookup: Plotting, + RNDM: SpikeMerging, + Reg: RegTerm, + P: ProxPenalty, { - // Set up parameters let config = &fbconfig.generic; - let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); + let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); let mut λ = 1.0; // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. @@ -338,12 +321,12 @@ let mut warned_merging = false; // Statistics - let full_stats = |ν : &RNDM, ε, stats| IterInfo { - value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), - n_spikes : ν.len(), + let full_stats = |ν: &RNDM, ε, stats| IterInfo { + value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), + n_spikes: ν.len(), ε, // postprocessing: config.postprocessing.then(|| ν.clone()), - .. stats + ..stats }; let mut stats = IterInfo::new(); @@ -354,12 +337,10 @@ // Save current base point let μ_base = μ.clone(); - + // Insert new spikes and reweigh let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( - &mut μ, &mut τv, &μ_base, None, - τ, ε, - config, ®, &state, &mut stats + &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, ); // (Do not) merge spikes. @@ -371,7 +352,7 @@ // Update inertial prameters let λ_prev = λ; - λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); + λ = 2.0 * λ_prev / (λ_prev + (4.0 + λ_prev * λ_prev).sqrt()); let θ = λ / λ_prev - λ; // Perform inertial update on μ.