diff -r 53136eba9abf -r 6b0db7251ebe src/sliding_fb.rs --- a/src/sliding_fb.rs Fri Feb 14 23:16:14 2025 -0500 +++ b/src/sliding_fb.rs Fri Feb 14 23:46:43 2025 -0500 @@ -4,56 +4,43 @@ */ use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; //use colored::Colorize; //use nalgebra::{DVector, DMatrix}; use itertools::izip; use std::iter::Iterator; +use alg_tools::euclidean::Euclidean; use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::euclidean::Euclidean; -use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; +use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping}; +use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::norms::Norm; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use crate::types::*; -use crate::measures::{DiscreteMeasure, Radon, RNDM}; +use crate::forward_model::{AdjointProductBoundedBy, BoundedCurvature, ForwardModel}; use crate::measures::merging::SpikeMerging; -use crate::forward_model::{ - ForwardModel, - AdjointProductBoundedBy, - BoundedCurvature, -}; +use crate::measures::{DiscreteMeasure, Radon, RNDM}; +use crate::types::*; //use crate::tolerance::Tolerance; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; +use crate::dataterm::{calculate_residual, calculate_residual2, DataTerm, L2Squared}; use crate::fb::*; +use crate::plot::{PlotLookup, Plotting, SeqPlotter}; use crate::regularisation::SlidingRegTerm; -use crate::dataterm::{ - L2Squared, - DataTerm, - calculate_residual, - calculate_residual2, -}; //use crate::transport::TransportLipschitz; /// Transport settings for [`pointsource_sliding_fb_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct TransportConfig { +pub struct TransportConfig { /// Transport step length $θ$ normalised to $(0, 1)$. - pub θ0 : F, + pub θ0: F, /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. - pub adaptation : F, + pub adaptation: F, /// A posteriori transport tolerance multiplier (C_pos) - pub tolerance_mult_con : F, + pub tolerance_mult_con: F, } #[replace_float_literals(F::cast_from(literal))] -impl TransportConfig { +impl TransportConfig { /// Check that the parameters are ok. Panics if not. pub fn check(&self) { assert!(self.θ0 > 0.0); @@ -63,12 +50,12 @@ } #[replace_float_literals(F::cast_from(literal))] -impl Default for TransportConfig { +impl Default for TransportConfig { fn default() -> Self { TransportConfig { - θ0 : 0.9, - adaptation : 0.9, - tolerance_mult_con : 100.0, + θ0: 0.9, + adaptation: 0.9, + tolerance_mult_con: 100.0, } } } @@ -76,55 +63,54 @@ /// Settings for [`pointsource_sliding_fb_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct SlidingFBConfig { +pub struct SlidingFBConfig { /// Step length scaling - pub τ0 : F, + pub τ0: F, /// Transport parameters - pub transport : TransportConfig, + pub transport: TransportConfig, /// Generic parameters - pub insertion : FBGenericConfig, + pub insertion: FBGenericConfig, } #[replace_float_literals(F::cast_from(literal))] -impl Default for SlidingFBConfig { +impl Default for SlidingFBConfig { fn default() -> Self { SlidingFBConfig { - τ0 : 0.99, - transport : Default::default(), - insertion : Default::default() + τ0: 0.99, + transport: Default::default(), + insertion: Default::default(), } } } /// Internal type of adaptive transport step length calculation -pub(crate) enum TransportStepLength F> { +pub(crate) enum TransportStepLength F> { /// Fixed, known step length #[allow(dead_code)] Fixed(F), /// Adaptive step length, only wrt. maximum transport. /// Content of `l` depends on use case, while `g` calculates the step length from `l`. - AdaptiveMax{ l : F, max_transport : F, g : G }, + AdaptiveMax { l: F, max_transport: F, g: G }, /// Adaptive step length. /// Content of `l` depends on use case, while `g` calculates the step length from `l`. - FullyAdaptive{ l : F, max_transport : F, g : G }, + FullyAdaptive { l: F, max_transport: F, g: G }, } /// Constrution of initial transport `γ1` from initial measure `μ` and `v=F'(μ)` /// with step lengh τ and transport step length `θ_or_adaptive`. #[replace_float_literals(F::cast_from(literal))] -pub(crate) fn initial_transport( - γ1 : &mut RNDM, - μ : &mut RNDM, - τ : F, - θ_or_adaptive : &mut TransportStepLength, - v : D, +pub(crate) fn initial_transport( + γ1: &mut RNDM, + μ: &mut RNDM, + τ: F, + θ_or_adaptive: &mut TransportStepLength, + v: D, ) -> (Vec, RNDM) where - F : Float + ToNalgebraRealField, - G : Fn(F, F) -> F, - D : DifferentiableRealMapping, + F: Float + ToNalgebraRealField, + G: Fn(F, F) -> F, + D: DifferentiableRealMapping, { - use TransportStepLength::*; // Save current base point and shift μ to new positions. Idea is that @@ -132,10 +118,10 @@ // μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} // γ1 = π_♯^1γ^{k+1} // μ = μ^{k+1} - let μ_base_masses : Vec = μ.iter_masses().collect(); + let μ_base_masses: Vec = μ.iter_masses().collect(); let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below. - // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates - //let mut sum_norm_dv = 0.0; + // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates + //let mut sum_norm_dv = 0.0; let γ_prev_len = γ1.len(); assert!(μ.len() >= γ_prev_len); γ1.extend(μ[γ_prev_len..].iter().cloned()); @@ -149,7 +135,7 @@ } else { δ.α }; - }; + } // Calculate transport rays. match *θ_or_adaptive { @@ -158,15 +144,23 @@ for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); } - }, - AdaptiveMax{ l : ℓ_F, ref mut max_transport, g : ref calculate_θ } => { + } + AdaptiveMax { + l: ℓ_F, + ref mut max_transport, + g: ref calculate_θ, + } => { *max_transport = max_transport.max(γ1.norm(Radon)); let θτ = τ * calculate_θ(ℓ_F, *max_transport); for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) { ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ); } - }, - FullyAdaptive{ l : ref mut adaptive_ℓ_F, ref mut max_transport, g : ref calculate_θ } => { + } + FullyAdaptive { + l: ref mut adaptive_ℓ_F, + ref mut max_transport, + g: ref calculate_θ, + } => { *max_transport = max_transport.max(γ1.norm(Radon)); let mut θ = calculate_θ(*adaptive_ℓ_F, *max_transport); // Do two runs through the spikes to update θ, breaking if first run did not cause @@ -187,7 +181,7 @@ } } if !changes { - break + break; } } } @@ -203,24 +197,29 @@ } } // Calculate μ^k-π_♯^0γ^{k+1} and v̆ = A_*(A[μ_transported + μ_transported_base]-b) - μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses()) - .map(|(&a,b)| a - b)); + μ_base_minus_γ0.set_masses( + μ_base_masses + .iter() + .zip(γ1.iter_masses()) + .map(|(&a, b)| a - b), + ); (μ_base_masses, μ_base_minus_γ0) } /// A posteriori transport adaptation. #[replace_float_literals(F::cast_from(literal))] -pub(crate) fn aposteriori_transport( - γ1 : &mut RNDM, - μ : &mut RNDM, - μ_base_minus_γ0 : &mut RNDM, - μ_base_masses : &Vec, - extra : Option, - ε : F, - tconfig : &TransportConfig +pub(crate) fn aposteriori_transport( + γ1: &mut RNDM, + μ: &mut RNDM, + μ_base_minus_γ0: &mut RNDM, + μ_base_masses: &Vec, + extra: Option, + ε: F, + tconfig: &TransportConfig, ) -> bool -where F : Float + ToNalgebraRealField { - +where + F: Float + ToNalgebraRealField, +{ // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not, // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1 // at that point to zero, and retry. @@ -238,19 +237,22 @@ let nγ = γ1.norm(Radon); let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1) + extra.unwrap_or(0.0); let t = ε * tconfig.tolerance_mult_con; - if nγ*nΔ > t { + if nγ * nΔ > t { // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1, // this will guarantee that eventually ‖γ‖ decreases sufficiently that we // will not enter here. - *γ1 *= tconfig.adaptation * t / ( nγ * nΔ ); + *γ1 *= tconfig.adaptation * t / (nγ * nΔ); all_ok = false } if !all_ok { // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1} - μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses()) - .map(|(&a,b)| a - b)); - + μ_base_minus_γ0.set_masses( + μ_base_masses + .iter() + .zip(γ1.iter_masses()) + .map(|(&a, b)| a - b), + ); } all_ok @@ -262,29 +264,28 @@ /// The parametrisation is as for [`pointsource_fb_reg`]. /// Inertia is currently not supported. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_sliding_fb_reg( - opA : &A, - b : &A::Observable, - reg : Reg, - prox_penalty : &P, - config : &SlidingFBConfig, - iterator : I, - mut plotter : SeqPlotter, +pub fn pointsource_sliding_fb_reg( + opA: &A, + b: &A::Observable, + reg: Reg, + prox_penalty: &P, + config: &SlidingFBConfig, + iterator: I, + mut plotter: SeqPlotter, ) -> RNDM where - F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - A : ForwardModel, F> - + AdjointProductBoundedBy, P, FloatType=F> - + BoundedCurvature, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - A::PreadjointCodomain : DifferentiableRealMapping, - RNDM : SpikeMerging, - Reg : SlidingRegTerm, - P : ProxPenalty, - PlotLookup : Plotting, + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + A: ForwardModel, F> + + AdjointProductBoundedBy, P, FloatType = F> + + BoundedCurvature, + for<'b> &'b A::Observable: std::ops::Neg + Instance, + A::PreadjointCodomain: DifferentiableRealMapping, + RNDM: SpikeMerging, + Reg: SlidingRegTerm, + P: ProxPenalty, + PlotLookup: Plotting, { - // Check parameters assert!(config.τ0 > 0.0, "Invalid step length parameter"); config.transport.check(); @@ -301,23 +302,23 @@ //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; let ℓ = 0.0; let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); - let (maybe_ℓ_v0, maybe_transport_lip) = opA.curvature_bound_components(); + let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components(); let transport_lip = maybe_transport_lip.unwrap(); - let calculate_θ = |ℓ_v, max_transport| { - let ℓ_F = ℓ_v + transport_lip * max_transport; - config.transport.θ0 / (τ*(ℓ + ℓ_F)) + let calculate_θ = |ℓ_F, max_transport| { + let ℓ_r = transport_lip * max_transport; + config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r)) }; - let mut θ_or_adaptive = match maybe_ℓ_v0 { - //Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * b.norm2(), 0.0)), - Some(ℓ_v0) => TransportStepLength::AdaptiveMax { - l: ℓ_v0 * b.norm2(), // TODO: could estimate computing the real reesidual - max_transport : 0.0, - g : calculate_θ + let mut θ_or_adaptive = match maybe_ℓ_F0 { + //Some(ℓ_F0) => TransportStepLength::Fixed(calculate_θ(ℓ_F0 * b.norm2(), 0.0)), + Some(ℓ_F0) => TransportStepLength::AdaptiveMax { + l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual + max_transport: 0.0, + g: calculate_θ, }, None => TransportStepLength::FullyAdaptive { - l : 10.0 * F::EPSILON, // Start with something very small to estimate differentials - max_transport : 0.0, - g : calculate_θ + l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials + max_transport: 0.0, + g: calculate_θ, }, }; // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled @@ -326,14 +327,12 @@ let mut ε = tolerance.initial(); // Statistics - let full_stats = |residual : &A::Observable, - μ : &RNDM, - ε, stats| IterInfo { - value : residual.norm2_squared_div2() + reg.apply(μ), - n_spikes : μ.len(), + let full_stats = |residual: &A::Observable, μ: &RNDM, ε, stats| IterInfo { + value: residual.norm2_squared_div2() + reg.apply(μ), + n_spikes: μ.len(), ε, // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), - .. stats + ..stats }; let mut stats = IterInfo::new(); @@ -341,9 +340,8 @@ for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { // Calculate initial transport let v = opA.preadjoint().apply(residual); - let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport( - &mut γ1, &mut μ, τ, &mut θ_or_adaptive, v - ); + let (μ_base_masses, mut μ_base_minus_γ0) = + initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. @@ -354,18 +352,29 @@ // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( - &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), - τ, ε, &config.insertion, - ®, &state, &mut stats, + &mut μ, + &mut τv̆, + &γ1, + Some(&μ_base_minus_γ0), + τ, + ε, + &config.insertion, + ®, + &state, + &mut stats, ); // A posteriori transport adaptation. if aposteriori_transport( - &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, + &mut γ1, + &mut μ, + &mut μ_base_minus_γ0, + &μ_base_masses, None, - ε, &config.transport + ε, + &config.transport, ) { - break 'adapt_transport (maybe_d, within_tolerances, τv̆) + break 'adapt_transport (maybe_d, within_tolerances, τv̆); } }; @@ -387,8 +396,15 @@ let ins = &config.insertion; if ins.merge_now(&state) { stats.merged += prox_penalty.merge_spikes( - &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, ins, ®, - Some(|μ̃ : &RNDM| L2Squared.calculate_fit_op(μ̃, opA, b)), + &mut μ, + &mut τv̆, + &γ1, + Some(&μ_base_minus_γ0), + τ, + ε, + ins, + ®, + Some(|μ̃: &RNDM| L2Squared.calculate_fit_op(μ̃, opA, b)), ); } @@ -412,7 +428,12 @@ // Give statistics if requested state.if_verbose(|| { plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); - full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) + full_stats( + &residual, + &μ, + ε, + std::mem::replace(&mut stats, IterInfo::new()), + ) }); // Update main tolerance for next iteration