Thu, 23 Jan 2025 23:35:28 +0100
Generic proximal penalty support
--- a/src/experiments.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/experiments.rs Thu Jan 23 23:35:28 2025 +0100 @@ -24,7 +24,8 @@ ExperimentBiased, Named, DefaultAlgorithm, - AlgorithmConfig + AlgorithmConfig, + ProxTerm }; //use crate::fb::FBGenericConfig; use crate::rand_distr::{SerializableNormal, SaltAndPepper}; @@ -153,7 +154,11 @@ .. Default::default() } }; - + let defaults_2d = HashMap::from([ + (DefaultAlgorithm::PDPS, AlgorithmConfig::PDPS(pdps_2d(), ProxTerm::Wave)), + (DefaultAlgorithm::RadonPDPS, AlgorithmConfig::PDPS(pdps_2d(), ProxTerm::RadonSquared)) + ]); + // We add a hash of the experiment name to the configured // noise seed to not use the same noise for different experiments. let mut h = DefaultHasher::new(); @@ -212,9 +217,7 @@ kernel : Prod(AutoConvolution(spread_cutoff), base_spread), kernel_plot_width, noise_seed, - algorithm_defaults: HashMap::from([ - (DefaultAlgorithm::PDPS, AlgorithmConfig::PDPS(pdps_2d())) - ]), + algorithm_defaults: defaults_2d, }}) }, Experiment2DFast => { @@ -231,9 +234,7 @@ kernel : base_spread, kernel_plot_width, noise_seed, - algorithm_defaults: HashMap::from([ - (DefaultAlgorithm::PDPS, AlgorithmConfig::PDPS(pdps_2d())) - ]), + algorithm_defaults: defaults_2d, }}) }, Experiment1D_L1 => { @@ -295,9 +296,7 @@ kernel : Prod(AutoConvolution(spread_cutoff), base_spread), kernel_plot_width, noise_seed, - algorithm_defaults: HashMap::from([ - (DefaultAlgorithm::PDPS, AlgorithmConfig::PDPS(pdps_2d())) - ]), + algorithm_defaults: defaults_2d, }}) }, Experiment2D_L1_Fast => { @@ -317,9 +316,7 @@ kernel : base_spread, kernel_plot_width, noise_seed, - algorithm_defaults: HashMap::from([ - (DefaultAlgorithm::PDPS, AlgorithmConfig::PDPS(pdps_2d())) - ]), + algorithm_defaults: defaults_2d, }}) }, Experiment1D_TV_Fast => {
--- a/src/fb.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/fb.rs Thu Jan 23 23:35:28 2025 +0100 @@ -80,40 +80,18 @@ use numeric_literals::replace_float_literals; use serde::{Serialize, Deserialize}; use colored::Colorize; -use nalgebra::DVector; -use alg_tools::iterate::{ - AlgIteratorFactory, - AlgIteratorIteration, - AlgIterator, -}; +use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; use alg_tools::linops::{Mapping, GEMV}; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - BothGenerators, -}; use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::instance::Instance; -use alg_tools::norms::Linfinity; use crate::types::*; use crate::measures::{ DiscreteMeasure, RNDM, - DeltaMeasure, - Radon, }; use crate::measures::merging::{ SpikeMergingMethod, @@ -121,14 +99,8 @@ }; use crate::forward_model::{ ForwardModel, - AdjointProductBoundedBy + AdjointProductBoundedBy, }; -use crate::seminorms::DiscreteMeasureOp; -use crate::subproblem::{ - InnerSettings, - InnerMethod, -}; -use crate::tolerance::Tolerance; use crate::plot::{ SeqPlotter, Plotting, @@ -140,6 +112,10 @@ L2Squared, DataTerm, }; +pub use crate::prox_penalty::{ + FBGenericConfig, + ProxPenalty +}; /// Settings for [`pointsource_fb_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] @@ -151,51 +127,6 @@ pub generic : FBGenericConfig<F>, } -/// Settings for the solution of the stepwise optimality condition. -#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] -#[serde(default)] -pub struct FBGenericConfig<F : Float> { - /// Tolerance for point insertion. - pub tolerance : Tolerance<F>, - - /// Stop looking for predual maximum (where to isert a new point) below - /// `tolerance` multiplied by this factor. 
- /// - /// Not used by [`super::radon_fb`]. - pub insertion_cutoff_factor : F, - - /// Settings for branch and bound refinement when looking for predual maxima - pub refinement : RefinementSettings<F>, - - /// Maximum insertions within each outer iteration - /// - /// Not used by [`super::radon_fb`]. - pub max_insertions : usize, - - /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. - /// - /// Not used by [`super::radon_fb`]. - pub bootstrap_insertions : Option<(usize, usize)>, - - /// Inner method settings - pub inner : InnerSettings<F>, - - /// Spike merging method - pub merging : SpikeMergingMethod<F>, - - /// Tolerance multiplier for merges - pub merge_tolerance_mult : F, - - /// Spike merging method after the last step - pub final_merging : SpikeMergingMethod<F>, - - /// Iterations between merging heuristic tries - pub merge_every : usize, - - // /// Save $μ$ for postprocessing optimisation - // pub postprocessing : bool -} - #[replace_float_literals(F::cast_from(literal))] impl<F : Float> Default for FBConfig<F> { fn default() -> Self { @@ -206,155 +137,6 @@ } } -#[replace_float_literals(F::cast_from(literal))] -impl<F : Float> Default for FBGenericConfig<F> { - fn default() -> Self { - FBGenericConfig { - tolerance : Default::default(), - insertion_cutoff_factor : 1.0, - refinement : Default::default(), - max_insertions : 100, - //bootstrap_insertions : None, - bootstrap_insertions : Some((10, 1)), - inner : InnerSettings { - method : InnerMethod::Default, - .. Default::default() - }, - merging : SpikeMergingMethod::None, - //merging : Default::default(), - final_merging : Default::default(), - merge_every : 10, - merge_tolerance_mult : 2.0, - // postprocessing : false, - } - } -} - -impl<F : Float> FBGenericConfig<F> { - /// Check if merging should be attempted this iteration - pub fn merge_now<I : AlgIterator>(&self, state : &AlgIteratorIteration<I>) -> bool { - state.iteration() % self.merge_every == 0 - } -} - -/// TODO: document. 
-/// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike -/// locations, while `ν_delta` may have different locations. -#[replace_float_literals(F::cast_from(literal))] -pub(crate) fn insert_and_reweigh< - 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, I, const N : usize ->( - μ : &mut RNDM<F, N>, - τv : &BTFN<F, GA, BTA, N>, - μ_base : &RNDM<F, N>, - ν_delta: Option<&RNDM<F, N>>, - op𝒟 : &'a 𝒟, - op𝒟norm : F, - τ : F, - ε : F, - config : &FBGenericConfig<F>, - reg : &Reg, - state : &AlgIteratorIteration<I>, - stats : &mut IterInfo<F, N>, -) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool) -where F : Float + ToNalgebraRealField, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, - 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, - 𝒟::Codomain : RealMapping<F, N>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Reg : RegTerm<F, N>, - I : AlgIterator { - - // Maximum insertion count and measure difference calculation depend on insertion style. - let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { - (i, Some((l, k))) if i <= l => (k, false), - _ => (config.max_insertions, !state.is_quiet()), - }; - - let ω0 = match ν_delta { - None => op𝒟.apply(μ_base), - Some(ν) => op𝒟.apply(μ_base + ν), - }; - - // Add points to support until within error tolerance or maximum insertion count reached. - let mut count = 0; - let (within_tolerances, d) = 'insertion: loop { - if μ.len() > 0 { - // Form finite-dimensional subproblem. The subproblem references to the original μ^k - // from the beginning of the iteration are all contained in the immutable c and g. 
- // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional - // problems have not yet been updated to sign change. - let à = op𝒟.findim_matrix(μ.iter_locations()); - let g̃ = DVector::from_iterator(μ.len(), - μ.iter_locations() - .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) - .map(F::to_nalgebra_mixed)); - let mut x = μ.masses_dvector(); - - // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. - // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ - // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ - // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 - // = n |𝒟| |x|_2, where n is the number of points. Therefore - let Ã_normest = op𝒟norm * F::cast_from(μ.len()); - - // Solve finite-dimensional subproblem. - stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config); - - // Update masses of μ based on solution of finite-dimensional subproblem. - μ.set_masses_dvector(&x); - } - - // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality - // conditions in the predual space, and finding new points for insertion, if necessary. - let mut d = τv + match ν_delta { - None => op𝒟.preapply(μ.sub_matching(μ_base)), - Some(ν) => op𝒟.preapply(μ.sub_matching(μ_base) - ν) - }; - - // If no merging heuristic is used, let's be more conservative about spike insertion, - // and skip it after first round. If merging is done, being more greedy about spike - // insertion also seems to improve performance. 
- let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging { - false - } else { - count > 0 - }; - - // Find a spike to insert, if needed - let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( - &mut d, τ, ε, skip_by_rough_check, config - ) { - None => break 'insertion (true, d), - Some(res) => res, - }; - - // Break if maximum insertion count reached - if count >= max_insertions { - break 'insertion (in_bounds, d) - } - - // No point in optimising the weight here; the finite-dimensional algorithm is fast. - *μ += DeltaMeasure { x : ξ, α : 0.0 }; - count += 1; - stats.inserted += 1; - }; - - if !within_tolerances && warn_insertions { - // Complain (but continue) if we failed to get within tolerances - // by inserting more points. - let err = format!("Maximum insertions reached without achieving \ - subproblem solution tolerance"); - println!("{}", err.red()); - } - - (d, within_tolerances) -} - pub(crate) fn prune_with_stats<F : Float, const N : usize>( μ : &mut RNDM<F, N>, ) -> usize { @@ -409,38 +191,32 @@ /// Returns the final iterate. 
#[replace_float_literals(F::cast_from(literal))] pub fn pointsource_fb_reg< - 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize + F, I, A, Reg, P, const N : usize >( - opA : &'a A, + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, fbconfig : &FBConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, ) -> RNDM<F, N> -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> - + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, - 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, - 𝒟::Codomain : RealMapping<F, N>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Cube<F, N>: P2Minimise<Loc<F, N>, F>, - PlotLookup : Plotting<N>, - RNDM<F, N> : SpikeMerging<F>, - Reg : RegTerm<F, N> { +where + F : Float + ToNalgebraRealField, + I : AlgIteratorFactory<IterInfo<F, N>>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, + A : ForwardModel<RNDM<F, N>, F> + + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, + A::PreadjointCodomain : RealMapping<F, N>, + PlotLookup : Plotting<N>, + RNDM<F, N> : SpikeMerging<F>, + Reg : RegTerm<F, N>, + P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, +{ // Set up parameters let config = &fbconfig.generic; - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); - let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap(); + let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the 
conditional gradient approach. let tolerance = config.tolerance * τ * reg.tolerance_scaling(); @@ -465,26 +241,23 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let τv = opA.preadjoint().apply(residual * τ); + let mut τv = opA.preadjoint().apply(residual * τ); // Save current base point let μ_base = μ.clone(); // Insert and reweigh - let (d, _within_tolerances) = insert_and_reweigh( - &mut μ, &τv, &μ_base, None, - op𝒟, op𝒟norm, + let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats ); // Prune and possibly merge spikes if config.merge_now(&state) { - stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { - let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base)); - reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) - }); + stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, &reg); } + stats.pruned += prune_with_stats(&mut μ); // Update residual @@ -495,7 +268,7 @@ // Give statistics if needed state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) }); @@ -526,38 +299,32 @@ /// Returns the final iterate. 
#[replace_float_literals(F::cast_from(literal))] pub fn pointsource_fista_reg< - 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize + F, I, A, Reg, P, const N : usize >( - opA : &'a A, + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, fbconfig : &FBConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, ) -> RNDM<F, N> -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> - + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, - 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, - 𝒟::Codomain : RealMapping<F, N>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Cube<F, N>: P2Minimise<Loc<F, N>, F>, - PlotLookup : Plotting<N>, - RNDM<F, N> : SpikeMerging<F>, - Reg : RegTerm<F, N> { +where + F : Float + ToNalgebraRealField, + I : AlgIteratorFactory<IterInfo<F, N>>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, + A : ForwardModel<RNDM<F, N>, F> + + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, + A::PreadjointCodomain : RealMapping<F, N>, + PlotLookup : Plotting<N>, + RNDM<F, N> : SpikeMerging<F>, + Reg : RegTerm<F, N>, + P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, +{ // Set up parameters let config = &fbconfig.generic; - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); - let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap(); + let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); let mut λ = 1.0; // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // 
by τ compared to the conditional gradient approach. @@ -583,15 +350,14 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let τv = opA.preadjoint().apply(residual * τ); + let mut τv = opA.preadjoint().apply(residual * τ); // Save current base point let μ_base = μ.clone(); // Insert new spikes and reweigh - let (d, _within_tolerances) = insert_and_reweigh( - &mut μ, &τv, &μ_base, None, - op𝒟, op𝒟norm, + let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats ); @@ -632,7 +398,7 @@ // Give statistics if needed state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ_prev); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ_prev); full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new())) });
--- a/src/forward_model/bias.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/forward_model/bias.rs Thu Jan 23 23:35:28 2025 +0100 @@ -43,7 +43,6 @@ F : Float, Z : Clone + Space + ClosedAdd, A : AdjointProductBoundedBy<Domain, D, FloatType=F, Codomain = Z>, - D : Linear<Domain>, A::Codomain : ClosedAdd, { type FloatType = F; @@ -92,7 +91,6 @@ } -/// TODO: should assume `D` to be positive semi-definite and self-adjoint. #[replace_float_literals(F::cast_from(literal))] impl<'a, F, D, XD, Y, const N : usize> AdjointProductBoundedBy<RNDM<F, N>, D> for ZeroOp<'a, RNDM<F, N>, XD, Y, F>
--- a/src/forward_pdps.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/forward_pdps.rs Thu Jan 23 23:35:28 2025 +0100 @@ -8,30 +8,15 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::mapping::{Mapping, Instance}; +use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; use alg_tools::norms::Norm; use alg_tools::direct_product::Pair; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - //Bounded, -}; -use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::linops::{ BoundedLinear, AXPY, GEMV, Adjointable, IdOp, }; use alg_tools::convex::{Conjugable, Prox}; -use alg_tools::norms::{L2, Linfinity, PairNorm}; +use alg_tools::norms::{L2, PairNorm}; use crate::types::*; use crate::measures::{DiscreteMeasure, Radon, RNDM}; @@ -40,7 +25,6 @@ ForwardModel, AdjointProductPairBoundedBy, }; -use crate::seminorms::DiscreteMeasureOp; use crate::plot::{ SeqPlotter, Plotting, @@ -83,12 +67,12 @@ /// using primal-dual proximal splitting with a forward step. 
#[replace_float_literals(F::cast_from(literal))] pub fn pointsource_forward_pdps_pair< - 'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize + F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize >( - opA : &'a A, + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, config : &ForwardPDPSConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, @@ -102,27 +86,19 @@ where F : Float + ToNalgebraRealField, I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, A : ForwardModel< MeasureZ<F, Z, N>, F, PairNorm<Radon, L2, L2>, - PreadjointCodomain = Pair<BTFN<F, GA, BTA, N>, Z>, + PreadjointCodomain = Pair<S, Z>, > - + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, 𝒟, IdOp<Z>, FloatType=F>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, - 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, - Codomain = BTFN<F, G𝒟, BT𝒟, N>>, - BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Cube<F, N>: P2Minimise<Loc<F, N>, F>, + + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType=F>, + S: DifferentiableRealMapping<F, N>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, PlotLookup : Plotting<N>, RNDM<F, N> : SpikeMerging<F>, Reg : RegTerm<F, N>, + P : ProxPenalty<F, S, Reg, N>, KOpZ : BoundedLinear<Z, L2, L2, F, Codomain=Y> + GEMV<F, Z> + Adjointable<Z, Y, AdjointCodomain = Z>, @@ -150,11 +126,10 @@ let mut residual = calculate_residual(Pair(&μ, &z), opA, b); // Set up parameters - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); - let bigM = 0.0; 
//opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt(); let nKz = opKz.opnorm_bound(L2, L2); let opIdZ = IdOp::new(); - let (l, l_z) = opA.adjoint_product_pair_bound(&op𝒟, &opIdZ).unwrap(); + let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); // We need to satisfy // // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 @@ -196,14 +171,13 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { // Calculate initial transport - let Pair(τv, τz) = opA.preadjoint().apply(residual * τ); + let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ); let z_base = z.clone(); let μ_base = μ.clone(); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. - let (d, _within_tolerances) = insert_and_reweigh( - &mut μ, &τv, &μ_base, None, - op𝒟, op𝒟norm, + let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv, &μ_base, None, τ, ε, &config.insertion, &reg, &state, &mut stats, ); @@ -248,7 +222,7 @@ stats.this_iters += 1; state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) });
--- a/src/frank_wolfe.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/frank_wolfe.rs Thu Jan 23 23:35:28 2025 +0100 @@ -425,8 +425,8 @@ /// `iterator` is used to iterate the steps of the method, and `plotter` may be used to /// save intermediate iteration states as images. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fw_reg<'a, F, I, A, GA, BTA, S, Reg, const N : usize>( - opA : &'a A, +pub fn pointsource_fw_reg<F, I, A, GA, BTA, S, Reg, const N : usize>( + opA : &A, b : &A::Observable, reg : Reg, //domain : Cube<F, N>,
--- a/src/main.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/main.rs Thu Jan 23 23:35:28 2025 +0100 @@ -37,8 +37,8 @@ pub mod tolerance; pub mod regularisation; pub mod dataterm; +pub mod prox_penalty; pub mod fb; -pub mod radon_fb; pub mod sliding_fb; pub mod sliding_pdps; pub mod forward_pdps;
--- a/src/measures/discrete.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/measures/discrete.rs Thu Jan 23 23:35:28 2025 +0100 @@ -326,6 +326,18 @@ pub fn set_masses_dvector(&mut self, x : &DVector<F::MixedType>) { self.set_masses(x.iter().map(|&α| F::from_nalgebra_mixed(α))); } + + /// Extracts the masses of the spikes as a [`Vec`]. + pub fn masses_vec(&self) -> Vec<F::MixedType> { + self.iter_masses() + .map(|α| α.to_nalgebra_mixed()) + .collect() + } + + /// Sets the masses of the spikes from the values of a [`Vec`]. + pub fn set_masses_vec(&mut self, x : &Vec<F::MixedType>) { + self.set_masses(x.iter().map(|&α| F::from_nalgebra_mixed(α))); + } } // impl<Domain, F :Num> Index<usize> for DiscreteMeasure<Domain, F> {
--- a/src/pdps.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/pdps.rs Thu Jan 23 23:35:28 2025 +0100 @@ -44,46 +44,36 @@ use clap::ValueEnum; use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::loc::Loc; use alg_tools::euclidean::Euclidean; use alg_tools::linops::Mapping; use alg_tools::norms::{ Linfinity, Projection, }; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - SupportGenerator, - LocalAnalysis, -}; use alg_tools::mapping::{RealMapping, Instance}; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::linops::AXPY; use crate::types::*; -use crate::measures::{DiscreteMeasure, RNDM, Radon}; +use crate::measures::{DiscreteMeasure, RNDM}; use crate::measures::merging::SpikeMerging; use crate::forward_model::{ + ForwardModel, AdjointProductBoundedBy, - ForwardModel }; -use crate::seminorms::DiscreteMeasureOp; use crate::plot::{ SeqPlotter, Plotting, PlotLookup }; use crate::fb::{ - FBGenericConfig, - insert_and_reweigh, postprocess, prune_with_stats }; +pub use crate::prox_penalty::{ + FBGenericConfig, + ProxPenalty +}; use crate::regularisation::RegTerm; use crate::dataterm::{ DataTerm, @@ -223,33 +213,29 @@ /// /// Returns the final iterate. 
#[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>( - opA : &'a A, - b : &'a A::Observable, +pub fn pointsource_pdps_reg<F, I, A, D, Reg, P, const N : usize>( + opA : &A, + b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, pdpsconfig : &PDPSConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, dataterm : D, ) -> RNDM<F, N> -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> - + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, - 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, - 𝒟::Codomain : RealMapping<F, N>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - PlotLookup : Plotting<N>, - RNDM<F, N> : SpikeMerging<F>, - D : PDPSDataTerm<F, A::Observable, N>, - Reg : RegTerm<F, N> { +where + F : Float + ToNalgebraRealField, + I : AlgIteratorFactory<IterInfo<F, N>>, + A : ForwardModel<RNDM<F, N>, F> + + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, + A::PreadjointCodomain : RealMapping<F, N>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, + PlotLookup : Plotting<N>, + RNDM<F, N> : SpikeMerging<F>, + D : PDPSDataTerm<F, A::Observable, N>, + Reg : RegTerm<F, N>, + P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, +{ // Check parameters assert!(pdpsconfig.τ0 > 0.0 && @@ -259,8 +245,7 @@ // Set up parameters let config = &pdpsconfig.generic; - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); - let l 
= opA.adjoint_product_bound(&op𝒟).unwrap().sqrt(); + let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt(); let mut τ = pdpsconfig.τ0 / l; let mut σ = pdpsconfig.σ0 / l; let γ = dataterm.factor_of_strong_convexity(); @@ -286,25 +271,21 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let τv = opA.preadjoint().apply(y * τ); + let mut τv = opA.preadjoint().apply(y * τ); // Save current base point let μ_base = μ.clone(); // Insert and reweigh - let (d, _within_tolerances) = insert_and_reweigh( - &mut μ, &τv, &μ_base, None, - op𝒟, op𝒟norm, + let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats ); // Prune and possibly merge spikes if config.merge_now(&state) { - stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { - let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base)); - reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) - }); + stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, &reg); } stats.pruned += prune_with_stats(&mut μ); @@ -323,7 +304,7 @@ stats.this_iters += 1; state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) });
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/prox_penalty.rs Thu Jan 23 23:35:28 2025 +0100 @@ -0,0 +1,158 @@ +/*! +Proximal penalty abstraction +*/ + +use numeric_literals::replace_float_literals; +use alg_tools::types::*; +use serde::{Serialize, Deserialize}; + +use alg_tools::mapping::RealMapping; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::iterate::{ + AlgIteratorIteration, + AlgIterator, +}; +use crate::measures::{ + RNDM, +}; +use crate::types::{ + RefinementSettings, + IterInfo, +}; +use crate::subproblem::{ + InnerSettings, + InnerMethod, +}; +use crate::tolerance::Tolerance; +use crate::measures::merging::SpikeMergingMethod; +use crate::regularisation::RegTerm; + +pub mod wave; +pub mod radon_squared; +pub use radon_squared::RadonSquared; + +/// Settings for the solution of the stepwise optimality condition. +#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] +#[serde(default)] +pub struct FBGenericConfig<F : Float> { + /// Tolerance for point insertion. + pub tolerance : Tolerance<F>, + + /// Stop looking for predual maximum (where to insert a new point) below + /// `tolerance` multiplied by this factor. + /// + /// Not used by [`super::radon_fb`]. + pub insertion_cutoff_factor : F, + + /// Settings for branch and bound refinement when looking for predual maxima + pub refinement : RefinementSettings<F>, + + /// Maximum insertions within each outer iteration + /// + /// Not used by [`super::radon_fb`]. + pub max_insertions : usize, + + /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. + /// + /// Not used by [`super::radon_fb`]. 
+ pub bootstrap_insertions : Option<(usize, usize)>, + + /// Inner method settings + pub inner : InnerSettings<F>, + + /// Spike merging method + pub merging : SpikeMergingMethod<F>, + + /// Tolerance multiplier for merges + pub merge_tolerance_mult : F, + + /// Spike merging method after the last step + pub final_merging : SpikeMergingMethod<F>, + + /// Iterations between merging heuristic tries + pub merge_every : usize, + + // /// Save $μ$ for postprocessing optimisation + // pub postprocessing : bool +} + +#[replace_float_literals(F::cast_from(literal))] +impl<F : Float> Default for FBGenericConfig<F> { + fn default() -> Self { + FBGenericConfig { + tolerance : Default::default(), + insertion_cutoff_factor : 1.0, + refinement : Default::default(), + max_insertions : 100, + //bootstrap_insertions : None, + bootstrap_insertions : Some((10, 1)), + inner : InnerSettings { + method : InnerMethod::Default, + .. Default::default() + }, + merging : SpikeMergingMethod::None, + //merging : Default::default(), + final_merging : Default::default(), + merge_every : 10, + merge_tolerance_mult : 2.0, + // postprocessing : false, + } + } +} + +impl<F : Float> FBGenericConfig<F> { + /// Check if merging should be attempted this iteration + pub fn merge_now<I : AlgIterator>(&self, state : &AlgIteratorIteration<I>) -> bool { + state.iteration() % self.merge_every == 0 + } +} + + +/// Trait for proximal penalties +pub trait ProxPenalty<F, PreadjointCodomain, Reg, const N : usize> +where + F : Float + ToNalgebraRealField, + Reg : RegTerm<F, N>, +{ + type ReturnMapping : RealMapping<F, N>; + + /// Insert new spikes into `μ` to approximately satisfy optimality conditions + /// with the forward step term fixed to `τv`. + /// + /// May return `τv + w` for `w` a subdifferential of the regularisation term `reg`, + /// as well as an indication of whether the tolerance bounds `ε` are satisfied. 
+ /// + /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same + /// spike locations, while `ν_delta` may have different locations. + /// + /// `τv` is mutable to allow [`alg_tools::bisection_tree::btfn::BTFN`] refinement. + /// Actual values of `τv` are not supposed to be mutated. + fn insert_and_reweigh<I>( + &self, + μ : &mut RNDM<F, N>, + τv : &mut PreadjointCodomain, + μ_base : &RNDM<F, N>, + ν_delta: Option<&RNDM<F, N>>, + τ : F, + ε : F, + config : &FBGenericConfig<F>, + reg : &Reg, + state : &AlgIteratorIteration<I>, + stats : &mut IterInfo<F, N>, + ) -> (Option<Self::ReturnMapping>, bool) + where + I : AlgIterator; + + + /// Merge spikes, if possible. + fn merge_spikes( + &self, + μ : &mut RNDM<F, N>, + τv : &mut PreadjointCodomain, + μ_base : &RNDM<F, N>, + τ : F, + ε : F, + config : &FBGenericConfig<F>, + reg : &Reg, + ) -> usize; +}
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/prox_penalty/radon_squared.rs Thu Jan 23 23:35:28 2025 +0100 @@ -0,0 +1,170 @@ +/*! +Solver for the point source localisation problem using a simplified forward-backward splitting method. + +Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map. +*/ + +use numeric_literals::replace_float_literals; +use serde::{Serialize, Deserialize}; +use nalgebra::DVector; + +use alg_tools::iterate::{ + AlgIteratorIteration, + AlgIterator +}; +use alg_tools::norms::L2; +use alg_tools::linops::Mapping; +use alg_tools::bisection_tree::{ + BTFN, + Bounds, + BTSearch, + SupportGenerator, + LocalAnalysis, +}; +use alg_tools::mapping::RealMapping; +use alg_tools::nalgebra_support::ToNalgebraRealField; + +use crate::types::*; +use crate::measures::{ + RNDM, + DeltaMeasure, + Radon, +}; +use crate::measures::merging::SpikeMerging; +use crate::regularisation::RegTerm; +use crate::forward_model::{ + ForwardModel, + AdjointProductBoundedBy +}; +use super::{ + FBGenericConfig, + ProxPenalty, +}; + +/// Radon-norm squared proximal penalty + +#[derive(Copy,Clone,Serialize,Deserialize)] +pub struct RadonSquared; + +#[replace_float_literals(F::cast_from(literal))] +impl<F, GA, BTA, S, Reg, const N : usize> +ProxPenalty<F, BTFN<F, GA, BTA, N>, Reg, N> for RadonSquared +where + F : Float + ToNalgebraRealField, + GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, + BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, + S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, + Reg : RegTerm<F, N>, + RNDM<F, N> : SpikeMerging<F>, +{ + type ReturnMapping = BTFN<F, GA, BTA, N>; + + fn insert_and_reweigh<I>( + &self, + μ : &mut RNDM<F, N>, + τv : &mut BTFN<F, GA, BTA, N>, + μ_base : &RNDM<F, N>, + ν_delta: Option<&RNDM<F, N>>, + τ : F, + ε : F, + config : &FBGenericConfig<F>, + reg : &Reg, + _state : &AlgIteratorIteration<I>, + stats : &mut IterInfo<F, N>, + ) -> (Option<Self::ReturnMapping>, bool) + 
where + I : AlgIterator + { + assert!(ν_delta.is_none(), "Transport not implemented for Radon-squared prox term"); + + let mut y = μ_base.masses_vec(); + + 'i_and_w: for i in 0..=1 { + // Optimise weights + if μ.len() > 0 { + // Form finite-dimensional subproblem. The subproblem references to the original μ^k + // from the beginning of the iteration are all contained in the immutable c and g. + // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional + // problems have not yet been updated to sign change. + let g̃ = DVector::from_iterator(μ.len(), + μ.iter_locations() + .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ)))); + let mut x = μ.masses_dvector(); + // Ugly hack because DVector::push doesn't push but copies. + let yvec = DVector::from_column_slice(y.as_slice()); + // Solve finite-dimensional subproblem. + stats.inner_iters += reg.solve_findim_l1squared(&yvec, &g̃, τ, &mut x, ε, config); + + // Update masses of μ based on solution of finite-dimensional subproblem. + μ.set_masses_dvector(&x); + } + + if i>0 { + // Simple debugging test to see if more inserts would be needed. Doesn't seem so. + //let n = μ.dist_matching(μ_base); + //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n)); + break 'i_and_w + } + + // Calculate ‖μ - μ_base‖_ℳ + let n = μ.dist_matching(μ_base); + + // Find a spike to insert, if needed. + // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ, + // which are supposed to have been guaranteed by the finite-dimensional weight optimisation. 
+ match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) { + None => { break 'i_and_w }, + Some((ξ, _v_ξ, _in_bounds)) => { + // Weight is found out by running the finite-dimensional optimisation algorithm + // above + *μ += DeltaMeasure { x : ξ, α : 0.0 }; + //*μ_base += DeltaMeasure { x : ξ, α : 0.0 }; + y.push(0.0.to_nalgebra_mixed()); + stats.inserted += 1; + } + }; + } + + (None, true) + } + + fn merge_spikes( + &self, + μ : &mut RNDM<F, N>, + τv : &mut BTFN<F, GA, BTA, N>, + μ_base : &RNDM<F, N>, + τ : F, + ε : F, + config : &FBGenericConfig<F>, + reg : &Reg, + ) -> usize + { + μ.merge_spikes(config.merging, |μ_candidate| { + // Important: μ_candidate's new points are afterwards, + // and do not conflict with μ_base. + // TODO: could simplify to requiring μ_base instead of μ_radon. + // but may complicate with sliding base's exgtra points that need to be + // after μ_candidate's extra points. + // TODO: doesn't seem to work, maybe need to merge μ_base as well? + // Although that doesn't seem to make sense. + let μ_radon = μ_candidate.sub_matching(μ_base); + reg.verify_merge_candidate_radonsq(τv, μ_candidate, τ, ε, &config, &μ_radon) + //let n = μ_candidate.dist_matching(μ_base); + //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none() + }) + } +} + + +impl<F, A, const N : usize> AdjointProductBoundedBy<RNDM<F, N>, RadonSquared> +for A +where + F : Float, + A : ForwardModel<RNDM<F, N>, F> +{ + type FloatType = F; + + fn adjoint_product_bound(&self, _ : &RadonSquared) -> Option<Self::FloatType> { + self.opnorm_bound(Radon, L2).powi(2).into() + } +}
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/prox_penalty/wave.rs Thu Jan 23 23:35:28 2025 +0100 @@ -0,0 +1,182 @@ +/*! +Basic proximal penalty based on convolution operators $𝒟$. + */ + +use numeric_literals::replace_float_literals; +use nalgebra::DVector; +use colored::Colorize; + +use alg_tools::types::*; +use alg_tools::loc::Loc; +use alg_tools::mapping::{Mapping, RealMapping}; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::norms::Linfinity; +use alg_tools::iterate::{ + AlgIteratorIteration, + AlgIterator, +}; +use alg_tools::bisection_tree::{ + BTFN, + PreBTFN, + Bounds, + BTSearch, + SupportGenerator, + LocalAnalysis, + BothGenerators, +}; +use crate::measures::{ + RNDM, + DeltaMeasure, + Radon, +}; +use crate::measures::merging::{ + SpikeMerging, +}; +use crate::seminorms::DiscreteMeasureOp; +use crate::types::{ + IterInfo, +}; +use crate::measures::merging::SpikeMergingMethod; +use crate::regularisation::RegTerm; +use super::{ProxPenalty, FBGenericConfig}; + +#[replace_float_literals(F::cast_from(literal))] +impl<F, GA, BTA, S, Reg, 𝒟, G𝒟, K, const N : usize> +ProxPenalty<F, BTFN<F, GA, BTA, N>, Reg, N> for 𝒟 +where + F : Float + ToNalgebraRealField, + GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, + BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, + S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, + G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, + 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, + 𝒟::Codomain : RealMapping<F, N>, + K : RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, + Reg : RegTerm<F, N>, + RNDM<F, N> : SpikeMerging<F>, +{ + type ReturnMapping = BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>; + + fn insert_and_reweigh<I>( + &self, + μ : &mut RNDM<F, N>, + τv : &mut BTFN<F, GA, BTA, N>, + μ_base : &RNDM<F, N>, + ν_delta: Option<&RNDM<F, N>>, + τ : F, + ε : F, + config : &FBGenericConfig<F>, + reg : &Reg, + state : &AlgIteratorIteration<I>, + 
stats : &mut IterInfo<F, N>, + ) -> (Option<BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>>, bool) + where + I : AlgIterator + { + + // TODO: is this inefficient to do in every iteration? + let op𝒟norm = self.opnorm_bound(Radon, Linfinity); + + // Maximum insertion count and measure difference calculation depend on insertion style. + let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { + (i, Some((l, k))) if i <= l => (k, false), + _ => (config.max_insertions, !state.is_quiet()), + }; + + let ω0 = match ν_delta { + None => self.apply(μ_base), + Some(ν) => self.apply(μ_base + ν), + }; + + // Add points to support until within error tolerance or maximum insertion count reached. + let mut count = 0; + let (within_tolerances, d) = 'insertion: loop { + if μ.len() > 0 { + // Form finite-dimensional subproblem. The subproblem references to the original μ^k + // from the beginning of the iteration are all contained in the immutable c and g. + // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional + // problems have not yet been updated to sign change. + let à = self.findim_matrix(μ.iter_locations()); + let g̃ = DVector::from_iterator(μ.len(), + μ.iter_locations() + .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) + .map(F::to_nalgebra_mixed)); + let mut x = μ.masses_dvector(); + + // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. + // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ + // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ + // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 + // = n |𝒟| |x|_2, where n is the number of points. Therefore + let Ã_normest = op𝒟norm * F::cast_from(μ.len()); + + // Solve finite-dimensional subproblem. + stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config); + + // Update masses of μ based on solution of finite-dimensional subproblem. 
+ μ.set_masses_dvector(&x); + } + + // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality + // conditions in the predual space, and finding new points for insertion, if necessary. + let mut d = &*τv + match ν_delta { + None => self.preapply(μ.sub_matching(μ_base)), + Some(ν) => self.preapply(μ.sub_matching(μ_base) - ν) + }; + + // If no merging heuristic is used, let's be more conservative about spike insertion, + // and skip it after first round. If merging is done, being more greedy about spike + // insertion also seems to improve performance. + let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging { + false + } else { + count > 0 + }; + + // Find a spike to insert, if needed + let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( + &mut d, τ, ε, skip_by_rough_check, config + ) { + None => break 'insertion (true, d), + Some(res) => res, + }; + + // Break if maximum insertion count reached + if count >= max_insertions { + break 'insertion (in_bounds, d) + } + + // No point in optimising the weight here; the finite-dimensional algorithm is fast. + *μ += DeltaMeasure { x : ξ, α : 0.0 }; + count += 1; + stats.inserted += 1; + }; + + if !within_tolerances && warn_insertions { + // Complain (but continue) if we failed to get within tolerances + // by inserting more points. + let err = format!("Maximum insertions reached without achieving \ + subproblem solution tolerance"); + println!("{}", err.red()); + } + + (Some(d), within_tolerances) + } + + fn merge_spikes( + &self, + μ : &mut RNDM<F, N>, + τv : &mut BTFN<F, GA, BTA, N>, + μ_base : &RNDM<F, N>, + τ : F, + ε : F, + config : &FBGenericConfig<F>, + reg : &Reg, + ) -> usize + { + μ.merge_spikes(config.merging, |μ_candidate| { + let mut d = &*τv + self.preapply(μ_candidate.sub_matching(μ_base)); + reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config) + }) + } +}
--- a/src/radon_fb.rs Mon Jan 06 11:32:57 2025 -0500 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,404 +0,0 @@ -/*! -Solver for the point source localisation problem using a simplified forward-backward splitting method. - -Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map. -*/ - -use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; -use colored::Colorize; -use nalgebra::DVector; - -use alg_tools::iterate::{ - AlgIteratorFactory, - AlgIteratorIteration, - AlgIterator -}; -use alg_tools::euclidean::Euclidean; -use alg_tools::linops::Mapping; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::bisection_tree::{ - BTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, -}; -use alg_tools::mapping::RealMapping; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::norms::L2; - -use crate::types::*; -use crate::measures::{ - RNDM, - DiscreteMeasure, - DeltaMeasure, - Radon, -}; -use crate::measures::merging::{ - SpikeMergingMethod, - SpikeMerging, -}; -use crate::forward_model::ForwardModel; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; -use crate::regularisation::RegTerm; -use crate::dataterm::{ - calculate_residual, - L2Squared, - DataTerm, -}; - -use crate::fb::{ - FBGenericConfig, - postprocess, - prune_with_stats -}; - -/// Settings for [`pointsource_radon_fb_reg`]. 
-#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] -#[serde(default)] -pub struct RadonFBConfig<F : Float> { - /// Step length scaling - pub τ0 : F, - /// Generic parameters - pub insertion : FBGenericConfig<F>, -} - -#[replace_float_literals(F::cast_from(literal))] -impl<F : Float> Default for RadonFBConfig<F> { - fn default() -> Self { - RadonFBConfig { - τ0 : 0.99, - insertion : Default::default() - } - } -} - -#[replace_float_literals(F::cast_from(literal))] -pub(crate) fn insert_and_reweigh< - 'a, F, GA, BTA, S, Reg, I, const N : usize ->( - μ : &mut RNDM<F, N>, - τv : &mut BTFN<F, GA, BTA, N>, - μ_base : &mut RNDM<F, N>, - //_ν_delta: Option<&RNDM<F, N>>, - τ : F, - ε : F, - config : &FBGenericConfig<F>, - reg : &Reg, - _state : &AlgIteratorIteration<I>, - stats : &mut IterInfo<F, N>, -) -where F : Float + ToNalgebraRealField, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - RNDM<F, N> : SpikeMerging<F>, - Reg : RegTerm<F, N>, - I : AlgIterator { - - 'i_and_w: for i in 0..=1 { - // Optimise weights - if μ.len() > 0 { - // Form finite-dimensional subproblem. The subproblem references to the original μ^k - // from the beginning of the iteration are all contained in the immutable c and g. - // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional - // problems have not yet been updated to sign change. - let g̃ = DVector::from_iterator(μ.len(), - μ.iter_locations() - .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ)))); - let mut x = μ.masses_dvector(); - let y = μ_base.masses_dvector(); - - // Solve finite-dimensional subproblem. - stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config); - - // Update masses of μ based on solution of finite-dimensional subproblem. 
- μ.set_masses_dvector(&x); - } - - if i>0 { - // Simple debugging test to see if more inserts would be needed. Doesn't seem so. - //let n = μ.dist_matching(μ_base); - //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n)); - break 'i_and_w - } - - // Calculate ‖μ - μ_base‖_ℳ - let n = μ.dist_matching(μ_base); - - // Find a spike to insert, if needed. - // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ, - // which are supposed to have been guaranteed by the finite-dimensional weight optimisation. - match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) { - None => { break 'i_and_w }, - Some((ξ, _v_ξ, _in_bounds)) => { - // Weight is found out by running the finite-dimensional optimisation algorithm - // above - *μ += DeltaMeasure { x : ξ, α : 0.0 }; - *μ_base += DeltaMeasure { x : ξ, α : 0.0 }; - stats.inserted += 1; - } - }; - } -} - - -/// Iteratively solve the pointsource localisation problem using simplified forward-backward splitting. -/// -/// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the -/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. -/// Finally, the `iterator` is an outer loop verbosity and iteration count control -/// as documented in [`alg_tools::iterate`]. -/// -/// For details on the mathematical formulation, see the [module level](self) documentation. -/// -/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of -/// sums of simple functions usign bisection trees, and the related -/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions -/// active at a specific points, and to maximise their sums. 
Through the implementation of the -/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features -/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. -/// -/// Returns the final iterate. -#[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_radon_fb_reg< - 'a, F, I, A, GA, BTA, S, Reg, const N : usize ->( - opA : &'a A, - b : &A::Observable, - reg : Reg, - fbconfig : &RadonFBConfig<F>, - iterator : I, - mut _plotter : SeqPlotter<F, N>, -) -> RNDM<F, N> -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Cube<F, N>: P2Minimise<Loc<F, N>, F>, - RNDM<F, N> : SpikeMerging<F>, - Reg : RegTerm<F, N> { - - // Set up parameters - let config = &fbconfig.insertion; - // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ - // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such - // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L. - let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2); - // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled - // by τ compared to the conditional gradient approach. 
- let tolerance = config.tolerance * τ * reg.tolerance_scaling(); - let mut ε = tolerance.initial(); - - // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut residual = -b; - - // Statistics - let full_stats = |residual : &A::Observable, - μ : &RNDM<F, N>, - ε, stats| IterInfo { - value : residual.norm2_squared_div2() + reg.apply(μ), - n_spikes : μ.len(), - ε, - // postprocessing: config.postprocessing.then(|| μ.clone()), - .. stats - }; - let mut stats = IterInfo::new(); - - // Run the algorithm - for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { - // Calculate smooth part of surrogate model. - let mut τv = opA.preadjoint().apply(residual * τ); - - // Save current base point - let mut μ_base = μ.clone(); - - // Insert and reweigh - insert_and_reweigh( - &mut μ, &mut τv, &mut μ_base, //None, - τ, ε, - config, ®, &state, &mut stats - ); - - // Prune and possibly merge spikes - assert!(μ_base.len() <= μ.len()); - if config.merge_now(&state) { - stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { - // Important: μ_candidate's new points are afterwards, - // and do not conflict with μ_base. - // TODO: could simplify to requiring μ_base instead of μ_radon. - // but may complicate with sliding base's exgtra points that need to be - // after μ_candidate's extra points. - // TODO: doesn't seem to work, maybe need to merge μ_base as well? - // Although that doesn't seem to make sense. 
- let μ_radon = μ_candidate.sub_matching(&μ_base); - reg.verify_merge_candidate_radonsq(&mut τv, μ_candidate, τ, ε, &config, &μ_radon) - //let n = μ_candidate.dist_matching(μ_base); - //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none() - }); - } - stats.pruned += prune_with_stats(&mut μ); - - // Update residual - residual = calculate_residual(&μ, opA, b); - - let iter = state.iteration(); - stats.this_iters += 1; - - // Give statistics if needed - state.if_verbose(|| { - full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) - }); - - // Update main tolerance for next iteration - ε = tolerance.update(ε, iter); - } - - postprocess(μ, config, L2Squared, opA, b) -} - -/// Iteratively solve the pointsource localisation problem using simplified inertial forward-backward splitting. -/// -/// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the -/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. -/// Finally, the `iterator` is an outer loop verbosity and iteration count control -/// as documented in [`alg_tools::iterate`]. -/// -/// For details on the mathematical formulation, see the [module level](self) documentation. -/// -/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of -/// sums of simple functions usign bisection trees, and the related -/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions -/// active at a specific points, and to maximise their sums. Through the implementation of the -/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features -/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. -/// -/// Returns the final iterate. 
-#[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_radon_fista_reg< - 'a, F, I, A, GA, BTA, S, Reg, const N : usize ->( - opA : &'a A, - b : &A::Observable, - reg : Reg, - fbconfig : &RadonFBConfig<F>, - iterator : I, - mut plotter : SeqPlotter<F, N>, -) -> RNDM<F, N> -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Cube<F, N>: P2Minimise<Loc<F, N>, F>, - PlotLookup : Plotting<N>, - RNDM<F, N> : SpikeMerging<F>, - Reg : RegTerm<F, N> { - - // Set up parameters - let config = &fbconfig.insertion; - // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ - // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such - // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L. - let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2); - let mut λ = 1.0; - // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled - // by τ compared to the conditional gradient approach. - let tolerance = config.tolerance * τ * reg.tolerance_scaling(); - let mut ε = tolerance.initial(); - - // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut μ_prev = DiscreteMeasure::new(); - let mut residual = -b; - let mut warned_merging = false; - - // Statistics - let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo { - value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), - n_spikes : ν.len(), - ε, - // postprocessing: config.postprocessing.then(|| ν.clone()), - .. 
stats - }; - let mut stats = IterInfo::new(); - - // Run the algorithm - for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { - // Calculate smooth part of surrogate model. - let mut τv = opA.preadjoint().apply(residual * τ); - - // Save current base point - let mut μ_base = μ.clone(); - - // Insert new spikes and reweigh - insert_and_reweigh( - &mut μ, &mut τv, &mut μ_base, //None, - τ, ε, - config, ®, &state, &mut stats - ); - - // (Do not) merge spikes. - if config.merge_now(&state) { - match config.merging { - SpikeMergingMethod::None => { }, - _ => if !warned_merging { - let err = format!("Merging not supported for μFISTA"); - println!("{}", err.red()); - warned_merging = true; - } - } - } - - // Update inertial prameters - let λ_prev = λ; - λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); - let θ = λ / λ_prev - λ; - - // Perform inertial update on μ. - // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ - // and μ_prev have zero weight. Since both have weights from the finite-dimensional - // subproblem with a proximal projection step, this is likely to happen when the - // spike is not needed. A copy of the pruned μ without artithmetic performed is - // stored in μ_prev. - let n_before_prune = μ.len(); - μ.pruning_sub(1.0 + θ, θ, &mut μ_prev); - debug_assert!(μ.len() <= n_before_prune); - stats.pruned += n_before_prune - μ.len(); - - // Update residual - residual = calculate_residual(&μ, opA, b); - - let iter = state.iteration(); - stats.this_iters += 1; - - // Give statistics if needed - state.if_verbose(|| { - plotter.plot_spikes(iter, Option::<&S>::None, Some(&τv), &μ_prev); - full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new())) - }); - - // Update main tolerance for next iteration - ε = tolerance.update(ε, iter); - } - - postprocess(μ_prev, config, L2Squared, opA, b) -}
--- a/src/run.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/run.rs Thu Jan 23 23:35:28 2025 +0100 @@ -72,11 +72,6 @@ pointsource_fb_reg, pointsource_fista_reg, }; -use crate::radon_fb::{ - RadonFBConfig, - pointsource_radon_fb_reg, - pointsource_radon_fista_reg, -}; use crate::sliding_fb::{ SlidingFBConfig, TransportConfig, @@ -114,22 +109,33 @@ L1, L2Squared, }; +use crate::prox_penalty::{ + RadonSquared, + //ProxPenalty, +}; use alg_tools::norms::{L2, NormExponent}; use alg_tools::operator_arithmetic::Weighted; use anyhow::anyhow; +/// Available proximal terms +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +pub enum ProxTerm { + /// Partial-to-wave operator 𝒟. + Wave, + /// Radon-norm squared + RadonSquared +} + /// Available algorithms and their configurations #[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub enum AlgorithmConfig<F : Float> { - FB(FBConfig<F>), - FISTA(FBConfig<F>), + FB(FBConfig<F>, ProxTerm), + FISTA(FBConfig<F>, ProxTerm), FW(FWConfig<F>), - PDPS(PDPSConfig<F>), - RadonFB(RadonFBConfig<F>), - RadonFISTA(RadonFBConfig<F>), - SlidingFB(SlidingFBConfig<F>), - ForwardPDPS(ForwardPDPSConfig<F>), - SlidingPDPS(SlidingPDPSConfig<F>), + PDPS(PDPSConfig<F>, ProxTerm), + SlidingFB(SlidingFBConfig<F>, ProxTerm), + ForwardPDPS(ForwardPDPSConfig<F>, ProxTerm), + SlidingPDPS(SlidingPDPSConfig<F>, ProxTerm), } fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { @@ -165,45 +171,35 @@ use AlgorithmConfig::*; match self { - FB(fb) => FB(FBConfig { + FB(fb, prox) => FB(FBConfig { τ0 : cli.tau0.unwrap_or(fb.τ0), generic : override_fb_generic(fb.generic), .. fb - }), - FISTA(fb) => FISTA(FBConfig { + }, prox), + FISTA(fb, prox) => FISTA(FBConfig { τ0 : cli.tau0.unwrap_or(fb.τ0), generic : override_fb_generic(fb.generic), .. 
fb - }), - PDPS(pdps) => PDPS(PDPSConfig { + }, prox), + PDPS(pdps, prox) => PDPS(PDPSConfig { τ0 : cli.tau0.unwrap_or(pdps.τ0), σ0 : cli.sigma0.unwrap_or(pdps.σ0), acceleration : cli.acceleration.unwrap_or(pdps.acceleration), generic : override_fb_generic(pdps.generic), .. pdps - }), + }, prox), FW(fw) => FW(FWConfig { merging : cli.merging.clone().unwrap_or(fw.merging), tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), .. fw }), - RadonFB(fb) => RadonFB(RadonFBConfig { - τ0 : cli.tau0.unwrap_or(fb.τ0), - insertion : override_fb_generic(fb.insertion), - .. fb - }), - RadonFISTA(fb) => RadonFISTA(RadonFBConfig { - τ0 : cli.tau0.unwrap_or(fb.τ0), - insertion : override_fb_generic(fb.insertion), - .. fb - }), - SlidingFB(sfb) => SlidingFB(SlidingFBConfig { + SlidingFB(sfb, prox) => SlidingFB(SlidingFBConfig { τ0 : cli.tau0.unwrap_or(sfb.τ0), transport : override_transport(sfb.transport), insertion : override_fb_generic(sfb.insertion), .. sfb - }), - SlidingPDPS(spdps) => SlidingPDPS(SlidingPDPSConfig { + }, prox), + SlidingPDPS(spdps, prox) => SlidingPDPS(SlidingPDPSConfig { τ0 : cli.tau0.unwrap_or(spdps.τ0), σp0 : cli.sigmap0.unwrap_or(spdps.σp0), σd0 : cli.sigma0.unwrap_or(spdps.σd0), @@ -211,15 +207,15 @@ transport : override_transport(spdps.transport), insertion : override_fb_generic(spdps.insertion), .. spdps - }), - ForwardPDPS(fpdps) => ForwardPDPS(ForwardPDPSConfig { + }, prox), + ForwardPDPS(fpdps, prox) => ForwardPDPS(ForwardPDPSConfig { τ0 : cli.tau0.unwrap_or(fpdps.τ0), σp0 : cli.sigmap0.unwrap_or(fpdps.σp0), σd0 : cli.sigma0.unwrap_or(fpdps.σd0), //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), insertion : override_fb_generic(fpdps.insertion), .. 
fpdps - }), + }, prox), } } } @@ -250,12 +246,6 @@ /// The μPDPS primal-dual proximal splitting method #[clap(name = "pdps")] PDPS, - /// The RadonFB forward-backward method - #[clap(name = "radon_fb")] - RadonFB, - /// The RadonFISTA inertial forward-backward method - #[clap(name = "radon_fista")] - RadonFISTA, /// The sliding FB method #[clap(name = "sliding_fb", alias = "sfb")] SlidingFB, @@ -265,6 +255,27 @@ /// The PDPS method with a forward step for the smooth function #[clap(name = "forward_pdps", alias = "fpdps")] ForwardPDPS, + + // Radon variants + + /// The μFB forward-backward method with radon-norm squared proximal term + #[clap(name = "radon_fb")] + RadonFB, + /// The μFISTA inertial forward-backward method with radon-norm squared proximal term + #[clap(name = "radon_fista")] + RadonFISTA, + /// The μPDPS primal-dual proximal splitting method with radon-norm squared proximal term + #[clap(name = "radon_pdps")] + RadonPDPS, + /// The sliding FB method with radon-norm squared proximal term + #[clap(name = "radon_sliding_fb", alias = "radon_sfb")] + RadonSlidingFB, + /// The sliding PDPS method with radon-norm squared proximal term + #[clap(name = "radon_sliding_pdps", alias = "radon_spdps")] + RadonSlidingPDPS, + /// The PDPS method with a forward step for the smooth function with radon-norm squared proximal term + #[clap(name = "radon_forward_pdps", alias = "radon_fpdps")] + RadonForwardPDPS, } impl DefaultAlgorithm { @@ -272,19 +283,26 @@ pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { use DefaultAlgorithm::*; match *self { - FB => AlgorithmConfig::FB(Default::default()), - FISTA => AlgorithmConfig::FISTA(Default::default()), + FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave), + FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave), FW => AlgorithmConfig::FW(Default::default()), FWRelax => AlgorithmConfig::FW(FWConfig{ variant : FWVariant::Relaxed, .. 
Default::default() }), - PDPS => AlgorithmConfig::PDPS(Default::default()), - RadonFB => AlgorithmConfig::RadonFB(Default::default()), - RadonFISTA => AlgorithmConfig::RadonFISTA(Default::default()), - SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), - SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default()), - ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default()), + PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave), + SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave), + SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave), + ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave), + + // Radon variants + + RadonFB => AlgorithmConfig::FB(Default::default(), ProxTerm::RadonSquared), + RadonFISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::RadonSquared), + RadonPDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::RadonSquared), + RadonSlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::RadonSquared), + RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::RadonSquared), + RadonForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::RadonSquared), } } @@ -602,7 +620,7 @@ } #[replace_float_literals(F::cast_from(literal))] -impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for +impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N : usize> RunnableExperiment<F> for Named<ExperimentV2<F, NoiseDistr, S, K, P, N>> where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, @@ -628,7 +646,13 @@ DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, RNDM<F, N> : SpikeMerging<F>, - NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug + NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, + // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, 
PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>, + // PreadjointCodomain : Space + Bounded<F> + DifferentiableRealMapping<F, N>, + // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, + // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, + // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, + // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, { fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> { @@ -663,7 +687,7 @@ let mut rng = StdRng::seed_from_u64(noise_seed); // Generate the data and calculate SSNR statistic - let b_hat : DVector<_> = opA.apply(μ_hat); + let b_hat = opA.apply(μ_hat); let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); let b = &b_hat + &noise; // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField @@ -683,128 +707,157 @@ |alg, iterator, plotter, running| { let μ = match alg { - AlgorithmConfig::FB(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + AlgorithmConfig::FB(ref algconfig, prox) => { + match (regularisation, dataterm, prox) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_fb_reg( &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, iterator, plotter ) }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_fb_reg( &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, iterator, plotter ) }), + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_fb_reg( + &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter + ) + }), + (Regularisation::Radon(α), DataTerm::L2Squared, 
ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_fb_reg( + &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter + ) + }), _ => Err(NotImplemented) } }, - AlgorithmConfig::FISTA(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + AlgorithmConfig::FISTA(ref algconfig, prox) => { + match (regularisation, dataterm, prox) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_fista_reg( &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, iterator, plotter ) }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_fista_reg( &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, iterator, plotter ) }), - _ => Err(NotImplemented), - } - }, - AlgorithmConfig::RadonFB(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ print!("{running}"); - pointsource_radon_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), algconfig, + pointsource_fista_reg( + &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, iterator, plotter ) }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ print!("{running}"); - pointsource_radon_fb_reg( - &opA, &b, RadonRegTerm(α), algconfig, + pointsource_fista_reg( + &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, iterator, plotter ) }), _ => Err(NotImplemented), } }, - AlgorithmConfig::RadonFISTA(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ - print!("{running}"); - pointsource_radon_fista_reg( - &opA, &b, NonnegRadonRegTerm(α), algconfig, - iterator, plotter - ) - }), - 
(Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ - print!("{running}"); - pointsource_radon_fista_reg( - &opA, &b, RadonRegTerm(α), algconfig, - iterator, plotter - ) - }), - _ => Err(NotImplemented), - } - }, - AlgorithmConfig::SlidingFB(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + AlgorithmConfig::SlidingFB(ref algconfig, prox) => { + match (regularisation, dataterm, prox) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_sliding_fb_reg( &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, iterator, plotter ) }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_sliding_fb_reg( &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, iterator, plotter ) }), + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_sliding_fb_reg( + &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter + ) + }), + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_sliding_fb_reg( + &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter + ) + }), _ => Err(NotImplemented), } }, - AlgorithmConfig::PDPS(ref algconfig) => { + AlgorithmConfig::PDPS(ref algconfig, prox) => { print!("{running}"); - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + match (regularisation, dataterm, prox) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ pointsource_pdps_reg( &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, iterator, plotter, L2Squared ) }), - (Regularisation::Radon(α),DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::Wave) => Ok({ pointsource_pdps_reg( 
&opA, &b, RadonRegTerm(α), &op𝒟, algconfig, iterator, plotter, L2Squared ) }), - (Regularisation::NonnegRadon(α), DataTerm::L1) => Ok({ + (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ pointsource_pdps_reg( &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, iterator, plotter, L1 ) }), - (Regularisation::Radon(α), DataTerm::L1) => Ok({ + (Regularisation::Radon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ pointsource_pdps_reg( &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, iterator, plotter, L1 ) }), + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + pointsource_pdps_reg( + &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, L2Squared + ) + }), + (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + pointsource_pdps_reg( + &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, L2Squared + ) + }), + (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ + pointsource_pdps_reg( + &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, L1 + ) + }), + (Regularisation::Radon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ + pointsource_pdps_reg( + &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, L1 + ) + }), + // _ => Err(NotImplemented), } }, AlgorithmConfig::FW(ref algconfig) => { @@ -831,7 +884,7 @@ #[replace_float_literals(F::cast_from(literal))] -impl<F, NoiseDistr, S, K, P, B, const N : usize> RunnableExperiment<F> for +impl<F, NoiseDistr, S, K, P, B, /*PreadjointCodomain,*/ const N : usize> RunnableExperiment<F> for Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>> where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, @@ -859,6 +912,12 @@ RNDM<F, N> : SpikeMerging<F>, NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, + // DefaultSG<F, S, P, N> : 
ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>, + // PreadjointCodomain : Bounded<F> + DifferentiableRealMapping<F, N>, + // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, + // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, + // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, + // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, { fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> { @@ -937,9 +996,9 @@ |alg, iterator, plotter, running| { let Pair(μ, z) = match alg { - AlgorithmConfig::ForwardPDPS(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + AlgorithmConfig::ForwardPDPS(ref algconfig, prox) => { + match (regularisation, dataterm, prox) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_forward_pdps_pair( &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, @@ -947,7 +1006,7 @@ /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), ) }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_forward_pdps_pair( &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, @@ -955,12 +1014,28 @@ /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), ) }), + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_forward_pdps_pair( + &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, + /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), + ) + }), + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_forward_pdps_pair( + &opAext, &b, RadonRegTerm(α), 
&RadonSquared, algconfig, + iterator, plotter, + /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), + ) + }), _ => Err(NotImplemented) } }, - AlgorithmConfig::SlidingPDPS(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + AlgorithmConfig::SlidingPDPS(ref algconfig, prox) => { + match (regularisation, dataterm, prox) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_sliding_pdps_pair( &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, @@ -968,7 +1043,7 @@ /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), ) }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_sliding_pdps_pair( &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, @@ -976,6 +1051,22 @@ /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), ) }), + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_sliding_pdps_pair( + &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, + /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), + ) + }), + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_sliding_pdps_pair( + &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, + /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), + ) + }), _ => Err(NotImplemented) } },
--- a/src/sliding_fb.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/sliding_fb.rs Thu Jan 23 23:35:28 2025 +0100 @@ -12,38 +12,19 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::mapping::{Mapping, DifferentiableMapping, Instance}; +use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; use alg_tools::norms::Norm; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - //Bounded, -}; -use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::norms::{L2, Linfinity}; +use alg_tools::norms::L2; use crate::types::*; use crate::measures::{DiscreteMeasure, Radon, RNDM}; -use crate::measures::merging::{ - SpikeMergingMethod, - SpikeMerging, -}; +use crate::measures::merging::SpikeMerging; use crate::forward_model::{ ForwardModel, AdjointProductBoundedBy, LipschitzValues, }; -use crate::seminorms::DiscreteMeasureOp; //use crate::tolerance::Tolerance; use crate::plot::{ SeqPlotter, @@ -151,7 +132,7 @@ Observable : Euclidean<F, Output=Observable>, for<'a> &'a Observable : Instance<Observable>, //for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, - D : DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F, N>>, + D : DifferentiableRealMapping<F, N>, { use TransportStepLength::*; @@ -353,40 +334,29 @@ /// The parametrisation is as for [`pointsource_fb_reg`]. /// Inertia is currently not supported. 
#[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>( - opA : &'a A, +pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N : usize>( + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, config : &SlidingFBConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, ) -> RNDM<F, N> -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, - for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, - A::PreadjointCodomain : DifferentiableMapping< - Loc<F, N>, DerivativeDomain=Loc<F, N>, Codomain=F - >, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, - A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> - + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, - //+ TransportLipschitz<L2Squared, FloatType=F>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, - 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, - Codomain = BTFN<F, G𝒟, BT𝒟, N>>, - BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N> - + DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F,N>>, - K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Cube<F, N>: P2Minimise<Loc<F, N>, F>, - PlotLookup : Plotting<N>, - RNDM<F, N> : SpikeMerging<F>, - Reg : SlidingRegTerm<F, N> { +where + F : Float + ToNalgebraRealField, + I : AlgIteratorFactory<IterInfo<F, N>>, + A : ForwardModel<RNDM<F, N>, F> + + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, + //+ TransportLipschitz<L2Squared, FloatType=F>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, 
+ for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, + A::PreadjointCodomain : DifferentiableRealMapping<F, N>, + RNDM<F, N> : SpikeMerging<F>, + Reg : SlidingRegTerm<F, N>, + P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, + PlotLookup : Plotting<N>, +{ // Check parameters assert!(config.τ0 > 0.0, "Invalid step length parameter"); @@ -398,13 +368,12 @@ let mut residual = -b; // Has to equal $Aμ-b$. // Set up parameters - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); let opAnorm = opA.opnorm_bound(Radon, L2); //let max_transport = config.max_transport.scale // * reg.radon_norm_bound(b.norm2_squared() / 2.0); //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; let ℓ = 0.0; - let τ = config.τ0 / opA.adjoint_product_bound(&op𝒟).unwrap(); + let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v @@ -446,15 +415,14 @@ // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. - let (d, _within_tolerances, τv̆) = 'adapt_transport: loop { + let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); - let τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); + let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 
- let (d, within_tolerances) = insert_and_reweigh( - &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), - op𝒟, op𝒟norm, + let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, &config.insertion, &reg, &state, &mut stats, ); @@ -464,7 +432,7 @@ &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, ε, &config.transport ) { - break 'adapt_transport (d, within_tolerances, τv̆) + break 'adapt_transport (maybe_d, within_tolerances, τv̆) } }; @@ -480,20 +448,20 @@ (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) }); - // Merge spikes. - // This expects the prune below to prune γ. - // TODO: This may not work correctly in all cases. - let ins = &config.insertion; - if ins.merge_now(&state) { - if let SpikeMergingMethod::None = ins.merging { - } else { - stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { - let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; - let mut d = &τv̆ + op𝒟.preapply(ν); - reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) - }); - } - } + // // Merge spikes. + // // This expects the prune below to prune γ. + // // TODO: This may not work correctly in all cases. + // let ins = &config.insertion; + // if ins.merge_now(&state) { + // if let SpikeMergingMethod::None = ins.merging { + // } else { + // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { + // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; + // let mut d = &τv̆ + op𝒟.preapply(ν); + // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) + // }); + // } + // } // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the // latter needs to be pruned when μ is. @@ -514,7 +482,7 @@ // Give statistics if requested state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) });
--- a/src/sliding_pdps.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/sliding_pdps.rs Thu Jan 23 23:35:28 2025 +0100 @@ -11,30 +11,15 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; use alg_tools::norms::Norm; use alg_tools::direct_product::Pair; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - //Bounded, -}; -use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::linops::{ BoundedLinear, AXPY, GEMV, Adjointable, IdOp, }; use alg_tools::convex::{Conjugable, Prox}; -use alg_tools::norms::{L2, Linfinity, PairNorm}; +use alg_tools::norms::{L2, PairNorm}; use crate::types::*; use crate::measures::{DiscreteMeasure, Radon, RNDM}; @@ -45,7 +30,6 @@ LipschitzValues, }; // use crate::transport::TransportLipschitz; -use crate::seminorms::DiscreteMeasureOp; //use crate::tolerance::Tolerance; use crate::plot::{ SeqPlotter, @@ -101,12 +85,12 @@ /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. 
#[replace_float_literals(F::cast_from(literal))] pub fn pointsource_sliding_pdps_pair< - 'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize + F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize >( - opA : &'a A, + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, config : &SlidingPDPSConfig<F>, iterator : I, mut plotter : SeqPlotter<F, N>, @@ -120,36 +104,25 @@ where F : Float + ToNalgebraRealField, I : AlgIteratorFactory<IterInfo<F, N>>, - for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, - for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, - BTFN<F, GA, BTA, N> : DifferentiableRealMapping<F, N>, - GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, A : ForwardModel< MeasureZ<F, Z, N>, F, PairNorm<Radon, L2, L2>, - PreadjointCodomain = Pair<BTFN<F, GA, BTA, N>, Z>, + PreadjointCodomain = Pair<S, Z>, > - + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, 𝒟, IdOp<Z>, FloatType=F>, - BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, - 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, - Codomain = BTFN<F, G𝒟, BT𝒟, N>>, - BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, - S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N> - + DifferentiableRealMapping<F, N>, - K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, - //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>, - BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - Cube<F, N>: P2Minimise<Loc<F, N>, F>, + + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType=F>, + S : DifferentiableRealMapping<F, N>, + for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, + for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, PlotLookup : Plotting<N>, RNDM<F, N> : SpikeMerging<F>, Reg : SlidingRegTerm<F, N>, + P : ProxPenalty<F, S, Reg, N>, // KOpM : 
Linear<RNDM<F, N>, Codomain=Y> // + GEMV<F, RNDM<F, N>> // + Preadjointable< // RNDM<F, N>, Y, - // PreadjointCodomain = BTFN<F, GA, BTA, N>, + // PreadjointCodomain = S, // > // + TransportLipschitz<L2Squared, FloatType=F> // + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, @@ -185,7 +158,6 @@ let zero_z = z.similar_origin(); // Set up parameters - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); // TODO: maybe this PairNorm doesn't make sense here? let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared); @@ -193,7 +165,7 @@ let nKz = opKz.opnorm_bound(L2, L2); let ℓ = 0.0; let opIdZ = IdOp::new(); - let (l, l_z) = opA.adjoint_product_pair_bound(&op𝒟, &opIdZ).unwrap(); + let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); // We need to satisfy // // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 @@ -278,18 +250,17 @@ // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. - let (d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop { + let (maybe_d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) let residual_μ̆ = calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); - let Pair(τv̆, τz) = opA.preadjoint().apply(residual_μ̆ * τ); + let mut τv̆z = opA.preadjoint().apply(residual_μ̆ * τ); // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 
- let (d, within_tolerances) = insert_and_reweigh( - &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), - op𝒟, op𝒟norm, + let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv̆z.0, &γ1, Some(&μ_base_minus_γ0), τ, ε, &config.insertion, &reg, &state, &mut stats, ); @@ -300,7 +271,7 @@ &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, ε, &config.transport ) { - break 'adapt_transport (d, within_tolerances, Pair(τv̆, τz)) + break 'adapt_transport (maybe_d, within_tolerances, τv̆z) } }; @@ -364,7 +335,7 @@ stats.this_iters += 1; state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) });