diff -r fb911f72e698 -r c5d8bd1a7728 src/sliding_pdps.rs --- a/src/sliding_pdps.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/sliding_pdps.rs Thu Jan 23 23:35:28 2025 +0100 @@ -11,30 +11,15 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; use alg_tools::norms::Norm; use alg_tools::direct_product::Pair; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - //Bounded, -}; -use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::linops::{ BoundedLinear, AXPY, GEMV, Adjointable, IdOp, }; use alg_tools::convex::{Conjugable, Prox}; -use alg_tools::norms::{L2, Linfinity, PairNorm}; +use alg_tools::norms::{L2, PairNorm}; use crate::types::*; use crate::measures::{DiscreteMeasure, Radon, RNDM}; @@ -45,7 +30,6 @@ LipschitzValues, }; // use crate::transport::TransportLipschitz; -use crate::seminorms::DiscreteMeasureOp; //use crate::tolerance::Tolerance; use crate::plot::{ SeqPlotter, @@ -101,12 +85,12 @@ /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_sliding_pdps_pair< - 'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize + F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize >( - opA : &'a A, + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, config : &SlidingPDPSConfig, iterator : I, mut plotter : SeqPlotter, @@ -120,36 +104,25 @@ where F : Float + ToNalgebraRealField, I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - for<'b> A::Preadjoint<'b> : LipschitzValues, - BTFN : DifferentiableRealMapping, - GA : SupportGenerator + Clone, A : ForwardModel< MeasureZ, F, PairNorm, - PreadjointCodomain = Pair, Z>, + PreadjointCodomain = Pair, > - + AdjointProductPairBoundedBy, 𝒟, IdOp, FloatType=F>, - BTA : BTSearch>, - G𝒟 : SupportGenerator + Clone, - 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN, - Codomain = BTFN>, - BT𝒟 : BTSearch>, - S: RealMapping + LocalAnalysis, N> - + DifferentiableRealMapping, - K: RealMapping + LocalAnalysis, N>, - //+ Differentiable, Derivative=Loc>, - BTNodeLookup: BTNode, N>, - Cube: P2Minimise, F>, + + AdjointProductPairBoundedBy, P, IdOp, FloatType=F>, + S : DifferentiableRealMapping, + for<'b> &'b A::Observable : std::ops::Neg + Instance, + for<'b> A::Preadjoint<'b> : LipschitzValues, PlotLookup : Plotting, RNDM : SpikeMerging, Reg : SlidingRegTerm, + P : ProxPenalty, // KOpM : Linear, Codomain=Y> // + GEMV> // + Preadjointable< // RNDM, Y, - // PreadjointCodomain = BTFN, + // PreadjointCodomain = S, // > // + TransportLipschitz // + AdjointProductBoundedBy, 𝒟, FloatType=F>, @@ -185,7 +158,6 @@ let zero_z = z.similar_origin(); // Set up parameters - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); // TODO: maybe this PairNorm doesn't make sense here? let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared); @@ -193,7 +165,7 @@ let nKz = opKz.opnorm_bound(L2, L2); let ℓ = 0.0; let opIdZ = IdOp::new(); - let (l, l_z) = opA.adjoint_product_pair_bound(&op𝒟, &opIdZ).unwrap(); + let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); // We need to satisfy // // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 @@ -278,18 +250,17 @@ // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. - let (d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop { + let (maybe_d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) let residual_μ̆ = calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); - let Pair(τv̆, τz) = opA.preadjoint().apply(residual_μ̆ * τ); + let mut τv̆z = opA.preadjoint().apply(residual_μ̆ * τ); // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. - let (d, within_tolerances) = insert_and_reweigh( - &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), - op𝒟, op𝒟norm, + let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv̆z.0, &γ1, Some(&μ_base_minus_γ0), τ, ε, &config.insertion, ®, &state, &mut stats, ); @@ -300,7 +271,7 @@ &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, ε, &config.transport ) { - break 'adapt_transport (d, within_tolerances, Pair(τv̆, τz)) + break 'adapt_transport (maybe_d, within_tolerances, τv̆z) } }; @@ -364,7 +335,7 @@ stats.this_iters += 1; state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) });