--- a/src/pdps.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/pdps.rs Thu Jan 23 23:35:28 2025 +0100 @@ -44,46 +44,36 @@ use clap::ValueEnum; use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::loc::Loc; use alg_tools::euclidean::Euclidean; use alg_tools::linops::Mapping; use alg_tools::norms::{ Linfinity, Projection, }; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - SupportGenerator, - LocalAnalysis, -}; use alg_tools::mapping::{RealMapping, Instance}; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::linops::AXPY; use crate::types::*; -use crate::measures::{DiscreteMeasure, RNDM, Radon}; +use crate::measures::{DiscreteMeasure, RNDM}; use crate::measures::merging::SpikeMerging; use crate::forward_model::{ + ForwardModel, AdjointProductBoundedBy, - ForwardModel }; -use crate::seminorms::DiscreteMeasureOp; use crate::plot::{ SeqPlotter, Plotting, PlotLookup }; use crate::fb::{ - FBGenericConfig, - insert_and_reweigh, postprocess, prune_with_stats }; +pub use crate::prox_penalty::{ + FBGenericConfig, + ProxPenalty +}; use crate::regularisation::RegTerm; use crate::dataterm::{ DataTerm, @@ -223,33 +213,29 @@ /// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>( - opA : &'a A, - b : &'a A::Observable, +pub fn pointsource_pdps_reg<F, I, A, D, Reg, P, const N : usize>( + opA : &A, + b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, pdpsconfig : &PDPSConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, dataterm : D, ) -> RNDM<F, N> -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> - + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, - 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, - 𝒟::Codomain : RealMapping<F, N>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - PlotLookup : Plotting<N>, - RNDM<F, N> : SpikeMerging<F>, - D : PDPSDataTerm<F, A::Observable, N>, - Reg : RegTerm<F, N> { +where + F : Float + ToNalgebraRealField, + I : AlgIteratorFactory<IterInfo<F, N>>, + A : ForwardModel<RNDM<F, N>, F> + + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, + A::PreadjointCodomain : RealMapping<F, N>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, + PlotLookup : Plotting<N>, + RNDM<F, N> : SpikeMerging<F>, + D : PDPSDataTerm<F, A::Observable, N>, + Reg : RegTerm<F, N>, + P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, +{ // Check parameters assert!(pdpsconfig.τ0 > 0.0 && @@ -259,8 +245,7 @@ // Set up parameters let config = &pdpsconfig.generic; - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); - let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt(); + let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt(); let mut τ = pdpsconfig.τ0 / l; let mut σ = pdpsconfig.σ0 / l; let γ = dataterm.factor_of_strong_convexity(); @@ -286,25 +271,21 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let τv = opA.preadjoint().apply(y * τ); + let mut τv = opA.preadjoint().apply(y * τ); // Save current base point let μ_base = μ.clone(); // Insert and reweigh - let (d, _within_tolerances) = insert_and_reweigh( - &mut μ, &τv, &μ_base, None, - op𝒟, op𝒟norm, + let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats ); // Prune and possibly merge spikes if config.merge_now(&state) { - stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { - let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base)); - reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) - }); + stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, ®); } stats.pruned += prune_with_stats(&mut μ); @@ -323,7 +304,7 @@ stats.this_iters += 1; state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) });