--- a/src/forward_pdps.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/forward_pdps.rs Thu Jan 23 23:35:28 2025 +0100 @@ -8,30 +8,15 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::mapping::{Mapping, Instance}; +use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; use alg_tools::norms::Norm; use alg_tools::direct_product::Pair; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - //Bounded, -}; -use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::linops::{ BoundedLinear, AXPY, GEMV, Adjointable, IdOp, }; use alg_tools::convex::{Conjugable, Prox}; -use alg_tools::norms::{L2, Linfinity, PairNorm}; +use alg_tools::norms::{L2, PairNorm}; use crate::types::*; use crate::measures::{DiscreteMeasure, Radon, RNDM}; @@ -40,7 +25,6 @@ ForwardModel, AdjointProductPairBoundedBy, }; -use crate::seminorms::DiscreteMeasureOp; use crate::plot::{ SeqPlotter, Plotting, @@ -83,12 +67,12 @@ /// using primal-dual proximal splitting with a forward step. #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_forward_pdps_pair< - 'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize + F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize >( - opA : &'a A, + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, config : &ForwardPDPSConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, @@ -102,27 +86,19 @@ where F : Float + ToNalgebraRealField, I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, A : ForwardModel< MeasureZ<F, Z, N>, F, PairNorm<Radon, L2, L2>, - PreadjointCodomain = Pair<BTFN<F, GA, BTA, N>, Z>, + PreadjointCodomain = Pair<S, Z>, > - + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, 𝒟, IdOp<Z>, FloatType=F>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, - 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, - Codomain = BTFN<F, G𝒟, BT𝒟, N>>, - BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Cube<F, N>: P2Minimise<Loc<F, N>, F>, + + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType=F>, + S: DifferentiableRealMapping<F, N>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, PlotLookup : Plotting<N>, RNDM<F, N> : SpikeMerging<F>, Reg : RegTerm<F, N>, + P : ProxPenalty<F, S, Reg, N>, KOpZ : BoundedLinear<Z, L2, L2, F, Codomain=Y> + GEMV<F, Z> + Adjointable<Z, Y, AdjointCodomain = Z>, @@ -150,11 +126,10 @@ let mut residual = calculate_residual(Pair(&μ, &z), opA, b); // Set up parameters - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); - let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt(); + let bigM = 0.0; //opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt(); let nKz = opKz.opnorm_bound(L2, L2); let opIdZ = IdOp::new(); - let (l, l_z) = opA.adjoint_product_pair_bound(&op𝒟, &opIdZ).unwrap(); + let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); // We need to satisfy // // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 @@ -196,14 +171,13 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { // Calculate initial transport - let Pair(τv, τz) = opA.preadjoint().apply(residual * τ); + let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ); let z_base = z.clone(); let μ_base = μ.clone(); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. - let (d, _within_tolerances) = insert_and_reweigh( - &mut μ, &τv, &μ_base, None, - op𝒟, op𝒟norm, + let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv, &μ_base, None, τ, ε, &config.insertion, ®, &state, &mut stats, ); @@ -248,7 +222,7 @@ stats.this_iters += 1; state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) });