// Thu, 26 Feb 2026 13:05:07 -0500
// Allow fitness merge when forward_pdps and sliding_pdps are used as
// forward-backward with an auxiliary variable.
/*! Solver for the point source localisation problem using a sliding forward-backward splitting method. */ use numeric_literals::replace_float_literals; use serde::{Deserialize, Serialize}; //use colored::Colorize; //use nalgebra::{DVector, DMatrix}; use itertools::izip; use std::iter::Iterator; use std::ops::MulAssign; use crate::fb::*; use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; use crate::measures::merging::SpikeMerging; use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon, RNDM}; use crate::plot::Plotter; use crate::prox_penalty::{ProxPenalty, StepLengthBound}; use crate::regularisation::SlidingRegTerm; use crate::types::*; use alg_tools::error::DynResult; use alg_tools::euclidean::Euclidean; use alg_tools::iterate::AlgIteratorFactory; use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping}; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::norms::Norm; use anyhow::ensure; use std::ops::ControlFlow; /// Transport settings for [`pointsource_sliding_fb_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] pub struct TransportConfig<F: Float> { /// Transport step length $θ$ normalised to $(0, 1)$. pub θ0: F, /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance. pub adaptation: F, /// A posteriori transport tolerance multiplier (C_pos) pub tolerance_mult_con: F, /// maximum number of adaptation iterations, until cancelling transport. pub max_attempts: usize, } #[replace_float_literals(F::cast_from(literal))] impl<F: Float> TransportConfig<F> { /// Check that the parameters are ok. Panics if not. 
pub fn check(&self) -> DynResult<()> { ensure!(self.θ0 > 0.0); ensure!(0.0 < self.adaptation && self.adaptation < 1.0); ensure!(self.tolerance_mult_con > 0.0); Ok(()) } } #[replace_float_literals(F::cast_from(literal))] impl<F: Float> Default for TransportConfig<F> { fn default() -> Self { TransportConfig { θ0: 0.9, adaptation: 0.9, tolerance_mult_con: 100.0, max_attempts: 2 } } } /// Settings for [`pointsource_sliding_fb_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] pub struct SlidingFBConfig<F: Float> { /// Step length scaling pub τ0: F, // Auxiliary variable step length scaling for [`crate::sliding_pdps::pointsource_sliding_fb_pair`] pub σp0: F, /// Transport parameters pub transport: TransportConfig<F>, /// Generic parameters pub insertion: InsertionConfig<F>, /// Guess for curvature bound calculations. pub guess: BoundedCurvatureGuess, } #[replace_float_literals(F::cast_from(literal))] impl<F: Float> Default for SlidingFBConfig<F> { fn default() -> Self { SlidingFBConfig { τ0: 0.99, σp0: 0.99, transport: Default::default(), insertion: Default::default(), guess: BoundedCurvatureGuess::BetterThanZero, } } } /// Internal type of adaptive transport step length calculation pub(crate) enum TransportStepLength<F: Float, G: Fn(F, F) -> F> { /// Fixed, known step length #[allow(dead_code)] Fixed(F), /// Adaptive step length, only wrt. maximum transport. /// Content of `l` depends on use case, while `g` calculates the step length from `l`. AdaptiveMax { l: F, max_transport: F, g: G }, /// Adaptive step length. /// Content of `l` depends on use case, while `g` calculates the step length from `l`. 
FullyAdaptive { l: F, max_transport: F, g: G }, } #[derive(Clone, Debug, Serialize)] pub struct SingleTransport<const N: usize, F: Float> { /// Source point x: Loc<N, F>, /// Target point y: Loc<N, F>, /// Original mass α_μ_orig: F, /// Transported mass α_γ: F, /// Helper for pruning prune: bool, } #[derive(Clone, Debug, Serialize)] pub struct Transport<const N: usize, F: Float> { vec: Vec<SingleTransport<N, F>>, } /// Whether partiall transported points are allowed. /// /// Partial transport can cause spike count explosion, so full or zero /// transport is generally preferred. If this is set to `true`, different /// transport adaptation heuristics will be used. const ALLOW_PARTIAL_TRANSPORT: bool = true; const MINIMAL_PARTIAL_TRANSPORT: bool = true; impl<const N: usize, F: Float> Transport<N, F> { pub(crate) fn new() -> Self { Transport { vec: Vec::new() } } pub(crate) fn iter(&self) -> impl Iterator<Item = &'_ SingleTransport<N, F>> { self.vec.iter() } pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut SingleTransport<N, F>> { self.vec.iter_mut() } pub(crate) fn extend<I>(&mut self, it: I) where I: IntoIterator<Item = SingleTransport<N, F>>, { self.vec.extend(it) } pub(crate) fn len(&self) -> usize { self.vec.len() } // pub(crate) fn dist_matching(&self, μ: &RNDM<N, F>) -> F { // self.iter() // .zip(μ.iter_spikes()) // .map(|(ρ, δ)| (ρ.α_γ - δ.α).abs()) // .sum() // } /// Construct `μ̆`, replacing the contents of `μ`. 
#[replace_float_literals(F::cast_from(literal))] pub(crate) fn μ̆_into(&self, μ: &mut RNDM<N, F>) { assert!(self.len() <= μ.len()); // First transported points for (δ, ρ) in izip!(μ.iter_spikes_mut(), self.iter()) { if ρ.α_γ.abs() > 0.0 { // Transport – transported point δ.α = ρ.α_γ; δ.x = ρ.y; } else { // No transport – original point δ.α = ρ.α_μ_orig; δ.x = ρ.x; } } // Then source points with partial transport let mut i = self.len(); if ALLOW_PARTIAL_TRANSPORT { // This can cause the number of points to explode, so cannot have partial transport. for ρ in self.iter() { let α = ρ.α_μ_orig - ρ.α_γ; if ρ.α_γ.abs() > F::EPSILON && α != 0.0 { let δ = DeltaMeasure { α, x: ρ.x }; if i < μ.len() { μ[i] = δ; } else { μ.push(δ) } i += 1; } } } μ.truncate(i); } /// Constrution of initial transport `γ1` from initial measure `μ` and `v=F'(μ)` /// with step lengh τ and transport step length `θ_or_adaptive`. #[replace_float_literals(F::cast_from(literal))] pub(crate) fn initial_transport<G, D>( &mut self, μ: &RNDM<N, F>, _τ: F, τθ_or_adaptive: &mut TransportStepLength<F, G>, v: D, ) where G: Fn(F, F) -> F, D: DifferentiableRealMapping<N, F>, { use TransportStepLength::*; // Initialise transport structure weights for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) { ρ.α_μ_orig = δ.α; ρ.x = δ.x; // If old transport has opposing sign, the new transport will be none. ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) { 0.0 } else { δ.α } } let γ_prev_len = self.len(); assert!(μ.len() >= γ_prev_len); self.extend(μ[γ_prev_len..].iter().map(|δ| SingleTransport { x: δ.x, y: δ.x, // Just something, will be filled properly in the next phase α_μ_orig: δ.α, α_γ: δ.α, prune: false, })); // Calculate transport rays. 
match *τθ_or_adaptive { Fixed(θ) => { for ρ in self.iter_mut() { ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ); } } AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θτ } => { *max_transport = max_transport.max(self.norm(Radon)); let θτ = calculate_θτ(ℓ_F, *max_transport); for ρ in self.iter_mut() { ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ); } } FullyAdaptive { l: ref mut adaptive_ℓ_F, ref mut max_transport, g: ref calculate_θτ, } => { *max_transport = max_transport.max(self.norm(Radon)); let mut θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport); // Do two runs through the spikes to update θ, breaking if first run did not cause // a change. for _i in 0..=1 { let mut changes = false; for ρ in self.iter_mut() { let dv_x = v.differential(&ρ.x); let g = &dv_x * (ρ.α_γ.signum() * θτ); ρ.y = ρ.x - g; let n = g.norm2(); if n >= F::EPSILON { // Estimate Lipschitz factor of ∇v let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n; *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F); θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport); changes = true } } if !changes { break; } } } } } /// A posteriori transport adaptation. #[replace_float_literals(F::cast_from(literal))] pub(crate) fn aposteriori_transport<D>( &mut self, μ: &RNDM<N, F>, μ̆: &RNDM<N, F>, _v: &mut D, extra: Option<F>, ε: F, tconfig: &TransportConfig<F>, attempts: &mut usize, ) -> bool where D: DifferentiableRealMapping<N, F>, { *attempts += 1; // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not, // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1 // at that point to zero, and retry. let mut all_ok = true; for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) { if δ.α == 0.0 && ρ.α_γ != 0.0 { all_ok = false; ρ.α_γ = 0.0; } } // 2. Through bounding ∫ B_ω(y, z) dλ(x, y, z). // through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖ for Δ := μ^{k+1}-μ̆^k // which holds for some some C if the convolution kernel in 𝒟 has Lipschitz gradient. 
let nγ = self.norm(Radon); let nΔ = μ.dist_matching(&μ̆) + extra.unwrap_or(0.0); let t = ε * tconfig.tolerance_mult_con; if nγ * nΔ > t && *attempts >= tconfig.max_attempts { all_ok = false; } else if nγ * nΔ > t { // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1, // this will guarantee that eventually ‖γ‖ decreases sufficiently that we // will not enter here. //*self *= tconfig.adaptation * t / (nγ * nΔ); // We want a consistent behaviour that has the potential to set many weights to zero. // Therefore, we find the smallest uniform reduction `chg_one`, subtracted // from all weights, that achieves total `adapt` adaptation. let adapt_to = tconfig.adaptation * t / nΔ; let reduction_target = nγ - adapt_to; assert!(reduction_target > 0.0); if ALLOW_PARTIAL_TRANSPORT { if MINIMAL_PARTIAL_TRANSPORT { // This reduces weights of transport, starting from … until `adapt` is // exhausted. It will, therefore, only ever cause one extrap point insertion // at the sources, unlike “full” partial transport. //let refs = self.vec.iter_mut().collect::<Vec<_>>(); //refs.sort_by(|ρ1, ρ2| ρ1.α_γ.abs().partial_cmp(&ρ2.α_γ.abs()).unwrap()); // let mut it = refs.into_iter(); // // Maybe sort by differential norm // let mut refs = self // .vec // .iter_mut() // .map(|ρ| { // let val = v.differential(&ρ.x).norm2_squared(); // (ρ, val) // }) // .collect::<Vec<_>>(); // refs.sort_by(|(_, v1), (_, v2)| v2.partial_cmp(&v1).unwrap()); // let mut it = refs.into_iter().map(|(ρ, _)| ρ); let mut it = self.vec.iter_mut().rev(); let _unused = it.try_fold(reduction_target, |left, ρ| { let w = ρ.α_γ.abs(); if left <= w { ρ.α_γ = ρ.α_γ.signum() * (w - left); ControlFlow::Break(()) } else { ρ.α_γ = 0.0; ControlFlow::Continue(left - w) } }); } else { // This version equally reduces all weights. 
It causes partial transport, which // has the problem that that we need to then adapt weights in both start and // end points, in insert_and_reweigh, somtimes causing the number of spikes μ // to explode. let mut abs_weights = self .vec .iter() .map(|ρ| ρ.α_γ.abs()) .filter(|t| *t > F::EPSILON) .collect::<Vec<F>>(); abs_weights.sort_by(|a, b| a.partial_cmp(b).unwrap()); let n = abs_weights.len(); // Cannot have partial transport; can cause spike count explosion let chg = abs_weights.into_iter().zip((1..=n).rev()).try_fold( 0.0, |smaller_total, (w, m)| { let mf = F::cast_from(m); let reduction = w * mf + smaller_total; if reduction >= reduction_target { ControlFlow::Break((reduction_target - smaller_total) / mf) } else { ControlFlow::Continue(smaller_total + w) } }, ); match chg { ControlFlow::Continue(_) => self.vec.iter_mut().for_each(|δ| δ.α_γ = 0.0), ControlFlow::Break(chg_one) => self.vec.iter_mut().for_each(|ρ| { let t = ρ.α_γ.abs(); if t > 0.0 { if ALLOW_PARTIAL_TRANSPORT { let new = (t - chg_one).max(0.0); ρ.α_γ = ρ.α_γ.signum() * new; } } }), } } } else { // This version zeroes smallest weights, avoiding partial transport. let mut abs_weights_idx = self .vec .iter() .map(|ρ| ρ.α_γ.abs()) .zip(0..) 
.filter(|(w, _)| *w >= 0.0) .collect::<Vec<(F, usize)>>(); abs_weights_idx.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap()); let mut left = reduction_target; for (w, i) in abs_weights_idx { left -= w; let ρ = &mut self.vec[i]; ρ.α_γ = 0.0; if left < 0.0 { break; } } } all_ok = false } if !all_ok && *attempts >= tconfig.max_attempts { for ρ in self.iter_mut() { ρ.α_γ = 0.0; } } all_ok } /// Returns $‖μ\^k - π\_♯\^0γ\^{k+1}‖$ pub(crate) fn μ0_minus_γ0_radon(&self) -> F { self.vec.iter().map(|ρ| (ρ.α_μ_orig - ρ.α_γ).abs()).sum() } /// Returns $∫ c_2 d|γ|$ #[replace_float_literals(F::cast_from(literal))] pub(crate) fn c2integral(&self) -> F { self.vec .iter() .map(|ρ| ρ.y.dist2_squared(&ρ.x) / 2.0 * ρ.α_γ.abs()) .sum() } #[replace_float_literals(F::cast_from(literal))] pub(crate) fn get_transport_stats(&self, stats: &mut IterInfo<F>, μ: &RNDM<N, F>) { // TODO: This doesn't take into account μ[i].α becoming zero in the latest tranport // attempt, for i < self.len(), when a corresponding source term also exists with index // j ≥ self.len(). For now, we let that be reflected in the prune count. stats.inserted += μ.len() - self.len(); let transp = stats.get_transport_mut(); transp.dist = { let (a, b) = transp.dist; (a + self.c2integral(), b + self.norm(Radon)) }; transp.untransported_fraction = { let (a, b) = transp.untransported_fraction; let source = self.iter().map(|ρ| ρ.α_μ_orig.abs()).sum(); (a + self.μ0_minus_γ0_radon(), b + source) }; transp.transport_error = { let (a, b) = transp.transport_error; //(a + self.dist_matching(&μ), b + self.norm(Radon)) // This ignores points that have been not transported at all, to only calculate // destnation error; untransported_fraction accounts for not being able to transport // at all. self.iter() .zip(μ.iter_spikes()) .fold((a, b), |(a, b), (ρ, δ)| { let transported = ρ.α_γ.abs(); if transported > F::EPSILON { (a + (ρ.α_γ - δ.α).abs(), b + transported) } else { (a, b) } }) }; } /// Prune spikes with zero weight. 
To maintain correct ordering between μ and γ, also the /// latter needs to be pruned when μ is. pub(crate) fn prune_compat(&mut self, μ: &mut RNDM<N, F>, stats: &mut IterInfo<F>) { assert!(self.vec.len() <= μ.len()); let old_len = μ.len(); for (ρ, δ) in self.vec.iter_mut().zip(μ.iter_spikes()) { ρ.prune = !(δ.α.abs() > F::EPSILON); } μ.prune_by(|δ| δ.α.abs() > F::EPSILON); stats.pruned += old_len - μ.len(); self.vec.retain(|ρ| !ρ.prune); assert!(self.vec.len() <= μ.len()); } } impl<const N: usize, F: Float> Norm<Radon, F> for Transport<N, F> { fn norm(&self, _: Radon) -> F { self.iter().map(|ρ| ρ.α_γ.abs()).sum() } } impl<const N: usize, F: Float> MulAssign<F> for Transport<N, F> { fn mul_assign(&mut self, factor: F) { for ρ in self.iter_mut() { ρ.α_γ *= factor; } } } /// Iteratively solve the pointsource localisation problem using sliding forward-backward /// splitting /// /// The parametrisation is as for [`pointsource_fb_reg`]. /// Inertia is currently not supported. #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_sliding_fb_reg<F, I, Dat, Reg, Plot, P, const N: usize>( f: &Dat, reg: &Reg, prox_penalty: &P, config: &SlidingFBConfig<F>, iterator: I, mut plotter: Plot, μ0: Option<RNDM<N, F>>, ) -> DynResult<RNDM<N, F>> where F: Float + ToNalgebraRealField, I: AlgIteratorFactory<IterInfo<F>>, Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>, Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>, //for<'a> Dat::Differential<'a>: Lipschitz<&'a P, FloatType = F>, RNDM<N, F>: SpikeMerging<F>, Reg: SlidingRegTerm<Loc<N, F>, F>, P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, { // Check parameters ensure!(config.τ0 > 0.0, "Invalid step length parameter"); config.transport.check()?; // Initialise iterates let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); let mut γ = Transport::new(); // Set up 
parameters // let opAnorm = opA.opnorm_bound(Radon, L2); //let max_transport = config.max_transport.scale // * reg.radon_norm_bound(b.norm2_squared() / 2.0); //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; let ℓ = 0.0; let τ = config.τ0 / prox_penalty.step_length_bound(&f)?; let mut θ_or_adaptive = match f.curvature_bound_components(config.guess) { (_, Err(_)) => TransportStepLength::Fixed(config.transport.θ0), (maybe_ℓ_F, Ok(transport_lip)) => { let calculate_θτ = move |ℓ_F, max_transport| { let ℓ_r = transport_lip * max_transport; config.transport.θ0 / (ℓ + ℓ_F + ℓ_r) }; match maybe_ℓ_F { Ok(ℓ_F) => TransportStepLength::AdaptiveMax { l: ℓ_F, // TODO: could estimate computing the real reesidual max_transport: 0.0, g: calculate_θτ, }, Err(_) => TransportStepLength::FullyAdaptive { l: 10.0 * F::EPSILON, // Start with something very small to estimate differentials max_transport: 0.0, g: calculate_θτ, }, } } }; // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling(); let mut ε = tolerance.initial(); // Statistics let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo { value: f.apply(μ) + reg.apply(μ), n_spikes: μ.len(), ε, // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), ..stats }; let mut stats = IterInfo::new(); // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate initial transport let v = f.differential(&μ); γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v); let mut attempts = 0; // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. let (maybe_d, _within_tolerances, mut τv̆, μ̆) = 'adapt_transport: loop { // Set initial guess for μ=μ^{k+1}. 
γ.μ̆_into(&mut μ); let μ̆ = μ.clone(); // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) //let residual_μ̆ = calculate_residual2(&γ1, &μ0_minus_γ0, opA, b); // TODO: this could be optimised by doing the differential like the // old residual2. // NOTE: This assumes that μ = γ1 let mut τv̆ = f.differential(&μ̆) * τ; // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( &mut μ, &mut τv̆, τ, ε, &config.insertion, ®, &state, &mut stats, )?; // A posteriori transport adaptation. if γ.aposteriori_transport(&μ, &μ̆, &mut τv̆, None, ε, &config.transport, &mut attempts) { break 'adapt_transport (maybe_d, within_tolerances, τv̆, μ̆); } stats.get_transport_mut().readjustment_iters += 1; }; γ.get_transport_stats(&mut stats, &μ); // Merge spikes. // This crucially expects the merge routine to be stable with respect to spike locations, // and not to performing any pruning. That is be to done below simultaneously for γ. if config.insertion.merge_now(&state) { stats.merged += prox_penalty.merge_spikes( &mut μ, &mut τv̆, &μ̆, τ, ε, &config.insertion, ®, Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)), ); } γ.prune_compat(&mut μ, &mut stats); let iter = state.iteration(); stats.this_iters += 1; // Give statistics if requested state.if_verbose(|| { plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) }); // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } //postprocess(μ, &config.insertion, f) postprocess(μ, &config.insertion, |μ̃| f.apply(μ̃)) }