--- a/src/pdps.rs Fri Apr 28 13:15:19 2023 +0300 +++ b/src/pdps.rs Tue Dec 31 09:34:24 2024 -0500 @@ -48,12 +48,16 @@ use nalgebra::DVector; use clap::ValueEnum; -use alg_tools::iterate:: AlgIteratorFactory; +use alg_tools::iterate::{ + AlgIteratorFactory, + AlgIteratorState, +}; use alg_tools::loc::Loc; use alg_tools::euclidean::Euclidean; +use alg_tools::linops::Apply; use alg_tools::norms::{ - L1, Linfinity, - Projection, Norm, + Linfinity, + Projection, }; use alg_tools::bisection_tree::{ BTFN, @@ -71,13 +75,9 @@ use crate::types::*; use crate::measures::DiscreteMeasure; -use crate::measures::merging::{ - SpikeMerging, -}; +use crate::measures::merging::SpikeMerging; use crate::forward_model::ForwardModel; -use crate::seminorms::{ - DiscreteMeasureOp, Lipschitz -}; +use crate::seminorms::DiscreteMeasureOp; use crate::plot::{ SeqPlotter, Plotting, @@ -85,9 +85,15 @@ }; use crate::fb::{ FBGenericConfig, - FBSpecialisation, - generic_pointsource_fb_reg, - RegTerm, + insert_and_reweigh, + postprocess, + prune_and_maybe_simple_merge +}; +use crate::regularisation::RegTerm; +use crate::dataterm::{ + DataTerm, + L2Squared, + L1 }; /// Acceleration @@ -131,160 +137,54 @@ } } -/// Trait for subdifferentiable objects -pub trait Subdifferentiable<F : Float, V, U=V> { - /// Calculate some subdifferential at `x` - fn some_subdifferential(&self, x : V) -> U; +/// Trait for data terms for the PDPS +#[replace_float_literals(F::cast_from(literal))] +pub trait PDPSDataTerm<F : Float, V, const N : usize> : DataTerm<F, V, N> { + /// Calculate some subdifferential at `x` for the conjugate + fn some_subdifferential(&self, x : V) -> V; + + /// Factor of strong convexity of the conjugate + #[inline] + fn factor_of_strong_convexity(&self) -> F { + 0.0 + } + + /// Perform dual update + fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F); } -/// Type for indicating norm-2-squared data fidelity. -pub struct L2Squared; + +#[replace_float_literals(F::cast_from(literal))] +impl<F : Float, V : Euclidean<F> + AXPY<F>, const N : usize> +PDPSDataTerm<F, V, N> +for L2Squared { + fn some_subdifferential(&self, x : V) -> V { x } -impl<F : Float, V : Euclidean<F>> Subdifferentiable<F, V> for L2Squared { - fn some_subdifferential(&self, x : V) -> V { x } + fn factor_of_strong_convexity(&self) -> F { + 1.0 + } + + #[inline] + fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) { + y.axpy(1.0 / (1.0 + σ), &y_prev, σ / (1.0 + σ)); + } } -impl<F : Float + nalgebra::RealField> Subdifferentiable<F, DVector<F>> for L1 { +#[replace_float_literals(F::cast_from(literal))] +impl<F : Float + nalgebra::RealField, const N : usize> +PDPSDataTerm<F, DVector<F>, N> +for L1 { fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> { // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well. x.iter_mut() .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) }); x } -} -/// Specialisation of [`generic_pointsource_fb_reg`] to PDPS. -pub struct PDPS< - 'a, - F : Float + ToNalgebraRealField, - A : ForwardModel<Loc<F, N>, F>, - D, - const N : usize -> { - /// The data - b : &'a A::Observable, - /// The forward operator - opA : &'a A, - /// Primal step length - τ : F, - // Dual step length - σ : F, - /// Whether acceleration should be applied (if data term supports) - acceleration : Acceleration, - /// The dataterm. Only used by the type system. - _dataterm : D, - /// Previous dual iterate. - y_prev : A::Observable, -} - -/// Implementation of [`FBSpecialisation`] for μPDPS with norm-2-squared data fidelity. -#[replace_float_literals(F::cast_from(literal))] -impl< - 'a, - F : Float + ToNalgebraRealField, - A : ForwardModel<Loc<F, N>, F>, - const N : usize -> FBSpecialisation<F, A::Observable, N> for PDPS<'a, F, A, L2Squared, N> -where for<'b> &'b A::Observable : std::ops::Add<A::Observable, Output=A::Observable> { - - fn update( - &mut self, - μ : &mut DiscreteMeasure<Loc<F, N>, F>, - μ_base : &DiscreteMeasure<Loc<F, N>, F> - ) -> (A::Observable, Option<F>) { - let σ = self.σ; - let τ = self.τ; - let ω = match self.acceleration { - Acceleration::None => 1.0, - Acceleration::Partial => { - let ω = 1.0 / (1.0 + σ).sqrt(); - self.σ = σ * ω; - self.τ = τ / ω; - ω - }, - Acceleration::Full => { - let ω = 1.0 / (1.0 + 2.0 * σ).sqrt(); - self.σ = σ * ω; - self.τ = τ / ω; - ω - }, - }; - - μ.prune(); - - let mut y = self.b.clone(); - self.opA.gemv(&mut y, 1.0 + ω, μ, -1.0); - self.opA.gemv(&mut y, -ω, μ_base, 1.0); - y.axpy(1.0 / (1.0 + σ), &self.y_prev, σ / (1.0 + σ)); - self.y_prev.copy_from(&y); - - (y, Some(self.τ)) - } - - fn calculate_fit( - &self, - μ : &DiscreteMeasure<Loc<F, N>, F>, - _y : &A::Observable - ) -> F { - self.calculate_fit_simple(μ) - } - - fn calculate_fit_simple( - &self, - μ : &DiscreteMeasure<Loc<F, N>, F>, - ) -> F { - let mut residual = self.b.clone(); - self.opA.gemv(&mut residual, 1.0, μ, -1.0); - residual.norm2_squared_div2() - } -} - -/// Implementation of [`FBSpecialisation`] for μPDPS with norm-1 data fidelity. -#[replace_float_literals(F::cast_from(literal))] -impl< - 'a, - F : Float + ToNalgebraRealField, - A : ForwardModel<Loc<F, N>, F>, - const N : usize -> FBSpecialisation<F, A::Observable, N> for PDPS<'a, F, A, L1, N> -where A::Observable : Projection<F, Linfinity> + Norm<F, L1>, - for<'b> &'b A::Observable : std::ops::Add<A::Observable, Output=A::Observable> { - fn update( - &mut self, - μ : &mut DiscreteMeasure<Loc<F, N>, F>, - μ_base : &DiscreteMeasure<Loc<F, N>, F> - ) -> (A::Observable, Option<F>) { - let σ = self.σ; - - μ.prune(); - - //let ȳ = self.opA.apply(μ) * 2.0 - self.opA.apply(μ_base); - //*y = proj_{[-1,1]}(&self.y_prev + (ȳ - self.b) * σ) - let mut y = self.y_prev.clone(); - self.opA.gemv(&mut y, 2.0 * σ, μ, 1.0); - self.opA.gemv(&mut y, -σ, μ_base, 1.0); - y.axpy(-σ, self.b, 1.0); + #[inline] + fn dual_update(&self, y : &mut DVector<F>, y_prev : &DVector<F>, σ : F) { + y.axpy(1.0, y_prev, σ); y.proj_ball_mut(1.0, Linfinity); - self.y_prev.copy_from(&y); - - (y, None) - } - - fn calculate_fit( - &self, - μ : &DiscreteMeasure<Loc<F, N>, F>, - _y : &A::Observable - ) -> F { - self.calculate_fit_simple(μ) - } - - fn calculate_fit_simple( - &self, - μ : &DiscreteMeasure<Loc<F, N>, F>, - ) -> F { - let mut residual = self.b.clone(); - self.opA.gemv(&mut residual, 1.0, μ, -1.0); - residual.norm(L1) } } @@ -306,9 +206,9 @@ b : &'a A::Observable, reg : Reg, op𝒟 : &'a 𝒟, - config : &PDPSConfig<F>, + pdpsconfig : &PDPSConfig<F>, iterator : I, - plotter : SeqPlotter<F, N>, + mut plotter : SeqPlotter<F, N>, dataterm : D, ) -> DiscreteMeasure<Loc<F, N>, F> where F : Float + ToNalgebraRealField, @@ -319,7 +219,7 @@ A::Observable : std::ops::MulAssign<F>, GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> - + Lipschitz<𝒟, FloatType=F>, + + Lipschitz<&'a 𝒟, FloatType=F>, BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, @@ -329,27 +229,108 @@ BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, PlotLookup : Plotting<N>, DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, - PDPS<'a, F, A, D, N> : FBSpecialisation<F, A::Observable, N>, - D : Subdifferentiable<F, A::Observable>, + D : PDPSDataTerm<F, A::Observable, N>, Reg : RegTerm<F, N> { - let y = dataterm.some_subdifferential(-b); + // Set up parameters + let config = &pdpsconfig.insertion; + let op𝒟norm = op𝒟.opnorm_bound(); let l = opA.lipschitz_factor(&op𝒟).unwrap().sqrt(); - let τ = config.τ0 / l; - let σ = config.σ0 / l; + let mut τ = pdpsconfig.τ0 / l; + let mut σ = pdpsconfig.σ0 / l; + let γ = dataterm.factor_of_strong_convexity(); + + // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled + // by τ compared to the conditional gradient approach. + let tolerance = config.tolerance * τ * reg.tolerance_scaling(); + let mut ε = tolerance.initial(); + + // Initialise iterates + let mut μ = DiscreteMeasure::new(); + let mut y = dataterm.some_subdifferential(-b); + let mut y_prev = y.clone(); + let mut stats = IterInfo::new(); + + // Run the algorithm + iterator.iterate(|state| { + // Calculate smooth part of surrogate model. + // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` + // has no significant overhead. For some reosn Rust doesn't allow us simply moving + // the residual and replacing it below before the end of this closure. + y *= -τ; + let r = std::mem::replace(&mut y, opA.empty_observable()); + let minus_τv = opA.preadjoint().apply(r); + + // Save current base point + let μ_base = μ.clone(); + + // Insert and reweigh + let (d, within_tolerances) = insert_and_reweigh( + &mut μ, &minus_τv, &μ_base, None, + op𝒟, op𝒟norm, + τ, ε, + config, ®, state, &mut stats + ); + + // Prune and possibly merge spikes + prune_and_maybe_simple_merge( + &mut μ, &minus_τv, &μ_base, + op𝒟, + τ, ε, + config, ®, state, &mut stats + ); - let pdps = PDPS { - b, - opA, - τ, - σ, - acceleration : config.acceleration, - _dataterm : dataterm, - y_prev : y.clone(), - }; + // Update step length parameters + let ω = match pdpsconfig.acceleration { + Acceleration::None => 1.0, + Acceleration::Partial => { + let ω = 1.0 / (1.0 + γ * σ).sqrt(); + σ = σ * ω; + τ = τ / ω; + ω + }, + Acceleration::Full => { + let ω = 1.0 / (1.0 + 2.0 * γ * σ).sqrt(); + σ = σ * ω; + τ = τ / ω; + ω + }, + }; + + // Do dual update + y = b.clone(); // y = b + opA.gemv(&mut y, 1.0 + ω, &μ, -1.0); // y = A[(1+ω)μ^{k+1}]-b + opA.gemv(&mut y, -ω, &μ_base, 1.0); // y = A[(1+ω)μ^{k+1} - ω μ^k]-b + dataterm.dual_update(&mut y, &y_prev, σ); + y_prev.copy_from(&y); - generic_pointsource_fb_reg( - opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, y, pdps - ) + // Update main tolerance for next iteration + let ε_prev = ε; + ε = tolerance.update(ε, state.iteration()); + stats.this_iters += 1; + + // Give function value if needed + state.if_verbose(|| { + // Plot if so requested + plotter.plot_spikes( + format!("iter {} end; {}", state.iteration(), within_tolerances), &d, + "start".to_string(), Some(&minus_τv), + reg.target_bounds(τ, ε_prev), &μ, + ); + // Calculate mean inner iterations and reset relevant counters. + // Return the statistics + let res = IterInfo { + value : dataterm.calculate_fit_op(&μ, opA, b) + reg.apply(&μ), + n_spikes : μ.len(), + ε : ε_prev, + postprocessing: config.postprocessing.then(|| μ.clone()), + .. stats + }; + stats = IterInfo::new(); + res + }) + }); + + postprocess(μ, config, dataterm, opA, b) }