diff -r 4f468d35fa29 -r 7a8a55fd41c0 src/sliding_pdps.rs --- a/src/sliding_pdps.rs Thu Feb 26 11:38:43 2026 -0500 +++ b/src/sliding_pdps.rs Thu Feb 26 11:36:22 2026 -0500 @@ -6,13 +6,11 @@ use crate::fb::*; use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; use crate::measures::merging::SpikeMerging; -use crate::measures::{DiscreteMeasure, Radon, RNDM}; +use crate::measures::{DiscreteMeasure, RNDM}; use crate::plot::Plotter; use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair}; use crate::regularisation::SlidingRegTerm; -use crate::sliding_fb::{ - aposteriori_transport, initial_transport, SlidingFBConfig, TransportConfig, TransportStepLength, -}; +use crate::sliding_fb::{SlidingFBConfig, Transport, TransportConfig, TransportStepLength}; use crate::types::*; use alg_tools::convex::{Conjugable, Prox, Zero}; use alg_tools::direct_product::Pair; @@ -24,13 +22,10 @@ }; use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance}; use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::norms::{Norm, L2}; +use alg_tools::norms::L2; use anyhow::ensure; use numeric_literals::replace_float_literals; use serde::{Deserialize, Serialize}; -//use colored::Colorize; -//use nalgebra::{DVector, DMatrix}; -use std::iter::Iterator; /// Settings for [`pointsource_sliding_pdps_pair`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] @@ -148,7 +143,7 @@ // Initialise iterates let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); - let mut γ1 = DiscreteMeasure::new(); + let mut γ = Transport::new(); //let zero_z = z.similar_origin(); // Set up parameters @@ -186,22 +181,25 @@ let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); // The factor two in the manuscript disappears due to the definition of 𝚹 being // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. - let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess); - let transport_lip = maybe_transport_lip?; - let calculate_θ = |ℓ_F, max_transport| { - let ℓ_r = transport_lip * max_transport; - config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport) - }; - let mut θ_or_adaptive = match maybe_ℓ_F { - // We assume that the residual is decreasing. - Ok(ℓ_F) => TransportStepLength::AdaptiveMax { - l: ℓ_F, // TODO: could estimate computing the real reesidual - max_transport: 0.0, - g: calculate_θ, - }, - Err(_) => { - TransportStepLength::FullyAdaptive { - l: F::EPSILON, max_transport: 0.0, g: calculate_θ + + let mut θ_or_adaptive = match f.curvature_bound_components(config.guess) { + (_, Err(_)) => TransportStepLength::Fixed(config.transport.θ0), + (maybe_ℓ_F, Ok(transport_lip)) => { + let calculate_θτ = move |ℓ_F, max_transport| { + let ℓ_r = transport_lip * max_transport; + config.transport.θ0 / ((ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport / τ) + }; + match maybe_ℓ_F { + Ok(ℓ_F) => TransportStepLength::AdaptiveMax { + l: ℓ_F, // TODO: could estimate computing the real reesidual + max_transport: 0.0, + g: calculate_θτ, + }, + Err(_) => TransportStepLength::FullyAdaptive { + l: F::EPSILON, // Start with something very small to estimate differentials + max_transport: 0.0, + g: calculate_θτ, + }, } } }; @@ -243,26 +241,25 @@ //dbg!(&μ); - let (μ_base_masses, mut μ_base_minus_γ0) = - initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); + γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v); + + let mut attempts = 0; // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. - let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop { + let (maybe_d, _within_tolerances, mut τv̆, z_new, μ̆) = 'adapt_transport: loop { + // Set initial guess for μ=μ^{k+1}. + γ.μ̆_into(&mut μ); + let μ̆ = μ.clone(); + // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) - // let residual_μ̆ = - // calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); - // let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ); - // TODO: might be able to optimise the measure sum working as calculate_residual2 above. - let Pair(mut τv̆, τz̆) = f.differential(Pair(&γ1 + &μ_base_minus_γ0, &z)) * τ; + let Pair(mut τv̆, τz̆) = f.differential(Pair(&μ̆, &z)) * τ; // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( &mut μ, &mut τv̆, - &γ1, - Some(&μ_base_minus_γ0), τ, ε, &config.insertion, @@ -277,59 +274,37 @@ z_new = fnR.prox(σ_p, z_new + &z); // A posteriori transport adaptation. - if aposteriori_transport( - &mut γ1, - &mut μ, - &mut μ_base_minus_γ0, - &μ_base_masses, + if γ.aposteriori_transport( + &μ, + &μ̆, + &mut τv̆, Some(z_new.dist2(&z)), ε, &config.transport, + &mut attempts, ) { - break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new); + break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new, μ̆); } }; - stats.untransported_fraction = Some({ - assert_eq!(μ_base_masses.len(), γ1.len()); - let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0)); - let source = μ_base_masses.iter().map(|v| v.abs()).sum(); - (a + μ_base_minus_γ0.norm(Radon), b + source) - }); - stats.transport_error = Some({ - assert_eq!(μ_base_masses.len(), γ1.len()); - let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); - (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) - }); + γ.get_transport_stats(&mut stats, &μ); // Merge spikes. // This crucially expects the merge routine to be stable with respect to spike locations, // and not to performing any pruning. That is be to done below simultaneously for γ. - let ins = &config.insertion; - if ins.merge_now(&state) { + if config.insertion.merge_now(&state) { stats.merged += prox_penalty.merge_spikes_no_fitness( &mut μ, &mut τv̆, - &γ1, - Some(&μ_base_minus_γ0), + &μ̆, τ, ε, - ins, + &config.insertion, ®, - //Some(|μ̃ : &RNDM| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), ); } - // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the - // latter needs to be pruned when μ is. - // TODO: This could do with a two-vector Vec::retain to avoid copies. - let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); - if μ_new.len() != μ.len() { - let mut μ_iter = μ.iter_spikes(); - γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO); - stats.pruned += μ.len() - μ_new.len(); - μ = μ_new; - } + γ.prune_compat(&mut μ, &mut stats); // Do dual update // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}]