diff -r fb911f72e698 -r c5d8bd1a7728 src/sliding_fb.rs --- a/src/sliding_fb.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/sliding_fb.rs Thu Jan 23 23:35:28 2025 +0100 @@ -12,38 +12,19 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::mapping::{Mapping, DifferentiableMapping, Instance}; +use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; use alg_tools::norms::Norm; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - //Bounded, -}; -use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::norms::{L2, Linfinity}; +use alg_tools::norms::L2; use crate::types::*; use crate::measures::{DiscreteMeasure, Radon, RNDM}; -use crate::measures::merging::{ - SpikeMergingMethod, - SpikeMerging, -}; +use crate::measures::merging::SpikeMerging; use crate::forward_model::{ ForwardModel, AdjointProductBoundedBy, LipschitzValues, }; -use crate::seminorms::DiscreteMeasureOp; //use crate::tolerance::Tolerance; use crate::plot::{ SeqPlotter, @@ -151,7 +132,7 @@ Observable : Euclidean, for<'a> &'a Observable : Instance, //for<'b> A::Preadjoint<'b> : LipschitzValues, - D : DifferentiableMapping, DerivativeDomain=Loc>, + D : DifferentiableRealMapping, { use TransportStepLength::*; @@ -353,40 +334,29 @@ /// The parametrisation is as for [`pointsource_fb_reg`]. /// Inertia is currently not supported. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>( - opA : &'a A, +pub fn pointsource_sliding_fb_reg( + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, config : &SlidingFBConfig, iterator : I, mut plotter : SeqPlotter, ) -> RNDM -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - for<'b> A::Preadjoint<'b> : LipschitzValues, - A::PreadjointCodomain : DifferentiableMapping< - Loc, DerivativeDomain=Loc, Codomain=F - >, - GA : SupportGenerator + Clone, - A : ForwardModel, F, PreadjointCodomain = BTFN> - + AdjointProductBoundedBy, 𝒟, FloatType=F>, - //+ TransportLipschitz, - BTA : BTSearch>, - G𝒟 : SupportGenerator + Clone, - 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN, - Codomain = BTFN>, - BT𝒟 : BTSearch>, - S: RealMapping + LocalAnalysis, N> - + DifferentiableMapping, DerivativeDomain=Loc>, - K: RealMapping + LocalAnalysis, N>, - //+ Differentiable, Derivative=Loc>, - BTNodeLookup: BTNode, N>, - Cube: P2Minimise, F>, - PlotLookup : Plotting, - RNDM : SpikeMerging, - Reg : SlidingRegTerm { +where + F : Float + ToNalgebraRealField, + I : AlgIteratorFactory>, + A : ForwardModel, F> + + AdjointProductBoundedBy, P, FloatType=F>, + //+ TransportLipschitz, + for<'b> &'b A::Observable : std::ops::Neg + Instance, + for<'b> A::Preadjoint<'b> : LipschitzValues, + A::PreadjointCodomain : DifferentiableRealMapping, + RNDM : SpikeMerging, + Reg : SlidingRegTerm, + P : ProxPenalty, + PlotLookup : Plotting, +{ // Check parameters assert!(config.τ0 > 0.0, "Invalid step length parameter"); @@ -398,13 +368,12 @@ let mut residual = -b; // Has to equal $Aμ-b$. // Set up parameters - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); let opAnorm = opA.opnorm_bound(Radon, L2); //let max_transport = config.max_transport.scale // * reg.radon_norm_bound(b.norm2_squared() / 2.0); //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; let ℓ = 0.0; - let τ = config.τ0 / opA.adjoint_product_bound(&op𝒟).unwrap(); + let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v @@ -446,15 +415,14 @@ // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. - let (d, _within_tolerances, τv̆) = 'adapt_transport: loop { + let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); - let τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); + let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. - let (d, within_tolerances) = insert_and_reweigh( - &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), - op𝒟, op𝒟norm, + let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, &config.insertion, ®, &state, &mut stats, ); @@ -464,7 +432,7 @@ &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, ε, &config.transport ) { - break 'adapt_transport (d, within_tolerances, τv̆) + break 'adapt_transport (maybe_d, within_tolerances, τv̆) } }; @@ -480,20 +448,20 @@ (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) }); - // Merge spikes. - // This expects the prune below to prune γ. - // TODO: This may not work correctly in all cases. - let ins = &config.insertion; - if ins.merge_now(&state) { - if let SpikeMergingMethod::None = ins.merging { - } else { - stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { - let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; - let mut d = &τv̆ + op𝒟.preapply(ν); - reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) - }); - } - } + // // Merge spikes. + // // This expects the prune below to prune γ. + // // TODO: This may not work correctly in all cases. + // let ins = &config.insertion; + // if ins.merge_now(&state) { + // if let SpikeMergingMethod::None = ins.merging { + // } else { + // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { + // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; + // let mut d = &τv̆ + op𝒟.preapply(ν); + // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) + // }); + // } + // } // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the // latter needs to be pruned when μ is. @@ -514,7 +482,7 @@ // Give statistics if requested state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) });