src/sliding_fb.rs

branch
dev
changeset 61
4f468d35fa29
parent 49
6b0db7251ebe
child 62
32328a74c790
child 63
7a8a55fd41c0
--- a/src/sliding_fb.rs	Sun Apr 27 15:03:51 2025 -0500
+++ b/src/sliding_fb.rs	Thu Feb 26 11:38:43 2026 -0500
@@ -10,22 +10,21 @@
 use itertools::izip;
 use std::iter::Iterator;
 
+use crate::fb::*;
+use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
+use crate::measures::merging::SpikeMerging;
+use crate::measures::{DiscreteMeasure, Radon, RNDM};
+use crate::plot::Plotter;
+use crate::prox_penalty::{ProxPenalty, StepLengthBound};
+use crate::regularisation::SlidingRegTerm;
+use crate::types::*;
+use alg_tools::error::DynResult;
 use alg_tools::euclidean::Euclidean;
 use alg_tools::iterate::AlgIteratorFactory;
-use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping};
+use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping};
 use alg_tools::nalgebra_support::ToNalgebraRealField;
 use alg_tools::norms::Norm;
-
-use crate::forward_model::{AdjointProductBoundedBy, BoundedCurvature, ForwardModel};
-use crate::measures::merging::SpikeMerging;
-use crate::measures::{DiscreteMeasure, Radon, RNDM};
-use crate::types::*;
-//use crate::tolerance::Tolerance;
-use crate::dataterm::{calculate_residual, calculate_residual2, DataTerm, L2Squared};
-use crate::fb::*;
-use crate::plot::{PlotLookup, Plotting, SeqPlotter};
-use crate::regularisation::SlidingRegTerm;
-//use crate::transport::TransportLipschitz;
+use anyhow::ensure;
 
 /// Transport settings for [`pointsource_sliding_fb_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
@@ -42,21 +41,18 @@
 #[replace_float_literals(F::cast_from(literal))]
 impl<F: Float> TransportConfig<F> {
     /// Check that the parameters are ok. Returns an error if not.
-    pub fn check(&self) {
-        assert!(self.θ0 > 0.0);
-        assert!(0.0 < self.adaptation && self.adaptation < 1.0);
-        assert!(self.tolerance_mult_con > 0.0);
+    pub fn check(&self) -> DynResult<()> {
+        ensure!(self.θ0 > 0.0);
+        ensure!(0.0 < self.adaptation && self.adaptation < 1.0);
+        ensure!(self.tolerance_mult_con > 0.0);
+        Ok(())
     }
 }
 
 #[replace_float_literals(F::cast_from(literal))]
 impl<F: Float> Default for TransportConfig<F> {
     fn default() -> Self {
-        TransportConfig {
-            θ0: 0.9,
-            adaptation: 0.9,
-            tolerance_mult_con: 100.0,
-        }
+        TransportConfig { θ0: 0.9, adaptation: 0.9, tolerance_mult_con: 100.0 }
     }
 }
 
@@ -66,10 +62,14 @@
 pub struct SlidingFBConfig<F: Float> {
     /// Step length scaling
     pub τ0: F,
+    /// Auxiliary variable step length scaling for [`crate::sliding_pdps::pointsource_sliding_fb_pair`]
+    pub σp0: F,
     /// Transport parameters
     pub transport: TransportConfig<F>,
     /// Generic parameters
-    pub insertion: FBGenericConfig<F>,
+    pub insertion: InsertionConfig<F>,
+    /// Guess for curvature bound calculations.
+    pub guess: BoundedCurvatureGuess,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
@@ -77,8 +77,10 @@
     fn default() -> Self {
         SlidingFBConfig {
             τ0: 0.99,
+            σp0: 0.99,
             transport: Default::default(),
             insertion: Default::default(),
+            guess: BoundedCurvatureGuess::BetterThanZero,
         }
     }
 }
@@ -100,16 +102,16 @@
 /// with step length τ and transport step length `θ_or_adaptive`.
 #[replace_float_literals(F::cast_from(literal))]
 pub(crate) fn initial_transport<F, G, D, const N: usize>(
-    γ1: &mut RNDM<F, N>,
-    μ: &mut RNDM<F, N>,
+    γ1: &mut RNDM<N, F>,
+    μ: &mut RNDM<N, F>,
     τ: F,
     θ_or_adaptive: &mut TransportStepLength<F, G>,
     v: D,
-) -> (Vec<F>, RNDM<F, N>)
+) -> (Vec<F>, RNDM<N, F>)
 where
     F: Float + ToNalgebraRealField,
     G: Fn(F, F) -> F,
-    D: DifferentiableRealMapping<F, N>,
+    D: DifferentiableRealMapping<N, F>,
 {
     use TransportStepLength::*;
 
@@ -145,22 +147,14 @@
                 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
             }
         }
-        AdaptiveMax {
-            l: ℓ_F,
-            ref mut max_transport,
-            g: ref calculate_θ,
-        } => {
+        AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θ } => {
             *max_transport = max_transport.max(γ1.norm(Radon));
             let θτ = τ * calculate_θ(ℓ_F, *max_transport);
             for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
                 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
             }
         }
-        FullyAdaptive {
-            l: ref mut adaptive_ℓ_F,
-            ref mut max_transport,
-            g: ref calculate_θ,
-        } => {
+        FullyAdaptive { l: ref mut adaptive_ℓ_F, ref mut max_transport, g: ref calculate_θ } => {
             *max_transport = max_transport.max(γ1.norm(Radon));
             let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
             // Do two runs through the spikes to update θ, breaking if first run did not cause
@@ -209,9 +203,9 @@
 /// A posteriori transport adaptation.
 #[replace_float_literals(F::cast_from(literal))]
 pub(crate) fn aposteriori_transport<F, const N: usize>(
-    γ1: &mut RNDM<F, N>,
-    μ: &mut RNDM<F, N>,
-    μ_base_minus_γ0: &mut RNDM<F, N>,
+    γ1: &mut RNDM<N, F>,
+    μ: &mut RNDM<N, F>,
+    μ_base_minus_γ0: &mut RNDM<N, F>,
     μ_base_masses: &Vec<F>,
     extra: Option<F>,
     ε: F,
@@ -264,36 +258,33 @@
 /// The parametrisation is as for [`pointsource_fb_reg`].
 /// Inertia is currently not supported.
 #[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N: usize>(
-    opA: &A,
-    b: &A::Observable,
-    reg: Reg,
+pub fn pointsource_sliding_fb_reg<F, I, Dat, Reg, Plot, P, const N: usize>(
+    f: &Dat,
+    reg: &Reg,
     prox_penalty: &P,
     config: &SlidingFBConfig<F>,
     iterator: I,
-    mut plotter: SeqPlotter<F, N>,
-) -> RNDM<F, N>
+    mut plotter: Plot,
+    μ0: Option<RNDM<N, F>>,
+) -> DynResult<RNDM<N, F>>
 where
     F: Float + ToNalgebraRealField,
-    I: AlgIteratorFactory<IterInfo<F, N>>,
-    A: ForwardModel<RNDM<F, N>, F>
-        + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>
-        + BoundedCurvature<FloatType = F>,
-    for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>,
-    A::PreadjointCodomain: DifferentiableRealMapping<F, N>,
-    RNDM<F, N>: SpikeMerging<F>,
-    Reg: SlidingRegTerm<F, N>,
-    P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
-    PlotLookup: Plotting<N>,
+    I: AlgIteratorFactory<IterInfo<F>>,
+    Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>,
+    Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>,
+    //for<'a> Dat::Differential<'a>: Lipschitz<&'a P, FloatType = F>,
+    RNDM<N, F>: SpikeMerging<F>,
+    Reg: SlidingRegTerm<Loc<N, F>, F>,
+    P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
+    Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
 {
     // Check parameters
-    assert!(config.τ0 > 0.0, "Invalid step length parameter");
-    config.transport.check();
+    ensure!(config.τ0 > 0.0, "Invalid step length parameter");
+    config.transport.check()?;
 
     // Initialise iterates
-    let mut μ = DiscreteMeasure::new();
+    let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
     let mut γ1 = DiscreteMeasure::new();
-    let mut residual = -b; // Has to equal $Aμ-b$.
 
     // Set up parameters
     // let opAnorm = opA.opnorm_bound(Radon, L2);
@@ -301,21 +292,21 @@
     //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0);
     //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
     let ℓ = 0.0;
-    let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
-    let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components();
-    let transport_lip = maybe_transport_lip.unwrap();
+    let τ = config.τ0 / prox_penalty.step_length_bound(&f)?;
+    let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess);
+    let transport_lip = maybe_transport_lip?;
     let calculate_θ = |ℓ_F, max_transport| {
         let ℓ_r = transport_lip * max_transport;
         config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r))
     };
-    let mut θ_or_adaptive = match maybe_ℓ_F0 {
+    let mut θ_or_adaptive = match maybe_ℓ_F {
         //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)),
-        Some(ℓ_F0) => TransportStepLength::AdaptiveMax {
-            l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual
+        Ok(ℓ_F) => TransportStepLength::AdaptiveMax {
+            l: ℓ_F, // TODO: could estimate by computing the real residual
             max_transport: 0.0,
             g: calculate_θ,
         },
-        None => TransportStepLength::FullyAdaptive {
+        Err(_) => TransportStepLength::FullyAdaptive {
             l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials
             max_transport: 0.0,
             g: calculate_θ,
@@ -327,8 +318,8 @@
     let mut ε = tolerance.initial();
 
     // Statistics
-    let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo {
-        value: residual.norm2_squared_div2() + reg.apply(μ),
+    let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo {
+        value: f.apply(μ) + reg.apply(μ),
         n_spikes: μ.len(),
         ε,
         // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
@@ -337,9 +328,9 @@
     let mut stats = IterInfo::new();
 
     // Run the algorithm
-    for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
+    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate initial transport
-        let v = opA.preadjoint().apply(residual);
+        let v = f.differential(&μ);
         let (μ_base_masses, mut μ_base_minus_γ0) =
             initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v);
 
@@ -347,8 +338,11 @@
         // regularisation term conforms to the assumptions made for the transport above.
         let (maybe_d, _within_tolerances, mut τv̆) = 'adapt_transport: loop {
             // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
-            let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
-            let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ);
+            //let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
+            // TODO: this could be optimised by doing the differential like the
+            // old residual2.
+            let μ̆ = &γ1 + &μ_base_minus_γ0;
+            let mut τv̆ = f.differential(μ̆) * τ;
 
             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
             let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
@@ -362,7 +356,7 @@
                 &reg,
                 &state,
                 &mut stats,
-            );
+            )?;
 
             // A posteriori transport adaptation.
             if aposteriori_transport(
@@ -404,7 +398,7 @@
                 ε,
                 ins,
                 &reg,
-                Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
+                Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)),
             );
         }
 
@@ -419,26 +413,19 @@
             μ = μ_new;
         }
 
-        // Update residual
-        residual = calculate_residual(&μ, opA, b);
-
         let iter = state.iteration();
         stats.this_iters += 1;
 
         // Give statistics if requested
         state.if_verbose(|| {
             plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
-            full_stats(
-                &residual,
-                &μ,
-                ε,
-                std::mem::replace(&mut stats, IterInfo::new()),
-            )
+            full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
         });
 
         // Update main tolerance for next iteration
         ε = tolerance.update(ε, iter);
     }
 
-    postprocess(μ, &config.insertion, L2Squared, opA, b)
+    //postprocess(μ, &config.insertion, f)
+    postprocess(μ, &config.insertion, |μ̃| f.apply(μ̃))
 }

mercurial