/*! Solver for the point source localisation problem using a sliding forward-backward splitting method. */

use numeric_literals::replace_float_literals;
use serde::{Serialize, Deserialize};
//use colored::Colorize;
//use nalgebra::{DVector, DMatrix};
use itertools::izip;
use std::iter::{Map, Flatten};

use alg_tools::iterate::{
    AlgIteratorFactory,
    AlgIteratorState
};
use alg_tools::euclidean::{Euclidean, Dot};
use alg_tools::sets::Cube;
use alg_tools::loc::Loc;
use alg_tools::mapping::{Apply, Differentiable};
use alg_tools::bisection_tree::{
    BTFN,
    PreBTFN,
    Bounds,
    BTNodeLookup,
    BTNode,
    BTSearch,
    P2Minimise,
    SupportGenerator,
    LocalAnalysis,
    //Bounded,
};
use alg_tools::mapping::RealMapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;

use crate::types::*;
use crate::measures::{
    DiscreteMeasure,
    DeltaMeasure,
};
use crate::measures::merging::{
    //SpikeMergingMethod,
    SpikeMerging,
};
use crate::forward_model::ForwardModel;
use crate::seminorms::DiscreteMeasureOp;
//use crate::tolerance::Tolerance;
use crate::plot::{
    SeqPlotter,
    Plotting,
    PlotLookup
};
use crate::fb::*;
use crate::regularisation::SlidingRegTerm;
use crate::dataterm::{
    L2Squared,
    //DataTerm,
    calculate_residual,
    calculate_residual2,
};
use crate::transport::TransportLipschitz;

/// Settings for [`pointsource_sliding_fb_reg`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct SlidingFBConfig<F : Float> {
    /// Step length scaling
    pub τ0 : F,
    /// Transport smoothness assumption
    pub ℓ0 : F,
    /// Inverse of the scaling factor $θ$ of the 2-norm-squared transport cost.
    /// This means that $τθ$ is the step length for the transport step.
    pub inverse_transport_scaling : F,
    /// Factor for deciding transport reduction based on smoothness assumption violation
    pub minimum_goodness_factor : F,
    /// Maximum rays to retain in transports from each source.
    pub maximum_rays : usize,
    /// Generic parameters
    pub insertion : FBGenericConfig<F>,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for SlidingFBConfig<F> {
    fn default() -> Self {
        SlidingFBConfig {
            τ0 : 0.99,
            ℓ0 : 1.5,
            inverse_transport_scaling : 1.0,
            minimum_goodness_factor : 1.0, // TODO: totally arbitrary choice,
                                           // should be scaled by problem data?
            maximum_rays : 10,
            insertion : Default::default()
        }
    }
}

/// A transport ray (including various additional computational information).
#[derive(Clone, Debug)]
pub struct Ray<Domain, F : Num> {
    /// The destination of the ray, and the mass. The source is indicated in a [`RaySet`].
    δ : DeltaMeasure<Domain, F>,
    /// Goodness of the data term for the ray: $v(z)-v(y)-⟨∇v(x), z-y⟩ + ℓ‖z-y‖^2$.
    goodness : F,
    /// Goodness of the regularisation term for the ray: $w(z)-w(y)$.
    /// Initially zero until $w$ can be constructed.
    reg_goodness : F,
    /// Indicates that this ray also forms a component in γ^{k+1} with the mass `to_return`.
    to_return : F,
}

/// A set of transport rays with the same source point.
#[derive(Clone, Debug)]
pub struct RaySet<Domain, F : Num> {
    /// Source of every ray in this set
    source : Domain,
    /// Mass of the diagonal ray, with destination the same as the source.
    diagonal : F,
    /// Goodness of the data term for the diagonal ray with $z=x$:
    /// $v(x)-v(y)-⟨∇v(x), x-y⟩ + ℓ‖x-y‖^2$.
    diagonal_goodness : F,
    /// Goodness of the regularisation term for the diagonal ray with $z=x$: $w(x)-w(y)$.
    diagonal_reg_goodness : F,
    /// The non-diagonal rays.
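    /// Each ray shares this set's `source`; its destination and mass are recorded in the
    /// ray's own [`DeltaMeasure`]. Rays are kept oldest-first: new rays are pushed to the
    /// back, while mass returns and staging for pruning start from the front.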
    rays : Vec<Ray<Domain, F>>,
}

#[replace_float_literals(F::cast_from(literal))]
impl<Domain, F : Float> RaySet<Domain, F> {
    fn non_diagonal_mass(&self) -> F {
        self.rays
            .iter()
            .map(|Ray{ δ : DeltaMeasure{ α, .. }, .. }| *α)
            .sum()
    }

    fn total_mass(&self) -> F {
        self.non_diagonal_mass() + self.diagonal
    }

    fn targets<'a>(&'a self)
    -> Map<
        std::slice::Iter<'a, Ray<Domain, F>>,
        fn(&'a Ray<Domain, F>) -> &'a DeltaMeasure<Domain, F>
    > {
        fn get_δ<'b, Domain, F : Float>(Ray{ δ, .. } : &'b Ray<Domain, F>)
        -> &'b DeltaMeasure<Domain, F> {
            δ
        }
        self.rays
            .iter()
            .map(get_δ)
    }

    // fn non_diagonal_goodness(&self) -> F {
    //     self.rays
    //         .iter()
    //         .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| {
    //             α * (goodness + reg_goodness)
    //         })
    //         .sum()
    // }

    // fn total_goodness(&self) -> F {
    //     self.non_diagonal_goodness() + (self.diagonal_goodness + self.diagonal_reg_goodness)
    // }

    fn non_diagonal_badness(&self) -> F {
        self.rays
            .iter()
            .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| {
                0.0.max(- α * (goodness + reg_goodness))
            })
            .sum()
    }

    fn total_badness(&self) -> F {
        self.non_diagonal_badness()
        + 0.0.max(- self.diagonal * (self.diagonal_goodness + self.diagonal_reg_goodness))
    }

    fn total_return(&self) -> F {
        self.rays
            .iter()
            .map(|&Ray{ to_return, .. }| to_return)
            .sum()
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<Domain : Clone, F : Num> RaySet<Domain, F> {
    fn return_targets<'a>(&'a self)
    -> Flatten<Map<
        std::slice::Iter<'a, Ray<Domain, F>>,
        fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>>
    >> {
        fn get_return<'b, Domain : Clone, F : Num>(ray : &'b Ray<Domain, F>)
        -> Option<DeltaMeasure<Domain, F>> {
            (ray.to_return != 0.0).then_some(
                DeltaMeasure{ x : ray.δ.x.clone(), α : ray.to_return }
            )
        }
        let tmp : Map<
            std::slice::Iter<'a, Ray<Domain, F>>,
            fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>>
        > = self.rays
                .iter()
                .map(get_return);
        tmp.flatten()
    }
}

/// Iteratively solve the pointsource localisation problem using sliding forward-backward
/// splitting.
///
/// The parametrisation is as for [`pointsource_fb_reg`].
/// Inertia is currently not supported.
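///
/// A minimal calling sketch, assuming `op_a`, `b`, `reg`, `op_d`, `iterator`, and `plotter`
/// (hypothetical names) have been set up as for [`pointsource_fb_reg`]:
/// ```ignore
/// // Default step length, transport, and insertion parameters.
/// let sfbconfig : SlidingFBConfig<f64> = Default::default();
/// let μ = pointsource_sliding_fb_reg(
///     &op_a, &b, reg, &op_d, &sfbconfig, iterator, plotter
/// );
/// ```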
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>(
    opA : &'a A,
    b : &A::Observable,
    reg : Reg,
    op𝒟 : &'a 𝒟,
    sfbconfig : &SlidingFBConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
) -> DiscreteMeasure<Loc<F, N>, F>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
                                  //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
      A::Observable : std::ops::MulAssign<F>,
      A::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output=Loc<F, N>>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
          + Lipschitz<&'a 𝒟, FloatType=F> + TransportLipschitz<L2Squared, FloatType=F>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
      𝒟::Codomain : RealMapping<F, N>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>
         + Differentiable<Loc<F, N>, Output=Loc<F,N>>,
      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
         //+ Differentiable<Loc<F, N>, Output=Loc<F,N>>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      Reg : SlidingRegTerm<F, N> {

    assert!(sfbconfig.τ0 > 0.0 &&
            sfbconfig.inverse_transport_scaling > 0.0 &&
            sfbconfig.ℓ0 > 0.0);

    // Set up parameters
    let config = &sfbconfig.insertion;
    let op𝒟norm = op𝒟.opnorm_bound();
    let θ = sfbconfig.inverse_transport_scaling;
    let τ = sfbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap()
                          .max(opA.transport_lipschitz_factor(L2Squared) * θ);
    let ℓ = sfbconfig.ℓ0; // TODO: v scaling?
    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
    // by τ compared to the conditional gradient approach.
    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates
    let mut μ : DiscreteMeasure<Loc<F, N>, F> = DiscreteMeasure::new();
    let mut μ_transported_base = DiscreteMeasure::new();
    let mut γ_hat : Vec<RaySet<Loc<F, N>, F>> = Vec::new(); // γ̂_k and extra info
    let mut residual = -b;
    let mut stats = IterInfo::new();

    // Run the algorithm
    iterator.iterate(|state| {
        // Calculate smooth part of surrogate model.
        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
        // has no significant overhead. For some reason Rust doesn't allow us simply moving
        // the residual and replacing it below before the end of this closure.
        residual *= -τ;
        let r = std::mem::replace(&mut residual, opA.empty_observable());
        let minus_τv = opA.preadjoint().apply(r);

        // Save current base point and shift μ to new positions.
        let μ_base = μ.clone();
        for δ in μ.iter_spikes_mut() {
            δ.x += minus_τv.differential(&δ.x) * θ;
        }
        let mut μ_transported = μ.clone();

        assert_eq!(μ.len(), γ_hat.len());

        // Calculate the goodness λ formed from γ_hat (≈ γ̂_k) and γ^{k+1}, where the latter
        // transports points x from μ_base to points y in μ as shifted above, or “returns”
        // them “home” to z given by the rays in γ_hat. Returning is necessary if the rays
        // are not “good” for the smoothness assumptions, or if γ_hat has more mass than
        // μ_base.
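        // For a ray with source x and destination z, with y the shifted position of the
        // corresponding spike of μ, the data term goodness computed below works out to
        //     τ[v(z)-v(y)-⟨∇v(x), z-y⟩] + ℓ‖z-y‖^2,
        // expressed through minus_τv (cf. [`Ray::goodness`]; only the bracketed part carries
        // the τ scaling, cf. the “TODO: v scaling?” above).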
        let mut total_goodness = 0.0;     // data term goodness
        let mut total_reg_goodness = 0.0; // regulariser goodness
        let minimum_goodness = - ε * sfbconfig.minimum_goodness_factor;

        for (δ, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) {
            // Calculate data term goodness for all rays.
            let &DeltaMeasure{ x : ref y, α : δ_mass } = δ;
            let x = &r.source;
            let mvy = minus_τv.apply(y);
            let mdvx = minus_τv.differential(x);
            let mut r_total_mass = 0.0; // Total mass of all rays with source r.source.
            let mut bad_mass = 0.0;
            let mut calc_goodness = |goodness : &mut F, reg_goodness : &mut F, α, z : &Loc<F, N>| {
                *reg_goodness = 0.0; // Initial guess
                *goodness = mvy - minus_τv.apply(z) + mdvx.dot(&(z-y))
                            + ℓ * z.dist2_squared(&y);
                total_goodness += *goodness * α;
                r_total_mass += α; // TODO: should this include to_return from staging? (Probably not)
                if *goodness < 0.0 {
                    bad_mass += α;
                }
            };
            for ray in r.rays.iter_mut() {
                calc_goodness(&mut ray.goodness, &mut ray.reg_goodness, ray.δ.α, &ray.δ.x);
            }
            calc_goodness(&mut r.diagonal_goodness, &mut r.diagonal_reg_goodness, r.diagonal, x);

            // If the total mass of the ray set is less than that of μ at the same source,
            // a diagonal component needs to be added to be able to (attempt to) transport
            // all mass of μ. In the opposite case, we need to construct γ_{k+1} to ‘return’
            // the extra mass of γ̂_k to the target z. We return mass from the oldest “bad”
            // rays in the set.
            if δ_mass >= r_total_mass {
                r.diagonal += δ_mass - r_total_mass;
            } else {
                let mut reduce_transport = r_total_mass - δ_mass;
                let mut good_needed = (bad_mass - reduce_transport).max(0.0);
                // NOTE: reg_goodness is zero at this point, so it is not used in this code.
                let mut reduce_ray = |goodness, to_return : Option<&mut F>, α : &mut F| {
                    if reduce_transport > 0.0 {
                        let return_amount = if goodness < 0.0 {
                            α.min(reduce_transport)
                        } else {
                            let amount = α.min(good_needed);
                            good_needed -= amount;
                            amount
                        };

                        if return_amount > 0.0 {
                            reduce_transport -= return_amount;
                            // Adjust total goodness by returned amount
                            total_goodness -= goodness * return_amount;
                            to_return.map(|tr| *tr += return_amount);
                            *α -= return_amount;
                            *α > 0.0
                        } else {
                            true
                        }
                    } else {
                        true
                    }
                };
                r.rays.retain_mut(|ray| {
                    reduce_ray(ray.goodness, Some(&mut ray.to_return), &mut ray.δ.α)
                });
                // A bad diagonal is simply reduced without any 'return'.
                // It was, after all, just added to match μ, but there is no need to match it.
                // It's just a heuristic.
                // TODO: Maybe a bad diagonal should be the first to go.
                reduce_ray(r.diagonal_goodness, None, &mut r.diagonal);
            }
        }

        // Solve finite-dimensional subproblem several times until the dual variable for the
        // regularisation term conforms to the assumptions made for the transport above.
        let (d, within_tolerances) = 'adapt_transport: loop {
            // If transport violates goodness requirements, shift it to ‘return’ mass to z,
            // forcing y = z. Based on the badness of each ray set (sum of bad rays' goodness),
            // we proportionally distribute the reductions to each ray set, and within each ray
            // set, prioritise reducing the oldest bad rays' weight.
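            // The transport is acceptable once
            //     total_goodness + total_reg_goodness ≥ minimum_goodness = -ε * minimum_goodness_factor;
            // any shortfall (`adaptation_needed` below) is distributed over the ray sets in
            // proportion to their badness.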
            let tg = total_goodness + total_reg_goodness;
            let adaptation_needed = minimum_goodness - tg;
            if adaptation_needed > 0.0 {
                let total_badness = γ_hat.iter().map(|r| r.total_badness()).sum();

                let mut return_ray = |goodness : F,
                                      reg_goodness : F,
                                      to_return : Option<&mut F>,
                                      α : &mut F,
                                      left_to_return : &mut F| {
                    let g = goodness + reg_goodness;
                    assert!(*α >= 0.0 && *left_to_return >= 0.0);
                    if *left_to_return > 0.0 && g < 0.0 {
                        let return_amount = (*left_to_return / (-g)).min(*α);
                        *left_to_return -= (-g) * return_amount;
                        total_goodness -= goodness * return_amount;
                        total_reg_goodness -= reg_goodness * return_amount;
                        to_return.map(|tr| *tr += return_amount);
                        *α -= return_amount;
                        *α > 0.0
                    } else {
                        true
                    }
                };

                for r in γ_hat.iter_mut() {
                    let mut left_to_return = adaptation_needed * r.total_badness() / total_badness;
                    if left_to_return > 0.0 {
                        for ray in r.rays.iter_mut() {
                            return_ray(ray.goodness, ray.reg_goodness,
                                       Some(&mut ray.to_return),
                                       &mut ray.δ.α, &mut left_to_return);
                        }
                        return_ray(r.diagonal_goodness, r.diagonal_reg_goodness,
                                   None, &mut r.diagonal, &mut left_to_return);
                    }
                }
            }

            // Construct μ_k + (π_#^1-π_#^0)γ_{k+1}.
            // This can be broken down into
            //
            // μ_transported_base = [μ - π_#^0 (γ_shift + γ_return)] + π_#^1 γ_return, and
            // μ_transported = π_#^1 γ_shift
            //
            // where γ_shift is our “true” γ_{k+1}, and γ_return is the return component.
            // The former can be constructed from δ.x and δ_new.x for δ in μ_base and δ_new in μ
            // (which has already been shifted), and the mass stored in a γ_hat ray's δ measure.
            // The latter can be constructed from γ_hat rays' source and destination with the
            // to_return mass.
            //
            // Note that μ_transported is constructed to have the same spike locations as μ, but
            // to have same length as μ_base. This loop does not iterate over the spikes of μ
            // (and corresponding transports of γ_hat) that have been newly added in the current
            // 'adapt_transport loop.
            for (δ, δ_transported, r) in izip!(μ_base.iter_spikes(),
                                               μ_transported.iter_spikes_mut(),
                                               γ_hat.iter()) {
                let &DeltaMeasure{ref x, α} = δ;
                debug_assert_eq!(*x, r.source);
                let shifted_mass = r.total_mass();
                let ret_mass = r.total_return();
                // μ - π_#^0 (γ_shift + γ_return)
                μ_transported_base += DeltaMeasure { x : *x, α : α - shifted_mass - ret_mass };
                // π_#^1 γ_return
                μ_transported_base.extend(r.return_targets());
                // π_#^1 γ_shift
                δ_transported.set_mass(shifted_mass);
            }

            // Calculate transported_minus_τv = -τA_*(A[μ_transported + μ_transported_base]-b)
            let transported_residual = calculate_residual2(&μ_transported,
                                                           &μ_transported_base,
                                                           opA, b);
            let transported_minus_τv = opA.preadjoint()
                                          .apply(transported_residual);

            // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
            let (mut d, within_tolerances) = insert_and_reweigh(
                &mut μ, &transported_minus_τv, &μ_transported, Some(&μ_transported_base),
                op𝒟, op𝒟norm,
                τ, ε,
                config, &reg, state, &mut stats
            );

            // We have d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv; more precisely
            // d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_transported, config));
            // We “essentially” assume that the subdifferential w of the regularisation term
            // satisfies w'(y)=0, so for a “goodness” estimate τ[w(z)-w(y)-w'(y)(z-y)]
            // that incorporates the assumption, we need to calculate τ[w(z) - w(y)] for
            // some w in the subdifferential of the regularisation term, such that
            // -ε ≤ τw - d ≤ ε. This is done by [`RegTerm::goodness`].
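            // Now that d (and through it a choice of w) is available, fill in the regulariser
            // goodness of each ray; it was initialised to zero in calc_goodness above.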
            for r in γ_hat.iter_mut() {
                for ray in r.rays.iter_mut() {
                    ray.reg_goodness = reg.goodness(&mut d, &μ, &r.source, &ray.δ.x,
                                                    τ, ε, config);
                    total_reg_goodness += ray.reg_goodness * ray.δ.α;
                }
            }

            // If update of regularisation term goodness didn't invalidate minimum goodness
            // requirements, we have found our step. Otherwise we need to keep reducing
            // transport by repeating the loop.
            if total_goodness + total_reg_goodness >= minimum_goodness {
                break 'adapt_transport (d, within_tolerances)
            }
        };

        // Update γ_hat to new location
        for (δ_new, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) {
            // Prune rays that only had a return component, as the return component becomes
            // a diagonal in γ̂^{k+1}.
            r.rays.retain(|ray| ray.δ.α != 0.0);
            // Otherwise zero out the return component, or stage rays for pruning
            // to keep memory and computational demands reasonable.
            let n_rays = r.rays.len();
            for (ray, ir) in izip!(r.rays.iter_mut(), (0..n_rays).rev()) {
                if ir >= sfbconfig.maximum_rays {
                    // Only keep sfbconfig.maximum_rays - 1 previous rays, staging others for
                    // pruning in next step.
                    ray.to_return = ray.δ.α;
                    ray.δ.α = 0.0;
                } else {
                    ray.to_return = 0.0;
                }
                ray.goodness = 0.0; // TODO: probably not needed
                ray.reg_goodness = 0.0;
            }
            // Add a new ray for the currently diagonal component
            if r.diagonal > 0.0 {
                r.rays.push(Ray{
                    δ : DeltaMeasure{ x : r.source, α : r.diagonal },
                    goodness : 0.0,
                    reg_goodness : 0.0,
                    to_return : 0.0,
                });
                // TODO: Maybe this does not need to be done here, and it is sufficient to do it
                // where the goodness is calculated.
                r.diagonal = 0.0;
            }
            r.diagonal_goodness = 0.0;

            // Shift source
            r.source = δ_new.x;
        }
        // Extend to new spikes
        γ_hat.extend(μ[γ_hat.len()..].iter().map(|δ_new| {
            RaySet{
                source : δ_new.x,
                rays : [].into(),
                diagonal : 0.0,
                diagonal_goodness : 0.0,
                diagonal_reg_goodness : 0.0
            }
        }));

        // Prune spikes with zero weight. This also moves the marginal differences of corresponding
        // transports from γ_hat to μ_transported_base.
        // TODO: optimise standard prune with swap_remove.
        μ_transported_base.clear();
        let mut i = 0;
        assert_eq!(μ.len(), γ_hat.len());
        while i < μ.len() {
            if μ[i].α == F::ZERO {
                μ.swap_remove(i);
                let r = γ_hat.swap_remove(i);
                μ_transported_base.extend(r.targets().cloned());
                μ_transported_base -= DeltaMeasure{ α : r.non_diagonal_mass(), x : r.source };
            } else {
                i += 1;
            }
        }

        // TODO: how to merge?

        // Update residual
        residual = calculate_residual(&μ, opA, b);

        // Update main tolerance for next iteration
        let ε_prev = ε;
        ε = tolerance.update(ε, state.iteration());
        stats.this_iters += 1;

        // Give function value if needed
        state.if_verbose(|| {
            // Plot if so requested
            plotter.plot_spikes(
                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
                "start".to_string(), Some(&minus_τv),
                reg.target_bounds(τ, ε_prev), &μ,
            );
            // Calculate mean inner iterations and reset relevant counters.
            // Return the statistics
            let res = IterInfo {
                value : residual.norm2_squared_div2() + reg.apply(&μ),
                n_spikes : μ.len(),
                ε : ε_prev,
                postprocessing : config.postprocessing.then(|| μ.clone()),
                .. stats
            };
            stats = IterInfo::new();
            res
        })
    });

    postprocess(μ, config, L2Squared, opA, b)
}