src/pdps.rs

branch
dev
changeset 32
56c8adc32b09
parent 29
87649ccfa6a8
child 34
efa60bc4f743
--- a/src/pdps.rs	Fri Apr 28 13:15:19 2023 +0300
+++ b/src/pdps.rs	Tue Dec 31 09:34:24 2024 -0500
@@ -48,12 +48,16 @@
 use nalgebra::DVector;
 use clap::ValueEnum;
 
-use alg_tools::iterate:: AlgIteratorFactory;
+use alg_tools::iterate::{
+    AlgIteratorFactory,
+    AlgIteratorState,
+};
 use alg_tools::loc::Loc;
 use alg_tools::euclidean::Euclidean;
+use alg_tools::linops::Apply;
 use alg_tools::norms::{
-    L1, Linfinity,
-    Projection, Norm,
+    Linfinity,
+    Projection,
 };
 use alg_tools::bisection_tree::{
     BTFN,
@@ -71,13 +75,9 @@
 
 use crate::types::*;
 use crate::measures::DiscreteMeasure;
-use crate::measures::merging::{
-    SpikeMerging,
-};
+use crate::measures::merging::SpikeMerging;
 use crate::forward_model::ForwardModel;
-use crate::seminorms::{
-    DiscreteMeasureOp, Lipschitz
-};
+use crate::seminorms::DiscreteMeasureOp;
 use crate::plot::{
     SeqPlotter,
     Plotting,
@@ -85,9 +85,15 @@
 };
 use crate::fb::{
     FBGenericConfig,
-    FBSpecialisation,
-    generic_pointsource_fb_reg,
-    RegTerm,
+    insert_and_reweigh,
+    postprocess,
+    prune_and_maybe_simple_merge
+};
+use crate::regularisation::RegTerm;
+use crate::dataterm::{
+    DataTerm,
+    L2Squared,
+    L1
 };
 
 /// Acceleration
@@ -131,160 +137,54 @@
     }
 }
 
-/// Trait for subdifferentiable objects
-pub trait Subdifferentiable<F : Float, V, U=V> {
-    /// Calculate some subdifferential at `x`
-    fn some_subdifferential(&self, x : V) -> U;
+/// Trait for data terms for the PDPS
+#[replace_float_literals(F::cast_from(literal))]
+pub trait PDPSDataTerm<F : Float, V, const N : usize> : DataTerm<F, V, N> {
+    /// Calculate some subdifferential at `x` for the conjugate
+    fn some_subdifferential(&self, x : V) -> V;
+
+    /// Factor of strong convexity of the conjugate
+    #[inline]
+    fn factor_of_strong_convexity(&self) -> F {
+        0.0
+    }
+
+    /// Perform dual update
+    fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F);
 }
 
-/// Type for indicating norm-2-squared data fidelity.
-pub struct L2Squared;
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F : Float, V :  Euclidean<F> + AXPY<F>, const N : usize>
+PDPSDataTerm<F, V, N>
+for L2Squared {
+    fn some_subdifferential(&self, x : V) -> V { x }
 
-impl<F : Float, V : Euclidean<F>> Subdifferentiable<F, V> for L2Squared {
-    fn some_subdifferential(&self, x : V) -> V { x }
+    fn factor_of_strong_convexity(&self) -> F {
+        1.0
+    }
+
+    #[inline]
+    fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) {
+        y.axpy(1.0 / (1.0 + σ), &y_prev, σ / (1.0 + σ));
+    }
 }
 
-impl<F : Float + nalgebra::RealField> Subdifferentiable<F, DVector<F>> for L1 {
+#[replace_float_literals(F::cast_from(literal))]
+impl<F : Float + nalgebra::RealField, const N : usize>
+PDPSDataTerm<F, DVector<F>, N>
+for L1 {
     fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> {
         // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well.
         x.iter_mut()
          .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) });
         x
     }
-}
 
-/// Specialisation of [`generic_pointsource_fb_reg`] to PDPS.
-pub struct PDPS<
-    'a,
-    F : Float + ToNalgebraRealField,
-    A : ForwardModel<Loc<F, N>, F>,
-    D,
-    const N : usize
-> {
-    /// The data
-    b : &'a A::Observable,
-    /// The forward operator
-    opA : &'a A,
-    /// Primal step length
-    τ : F,
-    // Dual step length
-    σ : F,
-    /// Whether acceleration should be applied (if data term supports)
-    acceleration : Acceleration,
-    /// The dataterm. Only used by the type system.
-    _dataterm : D,
-    /// Previous dual iterate.
-    y_prev : A::Observable,
-}
-
-/// Implementation of [`FBSpecialisation`] for μPDPS with norm-2-squared data fidelity.
-#[replace_float_literals(F::cast_from(literal))]
-impl<
-    'a,
-    F : Float + ToNalgebraRealField,
-    A : ForwardModel<Loc<F, N>, F>,
-    const N : usize
-> FBSpecialisation<F, A::Observable, N> for PDPS<'a, F, A, L2Squared, N>
-where for<'b> &'b A::Observable : std::ops::Add<A::Observable, Output=A::Observable> {
-
-    fn update(
-        &mut self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
-        μ_base : &DiscreteMeasure<Loc<F, N>, F>
-    ) -> (A::Observable, Option<F>) {
-        let σ = self.σ;
-        let τ = self.τ;
-        let ω = match self.acceleration {
-            Acceleration::None => 1.0,
-            Acceleration::Partial => {
-                let ω = 1.0 / (1.0 + σ).sqrt();
-                self.σ = σ * ω;
-                self.τ = τ / ω;
-                ω
-            },
-            Acceleration::Full => {
-                let ω = 1.0 / (1.0 + 2.0 * σ).sqrt();
-                self.σ = σ * ω;
-                self.τ = τ / ω;
-                ω
-            },
-        };
-
-        μ.prune();
-
-        let mut y = self.b.clone();
-        self.opA.gemv(&mut y, 1.0 + ω, μ, -1.0);
-        self.opA.gemv(&mut y, -ω, μ_base, 1.0);
-        y.axpy(1.0 / (1.0 + σ), &self.y_prev,  σ / (1.0 + σ));
-        self.y_prev.copy_from(&y);
-
-        (y, Some(self.τ))
-    }
-
-    fn calculate_fit(
-        &self,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-        _y : &A::Observable
-    ) -> F {
-        self.calculate_fit_simple(μ)
-    }
-
-    fn calculate_fit_simple(
-        &self,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-    ) -> F {
-        let mut residual = self.b.clone();
-        self.opA.gemv(&mut residual, 1.0, μ, -1.0);
-        residual.norm2_squared_div2()
-    }
-}
-
-/// Implementation of [`FBSpecialisation`] for μPDPS with norm-1 data fidelity.
-#[replace_float_literals(F::cast_from(literal))]
-impl<
-    'a,
-    F : Float + ToNalgebraRealField,
-    A : ForwardModel<Loc<F, N>, F>,
-    const N : usize
-> FBSpecialisation<F, A::Observable, N> for PDPS<'a, F, A, L1, N>
-where A::Observable : Projection<F, Linfinity> + Norm<F, L1>,
-      for<'b> &'b A::Observable : std::ops::Add<A::Observable, Output=A::Observable> {
-    fn update(
-        &mut self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
-        μ_base : &DiscreteMeasure<Loc<F, N>, F>
-    ) -> (A::Observable, Option<F>) {
-        let σ = self.σ;
-
-        μ.prune();
-
-        //let ȳ = self.opA.apply(μ) * 2.0 - self.opA.apply(μ_base);
-        //*y = proj_{[-1,1]}(&self.y_prev + (ȳ - self.b) * σ)
-        let mut y = self.y_prev.clone();
-        self.opA.gemv(&mut y, 2.0 * σ, μ, 1.0);
-        self.opA.gemv(&mut y, -σ, μ_base, 1.0);
-        y.axpy(-σ, self.b, 1.0);
+     #[inline]
+     fn dual_update(&self, y : &mut DVector<F>, y_prev : &DVector<F>, σ : F) {
+        y.axpy(1.0, y_prev, σ);
         y.proj_ball_mut(1.0, Linfinity);
-        self.y_prev.copy_from(&y);
-
-        (y, None)
-    }
-
-    fn calculate_fit(
-        &self,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-        _y : &A::Observable
-    ) -> F {
-        self.calculate_fit_simple(μ)
-    }
-
-    fn calculate_fit_simple(
-        &self,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-    ) -> F {
-        let mut residual = self.b.clone();
-        self.opA.gemv(&mut residual, 1.0, μ, -1.0);
-        residual.norm(L1)
     }
 }
 
@@ -306,9 +206,9 @@
     b : &'a A::Observable,
     reg : Reg,
     op𝒟 : &'a 𝒟,
-    config : &PDPSConfig<F>,
+    pdpsconfig : &PDPSConfig<F>,
     iterator : I,
-    plotter : SeqPlotter<F, N>,
+    mut plotter : SeqPlotter<F, N>,
     dataterm : D,
 ) -> DiscreteMeasure<Loc<F, N>, F>
 where F : Float + ToNalgebraRealField,
@@ -319,7 +219,7 @@
       A::Observable : std::ops::MulAssign<F>,
       GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
       A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
-          + Lipschitz<𝒟, FloatType=F>,
+          + Lipschitz<&'a 𝒟, FloatType=F>,
       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
       𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
@@ -329,27 +229,108 @@
       BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
       PlotLookup : Plotting<N>,
       DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
-      PDPS<'a, F, A, D, N> : FBSpecialisation<F, A::Observable, N>,
-      D : Subdifferentiable<F, A::Observable>,
+      D : PDPSDataTerm<F, A::Observable, N>,
       Reg : RegTerm<F, N> {
 
-    let y = dataterm.some_subdifferential(-b);
+    // Set up parameters
+    let config = &pdpsconfig.insertion;
+    let op𝒟norm = op𝒟.opnorm_bound();
     let l = opA.lipschitz_factor(&op𝒟).unwrap().sqrt();
-    let τ = config.τ0 / l;
-    let σ = config.σ0 / l;
+    let mut τ = pdpsconfig.τ0 / l;
+    let mut σ = pdpsconfig.σ0 / l;
+    let γ = dataterm.factor_of_strong_convexity();
+
+    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
+    // by τ compared to the conditional gradient approach.
+    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
+    let mut ε = tolerance.initial();
+
+    // Initialise iterates
+    let mut μ = DiscreteMeasure::new();
+    let mut y = dataterm.some_subdifferential(-b);
+    let mut y_prev = y.clone();
+    let mut stats = IterInfo::new();
+
+    // Run the algorithm
+    iterator.iterate(|state| {
+        // Calculate smooth part of surrogate model.
+        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
+        // has no significant overhead. For some reason Rust doesn't allow us to simply move
+        // the residual and replace it below before the end of this closure.
+        y *= -τ;
+        let r = std::mem::replace(&mut y, opA.empty_observable());
+        let minus_τv = opA.preadjoint().apply(r);
+
+        // Save current base point
+        let μ_base = μ.clone();
+        
+        // Insert and reweigh
+        let (d, within_tolerances) = insert_and_reweigh(
+            &mut μ, &minus_τv, &μ_base, None,
+            op𝒟, op𝒟norm,
+            τ, ε,
+            config, &reg, state, &mut stats
+        );
+
+        // Prune and possibly merge spikes
+        prune_and_maybe_simple_merge(
+            &mut μ, &minus_τv, &μ_base,
+            op𝒟,
+            τ, ε,
+            config, &reg, state, &mut stats
+        );
 
-    let pdps = PDPS {
-        b,
-        opA,
-        τ,
-        σ,
-        acceleration : config.acceleration,
-        _dataterm : dataterm,
-        y_prev : y.clone(),
-    };
+        // Update step length parameters
+        let ω = match pdpsconfig.acceleration {
+            Acceleration::None => 1.0,
+            Acceleration::Partial => {
+                let ω = 1.0 / (1.0 + γ * σ).sqrt();
+                σ = σ * ω;
+                τ = τ / ω;
+                ω
+            },
+            Acceleration::Full => {
+                let ω = 1.0 / (1.0 + 2.0 * γ * σ).sqrt();
+                σ = σ * ω;
+                τ = τ / ω;
+                ω
+            },
+        };
+
+        // Do dual update
+        y = b.clone();                          // y = b
+        opA.gemv(&mut y, 1.0 + ω, &μ, -1.0);    // y = A[(1+ω)μ^{k+1}]-b
+        opA.gemv(&mut y, -ω, &μ_base, 1.0);     // y = A[(1+ω)μ^{k+1} - ω μ^k]-b
+        dataterm.dual_update(&mut y, &y_prev, σ);
+        y_prev.copy_from(&y);
 
-    generic_pointsource_fb_reg(
-        opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, y, pdps
-    )
+        // Update main tolerance for next iteration
+        let ε_prev = ε;
+        ε = tolerance.update(ε, state.iteration());
+        stats.this_iters += 1;
+
+        // Give function value if needed
+        state.if_verbose(|| {
+            // Plot if so requested
+            plotter.plot_spikes(
+                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
+                "start".to_string(), Some(&minus_τv),
+                reg.target_bounds(τ, ε_prev), &μ,
+            );
+            // Calculate mean inner iterations and reset relevant counters.
+            // Return the statistics
+            let res = IterInfo {
+                value : dataterm.calculate_fit_op(&μ, opA, b) + reg.apply(&μ),
+                n_spikes : μ.len(),
+                ε : ε_prev,
+                postprocessing: config.postprocessing.then(|| μ.clone()),
+                .. stats
+            };
+            stats = IterInfo::new();
+            res
+        })
+    });
+
+    postprocess(μ, config, dataterm, opA, b)
 }
 

mercurial