src/pdps.rs

branch
dev
changeset 35
b087e3eab191
parent 34
efa60bc4f743
child 37
c5d8bd1a7728
--- a/src/pdps.rs	Thu Aug 29 00:00:00 2024 -0500
+++ b/src/pdps.rs	Tue Dec 31 09:25:45 2024 -0500
@@ -6,8 +6,7 @@
  * Valkonen T. - _Proximal methods for point source localisation_,
    [arXiv:2212.02991](https://arxiv.org/abs/2212.02991).
 
-The main routine is [`pointsource_pdps`]. It is based on specilisatinn of
-[`generic_pointsource_fb_reg`] through relevant [`FBSpecialisation`] implementations.
+The main routine is [`pointsource_pdps_reg`].
 Both norm-2-squared and norm-1 data terms are supported. That is, implemented are solvers for
 <div>
 $$
@@ -37,10 +36,6 @@
 For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$.
 For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$.
 </p>
-
-Based on zero initialisation for $μ$, we use the [`Subdifferentiable`] trait to make an
-initialisation corresponding to the second part of the optimality conditions.
-In the algorithm itself, standard proximal steps are taking with respect to $F\_0^* + ⟨b, ·⟩$.
 */
 
 use numeric_literals::replace_float_literals;
@@ -48,13 +43,10 @@
 use nalgebra::DVector;
 use clap::ValueEnum;
 
-use alg_tools::iterate::{
-    AlgIteratorFactory,
-    AlgIteratorState,
-};
+use alg_tools::iterate::AlgIteratorFactory;
 use alg_tools::loc::Loc;
 use alg_tools::euclidean::Euclidean;
-use alg_tools::linops::Apply;
+use alg_tools::linops::Mapping;
 use alg_tools::norms::{
     Linfinity,
     Projection,
@@ -69,14 +61,17 @@
     SupportGenerator,
     LocalAnalysis,
 };
-use alg_tools::mapping::RealMapping;
+use alg_tools::mapping::{RealMapping, Instance};
 use alg_tools::nalgebra_support::ToNalgebraRealField;
 use alg_tools::linops::AXPY;
 
 use crate::types::*;
-use crate::measures::DiscreteMeasure;
+use crate::measures::{DiscreteMeasure, RNDM, Radon};
 use crate::measures::merging::SpikeMerging;
-use crate::forward_model::ForwardModel;
+use crate::forward_model::{
+    AdjointProductBoundedBy,
+    ForwardModel
+};
 use crate::seminorms::DiscreteMeasureOp;
 use crate::plot::{
     SeqPlotter,
@@ -87,7 +82,7 @@
     FBGenericConfig,
     insert_and_reweigh,
     postprocess,
-    prune_and_maybe_simple_merge
+    prune_with_stats
 };
 use crate::regularisation::RegTerm;
 use crate::dataterm::{
@@ -110,7 +105,30 @@
     Full
 }
 
-/// Settings for [`pointsource_pdps`].
+#[replace_float_literals(F::cast_from(literal))]
+impl Acceleration {
+    /// PDPS parameter acceleration. Updates τ and σ and returns ω.
+    /// This uses dual strong convexity, not primal.
+    fn accelerate<F : Float>(self, τ : &mut F, σ : &mut F, γ : F) -> F {
+        match self {
+            Acceleration::None => 1.0,
+            Acceleration::Partial => {
+                let ω = 1.0 / (1.0 + γ * (*σ)).sqrt();
+                *σ *= ω;
+                *τ /= ω;
+                ω
+            },
+            Acceleration::Full => {
+                let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt();
+                *σ *= ω;
+                *τ /= ω;
+                ω
+            },
+        }
+    }
+}
+
+/// Settings for [`pointsource_pdps_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
 pub struct PDPSConfig<F : Float> {
@@ -155,9 +173,13 @@
 
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float, V :  Euclidean<F> + AXPY<F>, const N : usize>
-PDPSDataTerm<F, V, N>
-for L2Squared {
+impl<F, V, const N : usize> PDPSDataTerm<F, V, N>
+for L2Squared
+where
+    F : Float,
+    V :  Euclidean<F> + AXPY<F>,
+    for<'b> &'b V : Instance<V>,
+{
     fn some_subdifferential(&self, x : V) -> V { x }
 
     fn factor_of_strong_convexity(&self) -> F {
@@ -166,7 +188,7 @@
 
     #[inline]
     fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) {
-        y.axpy(1.0 / (1.0 + σ), &y_prev, σ / (1.0 + σ));
+        y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ));
     }
 }
 
@@ -210,16 +232,13 @@
     iterator : I,
     mut plotter : SeqPlotter<F, N>,
     dataterm : D,
-) -> DiscreteMeasure<Loc<F, N>, F>
+) -> RNDM<F, N>
 where F : Float + ToNalgebraRealField,
       I : AlgIteratorFactory<IterInfo<F, N>>,
-      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>
-                                  + std::ops::Add<A::Observable, Output=A::Observable>,
-                                  //+ std::ops::Mul<F, Output=A::Observable>, // <-- FIXME: compiler overflow
-      A::Observable : std::ops::MulAssign<F>,
+      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
       GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
-          + Lipschitz<&'a 𝒟, FloatType=F>,
+      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
+          + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
       𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
@@ -228,14 +247,20 @@
       K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
       BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
       PlotLookup : Plotting<N>,
-      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
+      RNDM<F, N> : SpikeMerging<F>,
       D : PDPSDataTerm<F, A::Observable, N>,
       Reg : RegTerm<F, N> {
 
+    // Check parameters
+    assert!(pdpsconfig.τ0 > 0.0 &&
+            pdpsconfig.σ0 > 0.0 &&
+            pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0,
+            "Invalid step length parameters");
+
     // Set up parameters
     let config = &pdpsconfig.generic;
-    let op𝒟norm = op𝒟.opnorm_bound();
-    let l = opA.lipschitz_factor(&op𝒟).unwrap().sqrt();
+    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
+    let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt();
     let mut τ = pdpsconfig.τ0 / l;
     let mut σ = pdpsconfig.σ0 / l;
     let γ = dataterm.factor_of_strong_convexity();
@@ -249,53 +274,42 @@
     let mut μ = DiscreteMeasure::new();
     let mut y = dataterm.some_subdifferential(-b);
     let mut y_prev = y.clone();
+    let full_stats = |μ : &RNDM<F, N>, ε, stats| IterInfo {
+        value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ),
+        n_spikes : μ.len(),
+        ε,
+        // postprocessing: config.postprocessing.then(|| μ.clone()),
+        .. stats
+    };
     let mut stats = IterInfo::new();
 
     // Run the algorithm
-    iterator.iterate(|state| {
+    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
-        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
-        // has no significant overhead. For some reosn Rust doesn't allow us simply moving
-        // the residual and replacing it below before the end of this closure.
-        y *= -τ;
-        let r = std::mem::replace(&mut y, opA.empty_observable());
-        let minus_τv = opA.preadjoint().apply(r);
+        let τv = opA.preadjoint().apply(y * τ);
 
         // Save current base point
         let μ_base = μ.clone();
         
         // Insert and reweigh
-        let (d, within_tolerances) = insert_and_reweigh(
-            &mut μ, &minus_τv, &μ_base, None,
+        let (d, _within_tolerances) = insert_and_reweigh(
+            &mut μ, &τv, &μ_base, None,
             op𝒟, op𝒟norm,
             τ, ε,
-            config, &reg, state, &mut stats
+            config, &reg, &state, &mut stats
         );
 
         // Prune and possibly merge spikes
-        prune_and_maybe_simple_merge(
-            &mut μ, &minus_τv, &μ_base,
-            op𝒟,
-            τ, ε,
-            config, &reg, state, &mut stats
-        );
+        if config.merge_now(&state) {
+            stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
+                let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base));
+                reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
+            });
+        }
+        stats.pruned += prune_with_stats(&mut μ);
 
         // Update step length parameters
-        let ω = match pdpsconfig.acceleration {
-            Acceleration::None => 1.0,
-            Acceleration::Partial => {
-                let ω = 1.0 / (1.0 + γ * σ).sqrt();
-                σ = σ * ω;
-                τ = τ / ω;
-                ω
-            },
-            Acceleration::Full => {
-                let ω = 1.0 / (1.0 + 2.0 * γ * σ).sqrt();
-                σ = σ * ω;
-                τ = τ / ω;
-                ω
-            },
-        };
+        let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);
 
         // Do dual update
         y = b.clone();                          // y = b
@@ -304,32 +318,17 @@
         dataterm.dual_update(&mut y, &y_prev, σ);
         y_prev.copy_from(&y);
 
-        // Update main tolerance for next iteration
-        let ε_prev = ε;
-        ε = tolerance.update(ε, state.iteration());
+        // Give statistics if requested
+        let iter = state.iteration();
         stats.this_iters += 1;
 
-        // Give function value if needed
         state.if_verbose(|| {
-            // Plot if so requested
-            plotter.plot_spikes(
-                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
-                "start".to_string(), Some(&minus_τv),
-                reg.target_bounds(τ, ε_prev), &μ,
-            );
-            // Calculate mean inner iterations and reset relevant counters.
-            // Return the statistics
-            let res = IterInfo {
-                value : dataterm.calculate_fit_op(&μ, opA, b) + reg.apply(&μ),
-                n_spikes : μ.len(),
-                ε : ε_prev,
-                postprocessing: config.postprocessing.then(|| μ.clone()),
-                .. stats
-            };
-            stats = IterInfo::new();
-            res
-        })
-    });
+            plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ);
+            full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
+        });
+
+        ε = tolerance.update(ε, iter);
+    }
 
     postprocess(μ, config, dataterm, opA, b)
 }

mercurial