--- a/src/sliding_pdps.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/sliding_pdps.rs Thu Jan 23 23:35:28 2025 +0100 @@ -11,30 +11,15 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; use alg_tools::norms::Norm; use alg_tools::direct_product::Pair; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - //Bounded, -}; -use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::linops::{ BoundedLinear, AXPY, GEMV, Adjointable, IdOp, }; use alg_tools::convex::{Conjugable, Prox}; -use alg_tools::norms::{L2, Linfinity, PairNorm}; +use alg_tools::norms::{L2, PairNorm}; use crate::types::*; use crate::measures::{DiscreteMeasure, Radon, RNDM}; @@ -45,7 +30,6 @@ LipschitzValues, }; // use crate::transport::TransportLipschitz; -use crate::seminorms::DiscreteMeasureOp; //use crate::tolerance::Tolerance; use crate::plot::{ SeqPlotter, @@ -101,12 +85,12 @@ /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_sliding_pdps_pair< - 'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize + F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize >( - opA : &'a A, + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, config : &SlidingPDPSConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, @@ -120,36 +104,25 @@ where F : Float + ToNalgebraRealField, I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, - for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, - BTFN<F, GA, BTA, N> : DifferentiableRealMapping<F, N>, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, A : ForwardModel< MeasureZ<F, Z, N>, F, PairNorm<Radon, L2, L2>, - PreadjointCodomain = Pair<BTFN<F, GA, BTA, N>, Z>, + PreadjointCodomain = Pair<S, Z>, > - + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, 𝒟, IdOp<Z>, FloatType=F>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, - 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, - Codomain = BTFN<F, G𝒟, BT𝒟, N>>, - BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N> - + DifferentiableRealMapping<F, N>, - K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Cube<F, N>: P2Minimise<Loc<F, N>, F>, + + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType=F>, + S : DifferentiableRealMapping<F, N>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, + for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, PlotLookup : Plotting<N>, RNDM<F, N> : SpikeMerging<F>, Reg : SlidingRegTerm<F, N>, + P : ProxPenalty<F, S, Reg, N>, // KOpM : Linear<RNDM<F, N>, Codomain=Y> // + GEMV<F, RNDM<F, N>> // + Preadjointable< // RNDM<F, N>, Y, - // PreadjointCodomain = BTFN<F, GA, BTA, N>, + // PreadjointCodomain = S, // > // + TransportLipschitz<L2Squared, FloatType=F> // + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, @@ -185,7 +158,6 @@ let zero_z = z.similar_origin(); // Set up parameters - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); // TODO: maybe this PairNorm doesn't make sense here? let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared); @@ -193,7 +165,7 @@ let nKz = opKz.opnorm_bound(L2, L2); let ℓ = 0.0; let opIdZ = IdOp::new(); - let (l, l_z) = opA.adjoint_product_pair_bound(&op𝒟, &opIdZ).unwrap(); + let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); // We need to satisfy // // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 @@ -278,18 +250,17 @@ // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. - let (d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop { + let (maybe_d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) let residual_μ̆ = calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); - let Pair(τv̆, τz) = opA.preadjoint().apply(residual_μ̆ * τ); + let mut τv̆z = opA.preadjoint().apply(residual_μ̆ * τ); // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. - let (d, within_tolerances) = insert_and_reweigh( - &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), - op𝒟, op𝒟norm, + let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv̆z.0, &γ1, Some(&μ_base_minus_γ0), τ, ε, &config.insertion, ®, &state, &mut stats, ); @@ -300,7 +271,7 @@ &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, ε, &config.transport ) { - break 'adapt_transport (d, within_tolerances, Pair(τv̆, τz)) + break 'adapt_transport (maybe_d, within_tolerances, τv̆z) } }; @@ -364,7 +335,7 @@ stats.this_iters += 1; state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) });