--- a/src/sliding_fb.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/sliding_fb.rs Thu Feb 26 11:38:43 2026 -0500 @@ -10,22 +10,21 @@ use itertools::izip; use std::iter::Iterator; +use crate::fb::*; +use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; +use crate::measures::merging::SpikeMerging; +use crate::measures::{DiscreteMeasure, Radon, RNDM}; +use crate::plot::Plotter; +use crate::prox_penalty::{ProxPenalty, StepLengthBound}; +use crate::regularisation::SlidingRegTerm; +use crate::types::*; +use alg_tools::error::DynResult; use alg_tools::euclidean::Euclidean; use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping}; +use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping}; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::norms::Norm; - -use crate::forward_model::{AdjointProductBoundedBy, BoundedCurvature, ForwardModel}; -use crate::measures::merging::SpikeMerging; -use crate::measures::{DiscreteMeasure, Radon, RNDM}; -use crate::types::*; -//use crate::tolerance::Tolerance; -use crate::dataterm::{calculate_residual, calculate_residual2, DataTerm, L2Squared}; -use crate::fb::*; -use crate::plot::{PlotLookup, Plotting, SeqPlotter}; -use crate::regularisation::SlidingRegTerm; -//use crate::transport::TransportLipschitz; +use anyhow::ensure; /// Transport settings for [`pointsource_sliding_fb_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] @@ -42,21 +41,18 @@ #[replace_float_literals(F::cast_from(literal))] impl<F: Float> TransportConfig<F> { /// Check that the parameters are ok. Panics if not. - pub fn check(&self) { - assert!(self.θ0 > 0.0); - assert!(0.0 < self.adaptation && self.adaptation < 1.0); - assert!(self.tolerance_mult_con > 0.0); + pub fn check(&self) -> DynResult<()> { + ensure!(self.θ0 > 0.0); + ensure!(0.0 < self.adaptation && self.adaptation < 1.0); + ensure!(self.tolerance_mult_con > 0.0); + Ok(()) } } #[replace_float_literals(F::cast_from(literal))] impl<F: Float> Default for TransportConfig<F> { fn default() -> Self { - TransportConfig { - θ0: 0.9, - adaptation: 0.9, - tolerance_mult_con: 100.0, - } + TransportConfig { θ0: 0.9, adaptation: 0.9, tolerance_mult_con: 100.0 } } } @@ -66,10 +62,14 @@ pub struct SlidingFBConfig<F: Float> { /// Step length scaling pub τ0: F, + // Auxiliary variable step length scaling for [`crate::sliding_pdps::pointsource_sliding_fb_pair`] + pub σp0: F, /// Transport parameters pub transport: TransportConfig<F>, /// Generic parameters - pub insertion: FBGenericConfig<F>, + pub insertion: InsertionConfig<F>, + /// Guess for curvature bound calculations. + pub guess: BoundedCurvatureGuess, } #[replace_float_literals(F::cast_from(literal))] @@ -77,8 +77,10 @@ fn default() -> Self { SlidingFBConfig { τ0: 0.99, + σp0: 0.99, transport: Default::default(), insertion: Default::default(), + guess: BoundedCurvatureGuess::BetterThanZero, } } } @@ -100,16 +102,16 @@ /// with step lengh τ and transport step length `θ_or_adaptive`. #[replace_float_literals(F::cast_from(literal))] pub(crate) fn initial_transport<F, G, D, const N: usize>( - γ1: &mut RNDM<F, N>, - μ: &mut RNDM<F, N>, + γ1: &mut RNDM<N, F>, + μ: &mut RNDM<N, F>, τ: F, θ_or_adaptive: &mut TransportStepLength<F, G>, v: D, -) -> (Vec<F>, RNDM<F, N>) +) -> (Vec<F>, RNDM<N, F>) where F: Float + ToNalgebraRealField, G: Fn(F, F) -> F, - D: DifferentiableRealMapping<F, N>, + D: DifferentiableRealMapping<N, F>, { use TransportStepLength::*; @@ -145,22 +147,14 @@ ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); } } - AdaptiveMax { - l: ℓ_F, - ref mut max_transport, - g: ref calculate_θ, - } => { + AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θ } => { *max_transport = max_transport.max(γ1.norm(Radon)); let θτ = τ * calculate_θ(ℓ_F, *max_transport); for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); } } - FullyAdaptive { - l: ref mut adaptive_ℓ_F, - ref mut max_transport, - g: ref calculate_θ, - } => { + FullyAdaptive { l: ref mut adaptive_ℓ_F, ref mut max_transport, g: ref calculate_θ } => { *max_transport = max_transport.max(γ1.norm(Radon)); let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport); // Do two runs through the spikes to update θ, breaking if first run did not cause @@ -209,9 +203,9 @@ /// A posteriori transport adaptation. #[replace_float_literals(F::cast_from(literal))] pub(crate) fn aposteriori_transport<F, const N: usize>( - γ1: &mut RNDM<F, N>, - μ: &mut RNDM<F, N>, - μ_base_minus_γ0: &mut RNDM<F, N>, + γ1: &mut RNDM<N, F>, + μ: &mut RNDM<N, F>, + μ_base_minus_γ0: &mut RNDM<N, F>, μ_base_masses: &Vec<F>, extra: Option<F>, ε: F, @@ -264,36 +258,33 @@ /// The parametrisation is as for [`pointsource_fb_reg`]. /// Inertia is currently not supported. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N: usize>( - opA: &A, - b: &A::Observable, - reg: Reg, +pub fn pointsource_sliding_fb_reg<F, I, Dat, Reg, Plot, P, const N: usize>( + f: &Dat, + reg: &Reg, prox_penalty: &P, config: &SlidingFBConfig<F>, iterator: I, - mut plotter: SeqPlotter<F, N>, -) -> RNDM<F, N> + mut plotter: Plot, + μ0: Option<RNDM<N, F>>, +) -> DynResult<RNDM<N, F>> where F: Float + ToNalgebraRealField, - I: AlgIteratorFactory<IterInfo<F, N>>, - A: ForwardModel<RNDM<F, N>, F> - + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F> - + BoundedCurvature<FloatType = F>, - for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>, - A::PreadjointCodomain: DifferentiableRealMapping<F, N>, - RNDM<F, N>: SpikeMerging<F>, - Reg: SlidingRegTerm<F, N>, - P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>, - PlotLookup: Plotting<N>, + I: AlgIteratorFactory<IterInfo<F>>, + Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>, + Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>, + //for<'a> Dat::Differential<'a>: Lipschitz<&'a P, FloatType = F>, + RNDM<N, F>: SpikeMerging<F>, + Reg: SlidingRegTerm<Loc<N, F>, F>, + P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, + Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, { // Check parameters - assert!(config.τ0 > 0.0, "Invalid step length parameter"); - config.transport.check(); + ensure!(config.τ0 > 0.0, "Invalid step length parameter"); + config.transport.check()?; // Initialise iterates - let mut μ = DiscreteMeasure::new(); + let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); let mut γ1 = DiscreteMeasure::new(); - let mut residual = -b; // Has to equal $Aμ-b$. // Set up parameters // let opAnorm = opA.opnorm_bound(Radon, L2); @@ -301,21 +292,21 @@ // * reg.radon_norm_bound(b.norm2_squared() / 2.0); //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; let ℓ = 0.0; - let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); - let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components(); - let transport_lip = maybe_transport_lip.unwrap(); + let τ = config.τ0 / prox_penalty.step_length_bound(&f)?; + let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess); + let transport_lip = maybe_transport_lip?; let calculate_θ = |ℓ_F, max_transport| { let ℓ_r = transport_lip * max_transport; config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r)) }; - let mut θ_or_adaptive = match maybe_ℓ_F0 { + let mut θ_or_adaptive = match maybe_ℓ_F { //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)), - Some(ℓ_F0) => TransportStepLength::AdaptiveMax { - l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual + Ok(ℓ_F) => TransportStepLength::AdaptiveMax { + l: ℓ_F, // TODO: could estimate computing the real reesidual max_transport: 0.0, g: calculate_θ, }, - None => TransportStepLength::FullyAdaptive { + Err(_) => TransportStepLength::FullyAdaptive { l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials max_transport: 0.0, g: calculate_θ, @@ -327,8 +318,8 @@ let mut ε = tolerance.initial(); // Statistics - let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo { - value: residual.norm2_squared_div2() + reg.apply(μ), + let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo { + value: f.apply(μ) + reg.apply(μ), n_spikes: μ.len(), ε, // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), @@ -337,9 +328,9 @@ let mut stats = IterInfo::new(); // Run the algorithm - for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { + for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate initial transport - let v = opA.preadjoint().apply(residual); + let v = f.differential(&μ); let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); @@ -347,8 +338,11 @@ // regularisation term conforms to the assumptions made for the transport above. let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) - let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); - let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); + //let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); + // TODO: this could be optimised by doing the differential like the + // old residual2. + let μ̆ = &γ1 + &μ_base_minus_γ0; + let mut τv̆ = f.differential(μ̆) * τ; // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( @@ -362,7 +356,7 @@ ®, &state, &mut stats, - ); + )?; // A posteriori transport adaptation. if aposteriori_transport( @@ -404,7 +398,7 @@ ε, ins, ®, - Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), + Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)), ); } @@ -419,26 +413,19 @@ μ = μ_new; } - // Update residual - residual = calculate_residual(&μ, opA, b); - let iter = state.iteration(); stats.this_iters += 1; // Give statistics if requested state.if_verbose(|| { plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); - full_stats( - &residual, - &μ, - ε, - std::mem::replace(&mut stats, IterInfo::new()), - ) + full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) }); // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } - postprocess(μ, &config.insertion, L2Squared, opA, b) + //postprocess(μ, &config.insertion, f) + postprocess(μ, &config.insertion, |μ̃| f.apply(μ̃)) }