diff -r 9738b51d90d7 -r 4f468d35fa29 src/pdps.rs
--- a/src/pdps.rs Sun Apr 27 15:03:51 2025 -0500
+++ b/src/pdps.rs Thu Feb 26 11:38:43 2026 -0500
@@ -38,50 +38,25 @@
*/
-use numeric_literals::replace_float_literals;
-use serde::{Serialize, Deserialize};
-use nalgebra::DVector;
-use clap::ValueEnum;
-
-use alg_tools::iterate::AlgIteratorFactory;
-use alg_tools::euclidean::Euclidean;
-use alg_tools::linops::Mapping;
-use alg_tools::norms::{
- Linfinity,
- Projection,
-};
-use alg_tools::mapping::{RealMapping, Instance};
-use alg_tools::nalgebra_support::ToNalgebraRealField;
-use alg_tools::linops::AXPY;
-
-use crate::types::*;
-use crate::measures::{DiscreteMeasure, RNDM};
+use crate::fb::{postprocess, prune_with_stats};
+use crate::forward_model::ForwardModel;
use crate::measures::merging::SpikeMerging;
-use crate::forward_model::{
- ForwardModel,
- AdjointProductBoundedBy,
-};
-use crate::plot::{
- SeqPlotter,
- Plotting,
- PlotLookup
-};
-use crate::fb::{
- postprocess,
- prune_with_stats
-};
-pub use crate::prox_penalty::{
- FBGenericConfig,
- ProxPenalty
-};
+use crate::measures::merging::SpikeMergingMethod;
+use crate::measures::{DiscreteMeasure, RNDM};
+use crate::plot::Plotter;
+pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBoundPD};
use crate::regularisation::RegTerm;
-use crate::dataterm::{
- DataTerm,
- L2Squared,
- L1
-};
-use crate::measures::merging::SpikeMergingMethod;
-
+use crate::types::*;
+use alg_tools::convex::{Conjugable, ConvexMapping, Prox};
+use alg_tools::error::DynResult;
+use alg_tools::iterate::AlgIteratorFactory;
+use alg_tools::linops::{Mapping, AXPY};
+use alg_tools::mapping::{DataTerm, Instance};
+use alg_tools::nalgebra_support::ToNalgebraRealField;
+use anyhow::ensure;
+use clap::ValueEnum;
+use numeric_literals::replace_float_literals;
+use serde::{Deserialize, Serialize};
/// Acceleration
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)]
@@ -93,15 +68,18 @@
#[clap(name = "partial", help = "Partial acceleration, ω = 1/√(1+σ)")]
Partial,
/// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed
- #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")]
- Full
+ #[clap(
+ name = "full",
+ help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed"
+ )]
+ Full,
}
#[replace_float_literals(F::cast_from(literal))]
impl Acceleration {
/// PDPS parameter acceleration. Updates τ and σ and returns ω.
/// This uses dual strong convexity, not primal.
- fn accelerate(self, τ : &mut F, σ : &mut F, γ : F) -> F {
+ fn accelerate(self, τ: &mut F, σ: &mut F, γ: F) -> F {
match self {
Acceleration::None => 1.0,
Acceleration::Partial => {
@@ -109,13 +87,13 @@
*σ *= ω;
*τ /= ω;
ω
- },
+ }
Acceleration::Full => {
let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt();
*σ *= ω;
*τ /= ω;
ω
- },
+ }
}
}
}
@@ -123,91 +101,35 @@
/// Settings for [`pointsource_pdps_reg`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
-pub struct PDPSConfig {
+pub struct PDPSConfig {
/// Primal step length scaling. We must have `τ0 * σ0 < 1`.
- pub τ0 : F,
+ pub τ0: F,
/// Dual step length scaling. We must have `τ0 * σ0 < 1`.
- pub σ0 : F,
+ pub σ0: F,
/// Accelerate if available
- pub acceleration : Acceleration,
+ pub acceleration: Acceleration,
/// Generic parameters
- pub generic : FBGenericConfig,
+ pub generic: InsertionConfig,
}
#[replace_float_literals(F::cast_from(literal))]
-impl Default for PDPSConfig {
+impl Default for PDPSConfig {
fn default() -> Self {
let τ0 = 5.0;
PDPSConfig {
τ0,
- σ0 : 0.99/τ0,
- acceleration : Acceleration::Partial,
- generic : FBGenericConfig {
- merging : SpikeMergingMethod { enabled : true, ..Default::default() },
- .. Default::default()
+ σ0: 0.99 / τ0,
+ acceleration: Acceleration::Partial,
+ generic: InsertionConfig {
+ merging: SpikeMergingMethod { enabled: true, ..Default::default() },
+ ..Default::default()
},
}
}
}
-/// Trait for data terms for the PDPS
-#[replace_float_literals(F::cast_from(literal))]
-pub trait PDPSDataTerm : DataTerm {
- /// Calculate some subdifferential at `x` for the conjugate
- fn some_subdifferential(&self, x : V) -> V;
-
- /// Factor of strong convexity of the conjugate
- #[inline]
- fn factor_of_strong_convexity(&self) -> F {
- 0.0
- }
-
- /// Perform dual update
- fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F);
-}
-
-
-#[replace_float_literals(F::cast_from(literal))]
-impl PDPSDataTerm
-for L2Squared
-where
- F : Float,
- V : Euclidean + AXPY,
- for<'b> &'b V : Instance,
-{
- fn some_subdifferential(&self, x : V) -> V { x }
-
- fn factor_of_strong_convexity(&self) -> F {
- 1.0
- }
-
- #[inline]
- fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) {
- y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ));
- }
-}
-
-#[replace_float_literals(F::cast_from(literal))]
-impl
-PDPSDataTerm, N>
-for L1 {
- fn some_subdifferential(&self, mut x : DVector) -> DVector {
- // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well.
- x.iter_mut()
- .for_each(|v| if *v != F::ZERO { *v = *v/::abs(*v) });
- x
- }
-
- #[inline]
- fn dual_update(&self, y : &mut DVector, y_prev : &DVector, σ : F) {
- y.axpy(1.0, y_prev, σ);
- y.proj_ball_mut(1.0, Linfinity);
- }
-}
-
/// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting.
///
-/// The `dataterm` should be either [`L1`] for norm-1 data term or [`L2Squared`] for norm-2-squared.
/// The settings in `config` have their [respective documentation](PDPSConfig). `opA` is the
/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
/// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
@@ -218,42 +140,44 @@
///
/// Returns the final iterate.
#[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_pdps_reg(
- opA : &A,
- b : &A::Observable,
- reg : Reg,
- prox_penalty : &P,
- pdpsconfig : &PDPSConfig,
- iterator : I,
- mut plotter : SeqPlotter,
- dataterm : D,
-) -> RNDM
+pub fn pointsource_pdps_reg<'a, F, I, A, Phi, Reg, Plot, P, const N: usize>(
+ f: &'a DataTerm, A, Phi>,
+ reg: &Reg,
+ prox_penalty: &P,
+ pdpsconfig: &PDPSConfig,
+ iterator: I,
+ mut plotter: Plot,
+ μ0 : Option>,
+) -> DynResult>
where
- F : Float + ToNalgebraRealField,
- I : AlgIteratorFactory>,
- A : ForwardModel, F>
- + AdjointProductBoundedBy, P, FloatType=F>,
- A::PreadjointCodomain : RealMapping,
- for<'b> &'b A::Observable : std::ops::Neg