diff -r 53136eba9abf -r 6b0db7251ebe src/sliding_pdps.rs --- a/src/sliding_pdps.rs Fri Feb 14 23:16:14 2025 -0500 +++ b/src/sliding_pdps.rs Fri Feb 14 23:46:43 2025 -0500 @@ -4,83 +4,69 @@ */ use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; //use colored::Colorize; //use nalgebra::{DVector, DMatrix}; use std::iter::Iterator; -use alg_tools::iterate::AlgIteratorFactory; +use alg_tools::convex::{Conjugable, Prox}; +use alg_tools::direct_product::Pair; use alg_tools::euclidean::Euclidean; -use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; -use alg_tools::norms::{Norm, Dist}; -use alg_tools::direct_product::Pair; +use alg_tools::iterate::AlgIteratorFactory; +use alg_tools::linops::{Adjointable, BoundedLinear, IdOp, AXPY, GEMV}; +use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping}; use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::linops::{ - BoundedLinear, AXPY, GEMV, Adjointable, IdOp, -}; -use alg_tools::convex::{Conjugable, Prox}; -use alg_tools::norms::{L2, PairNorm}; +use alg_tools::norms::{Dist, Norm}; +use alg_tools::norms::{PairNorm, L2}; +use crate::forward_model::{AdjointProductPairBoundedBy, BoundedCurvature, ForwardModel}; +use crate::measures::merging::SpikeMerging; +use crate::measures::{DiscreteMeasure, Radon, RNDM}; use crate::types::*; -use crate::measures::{DiscreteMeasure, Radon, RNDM}; -use crate::measures::merging::SpikeMerging; -use crate::forward_model::{ - ForwardModel, - AdjointProductPairBoundedBy, - BoundedCurvature, -}; // use crate::transport::TransportLipschitz; //use crate::tolerance::Tolerance; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; use crate::fb::*; +use crate::plot::{PlotLookup, Plotting, SeqPlotter}; use crate::regularisation::SlidingRegTerm; // use crate::dataterm::L2Squared; +use crate::dataterm::{calculate_residual, calculate_residual2}; use crate::sliding_fb::{ - TransportConfig, - TransportStepLength, - initial_transport, - aposteriori_transport, + aposteriori_transport, initial_transport, TransportConfig, TransportStepLength, }; -use crate::dataterm::{ - calculate_residual2, - calculate_residual, -}; - /// Settings for [`pointsource_sliding_pdps_pair`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct SlidingPDPSConfig { +pub struct SlidingPDPSConfig { /// Primal step length scaling. - pub τ0 : F, + pub τ0: F, /// Primal step length scaling. - pub σp0 : F, + pub σp0: F, /// Dual step length scaling. - pub σd0 : F, + pub σd0: F, /// Transport parameters - pub transport : TransportConfig, + pub transport: TransportConfig, /// Generic parameters - pub insertion : FBGenericConfig, + pub insertion: FBGenericConfig, } #[replace_float_literals(F::cast_from(literal))] -impl Default for SlidingPDPSConfig { +impl Default for SlidingPDPSConfig { fn default() -> Self { SlidingPDPSConfig { - τ0 : 0.99, - σd0 : 0.05, - σp0 : 0.99, - transport : TransportConfig { θ0 : 0.9, ..Default::default()}, - insertion : Default::default() + τ0: 0.99, + σd0: 0.05, + σp0: 0.99, + transport: TransportConfig { + θ0: 0.9, + ..Default::default() + }, + insertion: Default::default(), } } } -type MeasureZ = Pair, Z>; +type MeasureZ = Pair, Z>; /// Iteratively solve the pointsource localisation with an additional variable /// using sliding primal-dual proximal splitting @@ -88,39 +74,45 @@ /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_sliding_pdps_pair< - F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize + F, + I, + A, + S, + Reg, + P, + Z, + R, + Y, + /*KOpM, */ KOpZ, + H, + const N: usize, >( - opA : &A, - b : &A::Observable, - reg : Reg, - prox_penalty : &P, - config : &SlidingPDPSConfig, - iterator : I, - mut plotter : SeqPlotter, + opA: &A, + b: &A::Observable, + reg: Reg, + prox_penalty: &P, + config: &SlidingPDPSConfig, + iterator: I, + mut plotter: SeqPlotter, //opKμ : KOpM, - opKz : &KOpZ, - fnR : &R, - fnH : &H, - mut z : Z, - mut y : Y, + opKz: &KOpZ, + fnR: &R, + fnH: &H, + mut z: Z, + mut y: Y, ) -> MeasureZ where - F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - A : ForwardModel< - MeasureZ, - F, - PairNorm, - PreadjointCodomain = Pair, - > - + AdjointProductPairBoundedBy, P, IdOp, FloatType=F> - + BoundedCurvature, - S : DifferentiableRealMapping, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - PlotLookup : Plotting, - RNDM : SpikeMerging, - Reg : SlidingRegTerm, - P : ProxPenalty, + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + A: ForwardModel, F, PairNorm, PreadjointCodomain = Pair> + + AdjointProductPairBoundedBy, P, IdOp, FloatType = F> + + BoundedCurvature, + S: DifferentiableRealMapping, + for<'b> &'b A::Observable: std::ops::Neg + Instance, + PlotLookup: Plotting, + RNDM: SpikeMerging, + Reg: SlidingRegTerm, + P: ProxPenalty, // KOpM : Linear, Codomain=Y> // + GEMV> // + Preadjointable< @@ -131,27 +123,28 @@ // + AdjointProductBoundedBy, 𝒟, FloatType=F>, // for<'b> KOpM::Preadjoint<'b> : GEMV, // Since Z is Hilbert, we may just as well use adjoints for K_z. - KOpZ : BoundedLinear + KOpZ: BoundedLinear + GEMV + Adjointable, - for<'b> KOpZ::Adjoint<'b> : GEMV, - Y : AXPY + Euclidean + Clone + ClosedAdd, - for<'b> &'b Y : Instance, - Z : AXPY + Euclidean + Clone + Norm + Dist, - for<'b> &'b Z : Instance, - R : Prox, - H : Conjugable, - for<'b> H::Conjugate<'b> : Prox, + for<'b> KOpZ::Adjoint<'b>: GEMV, + Y: AXPY + Euclidean + Clone + ClosedAdd, + for<'b> &'b Y: Instance, + Z: AXPY + Euclidean + Clone + Norm + Dist, + for<'b> &'b Z: Instance, + R: Prox, + H: Conjugable, + for<'b> H::Conjugate<'b>: Prox, { - // Check parameters - assert!(config.τ0 > 0.0 && - config.τ0 < 1.0 && - config.σp0 > 0.0 && - config.σp0 < 1.0 && - config.σd0 > 0.0 && - config.σp0 * config.σd0 <= 1.0, - "Invalid step length parameters"); + assert!( + config.τ0 > 0.0 + && config.τ0 < 1.0 + && config.σp0 > 0.0 + && config.σp0 < 1.0 + && config.σd0 > 0.0 + && config.σp0 * config.σd0 <= 1.0, + "Invalid step length parameters" + ); config.transport.check(); // Initialise iterates @@ -168,7 +161,9 @@ let nKz = opKz.opnorm_bound(L2, L2); let ℓ = 0.0; let opIdZ = IdOp::new(); - let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); + let (l, l_z) = opA + .adjoint_product_pair_bound(prox_penalty, &opIdZ) + .unwrap(); // We need to satisfy // // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 @@ -184,7 +179,7 @@ // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0}) let φ = 1.0 - config.σp0; let a = 1.0 - σ_p * l_z; - let τ = config.τ0 * φ / ( σ_d * bigM * a + φ * l ); + let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l); let ψ = 1.0 - τ * l; let β = σ_p * config.σd0 * nKz / a; // σ_p * σ_d * (nKz * nK_z) / a; assert!(β < 1.0); @@ -192,23 +187,23 @@ let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); // The factor two in the manuscript disappears due to the definition of 𝚹 being // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. - let (maybe_ℓ_v0, maybe_transport_lip) = opA.curvature_bound_components(); + let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components(); let transport_lip = maybe_transport_lip.unwrap(); - let calculate_θ = |ℓ_v, max_transport| { - let ℓ_F = ℓ_v + transport_lip * max_transport; - config.transport.θ0 / (τ*(ℓ + ℓ_F) + κ * bigθ * max_transport) + let calculate_θ = |ℓ_F, max_transport| { + let ℓ_r = transport_lip * max_transport; + config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport) }; - let mut θ_or_adaptive = match maybe_ℓ_v0 { + let mut θ_or_adaptive = match maybe_ℓ_F0 { // We assume that the residual is decreasing. - Some(ℓ_v0) => TransportStepLength::AdaptiveMax { - l: ℓ_v0 * b.norm2(), // TODO: could estimate computing the real reesidual - max_transport : 0.0, - g : calculate_θ + Some(ℓ_F0) => TransportStepLength::AdaptiveMax { + l: ℓ_F0 * b.norm2(), // TODO: could estimate computing the real reesidual + max_transport: 0.0, + g: calculate_θ, }, None => TransportStepLength::FullyAdaptive { - l : F::EPSILON, - max_transport : 0.0, - g : calculate_θ + l: F::EPSILON, + max_transport: 0.0, + g: calculate_θ, }, }; // Acceleration is not currently supported @@ -223,13 +218,15 @@ let starH = fnH.conjugate(); // Statistics - let full_stats = |residual : &A::Observable, μ : &RNDM, z : &Z, ε, stats| IterInfo { - value : residual.norm2_squared_div2() + fnR.apply(z) - + reg.apply(μ) + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)), - n_spikes : μ.len(), + let full_stats = |residual: &A::Observable, μ: &RNDM, z: &Z, ε, stats| IterInfo { + value: residual.norm2_squared_div2() + + fnR.apply(z) + + reg.apply(μ) + + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)), + n_spikes: μ.len(), ε, // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), - .. stats + ..stats }; let mut stats = IterInfo::new(); @@ -244,40 +241,49 @@ // where A_ν^* becomes a multiplier. // This is much easier with K_μ = 0, which is the only reason why are enforcing it. // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0. - - let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport( - &mut γ1, &mut μ, τ, &mut θ_or_adaptive, v, - ); + + let (μ_base_masses, mut μ_base_minus_γ0) = + initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) - let residual_μ̆ = calculate_residual2(Pair(&γ1, &z), - Pair(&μ_base_minus_γ0, &zero_z), - opA, b); + let residual_μ̆ = + calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ); // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( - &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), - τ, ε, &config.insertion, - ®, &state, &mut stats, + &mut μ, + &mut τv̆, + &γ1, + Some(&μ_base_minus_γ0), + τ, + ε, + &config.insertion, + ®, + &state, + &mut stats, ); // Do z variable primal update here to able to estimate B_{v̆^k-v^{k+1}} let mut z_new = τz̆; - opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p/τ); + opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p / τ); z_new = fnR.prox(σ_p, z_new + &z); // A posteriori transport adaptation. if aposteriori_transport( - &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, + &mut γ1, + &mut μ, + &mut μ_base_minus_γ0, + &μ_base_masses, Some(z_new.dist(&z, L2)), - ε, &config.transport + ε, + &config.transport, ) { - break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new) + break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new); } }; @@ -299,7 +305,14 @@ let ins = &config.insertion; if ins.merge_now(&state) { stats.merged += prox_penalty.merge_spikes_no_fitness( - &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, ins, ®, + &mut μ, + &mut τv̆, + &γ1, + Some(&μ_base_minus_γ0), + τ, + ε, + ins, + ®, //Some(|μ̃ : &RNDM| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), ); } @@ -317,9 +330,9 @@ // Do dual update // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] - opKz.gemv(&mut y, σ_d*(1.0 + ω), &z_new, 1.0); + opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0); // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b - opKz.gemv(&mut y, -σ_d*ω, z, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b + opKz.gemv(&mut y, -σ_d * ω, z, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b y = starH.prox(σ_d, y); z = z_new; @@ -335,14 +348,20 @@ state.if_verbose(|| { plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); - full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) + full_stats( + &residual, + &μ, + &z, + ε, + std::mem::replace(&mut stats, IterInfo::new()), + ) }); // Update main tolerance for next iteration ε = tolerance.update(ε, iter); } - let fit = |μ̃ : &RNDM| { + let fit = |μ̃: &RNDM| { (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2() //+ fnR.apply(z) + reg.apply(μ) + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z))