--- a/src/fb.rs Mon Feb 17 13:45:11 2025 -0500 +++ b/src/fb.rs Mon Feb 17 13:51:50 2025 -0500 @@ -74,69 +74,50 @@ </p> We solve this with either SSN or FB as determined by -[`InnerSettings`] in [`FBGenericConfig::inner`]. +[`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`]. */ +use colored::Colorize; use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; -use colored::Colorize; +use serde::{Deserialize, Serialize}; +use alg_tools::euclidean::Euclidean; +use alg_tools::instance::Instance; use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::euclidean::Euclidean; use alg_tools::linops::{Mapping, GEMV}; use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::instance::Instance; -use crate::types::*; -use crate::measures::{ - DiscreteMeasure, - RNDM, -}; +use crate::dataterm::{calculate_residual, DataTerm, L2Squared}; +use crate::forward_model::{AdjointProductBoundedBy, ForwardModel}; use crate::measures::merging::SpikeMerging; -use crate::forward_model::{ - ForwardModel, - AdjointProductBoundedBy, -}; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; +use crate::measures::{DiscreteMeasure, RNDM}; +use crate::plot::{PlotLookup, Plotting, SeqPlotter}; +pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty}; use crate::regularisation::RegTerm; -use crate::dataterm::{ - calculate_residual, - L2Squared, - DataTerm, -}; -pub use crate::prox_penalty::{ - FBGenericConfig, - ProxPenalty -}; +use crate::types::*; /// Settings for [`pointsource_fb_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct FBConfig<F : Float> { +pub struct FBConfig<F: Float> { /// Step length scaling - pub τ0 : F, + pub τ0: F, /// Generic parameters - pub generic : FBGenericConfig<F>, + pub generic: FBGenericConfig<F>, } #[replace_float_literals(F::cast_from(literal))] -impl<F : Float> Default for FBConfig<F> { +impl<F: Float> Default for FBConfig<F> { fn default() -> Self { FBConfig { - τ0 : 0.99, - generic : Default::default(), + τ0: 0.99, + generic: Default::default(), } } } -pub(crate) fn prune_with_stats<F : Float, const N : usize>( - μ : &mut RNDM<F, N>, -) -> usize { +pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize { let n_before_prune = μ.len(); μ.prune(); debug_assert!(μ.len() <= n_before_prune); @@ -145,25 +126,27 @@ #[replace_float_literals(F::cast_from(literal))] pub(crate) fn postprocess< - F : Float, - V : Euclidean<F> + Clone, - A : GEMV<F, RNDM<F, N>, Codomain = V>, - D : DataTerm<F, V, N>, - const N : usize -> ( - mut μ : RNDM<F, N>, - config : &FBGenericConfig<F>, - dataterm : D, - opA : &A, - b : &V, + F: Float, + V: Euclidean<F> + Clone, + A: GEMV<F, RNDM<F, N>, Codomain = V>, + D: DataTerm<F, V, N>, + const N: usize, +>( + mut μ: RNDM<F, N>, + config: &FBGenericConfig<F>, + dataterm: D, + opA: &A, + b: &V, ) -> RNDM<F, N> where - RNDM<F, N> : SpikeMerging<F>, - for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>, + RNDM<F, N>: SpikeMerging<F>, + for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>, { - μ.merge_spikes_fitness(config.final_merging_method(), - |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), - |&v| v); + μ.merge_spikes_fitness( + config.final_merging_method(), + |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), + |&v| v, + ); μ.prune(); μ } @@ -187,33 +170,29 @@ /// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fb_reg< - F, I, A, Reg, P, const N : usize ->( - opA : &A, - b : &A::Observable, - reg : Reg, - prox_penalty : &P, - fbconfig : &FBConfig<F>, - iterator : I, - mut plotter : SeqPlotter<F, N>, +pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>( + opA: &A, + b: &A::Observable, + reg: Reg, + prox_penalty: &P, + fbconfig: &FBConfig<F>, + iterator: I, + mut plotter: SeqPlotter<F, N>, ) -> RNDM<F, N> where - F : Float + ToNalgebraRealField, - I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, - A : ForwardModel<RNDM<F, N>, F> - + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, - A::PreadjointCodomain : RealMapping<F, N>, - PlotLookup : Plotting<N>, - RNDM<F, N> : SpikeMerging<F>, - Reg : RegTerm<F, N>, - P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory<IterInfo<F, N>>, + for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, + A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, + A::PreadjointCodomain: RealMapping<F, N>, + PlotLookup: Plotting<N>, + RNDM<F, N>: SpikeMerging<F>, + Reg: RegTerm<F, N>, + P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, { - // Set up parameters let config = &fbconfig.generic; - let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); + let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. let tolerance = config.tolerance * τ * reg.tolerance_scaling(); @@ -224,14 +203,12 @@ let mut residual = -b; // Statistics - let full_stats = |residual : &A::Observable, - μ : &RNDM<F, N>, - ε, stats| IterInfo { - value : residual.norm2_squared_div2() + reg.apply(μ), - n_spikes : μ.len(), + let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { + value: residual.norm2_squared_div2() + reg.apply(μ), + n_spikes: μ.len(), ε, //postprocessing: config.postprocessing.then(|| μ.clone()), - .. stats + ..stats }; let mut stats = IterInfo::new(); @@ -242,19 +219,24 @@ // Save current base point let μ_base = μ.clone(); - + // Insert and reweigh let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( - &mut μ, &mut τv, &μ_base, None, - τ, ε, - config, ®, &state, &mut stats + &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, ); // Prune and possibly merge spikes if config.merge_now(&state) { stats.merged += prox_penalty.merge_spikes( - &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, - Some(|μ̃ : &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), + &mut μ, + &mut τv, + &μ_base, + None, + τ, + ε, + config, + ®, + Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), ); } @@ -269,9 +251,14 @@ // Give statistics if needed state.if_verbose(|| { plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); - full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) + full_stats( + &residual, + &μ, + ε, + std::mem::replace(&mut stats, IterInfo::new()), + ) }); - + // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } @@ -298,33 +285,29 @@ /// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fista_reg< - F, I, A, Reg, P, const N : usize ->( - opA : &A, - b : &A::Observable, - reg : Reg, - prox_penalty : &P, - fbconfig : &FBConfig<F>, - iterator : I, - mut plotter : SeqPlotter<F, N>, +pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>( + opA: &A, + b: &A::Observable, + reg: Reg, + prox_penalty: &P, + fbconfig: &FBConfig<F>, + iterator: I, + mut plotter: SeqPlotter<F, N>, ) -> RNDM<F, N> where - F : Float + ToNalgebraRealField, - I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, - A : ForwardModel<RNDM<F, N>, F> - + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, - A::PreadjointCodomain : RealMapping<F, N>, - PlotLookup : Plotting<N>, - RNDM<F, N> : SpikeMerging<F>, - Reg : RegTerm<F, N>, - P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory<IterInfo<F, N>>, + for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, + A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, + A::PreadjointCodomain: RealMapping<F, N>, + PlotLookup: Plotting<N>, + RNDM<F, N>: SpikeMerging<F>, + Reg: RegTerm<F, N>, + P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, { - // Set up parameters let config = &fbconfig.generic; - let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); + let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); let mut λ = 1.0; // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. @@ -338,12 +321,12 @@ let mut warned_merging = false; // Statistics - let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo { - value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), - n_spikes : ν.len(), + let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo { + value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), + n_spikes: ν.len(), ε, // postprocessing: config.postprocessing.then(|| ν.clone()), - .. stats + ..stats }; let mut stats = IterInfo::new(); @@ -354,12 +337,10 @@ // Save current base point let μ_base = μ.clone(); - + // Insert new spikes and reweigh let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( - &mut μ, &mut τv, &μ_base, None, - τ, ε, - config, ®, &state, &mut stats + &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, ); // (Do not) merge spikes. @@ -371,7 +352,7 @@ // Update inertial prameters let λ_prev = λ; - λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); + λ = 2.0 * λ_prev / (λ_prev + (4.0 + λ_prev * λ_prev).sqrt()); let θ = λ / λ_prev - λ; // Perform inertial update on μ.