--- a/src/fb.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/fb.rs Thu Feb 26 11:38:43 2026 -0500 @@ -74,37 +74,34 @@ </p> We solve this with either SSN or FB as determined by -[`crate::subproblem::InnerSettings`] in [`FBGenericConfig::inner`]. +[`crate::subproblem::InnerSettings`] in [`InsertionConfig::inner`]. */ +use crate::measures::merging::SpikeMerging; +use crate::measures::{DiscreteMeasure, RNDM}; +use crate::plot::Plotter; +pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound}; +use crate::regularisation::RegTerm; +use crate::types::*; +use alg_tools::error::DynResult; +use alg_tools::instance::Instance; +use alg_tools::iterate::AlgIteratorFactory; +use alg_tools::mapping::DifferentiableMapping; +use alg_tools::nalgebra_support::ToNalgebraRealField; use colored::Colorize; use numeric_literals::replace_float_literals; use serde::{Deserialize, Serialize}; -use alg_tools::euclidean::Euclidean; -use alg_tools::instance::Instance; -use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::linops::{Mapping, GEMV}; -use alg_tools::mapping::RealMapping; -use alg_tools::nalgebra_support::ToNalgebraRealField; - -use crate::dataterm::{calculate_residual, DataTerm, L2Squared}; -use crate::forward_model::{AdjointProductBoundedBy, ForwardModel}; -use crate::measures::merging::SpikeMerging; -use crate::measures::{DiscreteMeasure, RNDM}; -use crate::plot::{PlotLookup, Plotting, SeqPlotter}; -pub use crate::prox_penalty::{FBGenericConfig, ProxPenalty}; -use crate::regularisation::RegTerm; -use crate::types::*; - /// Settings for [`pointsource_fb_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] pub struct FBConfig<F: Float> { /// Step length scaling pub τ0: F, + // Auxiliary variable step length scaling for [`crate::forward_pdps::pointsource_fb_pair`] + pub σp0: F, /// Generic parameters - pub generic: FBGenericConfig<F>, + pub insertion: InsertionConfig<F>, } #[replace_float_literals(F::cast_from(literal))] @@ -112,12 +109,13 @@ fn default() -> Self { FBConfig { τ0: 0.99, - generic: Default::default(), + σp0: 0.99, + insertion: Default::default(), } } } -pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<F, N>) -> usize { +pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize { let n_before_prune = μ.len(); μ.prune(); debug_assert!(μ.len() <= n_before_prune); @@ -125,30 +123,19 @@ } #[replace_float_literals(F::cast_from(literal))] -pub(crate) fn postprocess< - F: Float, - V: Euclidean<F> + Clone, - A: GEMV<F, RNDM<F, N>, Codomain = V>, - D: DataTerm<F, V, N>, - const N: usize, ->( - mut μ: RNDM<F, N>, - config: &FBGenericConfig<F>, - dataterm: D, - opA: &A, - b: &V, -) -> RNDM<F, N> +pub(crate) fn postprocess<F: Float, Dat: Fn(&RNDM<N, F>) -> F, const N: usize>( + mut μ: RNDM<N, F>, + config: &InsertionConfig<F>, + f: Dat, +) -> DynResult<RNDM<N, F>> where - RNDM<F, N>: SpikeMerging<F>, - for<'a> &'a RNDM<F, N>: Instance<RNDM<F, N>>, + RNDM<N, F>: SpikeMerging<F>, + for<'a> &'a RNDM<N, F>: Instance<RNDM<N, F>>, { - μ.merge_spikes_fitness( - config.final_merging_method(), - |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), - |&v| v, - ); + //μ.merge_spikes_fitness(config.final_merging_method(), |μ̃| f.apply(μ̃), |&v| v); + μ.merge_spikes_fitness(config.final_merging_method(), f, |&v| v); μ.prune(); - μ + Ok(μ) } /// Iteratively solve the pointsource localisation problem using forward-backward splitting. @@ -161,50 +148,41 @@ /// /// For details on the mathematical formulation, see the [module level](self) documentation. /// -/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of -/// sums of simple functions usign bisection trees, and the related -/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions -/// active at a specific points, and to maximise their sums. Through the implementation of the -/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features -/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. -/// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fb_reg<F, I, A, Reg, P, const N: usize>( - opA: &A, - b: &A::Observable, - reg: Reg, +pub fn pointsource_fb_reg<F, I, Dat, Reg, P, Plot, const N: usize>( + f: &Dat, + reg: &Reg, prox_penalty: &P, fbconfig: &FBConfig<F>, iterator: I, - mut plotter: SeqPlotter<F, N>, -) -> RNDM<F, N> + mut plotter: Plot, + μ0 : Option<RNDM<N, F>>, +) -> DynResult<RNDM<N, F>> where F: Float + ToNalgebraRealField, - I: AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, - A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, - A::PreadjointCodomain: RealMapping<F, N>, - PlotLookup: Plotting<N>, - RNDM<F, N>: SpikeMerging<F>, - Reg: RegTerm<F, N>, - P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, + I: AlgIteratorFactory<IterInfo<F>>, + RNDM<N, F>: SpikeMerging<F>, + Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, + Dat::DerivativeDomain: ClosedMul<F>, + Reg: RegTerm<Loc<N, F>, F>, + P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, + Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, { // Set up parameters - let config = &fbconfig.generic; - let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); + let config = &fbconfig.insertion; + let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. let tolerance = config.tolerance * τ * reg.tolerance_scaling(); let mut ε = tolerance.initial(); // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut residual = -b; + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); // Statistics - let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { - value: residual.norm2_squared_div2() + reg.apply(μ), + let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo { + value: f.apply(μ) + reg.apply(μ), n_spikes: μ.len(), ε, //postprocessing: config.postprocessing.then(|| μ.clone()), @@ -213,9 +191,10 @@ let mut stats = IterInfo::new(); // Run the algorithm - for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { + for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let mut τv = opA.preadjoint().apply(residual * τ); + // TODO: optimise τ to be applied to residual. + let mut τv = f.differential(&μ) * τ; // Save current base point let μ_base = μ.clone(); @@ -223,7 +202,7 @@ // Insert and reweigh let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, - ); + )?; // Prune and possibly merge spikes if config.merge_now(&state) { @@ -236,34 +215,27 @@ ε, config, ®, - Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), + Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)), ); } stats.pruned += prune_with_stats(&mut μ); - // Update residual - residual = calculate_residual(&μ, opA, b); - let iter = state.iteration(); stats.this_iters += 1; // Give statistics if needed state.if_verbose(|| { plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); - full_stats( - &residual, - &μ, - ε, - std::mem::replace(&mut stats, IterInfo::new()), - ) + full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) }); // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } - postprocess(μ, config, L2Squared, opA, b) + //postprocess(μ_prev, config, f) + postprocess(μ, config, |μ̃| f.apply(μ̃)) } /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. @@ -276,38 +248,30 @@ /// /// For details on the mathematical formulation, see the [module level](self) documentation. /// -/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of -/// sums of simple functions usign bisection trees, and the related -/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions -/// active at a specific points, and to maximise their sums. Through the implementation of the -/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features -/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. -/// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fista_reg<F, I, A, Reg, P, const N: usize>( - opA: &A, - b: &A::Observable, - reg: Reg, +pub fn pointsource_fista_reg<F, I, Dat, Reg, P, Plot, const N: usize>( + f: &Dat, + reg: &Reg, prox_penalty: &P, fbconfig: &FBConfig<F>, iterator: I, - mut plotter: SeqPlotter<F, N>, -) -> RNDM<F, N> + mut plotter: Plot, + μ0: Option<RNDM<N, F>> +) -> DynResult<RNDM<N, F>> where F: Float + ToNalgebraRealField, - I: AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable>, - A: ForwardModel<RNDM<F, N>, F> + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>, - A::PreadjointCodomain: RealMapping<F, N>, - PlotLookup: Plotting<N>, - RNDM<F, N>: SpikeMerging<F>, - Reg: RegTerm<F, N>, - P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, + I: AlgIteratorFactory<IterInfo<F>>, + RNDM<N, F>: SpikeMerging<F>, + Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>, + Dat::DerivativeDomain: ClosedMul<F>, + Reg: RegTerm<Loc<N, F>, F>, + P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, + Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, { // Set up parameters - let config = &fbconfig.generic; - let τ = fbconfig.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); + let config = &fbconfig.insertion; + let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; let mut λ = 1.0; // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. @@ -315,14 +279,13 @@ let mut ε = tolerance.initial(); // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut μ_prev = DiscreteMeasure::new(); - let mut residual = -b; + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); + let mut μ_prev = μ.clone(); let mut warned_merging = false; // Statistics - let full_stats = |ν: &RNDM<F, N>, ε, stats| IterInfo { - value: L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), + let full_stats = |ν: &RNDM<N, F>, ε, stats| IterInfo { + value: f.apply(ν) + reg.apply(ν), n_spikes: ν.len(), ε, // postprocessing: config.postprocessing.then(|| ν.clone()), @@ -333,7 +296,7 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let mut τv = opA.preadjoint().apply(residual * τ); + let mut τv = f.differential(&μ) * τ; // Save current base point let μ_base = μ.clone(); @@ -341,7 +304,7 @@ // Insert new spikes and reweigh let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, - ); + )?; // (Do not) merge spikes. if config.merge_now(&state) && !warned_merging { @@ -369,9 +332,6 @@ debug_assert!(μ.len() <= n_before_prune); stats.pruned += n_before_prune - μ.len(); - // Update residual - residual = calculate_residual(&μ, opA, b); - let iter = state.iteration(); stats.this_iters += 1; @@ -385,5 +345,6 @@ ε = tolerance.update(ε, iter); } - postprocess(μ_prev, config, L2Squared, opA, b) + //postprocess(μ_prev, config, f) + postprocess(μ_prev, config, |μ̃| f.apply(μ̃)) }