src/sliding_fb.rs

branch
dev
changeset 49
6b0db7251ebe
parent 46
f358958cc1a6
--- a/src/sliding_fb.rs	Fri Feb 14 23:16:14 2025 -0500
+++ b/src/sliding_fb.rs	Fri Feb 14 23:46:43 2025 -0500
@@ -4,56 +4,43 @@
 */
 
 use numeric_literals::replace_float_literals;
-use serde::{Serialize, Deserialize};
+use serde::{Deserialize, Serialize};
 //use colored::Colorize;
 //use nalgebra::{DVector, DMatrix};
 use itertools::izip;
 use std::iter::Iterator;
 
+use alg_tools::euclidean::Euclidean;
 use alg_tools::iterate::AlgIteratorFactory;
-use alg_tools::euclidean::Euclidean;
-use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance};
+use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping};
+use alg_tools::nalgebra_support::ToNalgebraRealField;
 use alg_tools::norms::Norm;
-use alg_tools::nalgebra_support::ToNalgebraRealField;
 
-use crate::types::*;
-use crate::measures::{DiscreteMeasure, Radon, RNDM};
+use crate::forward_model::{AdjointProductBoundedBy, BoundedCurvature, ForwardModel};
 use crate::measures::merging::SpikeMerging;
-use crate::forward_model::{
-    ForwardModel,
-    AdjointProductBoundedBy,
-    BoundedCurvature,
-};
+use crate::measures::{DiscreteMeasure, Radon, RNDM};
+use crate::types::*;
 //use crate::tolerance::Tolerance;
-use crate::plot::{
-    SeqPlotter,
-    Plotting,
-    PlotLookup
-};
+use crate::dataterm::{calculate_residual, calculate_residual2, DataTerm, L2Squared};
 use crate::fb::*;
+use crate::plot::{PlotLookup, Plotting, SeqPlotter};
 use crate::regularisation::SlidingRegTerm;
-use crate::dataterm::{
-    L2Squared,
-    DataTerm,
-    calculate_residual,
-    calculate_residual2,
-};
 //use crate::transport::TransportLipschitz;
 
 /// Transport settings for [`pointsource_sliding_fb_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
-pub struct TransportConfig<F : Float> {
+pub struct TransportConfig<F: Float> {
     /// Transport step length $θ$ normalised to $(0, 1)$.
-    pub θ0 : F,
+    pub θ0: F,
     /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance.
-    pub adaptation : F,
+    pub adaptation: F,
     /// A posteriori transport tolerance multiplier (C_pos)
-    pub tolerance_mult_con : F,
+    pub tolerance_mult_con: F,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl <F : Float> TransportConfig<F> {
+impl<F: Float> TransportConfig<F> {
     /// Check that the parameters are ok. Panics if not.
     pub fn check(&self) {
         assert!(self.θ0 > 0.0);
@@ -63,12 +50,12 @@
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> Default for TransportConfig<F> {
+impl<F: Float> Default for TransportConfig<F> {
     fn default() -> Self {
         TransportConfig {
-            θ0 : 0.9,
-            adaptation : 0.9,
-            tolerance_mult_con : 100.0,
+            θ0: 0.9,
+            adaptation: 0.9,
+            tolerance_mult_con: 100.0,
         }
     }
 }
@@ -76,55 +63,54 @@
 /// Settings for [`pointsource_sliding_fb_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
-pub struct SlidingFBConfig<F : Float> {
+pub struct SlidingFBConfig<F: Float> {
     /// Step length scaling
-    pub τ0 : F,
+    pub τ0: F,
     /// Transport parameters
-    pub transport : TransportConfig<F>,
+    pub transport: TransportConfig<F>,
     /// Generic parameters
-    pub insertion : FBGenericConfig<F>,
+    pub insertion: FBGenericConfig<F>,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> Default for SlidingFBConfig<F> {
+impl<F: Float> Default for SlidingFBConfig<F> {
     fn default() -> Self {
         SlidingFBConfig {
-            τ0 : 0.99,
-            transport : Default::default(),
-            insertion : Default::default()
+            τ0: 0.99,
+            transport: Default::default(),
+            insertion: Default::default(),
         }
     }
 }
 
 /// Internal type of adaptive transport step length calculation
-pub(crate) enum TransportStepLength<F : Float, G : Fn(F, F) -> F> {
+pub(crate) enum TransportStepLength<F: Float, G: Fn(F, F) -> F> {
     /// Fixed, known step length
     #[allow(dead_code)]
     Fixed(F),
     /// Adaptive step length, only wrt. maximum transport.
     /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
-    AdaptiveMax{ l : F, max_transport : F, g : G },
+    AdaptiveMax { l: F, max_transport: F, g: G },
     /// Adaptive step length.
     /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
-    FullyAdaptive{ l : F, max_transport : F, g : G },
+    FullyAdaptive { l: F, max_transport: F, g: G },
 }
 
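-/// Constrution of initial transport `γ1` from initial measure `μ` and `v=F'(μ)`
-/// with step lengh τ and transport step length `θ_or_adaptive`.
+/// Construction of initial transport `γ1` from initial measure `μ` and `v=F'(μ)`
+/// with step length τ and transport step length `θ_or_adaptive`.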
 #[replace_float_literals(F::cast_from(literal))]
-pub(crate) fn initial_transport<F, G, D, const N : usize>(
-    γ1 : &mut RNDM<F, N>,
-    μ : &mut RNDM<F, N>,
-    τ : F,
-    θ_or_adaptive : &mut TransportStepLength<F, G>,
-    v : D,
+pub(crate) fn initial_transport<F, G, D, const N: usize>(
+    γ1: &mut RNDM<F, N>,
+    μ: &mut RNDM<F, N>,
+    τ: F,
+    θ_or_adaptive: &mut TransportStepLength<F, G>,
+    v: D,
 ) -> (Vec<F>, RNDM<F, N>)
 where
-    F : Float + ToNalgebraRealField,
-    G : Fn(F, F) -> F,
-    D : DifferentiableRealMapping<F, N>,
+    F: Float + ToNalgebraRealField,
+    G: Fn(F, F) -> F,
+    D: DifferentiableRealMapping<F, N>,
 {
-
     use TransportStepLength::*;
 
     // Save current base point and shift μ to new positions. Idea is that
@@ -132,10 +118,10 @@
     //  μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
     //  γ1 = π_♯^1γ^{k+1}
     //  μ = μ^{k+1}
-    let μ_base_masses : Vec<F> = μ.iter_masses().collect();
+    let μ_base_masses: Vec<F> = μ.iter_masses().collect();
     let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below.
-    // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
-    //let mut sum_norm_dv = 0.0;
+                                         // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
+                                         //let mut sum_norm_dv = 0.0;
     let γ_prev_len = γ1.len();
     assert!(μ.len() >= γ_prev_len);
     γ1.extend(μ[γ_prev_len..].iter().cloned());
@@ -149,7 +135,7 @@
         } else {
             δ.α
         };
-    };
+    }
 
     // Calculate transport rays.
     match *θ_or_adaptive {
@@ -158,15 +144,23 @@
             for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
                 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
             }
-        },
-        AdaptiveMax{ l : ℓ_F, ref mut max_transport, g : ref calculate_θ } => {
+        }
+        AdaptiveMax {
+            l: ℓ_F,
+            ref mut max_transport,
+            g: ref calculate_θ,
+        } => {
             *max_transport = max_transport.max(γ1.norm(Radon));
             let θτ = τ * calculate_θ(ℓ_F, *max_transport);
             for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
                 ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
             }
-        },
-        FullyAdaptive{ l : ref mut adaptive_ℓ_F, ref mut max_transport, g : ref calculate_θ } => {
+        }
+        FullyAdaptive {
+            l: ref mut adaptive_ℓ_F,
+            ref mut max_transport,
+            g: ref calculate_θ,
+        } => {
             *max_transport = max_transport.max(γ1.norm(Radon));
             let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport);
             // Do two runs through the spikes to update θ, breaking if first run did not cause
@@ -187,7 +181,7 @@
                     }
                 }
                 if !changes {
-                    break
+                    break;
                 }
             }
         }
@@ -203,24 +197,29 @@
         }
     }
     // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b)
-    μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses())
-                                                   .map(|(&a,b)| a - b));
+    μ_base_minus_γ0.set_masses(
+        μ_base_masses
+            .iter()
+            .zip(γ1.iter_masses())
+            .map(|(&a, b)| a - b),
+    );
     (μ_base_masses, μ_base_minus_γ0)
 }
 
 /// A posteriori transport adaptation.
 #[replace_float_literals(F::cast_from(literal))]
-pub(crate) fn aposteriori_transport<F, const N : usize>(
-    γ1 : &mut RNDM<F, N>,
-    μ : &mut RNDM<F, N>,
-    μ_base_minus_γ0 : &mut RNDM<F, N>,
-    μ_base_masses : &Vec<F>,
-    extra : Option<F>,
-    ε : F,
-    tconfig : &TransportConfig<F>
+pub(crate) fn aposteriori_transport<F, const N: usize>(
+    γ1: &mut RNDM<F, N>,
+    μ: &mut RNDM<F, N>,
+    μ_base_minus_γ0: &mut RNDM<F, N>,
+    μ_base_masses: &Vec<F>,
+    extra: Option<F>,
+    ε: F,
+    tconfig: &TransportConfig<F>,
 ) -> bool
-where F : Float + ToNalgebraRealField {
-
+where
+    F: Float + ToNalgebraRealField,
+{
     // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
     // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
     // at that point to zero, and retry.
@@ -238,19 +237,22 @@
     let nγ = γ1.norm(Radon);
     let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0);
     let t = ε * tconfig.tolerance_mult_con;
-    if nγ*nΔ > t {
+    if nγ * nΔ > t {
         // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1,
         // this will guarantee that eventually ‖γ‖ decreases sufficiently that we
         // will not enter here.
-        *γ1 *= tconfig.adaptation * t / ( nγ * nΔ );
+        *γ1 *= tconfig.adaptation * t / (nγ * nΔ);
         all_ok = false
     }
 
     if !all_ok {
         // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
-        μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses())
-                                                        .map(|(&a,b)| a - b));
-
+        μ_base_minus_γ0.set_masses(
+            μ_base_masses
+                .iter()
+                .zip(γ1.iter_masses())
+                .map(|(&a, b)| a - b),
+        );
     }
 
     all_ok
@@ -262,29 +264,28 @@
 /// The parametrisation is as for [`pointsource_fb_reg`].
 /// Inertia is currently not supported.
 #[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N : usize>(
-    opA : &A,
-    b : &A::Observable,
-    reg : Reg,
-    prox_penalty : &P,
-    config : &SlidingFBConfig<F>,
-    iterator : I,
-    mut plotter : SeqPlotter<F, N>,
+pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N: usize>(
+    opA: &A,
+    b: &A::Observable,
+    reg: Reg,
+    prox_penalty: &P,
+    config: &SlidingFBConfig<F>,
+    iterator: I,
+    mut plotter: SeqPlotter<F, N>,
 ) -> RNDM<F, N>
 where
-    F : Float + ToNalgebraRealField,
-    I : AlgIteratorFactory<IterInfo<F, N>>,
-    A : ForwardModel<RNDM<F, N>, F>
-        + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>
-        + BoundedCurvature<FloatType=F>,
-    for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
-    A::PreadjointCodomain : DifferentiableRealMapping<F, N>,
-    RNDM<F, N> : SpikeMerging<F>,
-    Reg : SlidingRegTerm<F, N>,
-    P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
-    PlotLookup : Plotting<N>,
+    F: Float + ToNalgebraRealField,
+    I: AlgIteratorFactory<IterInfo<F, N>>,
+    A: ForwardModel<RNDM<F, N>, F>
+        + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType = F>
+        + BoundedCurvature<FloatType = F>,
+    for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>,
+    A::PreadjointCodomain: DifferentiableRealMapping<F, N>,
+    RNDM<F, N>: SpikeMerging<F>,
+    Reg: SlidingRegTerm<F, N>,
+    P: ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
+    PlotLookup: Plotting<N>,
 {
-
     // Check parameters
     assert!(config.τ0 > 0.0, "Invalid step length parameter");
     config.transport.check();
@@ -301,23 +302,23 @@
     //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
     let ℓ = 0.0;
     let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
-    let (maybe_ℓ_v0, maybe_transport_lip) = opA.curvature_bound_components();
+    let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components();
     let transport_lip = maybe_transport_lip.unwrap();
-    let calculate_θ = |ℓ_v, max_transport| {
-        let ℓ_F = ℓ_v + transport_lip * max_transport;
-        config.transport.θ0 / (τ*(ℓ + ℓ_F))
+    let calculate_θ = |ℓ_F, max_transport| {
+        let ℓ_r = transport_lip * max_transport;
+        config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r))
     };
-    let mut θ_or_adaptive = match maybe_ℓ_v0 {
-        //Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * b.norm2(), 0.0)),
-        Some(ℓ_v0) => TransportStepLength::AdaptiveMax {
-            l: ℓ_v0 * b.norm2(), // TODO: could estimate computing the real reesidual
-            max_transport : 0.0,
-            g : calculate_θ
+    let mut θ_or_adaptive = match maybe_ℓ_F0 {
+        //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)),
+        Some(ℓ_F0) => TransportStepLength::AdaptiveMax {
+            l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real residual
+            max_transport: 0.0,
+            g: calculate_θ,
         },
         None => TransportStepLength::FullyAdaptive {
-            l : 10.0 * F::EPSILON, // Start with something very small to estimate differentials
-            max_transport : 0.0,
-            g : calculate_θ
+            l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials
+            max_transport: 0.0,
+            g: calculate_θ,
         },
     };
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
@@ -326,14 +327,12 @@
     let mut ε = tolerance.initial();
 
     // Statistics
-    let full_stats = |residual : &A::Observable,
-                      μ : &RNDM<F, N>,
-                      ε, stats| IterInfo {
-        value : residual.norm2_squared_div2() + reg.apply(μ),
-        n_spikes : μ.len(),
+    let full_stats = |residual: &A::Observable, μ: &RNDM<F, N>, ε, stats| IterInfo {
+        value: residual.norm2_squared_div2() + reg.apply(μ),
+        n_spikes: μ.len(),
         ε,
         // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
-        .. stats
+        ..stats
     };
     let mut stats = IterInfo::new();
 
@@ -341,9 +340,8 @@
     for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
         // Calculate initial transport
         let v = opA.preadjoint().apply(residual);
-        let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport(
-            &mut γ1, &mut μ, τ, &mut θ_or_adaptive, v
-        );
+        let (μ_base_masses, mut μ_base_minus_γ0) =
+            initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v);
 
         // Solve finite-dimensional subproblem several times until the dual variable for the
         // regularisation term conforms to the assumptions made for the transport above.
@@ -354,18 +352,29 @@
 
             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
             let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
-                &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0),
-                τ, ε, &config.insertion,
-                &reg, &state, &mut stats,
+                &mut μ,
+                &mut τv̆,
+                &γ1,
+                Some(&μ_base_minus_γ0),
+                τ,
+                ε,
+                &config.insertion,
+                &reg,
+                &state,
+                &mut stats,
             );
 
             // A posteriori transport adaptation.
             if aposteriori_transport(
-                &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses,
+                &mut γ1,
+                &mut μ,
+                &mut μ_base_minus_γ0,
+                &μ_base_masses,
                 None,
-                ε, &config.transport
+                ε,
+                &config.transport,
             ) {
-                break 'adapt_transport (maybe_d, within_tolerances, τv̆)
+                break 'adapt_transport (maybe_d, within_tolerances, τv̆);
             }
         };
 
@@ -387,8 +396,15 @@
         let ins = &config.insertion;
         if ins.merge_now(&state) {
             stats.merged += prox_penalty.merge_spikes(
-                &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, ins, &reg,
-                Some(|μ̃ : &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
+                &mut μ,
+                &mut τv̆,
+                &γ1,
+                Some(&μ_base_minus_γ0),
+                τ,
+                ε,
+                ins,
+                &reg,
+                Some(|μ̃: &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
             );
         }
 
@@ -412,7 +428,12 @@
         // Give statistics if requested
         state.if_verbose(|| {
             plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
-            full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
+            full_stats(
+                &residual,
+                &μ,
+                ε,
+                std::mem::replace(&mut stats, IterInfo::new()),
+            )
         });
 
         // Update main tolerance for next iteration

mercurial