src/pdps.rs

branch
dev
changeset 37
c5d8bd1a7728
parent 35
b087e3eab191
child 39
6316d68b58af
--- a/src/pdps.rs	Mon Jan 06 11:32:57 2025 -0500
+++ b/src/pdps.rs	Thu Jan 23 23:35:28 2025 +0100
@@ -44,46 +44,36 @@
 use clap::ValueEnum;
 
 use alg_tools::iterate::AlgIteratorFactory;
-use alg_tools::loc::Loc;
 use alg_tools::euclidean::Euclidean;
 use alg_tools::linops::Mapping;
 use alg_tools::norms::{
     Linfinity,
     Projection,
 };
-use alg_tools::bisection_tree::{
-    BTFN,
-    PreBTFN,
-    Bounds,
-    BTNodeLookup,
-    BTNode,
-    BTSearch,
-    SupportGenerator,
-    LocalAnalysis,
-};
 use alg_tools::mapping::{RealMapping, Instance};
 use alg_tools::nalgebra_support::ToNalgebraRealField;
 use alg_tools::linops::AXPY;
 
 use crate::types::*;
-use crate::measures::{DiscreteMeasure, RNDM, Radon};
+use crate::measures::{DiscreteMeasure, RNDM};
 use crate::measures::merging::SpikeMerging;
 use crate::forward_model::{
+    ForwardModel,
     AdjointProductBoundedBy,
-    ForwardModel
 };
-use crate::seminorms::DiscreteMeasureOp;
 use crate::plot::{
     SeqPlotter,
     Plotting,
     PlotLookup
 };
 use crate::fb::{
-    FBGenericConfig,
-    insert_and_reweigh,
     postprocess,
     prune_with_stats
 };
+pub use crate::prox_penalty::{
+    FBGenericConfig,
+    ProxPenalty
+};
 use crate::regularisation::RegTerm;
 use crate::dataterm::{
     DataTerm,
@@ -223,33 +213,29 @@
 ///
 /// Returns the final iterate.
 #[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>(
-    opA : &'a A,
-    b : &'a A::Observable,
+pub fn pointsource_pdps_reg<F, I, A, D, Reg, P, const N : usize>(
+    opA : &A,
+    b : &A::Observable,
     reg : Reg,
-    op𝒟 : &'a 𝒟,
+    prox_penalty : &P,
     pdpsconfig : &PDPSConfig<F>,
     iterator : I,
     mut plotter : SeqPlotter<F, N>,
     dataterm : D,
 ) -> RNDM<F, N>
-where F : Float + ToNalgebraRealField,
-      I : AlgIteratorFactory<IterInfo<F, N>>,
-      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
-      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
-      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
-          + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
-      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
-      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
-      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
-      𝒟::Codomain : RealMapping<F, N>,
-      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
-      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
-      PlotLookup : Plotting<N>,
-      RNDM<F, N> : SpikeMerging<F>,
-      D : PDPSDataTerm<F, A::Observable, N>,
-      Reg : RegTerm<F, N> {
+where
+    F : Float + ToNalgebraRealField,
+    I : AlgIteratorFactory<IterInfo<F, N>>,
+    A : ForwardModel<RNDM<F, N>, F>
+        + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>,
+    A::PreadjointCodomain : RealMapping<F, N>,
+    for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
+    PlotLookup : Plotting<N>,
+    RNDM<F, N> : SpikeMerging<F>,
+    D : PDPSDataTerm<F, A::Observable, N>,
+    Reg : RegTerm<F, N>,
+    P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
+{
 
     // Check parameters
     assert!(pdpsconfig.τ0 > 0.0 &&
@@ -259,8 +245,7 @@
 
     // Set up parameters
     let config = &pdpsconfig.generic;
-    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
-    let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt();
+    let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt();
     let mut τ = pdpsconfig.τ0 / l;
     let mut σ = pdpsconfig.σ0 / l;
     let γ = dataterm.factor_of_strong_convexity();
@@ -286,25 +271,21 @@
     // Run the algorithm
     for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
-        let τv = opA.preadjoint().apply(y * τ);
+        let mut τv = opA.preadjoint().apply(y * τ);
 
         // Save current base point
         let μ_base = μ.clone();
         
         // Insert and reweigh
-        let (d, _within_tolerances) = insert_and_reweigh(
-            &mut μ, &τv, &μ_base, None,
-            op𝒟, op𝒟norm,
+        let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
+            &mut μ, &mut τv, &μ_base, None,
             τ, ε,
             config, &reg, &state, &mut stats
         );
 
         // Prune and possibly merge spikes
         if config.merge_now(&state) {
-            stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
-                let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base));
-                reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
-            });
+            stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, &reg);
         }
         stats.pruned += prune_with_stats(&mut μ);
 
@@ -323,7 +304,7 @@
         stats.this_iters += 1;
 
         state.if_verbose(|| {
-            plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ);
+            plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
             full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
         });
 

mercurial