diff -r fb911f72e698 -r c5d8bd1a7728 src/pdps.rs --- a/src/pdps.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/pdps.rs Thu Jan 23 23:35:28 2025 +0100 @@ -44,46 +44,36 @@ use clap::ValueEnum; use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::loc::Loc; use alg_tools::euclidean::Euclidean; use alg_tools::linops::Mapping; use alg_tools::norms::{ Linfinity, Projection, }; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - SupportGenerator, - LocalAnalysis, -}; use alg_tools::mapping::{RealMapping, Instance}; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::linops::AXPY; use crate::types::*; -use crate::measures::{DiscreteMeasure, RNDM, Radon}; +use crate::measures::{DiscreteMeasure, RNDM}; use crate::measures::merging::SpikeMerging; use crate::forward_model::{ + ForwardModel, AdjointProductBoundedBy, - ForwardModel }; -use crate::seminorms::DiscreteMeasureOp; use crate::plot::{ SeqPlotter, Plotting, PlotLookup }; use crate::fb::{ - FBGenericConfig, - insert_and_reweigh, postprocess, prune_with_stats }; +pub use crate::prox_penalty::{ + FBGenericConfig, + ProxPenalty +}; use crate::regularisation::RegTerm; use crate::dataterm::{ DataTerm, @@ -223,33 +213,29 @@ /// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>( - opA : &'a A, - b : &'a A::Observable, +pub fn pointsource_pdps_reg( + opA : &A, + b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, pdpsconfig : &PDPSConfig, iterator : I, mut plotter : SeqPlotter, dataterm : D, ) -> RNDM -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - GA : SupportGenerator + Clone, - A : ForwardModel, F, PreadjointCodomain = BTFN> - + AdjointProductBoundedBy, 𝒟, FloatType=F>, - BTA : BTSearch>, - G𝒟 : SupportGenerator + Clone, - 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>, - 𝒟::Codomain : RealMapping, - S: RealMapping + LocalAnalysis, N>, - K: RealMapping + LocalAnalysis, N>, - BTNodeLookup: BTNode, N>, - PlotLookup : Plotting, - RNDM : SpikeMerging, - D : PDPSDataTerm, - Reg : RegTerm { +where + F : Float + ToNalgebraRealField, + I : AlgIteratorFactory>, + A : ForwardModel, F> + + AdjointProductBoundedBy, P, FloatType=F>, + A::PreadjointCodomain : RealMapping, + for<'b> &'b A::Observable : std::ops::Neg + Instance, + PlotLookup : Plotting, + RNDM : SpikeMerging, + D : PDPSDataTerm, + Reg : RegTerm, + P : ProxPenalty, +{ // Check parameters assert!(pdpsconfig.τ0 > 0.0 && @@ -259,8 +245,7 @@ // Set up parameters let config = &pdpsconfig.generic; - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); - let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt(); + let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt(); let mut τ = pdpsconfig.τ0 / l; let mut σ = pdpsconfig.σ0 / l; let γ = dataterm.factor_of_strong_convexity(); @@ -286,25 +271,21 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let τv = opA.preadjoint().apply(y * τ); + let mut τv = opA.preadjoint().apply(y * τ); // Save current base point let μ_base = μ.clone(); // Insert and reweigh - let (d, _within_tolerances) = insert_and_reweigh( - &mut μ, &τv, &μ_base, None, - op𝒟, op𝒟norm, + let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats ); // Prune and possibly merge spikes if config.merge_now(&state) { - stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { - let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base)); - reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) - }); + stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, ®); } stats.pruned += prune_with_stats(&mut μ); @@ -323,7 +304,7 @@ stats.this_iters += 1; state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) });