--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/sliding_fb.rs Tue Dec 31 09:34:24 2024 -0500 @@ -0,0 +1,583 @@ +/*! +Solver for the point source localisation problem using a sliding +forward-backward splitting method. +*/ + +use numeric_literals::replace_float_literals; +use serde::{Serialize, Deserialize}; +//use colored::Colorize; +//use nalgebra::{DVector, DMatrix}; +use itertools::izip; +use std::iter::{Map, Flatten}; + +use alg_tools::iterate::{ + AlgIteratorFactory, + AlgIteratorState +}; +use alg_tools::euclidean::{ + Euclidean, + Dot +}; +use alg_tools::sets::Cube; +use alg_tools::loc::Loc; +use alg_tools::mapping::{Apply, Differentiable}; +use alg_tools::bisection_tree::{ + BTFN, + PreBTFN, + Bounds, + BTNodeLookup, + BTNode, + BTSearch, + P2Minimise, + SupportGenerator, + LocalAnalysis, + //Bounded, +}; +use alg_tools::mapping::RealMapping; +use alg_tools::nalgebra_support::ToNalgebraRealField; + +use crate::types::*; +use crate::measures::{ + DiscreteMeasure, + DeltaMeasure, +}; +use crate::measures::merging::{ + //SpikeMergingMethod, + SpikeMerging, +}; +use crate::forward_model::ForwardModel; +use crate::seminorms::DiscreteMeasureOp; +//use crate::tolerance::Tolerance; +use crate::plot::{ + SeqPlotter, + Plotting, + PlotLookup +}; +use crate::fb::*; +use crate::regularisation::SlidingRegTerm; +use crate::dataterm::{ + L2Squared, + //DataTerm, + calculate_residual, + calculate_residual2, +}; +use crate::transport::TransportLipschitz; + +/// Settings for [`pointsource_sliding_fb_reg`]. +#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] +#[serde(default)] +pub struct SlidingFBConfig<F : Float> { + /// Step length scaling + pub τ0 : F, + /// Transport smoothness assumption + pub ℓ0 : F, + /// Inverse of the scaling factor $θ$ of the 2-norm-squared transport cost. + /// This means that $τθ$ is the step length for the transport step. + pub inverse_transport_scaling : F, + /// Factor for deciding transport reduction based on smoothness assumption violation + pub minimum_goodness_factor : F, + /// Maximum rays to retain in transports from each source. + pub maximum_rays : usize, + /// Generic parameters + pub insertion : FBGenericConfig<F>, +} + +#[replace_float_literals(F::cast_from(literal))] +impl<F : Float> Default for SlidingFBConfig<F> { + fn default() -> Self { + SlidingFBConfig { + τ0 : 0.99, + ℓ0 : 1.5, + inverse_transport_scaling : 1.0, + minimum_goodness_factor : 1.0, // TODO: totally arbitrary choice, + // should be scaled by problem data? + maximum_rays : 10, + insertion : Default::default() + } + } +} + +/// A transport ray (including various additional computational information). +#[derive(Clone, Debug)] +pub struct Ray<Domain, F : Num> { + /// The destination of the ray, and the mass. The source is indicated in a [`RaySet`]. + δ : DeltaMeasure<Domain, F>, + /// Goodness of the data term for the aray: $v(z)-v(y)-⟨∇v(x), z-y⟩ + ℓ‖z-y‖^2$. + goodness : F, + /// Goodness of the regularisation term for the ray: $w(z)-w(y)$. + /// Initially zero until $w$ can be constructed. + reg_goodness : F, + /// Indicates that this ray also forms a component in γ^{k+1} with the mass `to_return`. + to_return : F, +} + +/// A set of transport rays with the same source point. +#[derive(Clone, Debug)] +pub struct RaySet<Domain, F : Num> { + /// Source of every ray in thset + source : Domain, + /// Mass of the diagonal ray, with destination the same as the source. + diagonal: F, + /// Goodness of the data term for the diagonal ray with $z=x$: + /// $v(x)-v(y)-⟨∇v(x), x-y⟩ + ℓ‖x-y‖^2$. + diagonal_goodness : F, + /// Goodness of the data term for the diagonal ray with $z=x$: $w(x)-w(y)$. + diagonal_reg_goodness : F, + /// The non-diagonal rays. + rays : Vec<Ray<Domain, F>>, +} + +#[replace_float_literals(F::cast_from(literal))] +impl<Domain, F : Float> RaySet<Domain, F> { + fn non_diagonal_mass(&self) -> F { + self.rays + .iter() + .map(|Ray{ δ : DeltaMeasure{ α, .. }, .. }| *α) + .sum() + } + + fn total_mass(&self) -> F { + self.non_diagonal_mass() + self.diagonal + } + + fn targets<'a>(&'a self) + -> Map< + std::slice::Iter<'a, Ray<Domain, F>>, + fn(&'a Ray<Domain, F>) -> &'a DeltaMeasure<Domain, F> + > { + fn get_δ<'b, Domain, F : Float>(Ray{ δ, .. }: &'b Ray<Domain, F>) + -> &'b DeltaMeasure<Domain, F> { + δ + } + self.rays + .iter() + .map(get_δ) + } + + // fn non_diagonal_goodness(&self) -> F { + // self.rays + // .iter() + // .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| { + // α * (goodness + reg_goodness) + // }) + // .sum() + // } + + // fn total_goodness(&self) -> F { + // self.non_diagonal_goodness() + (self.diagonal_goodness + self.diagonal_reg_goodness) + // } + + fn non_diagonal_badness(&self) -> F { + self.rays + .iter() + .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| { + 0.0.max(- α * (goodness + reg_goodness)) + }) + .sum() + } + + fn total_badness(&self) -> F { + self.non_diagonal_badness() + + 0.0.max(- self.diagonal * (self.diagonal_goodness + self.diagonal_reg_goodness)) + } + + fn total_return(&self) -> F { + self.rays + .iter() + .map(|&Ray{ to_return, .. }| to_return) + .sum() + } +} + +#[replace_float_literals(F::cast_from(literal))] +impl<Domain : Clone, F : Num> RaySet<Domain, F> { + fn return_targets<'a>(&'a self) + -> Flatten<Map< + std::slice::Iter<'a, Ray<Domain, F>>, + fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>> + >> { + fn get_return<'b, Domain : Clone, F : Num>(ray: &'b Ray<Domain, F>) + -> Option<DeltaMeasure<Domain, F>> { + (ray.to_return != 0.0).then_some( + DeltaMeasure{x : ray.δ.x.clone(), α : ray.to_return} + ) + } + let tmp : Map< + std::slice::Iter<'a, Ray<Domain, F>>, + fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>> + > = self.rays + .iter() + .map(get_return); + tmp.flatten() + } +} + +/// Iteratively solve the pointsource localisation problem using sliding forward-backward +/// splitting +/// +/// The parametrisatio is as for [`pointsource_fb_reg`]. +/// Inertia is currently not supported. +#[replace_float_literals(F::cast_from(literal))] +pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>( + opA : &'a A, + b : &A::Observable, + reg : Reg, + op𝒟 : &'a 𝒟, + sfbconfig : &SlidingFBConfig<F>, + iterator : I, + mut plotter : SeqPlotter<F, N>, +) -> DiscreteMeasure<Loc<F, N>, F> +where F : Float + ToNalgebraRealField, + I : AlgIteratorFactory<IterInfo<F, N>>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, + //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow + A::Observable : std::ops::MulAssign<F>, + A::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output=Loc<F, N>>, + GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, + A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> + + Lipschitz<&'a 𝒟, FloatType=F> + TransportLipschitz<L2Squared, FloatType=F>, + BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, + G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, + 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, + 𝒟::Codomain : RealMapping<F, N>, + S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N> + + Differentiable<Loc<F, N>, Output=Loc<F,N>>, + K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, + //+ Differentiable<Loc<F, N>, Output=Loc<F,N>>, + BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, + Cube<F, N>: P2Minimise<Loc<F, N>, F>, + PlotLookup : Plotting<N>, + DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, + Reg : SlidingRegTerm<F, N> { + + assert!(sfbconfig.τ0 > 0.0 && + sfbconfig.inverse_transport_scaling > 0.0 && + sfbconfig.ℓ0 > 0.0); + + // Set up parameters + let config = &sfbconfig.insertion; + let op𝒟norm = op𝒟.opnorm_bound(); + let θ = sfbconfig.inverse_transport_scaling; + let τ = sfbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap() + .max(opA.transport_lipschitz_factor(L2Squared) * θ); + let ℓ = sfbconfig.ℓ0; // TODO: v scaling? + // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled + // by τ compared to the conditional gradient approach. + let tolerance = config.tolerance * τ * reg.tolerance_scaling(); + let mut ε = tolerance.initial(); + + // Initialise iterates + let mut μ : DiscreteMeasure<Loc<F, N>, F> = DiscreteMeasure::new(); + let mut μ_transported_base = DiscreteMeasure::new(); + let mut γ_hat : Vec<RaySet<Loc<F, N>, F>> = Vec::new(); // γ̂_k and extra info + let mut residual = -b; + let mut stats = IterInfo::new(); + + // Run the algorithm + iterator.iterate(|state| { + // Calculate smooth part of surrogate model. + // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` + // has no significant overhead. For some reosn Rust doesn't allow us simply moving + // the residual and replacing it below before the end of this closure. + residual *= -τ; + let r = std::mem::replace(&mut residual, opA.empty_observable()); + let minus_τv = opA.preadjoint().apply(r); + + // Save current base point and shift μ to new positions. + let μ_base = μ.clone(); + for δ in μ.iter_spikes_mut() { + δ.x += minus_τv.differential(&δ.x) * θ; + } + let mut μ_transported = μ.clone(); + + assert_eq!(μ.len(), γ_hat.len()); + + // Calculate the goodness λ formed from γ_hat (≈ γ̂_k) and γ^{k+1}, where the latter + // transports points x from μ_base to points y in μ as shifted above, or “returns” + // them “home” to z given by the rays in γ_hat. Returning is necessary if the rays + // are not “good” for the smoothness assumptions, or if γ_hat has more mass than + // μ_base. + let mut total_goodness = 0.0; // data term goodness + let mut total_reg_goodness = 0.0; // regulariser goodness + let minimum_goodness = - ε * sfbconfig.minimum_goodness_factor; + + for (δ, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) { + // Calculate data term goodness for all rays. + let &DeltaMeasure{ x : ref y, α : δ_mass } = δ; + let x = &r.source; + let mvy = minus_τv.apply(y); + let mdvx = minus_τv.differential(x); + let mut r_total_mass = 0.0; // Total mass of all rays with source r.source. + let mut bad_mass = 0.0; + let mut calc_goodness = |goodness : &mut F, reg_goodness : &mut F, α, z : &Loc<F, N>| { + *reg_goodness = 0.0; // Initial guess + *goodness = mvy - minus_τv.apply(z) + mdvx.dot(&(z-y)) + + ℓ * z.dist2_squared(&y); + total_goodness += *goodness * α; + r_total_mass += α; // TODO: should this include to_return from staging? (Probably not) + if *goodness < 0.0 { + bad_mass += α; + } + }; + for ray in r.rays.iter_mut() { + calc_goodness(&mut ray.goodness, &mut ray.reg_goodness, ray.δ.α, &ray.δ.x); + } + calc_goodness(&mut r.diagonal_goodness, &mut r.diagonal_reg_goodness, r.diagonal, x); + + // If the total mass of the ray set is less than that of μ at the same source, + // a diagonal component needs to be added to be able to (attempt to) transport + // all mass of μ. In the opposite case, we need to construct γ_{k+1} to ‘return’ + // the the extra mass of γ̂_k to the target z. We return mass from the oldest “bad” + // rays in the set. + if δ_mass >= r_total_mass { + r.diagonal += δ_mass - r_total_mass; + } else { + let mut reduce_transport = r_total_mass - δ_mass; + let mut good_needed = (bad_mass - reduce_transport).max(0.0); + // NOTE: reg_goodness is zero at this point, so it is not used in this code. + let mut reduce_ray = |goodness, to_return : Option<&mut F>, α : &mut F| { + if reduce_transport > 0.0 { + let return_amount = if goodness < 0.0 { + α.min(reduce_transport) + } else { + let amount = α.min(good_needed); + good_needed -= amount; + amount + }; + + if return_amount > 0.0 { + reduce_transport -= return_amount; + // Adjust total goodness by returned amount + total_goodness -= goodness * return_amount; + to_return.map(|tr| *tr += return_amount); + *α -= return_amount; + *α > 0.0 + } else { + true + } + } else { + true + } + }; + r.rays.retain_mut(|ray| { + reduce_ray(ray.goodness, Some(&mut ray.to_return), &mut ray.δ.α) + }); + // A bad diagonal is simply reduced without any 'return'. + // It was, after all, just added to match μ, but there is no need to match it. + // It's just a heuristic. + // TODO: Maybe a bad diagonal should be the first to go. + reduce_ray(r.diagonal_goodness, None, &mut r.diagonal); + } + } + + // Solve finite-dimensional subproblem several times until the dual variable for the + // regularisation term conforms to the assumptions made for the transport above. + let (d, within_tolerances) = 'adapt_transport: loop { + // If transport violates goodness requirements, shift it to ‘return’ mass to z, + // forcing y = z. Based on the badness of each ray set (sum of bad rays' goodness), + // we proportionally distribute the reductions to each ray set, and within each ray + // set, prioritise reducing the oldest bad rays' weight. + let tg = total_goodness + total_reg_goodness; + let adaptation_needed = minimum_goodness - tg; + if adaptation_needed > 0.0 { + let total_badness = γ_hat.iter().map(|r| r.total_badness()).sum(); + + let mut return_ray = |goodness : F, + reg_goodness : F, + to_return : Option<&mut F>, + α : &mut F, + left_to_return : &mut F| { + let g = goodness + reg_goodness; + assert!(*α >= 0.0 && *left_to_return >= 0.0); + if *left_to_return > 0.0 && g < 0.0 { + let return_amount = (*left_to_return / (-g)).min(*α); + *left_to_return -= (-g) * return_amount; + total_goodness -= goodness * return_amount; + total_reg_goodness -= reg_goodness * return_amount; + to_return.map(|tr| *tr += return_amount); + *α -= return_amount; + *α > 0.0 + } else { + true + } + }; + + for r in γ_hat.iter_mut() { + let mut left_to_return = adaptation_needed * r.total_badness() / total_badness; + if left_to_return > 0.0 { + for ray in r.rays.iter_mut() { + return_ray(ray.goodness, ray.reg_goodness, + Some(&mut ray.to_return), &mut ray.δ.α, &mut left_to_return); + } + return_ray(r.diagonal_goodness, r.diagonal_reg_goodness, + None, &mut r.diagonal, &mut left_to_return); + } + } + } + + // Construct μ_k + (π_#^1-π_#^0)γ_{k+1}. + // This can be broken down into + // + // μ_transported_base = [μ - π_#^0 (γ_shift + γ_return)] + π_#^1 γ_return, and + // μ_transported = π_#^1 γ_shift + // + // where γ_shift is our “true” γ_{k+1}, and γ_return is the return compoennt. + // The former can be constructed from δ.x and δ_new.x for δ in μ_base and δ_new in μ + // (which has already been shifted), and the mass stored in a γ_hat ray's δ measure + // The latter can be constructed from γ_hat rays' source and destination with the + // to_return mass. + // + // Note that μ_transported is constructed to have the same spike locations as μ, but + // to have same length as μ_base. This loop does not iterate over the spikes of μ + // (and corresponding transports of γ_hat) that have been newly added in the current + // 'adapt_transport loop. + for (δ, δ_transported, r) in izip!(μ_base.iter_spikes(), + μ_transported.iter_spikes_mut(), + γ_hat.iter()) { + let &DeltaMeasure{ref x, α} = δ; + debug_assert_eq!(*x, r.source); + let shifted_mass = r.total_mass(); + let ret_mass = r.total_return(); + // μ - π_#^0 (γ_shift + γ_return) + μ_transported_base += DeltaMeasure { x : *x, α : α - shifted_mass - ret_mass }; + // π_#^1 γ_return + μ_transported_base.extend(r.return_targets()); + // π_#^1 γ_shift + δ_transported.set_mass(shifted_mass); + } + // Calculate transported_minus_τv = -τA_*(A[μ_transported + μ_transported_base]-b) + let transported_residual = calculate_residual2(&μ_transported, + &μ_transported_base, + opA, b); + let transported_minus_τv = opA.preadjoint() + .apply(transported_residual); + + // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. + let (mut d, within_tolerances) = insert_and_reweigh( + &mut μ, &transported_minus_τv, &μ_transported, Some(&μ_transported_base), + op𝒟, op𝒟norm, + τ, ε, + config, ®, state, &mut stats + ); + + // We have d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv; more precisely + // d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_transported, config)); + // We “essentially” assume that the subdifferential w of the regularisation term + // satisfies w'(y)=0, so for a “goodness” estimate τ[w(y)-w(z)-w'(y)(z-y)] + // that incorporates the assumption, we need to calculate τ[w(z) - w(y)] for + // some w in the subdifferential of the regularisation term, such that + // -ε ≤ τw - d ≤ ε. This is done by [`RegTerm::goodness`]. + for r in γ_hat.iter_mut() { + for ray in r.rays.iter_mut() { + ray.reg_goodness = reg.goodness(&mut d, &μ, &r.source, &ray.δ.x, τ, ε, config); + total_reg_goodness += ray.reg_goodness * ray.δ.α; + } + } + + // If update of regularisation term goodness didn't invalidate minimum goodness + // requirements, we have found our step. Otherwise we need to keep reducing + // transport by repeating the loop. + if total_goodness + total_reg_goodness >= minimum_goodness { + break 'adapt_transport (d, within_tolerances) + } + }; + + // Update γ_hat to new location + for (δ_new, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) { + // Prune rays that only had a return component, as the return component becomes + // a diagonal in γ̂^{k+1}. + r.rays.retain(|ray| ray.δ.α != 0.0); + // Otherwise zero out the return component, or stage rays for pruning + // to keep memory and computational demands reasonable. + let n_rays = r.rays.len(); + for (ray, ir) in izip!(r.rays.iter_mut(), (0..n_rays).rev()) { + if ir >= sfbconfig.maximum_rays { + // Only keep sfbconfig.maximum_rays - 1 previous rays, staging others for + // pruning in next step. + ray.to_return = ray.δ.α; + ray.δ.α = 0.0; + } else { + ray.to_return = 0.0; + } + ray.goodness = 0.0; // TODO: probably not needed + ray.reg_goodness = 0.0; + } + // Add a new ray for the currently diagonal component + if r.diagonal > 0.0 { + r.rays.push(Ray{ + δ : DeltaMeasure{x : r.source, α : r.diagonal}, + goodness : 0.0, + reg_goodness : 0.0, + to_return : 0.0, + }); + // TODO: Maybe this does not need to be done here, and is sufficent to to do where + // the goodness is calculated. + r.diagonal = 0.0; + } + r.diagonal_goodness = 0.0; + + // Shift source + r.source = δ_new.x; + } + // Extend to new spikes + γ_hat.extend(μ[γ_hat.len()..].iter().map(|δ_new| { + RaySet{ + source : δ_new.x, + rays : [].into(), + diagonal : 0.0, + diagonal_goodness : 0.0, + diagonal_reg_goodness : 0.0 + } + })); + + // Prune spikes with zero weight. This also moves the marginal differences of corresponding + // transports from γ_hat to γ_pruned_marginal_diff. + // TODO: optimise standard prune with swap_remove. + μ_transported_base.clear(); + let mut i = 0; + assert_eq!(μ.len(), γ_hat.len()); + while i < μ.len() { + if μ[i].α == F::ZERO { + μ.swap_remove(i); + let r = γ_hat.swap_remove(i); + μ_transported_base.extend(r.targets().cloned()); + μ_transported_base -= DeltaMeasure{ α : r.non_diagonal_mass(), x : r.source }; + } else { + i += 1; + } + } + + // TODO: how to merge? + + // Update residual + residual = calculate_residual(&μ, opA, b); + + // Update main tolerance for next iteration + let ε_prev = ε; + ε = tolerance.update(ε, state.iteration()); + stats.this_iters += 1; + + // Give function value if needed + state.if_verbose(|| { + // Plot if so requested + plotter.plot_spikes( + format!("iter {} end; {}", state.iteration(), within_tolerances), &d, + "start".to_string(), Some(&minus_τv), + reg.target_bounds(τ, ε_prev), &μ, + ); + // Calculate mean inner iterations and reset relevant counters. + // Return the statistics + let res = IterInfo { + value : residual.norm2_squared_div2() + reg.apply(&μ), + n_spikes : μ.len(), + ε : ε_prev, + postprocessing: config.postprocessing.then(|| μ.clone()), + .. stats + }; + stats = IterInfo::new(); + res + }) + }); + + postprocess(μ, config, L2Squared, opA, b) +}