--- a/src/sliding_fb.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/sliding_fb.rs Thu Jan 23 23:35:28 2025 +0100 @@ -12,38 +12,19 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::mapping::{Mapping, DifferentiableMapping, Instance}; +use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; use alg_tools::norms::Norm; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - //Bounded, -}; -use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::norms::{L2, Linfinity}; +use alg_tools::norms::L2; use crate::types::*; use crate::measures::{DiscreteMeasure, Radon, RNDM}; -use crate::measures::merging::{ - SpikeMergingMethod, - SpikeMerging, -}; +use crate::measures::merging::SpikeMerging; use crate::forward_model::{ ForwardModel, AdjointProductBoundedBy, LipschitzValues, }; -use crate::seminorms::DiscreteMeasureOp; //use crate::tolerance::Tolerance; use crate::plot::{ SeqPlotter, @@ -151,7 +132,7 @@ Observable : Euclidean<F, Output=Observable>, for<'a> &'a Observable : Instance<Observable>, //for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, - D : DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F, N>>, + D : DifferentiableRealMapping<F, N>, { use TransportStepLength::*; @@ -353,40 +334,29 @@ /// The parametrisation is as for [`pointsource_fb_reg`]. /// Inertia is currently not supported. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>( - opA : &'a A, +pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N : usize>( + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, config : &SlidingFBConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, ) -> RNDM<F, N> -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, - for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, - A::PreadjointCodomain : DifferentiableMapping< - Loc<F, N>, DerivativeDomain=Loc<F, N>, Codomain=F - >, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> - + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, - //+ TransportLipschitz<L2Squared, FloatType=F>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, - 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, - Codomain = BTFN<F, G𝒟, BT𝒟, N>>, - BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N> - + DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F,N>>, - K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Cube<F, N>: P2Minimise<Loc<F, N>, F>, - PlotLookup : Plotting<N>, - RNDM<F, N> : SpikeMerging<F>, - Reg : SlidingRegTerm<F, N> { +where + F : Float + ToNalgebraRealField, + I : AlgIteratorFactory<IterInfo<F, N>>, + A : ForwardModel<RNDM<F, N>, F> + + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, + //+ TransportLipschitz<L2Squared, FloatType=F>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, + for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, + A::PreadjointCodomain : DifferentiableRealMapping<F, N>, + RNDM<F, N> : SpikeMerging<F>, + Reg : SlidingRegTerm<F, N>, + P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, + PlotLookup : Plotting<N>, +{ // Check parameters assert!(config.τ0 > 0.0, "Invalid step length parameter"); @@ -398,13 +368,12 @@ let mut residual = -b; // Has to equal $Aμ-b$. // Set up parameters - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); let opAnorm = opA.opnorm_bound(Radon, L2); //let max_transport = config.max_transport.scale // * reg.radon_norm_bound(b.norm2_squared() / 2.0); //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; let ℓ = 0.0; - let τ = config.τ0 / opA.adjoint_product_bound(&op𝒟).unwrap(); + let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v @@ -446,15 +415,14 @@ // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. - let (d, _within_tolerances, τv̆) = 'adapt_transport: loop { + let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); - let τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); + let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. - let (d, within_tolerances) = insert_and_reweigh( - &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), - op𝒟, op𝒟norm, + let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, &config.insertion, ®, &state, &mut stats, ); @@ -464,7 +432,7 @@ &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, ε, &config.transport ) { - break 'adapt_transport (d, within_tolerances, τv̆) + break 'adapt_transport (maybe_d, within_tolerances, τv̆) } }; @@ -480,20 +448,20 @@ (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) }); - // Merge spikes. - // This expects the prune below to prune γ. - // TODO: This may not work correctly in all cases. - let ins = &config.insertion; - if ins.merge_now(&state) { - if let SpikeMergingMethod::None = ins.merging { - } else { - stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { - let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; - let mut d = &τv̆ + op𝒟.preapply(ν); - reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) - }); - } - } + // // Merge spikes. + // // This expects the prune below to prune γ. + // // TODO: This may not work correctly in all cases. + // let ins = &config.insertion; + // if ins.merge_now(&state) { + // if let SpikeMergingMethod::None = ins.merging { + // } else { + // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { + // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; + // let mut d = &τv̆ + op𝒟.preapply(ν); + // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) + // }); + // } + // } // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the // latter needs to be pruned when μ is. @@ -514,7 +482,7 @@ // Give statistics if requested state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) });