# HG changeset patch # User Tuomo Valkonen # Date 1737671728 -3600 # Node ID c5d8bd1a7728c99ce2acbb5c2dd0768116035ad1 # Parent fb911f72e6981ce8bdcaa171d552d878e6b3e85a Generic proximal penalty support diff -r fb911f72e698 -r c5d8bd1a7728 src/experiments.rs --- a/src/experiments.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/experiments.rs Thu Jan 23 23:35:28 2025 +0100 @@ -24,7 +24,8 @@ ExperimentBiased, Named, DefaultAlgorithm, - AlgorithmConfig + AlgorithmConfig, + ProxTerm }; //use crate::fb::FBGenericConfig; use crate::rand_distr::{SerializableNormal, SaltAndPepper}; @@ -153,7 +154,11 @@ .. Default::default() } }; - + let defaults_2d = HashMap::from([ + (DefaultAlgorithm::PDPS, AlgorithmConfig::PDPS(pdps_2d(), ProxTerm::Wave)), + (DefaultAlgorithm::RadonPDPS, AlgorithmConfig::PDPS(pdps_2d(), ProxTerm::RadonSquared)) + ]); + // We add a hash of the experiment name to the configured // noise seed to not use the same noise for different experiments. let mut h = DefaultHasher::new(); @@ -212,9 +217,7 @@ kernel : Prod(AutoConvolution(spread_cutoff), base_spread), kernel_plot_width, noise_seed, - algorithm_defaults: HashMap::from([ - (DefaultAlgorithm::PDPS, AlgorithmConfig::PDPS(pdps_2d())) - ]), + algorithm_defaults: defaults_2d, }}) }, Experiment2DFast => { @@ -231,9 +234,7 @@ kernel : base_spread, kernel_plot_width, noise_seed, - algorithm_defaults: HashMap::from([ - (DefaultAlgorithm::PDPS, AlgorithmConfig::PDPS(pdps_2d())) - ]), + algorithm_defaults: defaults_2d, }}) }, Experiment1D_L1 => { @@ -295,9 +296,7 @@ kernel : Prod(AutoConvolution(spread_cutoff), base_spread), kernel_plot_width, noise_seed, - algorithm_defaults: HashMap::from([ - (DefaultAlgorithm::PDPS, AlgorithmConfig::PDPS(pdps_2d())) - ]), + algorithm_defaults: defaults_2d, }}) }, Experiment2D_L1_Fast => { @@ -317,9 +316,7 @@ kernel : base_spread, kernel_plot_width, noise_seed, - algorithm_defaults: HashMap::from([ - (DefaultAlgorithm::PDPS, AlgorithmConfig::PDPS(pdps_2d())) - ]), + algorithm_defaults: defaults_2d, }}) }, Experiment1D_TV_Fast => { diff -r fb911f72e698 -r c5d8bd1a7728 src/fb.rs --- a/src/fb.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/fb.rs Thu Jan 23 23:35:28 2025 +0100 @@ -80,40 +80,18 @@ use numeric_literals::replace_float_literals; use serde::{Serialize, Deserialize}; use colored::Colorize; -use nalgebra::DVector; -use alg_tools::iterate::{ - AlgIteratorFactory, - AlgIteratorIteration, - AlgIterator, -}; +use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; use alg_tools::linops::{Mapping, GEMV}; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - BothGenerators, -}; use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::instance::Instance; -use alg_tools::norms::Linfinity; use crate::types::*; use crate::measures::{ DiscreteMeasure, RNDM, - DeltaMeasure, - Radon, }; use crate::measures::merging::{ SpikeMergingMethod, @@ -121,14 +99,8 @@ }; use crate::forward_model::{ ForwardModel, - AdjointProductBoundedBy + AdjointProductBoundedBy, }; -use crate::seminorms::DiscreteMeasureOp; -use crate::subproblem::{ - InnerSettings, - InnerMethod, -}; -use crate::tolerance::Tolerance; use crate::plot::{ SeqPlotter, Plotting, @@ -140,6 +112,10 @@ L2Squared, DataTerm, }; +pub use crate::prox_penalty::{ + FBGenericConfig, + ProxPenalty +}; /// Settings for [`pointsource_fb_reg`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] @@ -151,51 +127,6 @@ pub generic : FBGenericConfig, } -/// Settings for the solution of the stepwise optimality condition. -#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] -#[serde(default)] -pub struct FBGenericConfig { - /// Tolerance for point insertion. - pub tolerance : Tolerance, - - /// Stop looking for predual maximum (where to isert a new point) below - /// `tolerance` multiplied by this factor. - /// - /// Not used by [`super::radon_fb`]. - pub insertion_cutoff_factor : F, - - /// Settings for branch and bound refinement when looking for predual maxima - pub refinement : RefinementSettings, - - /// Maximum insertions within each outer iteration - /// - /// Not used by [`super::radon_fb`]. - pub max_insertions : usize, - - /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. - /// - /// Not used by [`super::radon_fb`]. - pub bootstrap_insertions : Option<(usize, usize)>, - - /// Inner method settings - pub inner : InnerSettings, - - /// Spike merging method - pub merging : SpikeMergingMethod, - - /// Tolerance multiplier for merges - pub merge_tolerance_mult : F, - - /// Spike merging method after the last step - pub final_merging : SpikeMergingMethod, - - /// Iterations between merging heuristic tries - pub merge_every : usize, - - // /// Save $μ$ for postprocessing optimisation - // pub postprocessing : bool -} - #[replace_float_literals(F::cast_from(literal))] impl Default for FBConfig { fn default() -> Self { @@ -206,155 +137,6 @@ } } -#[replace_float_literals(F::cast_from(literal))] -impl Default for FBGenericConfig { - fn default() -> Self { - FBGenericConfig { - tolerance : Default::default(), - insertion_cutoff_factor : 1.0, - refinement : Default::default(), - max_insertions : 100, - //bootstrap_insertions : None, - bootstrap_insertions : Some((10, 1)), - inner : InnerSettings { - method : InnerMethod::Default, - .. Default::default() - }, - merging : SpikeMergingMethod::None, - //merging : Default::default(), - final_merging : Default::default(), - merge_every : 10, - merge_tolerance_mult : 2.0, - // postprocessing : false, - } - } -} - -impl FBGenericConfig { - /// Check if merging should be attempted this iteration - pub fn merge_now(&self, state : &AlgIteratorIteration) -> bool { - state.iteration() % self.merge_every == 0 - } -} - -/// TODO: document. -/// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike -/// locations, while `ν_delta` may have different locations. -#[replace_float_literals(F::cast_from(literal))] -pub(crate) fn insert_and_reweigh< - 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, I, const N : usize ->( - μ : &mut RNDM, - τv : &BTFN, - μ_base : &RNDM, - ν_delta: Option<&RNDM>, - op𝒟 : &'a 𝒟, - op𝒟norm : F, - τ : F, - ε : F, - config : &FBGenericConfig, - reg : &Reg, - state : &AlgIteratorIteration, - stats : &mut IterInfo, -) -> (BTFN, BTA, N>, bool) -where F : Float + ToNalgebraRealField, - GA : SupportGenerator + Clone, - BTA : BTSearch>, - G𝒟 : SupportGenerator + Clone, - 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>, - 𝒟::Codomain : RealMapping, - S: RealMapping + LocalAnalysis, N>, - K: RealMapping + LocalAnalysis, N>, - BTNodeLookup: BTNode, N>, - Reg : RegTerm, - I : AlgIterator { - - // Maximum insertion count and measure difference calculation depend on insertion style. - let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { - (i, Some((l, k))) if i <= l => (k, false), - _ => (config.max_insertions, !state.is_quiet()), - }; - - let ω0 = match ν_delta { - None => op𝒟.apply(μ_base), - Some(ν) => op𝒟.apply(μ_base + ν), - }; - - // Add points to support until within error tolerance or maximum insertion count reached. - let mut count = 0; - let (within_tolerances, d) = 'insertion: loop { - if μ.len() > 0 { - // Form finite-dimensional subproblem. The subproblem references to the original μ^k - // from the beginning of the iteration are all contained in the immutable c and g. - // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional - // problems have not yet been updated to sign change. - let à = op𝒟.findim_matrix(μ.iter_locations()); - let g̃ = DVector::from_iterator(μ.len(), - μ.iter_locations() - .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) - .map(F::to_nalgebra_mixed)); - let mut x = μ.masses_dvector(); - - // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. - // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ - // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ - // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 - // = n |𝒟| |x|_2, where n is the number of points. Therefore - let Ã_normest = op𝒟norm * F::cast_from(μ.len()); - - // Solve finite-dimensional subproblem. - stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config); - - // Update masses of μ based on solution of finite-dimensional subproblem. - μ.set_masses_dvector(&x); - } - - // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality - // conditions in the predual space, and finding new points for insertion, if necessary. - let mut d = τv + match ν_delta { - None => op𝒟.preapply(μ.sub_matching(μ_base)), - Some(ν) => op𝒟.preapply(μ.sub_matching(μ_base) - ν) - }; - - // If no merging heuristic is used, let's be more conservative about spike insertion, - // and skip it after first round. If merging is done, being more greedy about spike - // insertion also seems to improve performance. - let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging { - false - } else { - count > 0 - }; - - // Find a spike to insert, if needed - let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( - &mut d, τ, ε, skip_by_rough_check, config - ) { - None => break 'insertion (true, d), - Some(res) => res, - }; - - // Break if maximum insertion count reached - if count >= max_insertions { - break 'insertion (in_bounds, d) - } - - // No point in optimising the weight here; the finite-dimensional algorithm is fast. - *μ += DeltaMeasure { x : ξ, α : 0.0 }; - count += 1; - stats.inserted += 1; - }; - - if !within_tolerances && warn_insertions { - // Complain (but continue) if we failed to get within tolerances - // by inserting more points. - let err = format!("Maximum insertions reached without achieving \ - subproblem solution tolerance"); - println!("{}", err.red()); - } - - (d, within_tolerances) -} - pub(crate) fn prune_with_stats( μ : &mut RNDM, ) -> usize { @@ -409,38 +191,32 @@ /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_fb_reg< - 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize + F, I, A, Reg, P, const N : usize >( - opA : &'a A, + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, fbconfig : &FBConfig, iterator : I, mut plotter : SeqPlotter, ) -> RNDM -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg, - GA : SupportGenerator + Clone, - A : ForwardModel, F, PreadjointCodomain = BTFN> - + AdjointProductBoundedBy, 𝒟, FloatType=F>, - BTA : BTSearch>, - G𝒟 : SupportGenerator + Clone, - 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>, - 𝒟::Codomain : RealMapping, - S: RealMapping + LocalAnalysis, N>, - K: RealMapping + LocalAnalysis, N>, - BTNodeLookup: BTNode, N>, - Cube: P2Minimise, F>, - PlotLookup : Plotting, - RNDM : SpikeMerging, - Reg : RegTerm { +where + F : Float + ToNalgebraRealField, + I : AlgIteratorFactory>, + for<'b> &'b A::Observable : std::ops::Neg, + A : ForwardModel, F> + + AdjointProductBoundedBy, P, FloatType=F>, + A::PreadjointCodomain : RealMapping, + PlotLookup : Plotting, + RNDM : SpikeMerging, + Reg : RegTerm, + P : ProxPenalty, +{ // Set up parameters let config = &fbconfig.generic; - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); - let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap(); + let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. let tolerance = config.tolerance * τ * reg.tolerance_scaling(); @@ -465,26 +241,23 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let τv = opA.preadjoint().apply(residual * τ); + let mut τv = opA.preadjoint().apply(residual * τ); // Save current base point let μ_base = μ.clone(); // Insert and reweigh - let (d, _within_tolerances) = insert_and_reweigh( - &mut μ, &τv, &μ_base, None, - op𝒟, op𝒟norm, + let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats ); // Prune and possibly merge spikes if config.merge_now(&state) { - stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { - let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base)); - reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) - }); + stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, ®); } + stats.pruned += prune_with_stats(&mut μ); // Update residual @@ -495,7 +268,7 @@ // Give statistics if needed state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) }); @@ -526,38 +299,32 @@ /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_fista_reg< - 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize + F, I, A, Reg, P, const N : usize >( - opA : &'a A, + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, fbconfig : &FBConfig, iterator : I, mut plotter : SeqPlotter, ) -> RNDM -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg, - GA : SupportGenerator + Clone, - A : ForwardModel, F, PreadjointCodomain = BTFN> - + AdjointProductBoundedBy, 𝒟, FloatType=F>, - BTA : BTSearch>, - G𝒟 : SupportGenerator + Clone, - 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>, - 𝒟::Codomain : RealMapping, - S: RealMapping + LocalAnalysis, N>, - K: RealMapping + LocalAnalysis, N>, - BTNodeLookup: BTNode, N>, - Cube: P2Minimise, F>, - PlotLookup : Plotting, - RNDM : SpikeMerging, - Reg : RegTerm { +where + F : Float + ToNalgebraRealField, + I : AlgIteratorFactory>, + for<'b> &'b A::Observable : std::ops::Neg, + A : ForwardModel, F> + + AdjointProductBoundedBy, P, FloatType=F>, + A::PreadjointCodomain : RealMapping, + PlotLookup : Plotting, + RNDM : SpikeMerging, + Reg : RegTerm, + P : ProxPenalty, +{ // Set up parameters let config = &fbconfig.generic; - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); - let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap(); + let τ = fbconfig.τ0/opA.adjoint_product_bound(prox_penalty).unwrap(); let mut λ = 1.0; // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. @@ -583,15 +350,14 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let τv = opA.preadjoint().apply(residual * τ); + let mut τv = opA.preadjoint().apply(residual * τ); // Save current base point let μ_base = μ.clone(); // Insert new spikes and reweigh - let (d, _within_tolerances) = insert_and_reweigh( - &mut μ, &τv, &μ_base, None, - op𝒟, op𝒟norm, + let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats ); @@ -632,7 +398,7 @@ // Give statistics if needed state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ_prev); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ_prev); full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new())) }); diff -r fb911f72e698 -r c5d8bd1a7728 src/forward_model/bias.rs --- a/src/forward_model/bias.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/forward_model/bias.rs Thu Jan 23 23:35:28 2025 +0100 @@ -43,7 +43,6 @@ F : Float, Z : Clone + Space + ClosedAdd, A : AdjointProductBoundedBy, - D : Linear, A::Codomain : ClosedAdd, { type FloatType = F; @@ -92,7 +91,6 @@ } -/// TODO: should assume `D` to be positive semi-definite and self-adjoint. #[replace_float_literals(F::cast_from(literal))] impl<'a, F, D, XD, Y, const N : usize> AdjointProductBoundedBy, D> for ZeroOp<'a, RNDM, XD, Y, F> diff -r fb911f72e698 -r c5d8bd1a7728 src/forward_pdps.rs --- a/src/forward_pdps.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/forward_pdps.rs Thu Jan 23 23:35:28 2025 +0100 @@ -8,30 +8,15 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::mapping::{Mapping, Instance}; +use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; use alg_tools::norms::Norm; use alg_tools::direct_product::Pair; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - //Bounded, -}; -use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::linops::{ BoundedLinear, AXPY, GEMV, Adjointable, IdOp, }; use alg_tools::convex::{Conjugable, Prox}; -use alg_tools::norms::{L2, Linfinity, PairNorm}; +use alg_tools::norms::{L2, PairNorm}; use crate::types::*; use crate::measures::{DiscreteMeasure, Radon, RNDM}; @@ -40,7 +25,6 @@ ForwardModel, AdjointProductPairBoundedBy, }; -use crate::seminorms::DiscreteMeasureOp; use crate::plot::{ SeqPlotter, Plotting, @@ -83,12 +67,12 @@ /// using primal-dual proximal splitting with a forward step. #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_forward_pdps_pair< - 'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize + F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize >( - opA : &'a A, + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, config : &ForwardPDPSConfig, iterator : I, mut plotter : SeqPlotter, @@ -102,27 +86,19 @@ where F : Float + ToNalgebraRealField, I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - GA : SupportGenerator + Clone, A : ForwardModel< MeasureZ, F, PairNorm, - PreadjointCodomain = Pair, Z>, + PreadjointCodomain = Pair, > - + AdjointProductPairBoundedBy, 𝒟, IdOp, FloatType=F>, - BTA : BTSearch>, - G𝒟 : SupportGenerator + Clone, - 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN, - Codomain = BTFN>, - BT𝒟 : BTSearch>, - S: RealMapping + LocalAnalysis, N>, - K: RealMapping + LocalAnalysis, N>, - BTNodeLookup: BTNode, N>, - Cube: P2Minimise, F>, + + AdjointProductPairBoundedBy, P, IdOp, FloatType=F>, + S: DifferentiableRealMapping, + for<'b> &'b A::Observable : std::ops::Neg + Instance, PlotLookup : Plotting, RNDM : SpikeMerging, Reg : RegTerm, + P : ProxPenalty, KOpZ : BoundedLinear + GEMV + Adjointable, @@ -150,11 +126,10 @@ let mut residual = calculate_residual(Pair(&μ, &z), opA, b); // Set up parameters - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); - let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt(); + let bigM = 0.0; //opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt(); let nKz = opKz.opnorm_bound(L2, L2); let opIdZ = IdOp::new(); - let (l, l_z) = opA.adjoint_product_pair_bound(&op𝒟, &opIdZ).unwrap(); + let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); // We need to satisfy // // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 @@ -196,14 +171,13 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { // Calculate initial transport - let Pair(τv, τz) = opA.preadjoint().apply(residual * τ); + let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ); let z_base = z.clone(); let μ_base = μ.clone(); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. - let (d, _within_tolerances) = insert_and_reweigh( - &mut μ, &τv, &μ_base, None, - op𝒟, op𝒟norm, + let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv, &μ_base, None, τ, ε, &config.insertion, ®, &state, &mut stats, ); @@ -248,7 +222,7 @@ stats.this_iters += 1; state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) }); diff -r fb911f72e698 -r c5d8bd1a7728 src/frank_wolfe.rs --- a/src/frank_wolfe.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/frank_wolfe.rs Thu Jan 23 23:35:28 2025 +0100 @@ -425,8 +425,8 @@ /// `iterator` is used to iterate the steps of the method, and `plotter` may be used to /// save intermediate iteration states as images. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_fw_reg<'a, F, I, A, GA, BTA, S, Reg, const N : usize>( - opA : &'a A, +pub fn pointsource_fw_reg( + opA : &A, b : &A::Observable, reg : Reg, //domain : Cube, diff -r fb911f72e698 -r c5d8bd1a7728 src/main.rs --- a/src/main.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/main.rs Thu Jan 23 23:35:28 2025 +0100 @@ -37,8 +37,8 @@ pub mod tolerance; pub mod regularisation; pub mod dataterm; +pub mod prox_penalty; pub mod fb; -pub mod radon_fb; pub mod sliding_fb; pub mod sliding_pdps; pub mod forward_pdps; diff -r fb911f72e698 -r c5d8bd1a7728 src/measures/discrete.rs --- a/src/measures/discrete.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/measures/discrete.rs Thu Jan 23 23:35:28 2025 +0100 @@ -326,6 +326,18 @@ pub fn set_masses_dvector(&mut self, x : &DVector) { self.set_masses(x.iter().map(|&α| F::from_nalgebra_mixed(α))); } + + /// Extracts the masses of the spikes as a [`Vec`]. + pub fn masses_vec(&self) -> Vec { + self.iter_masses() + .map(|α| α.to_nalgebra_mixed()) + .collect() + } + + /// Sets the masses of the spikes from the values of a [`Vec`]. + pub fn set_masses_vec(&mut self, x : &Vec) { + self.set_masses(x.iter().map(|&α| F::from_nalgebra_mixed(α))); + } } // impl Index for DiscreteMeasure { diff -r fb911f72e698 -r c5d8bd1a7728 src/pdps.rs --- a/src/pdps.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/pdps.rs Thu Jan 23 23:35:28 2025 +0100 @@ -44,46 +44,36 @@ use clap::ValueEnum; use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::loc::Loc; use alg_tools::euclidean::Euclidean; use alg_tools::linops::Mapping; use alg_tools::norms::{ Linfinity, Projection, }; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - SupportGenerator, - LocalAnalysis, -}; use alg_tools::mapping::{RealMapping, Instance}; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::linops::AXPY; use crate::types::*; -use crate::measures::{DiscreteMeasure, RNDM, Radon}; +use crate::measures::{DiscreteMeasure, RNDM}; use crate::measures::merging::SpikeMerging; use crate::forward_model::{ + ForwardModel, AdjointProductBoundedBy, - ForwardModel }; -use crate::seminorms::DiscreteMeasureOp; use crate::plot::{ SeqPlotter, Plotting, PlotLookup }; use crate::fb::{ - FBGenericConfig, - insert_and_reweigh, postprocess, prune_with_stats }; +pub use crate::prox_penalty::{ + FBGenericConfig, + ProxPenalty +}; use crate::regularisation::RegTerm; use crate::dataterm::{ DataTerm, @@ -223,33 +213,29 @@ /// /// Returns the final iterate. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>( - opA : &'a A, - b : &'a A::Observable, +pub fn pointsource_pdps_reg( + opA : &A, + b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, pdpsconfig : &PDPSConfig, iterator : I, mut plotter : SeqPlotter, dataterm : D, ) -> RNDM -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - GA : SupportGenerator + Clone, - A : ForwardModel, F, PreadjointCodomain = BTFN> - + AdjointProductBoundedBy, 𝒟, FloatType=F>, - BTA : BTSearch>, - G𝒟 : SupportGenerator + Clone, - 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>, - 𝒟::Codomain : RealMapping, - S: RealMapping + LocalAnalysis, N>, - K: RealMapping + LocalAnalysis, N>, - BTNodeLookup: BTNode, N>, - PlotLookup : Plotting, - RNDM : SpikeMerging, - D : PDPSDataTerm, - Reg : RegTerm { +where + F : Float + ToNalgebraRealField, + I : AlgIteratorFactory>, + A : ForwardModel, F> + + AdjointProductBoundedBy, P, FloatType=F>, + A::PreadjointCodomain : RealMapping, + for<'b> &'b A::Observable : std::ops::Neg + Instance, + PlotLookup : Plotting, + RNDM : SpikeMerging, + D : PDPSDataTerm, + Reg : RegTerm, + P : ProxPenalty, +{ // Check parameters assert!(pdpsconfig.τ0 > 0.0 && @@ -259,8 +245,7 @@ // Set up parameters let config = &pdpsconfig.generic; - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); - let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt(); + let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt(); let mut τ = pdpsconfig.τ0 / l; let mut σ = pdpsconfig.σ0 / l; let γ = dataterm.factor_of_strong_convexity(); @@ -286,25 +271,21 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let τv = opA.preadjoint().apply(y * τ); + let mut τv = opA.preadjoint().apply(y * τ); // Save current base point let μ_base = μ.clone(); // Insert and reweigh - let (d, _within_tolerances) = insert_and_reweigh( - &mut μ, &τv, &μ_base, None, - op𝒟, op𝒟norm, + let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats ); // Prune and possibly merge spikes if config.merge_now(&state) { - stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { - let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base)); - reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) - }); + stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, ®); } stats.pruned += prune_with_stats(&mut μ); @@ -323,7 +304,7 @@ stats.this_iters += 1; state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) }); diff -r fb911f72e698 -r c5d8bd1a7728 src/prox_penalty.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/prox_penalty.rs Thu Jan 23 23:35:28 2025 +0100 @@ -0,0 +1,158 @@ +/*! +Proximal penalty abstraction +*/ + +use numeric_literals::replace_float_literals; +use alg_tools::types::*; +use serde::{Serialize, Deserialize}; + +use alg_tools::mapping::RealMapping; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::iterate::{ + AlgIteratorIteration, + AlgIterator, +}; +use crate::measures::{ + RNDM, +}; +use crate::types::{ + RefinementSettings, + IterInfo, +}; +use crate::subproblem::{ + InnerSettings, + InnerMethod, +}; +use crate::tolerance::Tolerance; +use crate::measures::merging::SpikeMergingMethod; +use crate::regularisation::RegTerm; + +pub mod wave; +pub mod radon_squared; +pub use radon_squared::RadonSquared; + +/// Settings for the solution of the stepwise optimality condition. +#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] +#[serde(default)] +pub struct FBGenericConfig { + /// Tolerance for point insertion. + pub tolerance : Tolerance, + + /// Stop looking for predual maximum (where to isert a new point) below + /// `tolerance` multiplied by this factor. + /// + /// Not used by [`super::radon_fb`]. + pub insertion_cutoff_factor : F, + + /// Settings for branch and bound refinement when looking for predual maxima + pub refinement : RefinementSettings, + + /// Maximum insertions within each outer iteration + /// + /// Not used by [`super::radon_fb`]. + pub max_insertions : usize, + + /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. + /// + /// Not used by [`super::radon_fb`]. + pub bootstrap_insertions : Option<(usize, usize)>, + + /// Inner method settings + pub inner : InnerSettings, + + /// Spike merging method + pub merging : SpikeMergingMethod, + + /// Tolerance multiplier for merges + pub merge_tolerance_mult : F, + + /// Spike merging method after the last step + pub final_merging : SpikeMergingMethod, + + /// Iterations between merging heuristic tries + pub merge_every : usize, + + // /// Save $μ$ for postprocessing optimisation + // pub postprocessing : bool +} + +#[replace_float_literals(F::cast_from(literal))] +impl Default for FBGenericConfig { + fn default() -> Self { + FBGenericConfig { + tolerance : Default::default(), + insertion_cutoff_factor : 1.0, + refinement : Default::default(), + max_insertions : 100, + //bootstrap_insertions : None, + bootstrap_insertions : Some((10, 1)), + inner : InnerSettings { + method : InnerMethod::Default, + .. Default::default() + }, + merging : SpikeMergingMethod::None, + //merging : Default::default(), + final_merging : Default::default(), + merge_every : 10, + merge_tolerance_mult : 2.0, + // postprocessing : false, + } + } +} + +impl FBGenericConfig { + /// Check if merging should be attempted this iteration + pub fn merge_now(&self, state : &AlgIteratorIteration) -> bool { + state.iteration() % self.merge_every == 0 + } +} + + +/// Trait for proximal penalties +pub trait ProxPenalty +where + F : Float + ToNalgebraRealField, + Reg : RegTerm, +{ + type ReturnMapping : RealMapping; + + /// Insert new spikes into `μ` to approximately satisfy optimality conditions + /// with the forward step term fixed to `τv`. + /// + /// May return `τv + w` for `w` a subdifferential of the regularisation term `reg`, + /// as well as an indication of whether the tolerance bounds `ε` are satisfied. + /// + /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same + /// spike locations, while `ν_delta` may have different locations. + /// + /// `τv` is mutable to allow [`alg_tools::bisection_tree::btfn::BTFN`] refinement. + /// Actual values of `τv` are not supposed to be mutated. + fn insert_and_reweigh( + &self, + μ : &mut RNDM, + τv : &mut PreadjointCodomain, + μ_base : &RNDM, + ν_delta: Option<&RNDM>, + τ : F, + ε : F, + config : &FBGenericConfig, + reg : &Reg, + state : &AlgIteratorIteration, + stats : &mut IterInfo, + ) -> (Option, bool) + where + I : AlgIterator; + + + /// Merge spikes, if possible. + fn merge_spikes( + &self, + μ : &mut RNDM, + τv : &mut PreadjointCodomain, + μ_base : &RNDM, + τ : F, + ε : F, + config : &FBGenericConfig, + reg : &Reg, + ) -> usize; +} diff -r fb911f72e698 -r c5d8bd1a7728 src/prox_penalty/radon_squared.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/prox_penalty/radon_squared.rs Thu Jan 23 23:35:28 2025 +0100 @@ -0,0 +1,170 @@ +/*! +Solver for the point source localisation problem using a simplified forward-backward splitting method. + +Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map. +*/ + +use numeric_literals::replace_float_literals; +use serde::{Serialize, Deserialize}; +use nalgebra::DVector; + +use alg_tools::iterate::{ + AlgIteratorIteration, + AlgIterator +}; +use alg_tools::norms::L2; +use alg_tools::linops::Mapping; +use alg_tools::bisection_tree::{ + BTFN, + Bounds, + BTSearch, + SupportGenerator, + LocalAnalysis, +}; +use alg_tools::mapping::RealMapping; +use alg_tools::nalgebra_support::ToNalgebraRealField; + +use crate::types::*; +use crate::measures::{ + RNDM, + DeltaMeasure, + Radon, +}; +use crate::measures::merging::SpikeMerging; +use crate::regularisation::RegTerm; +use crate::forward_model::{ + ForwardModel, + AdjointProductBoundedBy +}; +use super::{ + FBGenericConfig, + ProxPenalty, +}; + +/// Radon-norm squared proximal penalty + +#[derive(Copy,Clone,Serialize,Deserialize)] +pub struct RadonSquared; + +#[replace_float_literals(F::cast_from(literal))] +impl +ProxPenalty, Reg, N> for RadonSquared +where + F : Float + ToNalgebraRealField, + GA : SupportGenerator + Clone, + BTA : BTSearch>, + S: RealMapping + LocalAnalysis, N>, + Reg : RegTerm, + RNDM : SpikeMerging, +{ + type ReturnMapping = BTFN; + + fn insert_and_reweigh( + &self, + μ : &mut RNDM, + τv : &mut BTFN, + μ_base : &RNDM, + ν_delta: Option<&RNDM>, + τ : F, + ε : F, + config : &FBGenericConfig, + reg : &Reg, + _state : &AlgIteratorIteration, + stats : &mut IterInfo, + ) -> (Option, bool) + where + I : AlgIterator + { + assert!(ν_delta.is_none(), "Transport not implemented for Radon-squared prox term"); + + let mut y = μ_base.masses_vec(); + + 'i_and_w: for i in 0..=1 { + // Optimise weights + if μ.len() > 0 { + // Form finite-dimensional subproblem. The subproblem references to the original μ^k + // from the beginning of the iteration are all contained in the immutable c and g. + // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional + // problems have not yet been updated to sign change. + let g̃ = DVector::from_iterator(μ.len(), + μ.iter_locations() + .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ)))); + let mut x = μ.masses_dvector(); + // Ugly hack because DVector::push doesn't push but copies. + let yvec = DVector::from_column_slice(y.as_slice()); + // Solve finite-dimensional subproblem. + stats.inner_iters += reg.solve_findim_l1squared(&yvec, &g̃, τ, &mut x, ε, config); + + // Update masses of μ based on solution of finite-dimensional subproblem. + μ.set_masses_dvector(&x); + } + + if i>0 { + // Simple debugging test to see if more inserts would be needed. Doesn't seem so. + //let n = μ.dist_matching(μ_base); + //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n)); + break 'i_and_w + } + + // Calculate ‖μ - μ_base‖_ℳ + let n = μ.dist_matching(μ_base); + + // Find a spike to insert, if needed. + // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ, + // which are supposed to have been guaranteed by the finite-dimensional weight optimisation. + match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) { + None => { break 'i_and_w }, + Some((ξ, _v_ξ, _in_bounds)) => { + // Weight is found out by running the finite-dimensional optimisation algorithm + // above + *μ += DeltaMeasure { x : ξ, α : 0.0 }; + //*μ_base += DeltaMeasure { x : ξ, α : 0.0 }; + y.push(0.0.to_nalgebra_mixed()); + stats.inserted += 1; + } + }; + } + + (None, true) + } + + fn merge_spikes( + &self, + μ : &mut RNDM, + τv : &mut BTFN, + μ_base : &RNDM, + τ : F, + ε : F, + config : &FBGenericConfig, + reg : &Reg, + ) -> usize + { + μ.merge_spikes(config.merging, |μ_candidate| { + // Important: μ_candidate's new points are afterwards, + // and do not conflict with μ_base. + // TODO: could simplify to requiring μ_base instead of μ_radon. + // but may complicate with sliding base's exgtra points that need to be + // after μ_candidate's extra points. + // TODO: doesn't seem to work, maybe need to merge μ_base as well? + // Although that doesn't seem to make sense. + let μ_radon = μ_candidate.sub_matching(μ_base); + reg.verify_merge_candidate_radonsq(τv, μ_candidate, τ, ε, &config, &μ_radon) + //let n = μ_candidate.dist_matching(μ_base); + //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none() + }) + } +} + + +impl AdjointProductBoundedBy, RadonSquared> +for A +where + F : Float, + A : ForwardModel, F> +{ + type FloatType = F; + + fn adjoint_product_bound(&self, _ : &RadonSquared) -> Option { + self.opnorm_bound(Radon, L2).powi(2).into() + } +} diff -r fb911f72e698 -r c5d8bd1a7728 src/prox_penalty/wave.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/prox_penalty/wave.rs Thu Jan 23 23:35:28 2025 +0100 @@ -0,0 +1,182 @@ +/*! +Basic proximal penalty based on convolution operators $𝒟$. + */ + +use numeric_literals::replace_float_literals; +use nalgebra::DVector; +use colored::Colorize; + +use alg_tools::types::*; +use alg_tools::loc::Loc; +use alg_tools::mapping::{Mapping, RealMapping}; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::norms::Linfinity; +use alg_tools::iterate::{ + AlgIteratorIteration, + AlgIterator, +}; +use alg_tools::bisection_tree::{ + BTFN, + PreBTFN, + Bounds, + BTSearch, + SupportGenerator, + LocalAnalysis, + BothGenerators, +}; +use crate::measures::{ + RNDM, + DeltaMeasure, + Radon, +}; +use crate::measures::merging::{ + SpikeMerging, +}; +use crate::seminorms::DiscreteMeasureOp; +use crate::types::{ + IterInfo, +}; +use crate::measures::merging::SpikeMergingMethod; +use crate::regularisation::RegTerm; +use super::{ProxPenalty, FBGenericConfig}; + +#[replace_float_literals(F::cast_from(literal))] +impl +ProxPenalty, Reg, N> for 𝒟 +where + F : Float + ToNalgebraRealField, + GA : SupportGenerator + Clone, + BTA : BTSearch>, + S: RealMapping + LocalAnalysis, N>, + G𝒟 : SupportGenerator + Clone, + 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>, + 𝒟::Codomain : RealMapping, + K : RealMapping + LocalAnalysis, N>, + Reg : RegTerm, + RNDM : SpikeMerging, +{ + type ReturnMapping = BTFN, BTA, N>; + + fn insert_and_reweigh( + &self, + μ : &mut RNDM, + τv : &mut BTFN, + μ_base : &RNDM, + ν_delta: Option<&RNDM>, + τ : F, + ε : F, + config : &FBGenericConfig, + reg : &Reg, + state : &AlgIteratorIteration, + stats : &mut IterInfo, + ) -> (Option, BTA, N>>, bool) + where + I : AlgIterator + { + + // TODO: is this inefficient to do in every iteration? + let op𝒟norm = self.opnorm_bound(Radon, Linfinity); + + // Maximum insertion count and measure difference calculation depend on insertion style. + let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { + (i, Some((l, k))) if i <= l => (k, false), + _ => (config.max_insertions, !state.is_quiet()), + }; + + let ω0 = match ν_delta { + None => self.apply(μ_base), + Some(ν) => self.apply(μ_base + ν), + }; + + // Add points to support until within error tolerance or maximum insertion count reached. + let mut count = 0; + let (within_tolerances, d) = 'insertion: loop { + if μ.len() > 0 { + // Form finite-dimensional subproblem. The subproblem references to the original μ^k + // from the beginning of the iteration are all contained in the immutable c and g. + // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional + // problems have not yet been updated to sign change. + let à = self.findim_matrix(μ.iter_locations()); + let g̃ = DVector::from_iterator(μ.len(), + μ.iter_locations() + .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) + .map(F::to_nalgebra_mixed)); + let mut x = μ.masses_dvector(); + + // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. + // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ + // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ + // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 + // = n |𝒟| |x|_2, where n is the number of points. Therefore + let Ã_normest = op𝒟norm * F::cast_from(μ.len()); + + // Solve finite-dimensional subproblem. + stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config); + + // Update masses of μ based on solution of finite-dimensional subproblem. + μ.set_masses_dvector(&x); + } + + // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality + // conditions in the predual space, and finding new points for insertion, if necessary. + let mut d = &*τv + match ν_delta { + None => self.preapply(μ.sub_matching(μ_base)), + Some(ν) => self.preapply(μ.sub_matching(μ_base) - ν) + }; + + // If no merging heuristic is used, let's be more conservative about spike insertion, + // and skip it after first round. If merging is done, being more greedy about spike + // insertion also seems to improve performance. + let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging { + false + } else { + count > 0 + }; + + // Find a spike to insert, if needed + let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( + &mut d, τ, ε, skip_by_rough_check, config + ) { + None => break 'insertion (true, d), + Some(res) => res, + }; + + // Break if maximum insertion count reached + if count >= max_insertions { + break 'insertion (in_bounds, d) + } + + // No point in optimising the weight here; the finite-dimensional algorithm is fast. + *μ += DeltaMeasure { x : ξ, α : 0.0 }; + count += 1; + stats.inserted += 1; + }; + + if !within_tolerances && warn_insertions { + // Complain (but continue) if we failed to get within tolerances + // by inserting more points. + let err = format!("Maximum insertions reached without achieving \ + subproblem solution tolerance"); + println!("{}", err.red()); + } + + (Some(d), within_tolerances) + } + + fn merge_spikes( + &self, + μ : &mut RNDM, + τv : &mut BTFN, + μ_base : &RNDM, + τ : F, + ε : F, + config : &FBGenericConfig, + reg : &Reg, + ) -> usize + { + μ.merge_spikes(config.merging, |μ_candidate| { + let mut d = &*τv + self.preapply(μ_candidate.sub_matching(μ_base)); + reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config) + }) + } +} diff -r fb911f72e698 -r c5d8bd1a7728 src/radon_fb.rs --- a/src/radon_fb.rs Mon Jan 06 11:32:57 2025 -0500 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,404 +0,0 @@ -/*! -Solver for the point source localisation problem using a simplified forward-backward splitting method. - -Instead of the $𝒟$-norm of `fb.rs`, this uses a standard Radon norm for the proximal map. -*/ - -use numeric_literals::replace_float_literals; -use serde::{Serialize, Deserialize}; -use colored::Colorize; -use nalgebra::DVector; - -use alg_tools::iterate::{ - AlgIteratorFactory, - AlgIteratorIteration, - AlgIterator -}; -use alg_tools::euclidean::Euclidean; -use alg_tools::linops::Mapping; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::bisection_tree::{ - BTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, -}; -use alg_tools::mapping::RealMapping; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::norms::L2; - -use crate::types::*; -use crate::measures::{ - RNDM, - DiscreteMeasure, - DeltaMeasure, - Radon, -}; -use crate::measures::merging::{ - SpikeMergingMethod, - SpikeMerging, -}; -use crate::forward_model::ForwardModel; -use crate::plot::{ - SeqPlotter, - Plotting, - PlotLookup -}; -use crate::regularisation::RegTerm; -use crate::dataterm::{ - calculate_residual, - L2Squared, - DataTerm, -}; - -use crate::fb::{ - FBGenericConfig, - postprocess, - prune_with_stats -}; - -/// Settings for [`pointsource_radon_fb_reg`]. -#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] -#[serde(default)] -pub struct RadonFBConfig { - /// Step length scaling - pub τ0 : F, - /// Generic parameters - pub insertion : FBGenericConfig, -} - -#[replace_float_literals(F::cast_from(literal))] -impl Default for RadonFBConfig { - fn default() -> Self { - RadonFBConfig { - τ0 : 0.99, - insertion : Default::default() - } - } -} - -#[replace_float_literals(F::cast_from(literal))] -pub(crate) fn insert_and_reweigh< - 'a, F, GA, BTA, S, Reg, I, const N : usize ->( - μ : &mut RNDM, - τv : &mut BTFN, - μ_base : &mut RNDM, - //_ν_delta: Option<&RNDM>, - τ : F, - ε : F, - config : &FBGenericConfig, - reg : &Reg, - _state : &AlgIteratorIteration, - stats : &mut IterInfo, -) -where F : Float + ToNalgebraRealField, - GA : SupportGenerator + Clone, - BTA : BTSearch>, - S: RealMapping + LocalAnalysis, N>, - BTNodeLookup: BTNode, N>, - RNDM : SpikeMerging, - Reg : RegTerm, - I : AlgIterator { - - 'i_and_w: for i in 0..=1 { - // Optimise weights - if μ.len() > 0 { - // Form finite-dimensional subproblem. The subproblem references to the original μ^k - // from the beginning of the iteration are all contained in the immutable c and g. - // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional - // problems have not yet been updated to sign change. - let g̃ = DVector::from_iterator(μ.len(), - μ.iter_locations() - .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ)))); - let mut x = μ.masses_dvector(); - let y = μ_base.masses_dvector(); - - // Solve finite-dimensional subproblem. - stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config); - - // Update masses of μ based on solution of finite-dimensional subproblem. - μ.set_masses_dvector(&x); - } - - if i>0 { - // Simple debugging test to see if more inserts would be needed. Doesn't seem so. - //let n = μ.dist_matching(μ_base); - //println!("{:?}", reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n)); - break 'i_and_w - } - - // Calculate ‖μ - μ_base‖_ℳ - let n = μ.dist_matching(μ_base); - - // Find a spike to insert, if needed. - // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ, - // which are supposed to have been guaranteed by the finite-dimensional weight optimisation. - match reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n) { - None => { break 'i_and_w }, - Some((ξ, _v_ξ, _in_bounds)) => { - // Weight is found out by running the finite-dimensional optimisation algorithm - // above - *μ += DeltaMeasure { x : ξ, α : 0.0 }; - *μ_base += DeltaMeasure { x : ξ, α : 0.0 }; - stats.inserted += 1; - } - }; - } -} - - -/// Iteratively solve the pointsource localisation problem using simplified forward-backward splitting. -/// -/// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the -/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. -/// Finally, the `iterator` is an outer loop verbosity and iteration count control -/// as documented in [`alg_tools::iterate`]. -/// -/// For details on the mathematical formulation, see the [module level](self) documentation. -/// -/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of -/// sums of simple functions usign bisection trees, and the related -/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions -/// active at a specific points, and to maximise their sums. Through the implementation of the -/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features -/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. -/// -/// Returns the final iterate. -#[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_radon_fb_reg< - 'a, F, I, A, GA, BTA, S, Reg, const N : usize ->( - opA : &'a A, - b : &A::Observable, - reg : Reg, - fbconfig : &RadonFBConfig, - iterator : I, - mut _plotter : SeqPlotter, -) -> RNDM -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg, - GA : SupportGenerator + Clone, - A : ForwardModel, F, PreadjointCodomain = BTFN>, - BTA : BTSearch>, - S: RealMapping + LocalAnalysis, N>, - BTNodeLookup: BTNode, N>, - Cube: P2Minimise, F>, - RNDM : SpikeMerging, - Reg : RegTerm { - - // Set up parameters - let config = &fbconfig.insertion; - // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ - // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such - // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L. - let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2); - // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled - // by τ compared to the conditional gradient approach. - let tolerance = config.tolerance * τ * reg.tolerance_scaling(); - let mut ε = tolerance.initial(); - - // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut residual = -b; - - // Statistics - let full_stats = |residual : &A::Observable, - μ : &RNDM, - ε, stats| IterInfo { - value : residual.norm2_squared_div2() + reg.apply(μ), - n_spikes : μ.len(), - ε, - // postprocessing: config.postprocessing.then(|| μ.clone()), - .. stats - }; - let mut stats = IterInfo::new(); - - // Run the algorithm - for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { - // Calculate smooth part of surrogate model. - let mut τv = opA.preadjoint().apply(residual * τ); - - // Save current base point - let mut μ_base = μ.clone(); - - // Insert and reweigh - insert_and_reweigh( - &mut μ, &mut τv, &mut μ_base, //None, - τ, ε, - config, ®, &state, &mut stats - ); - - // Prune and possibly merge spikes - assert!(μ_base.len() <= μ.len()); - if config.merge_now(&state) { - stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { - // Important: μ_candidate's new points are afterwards, - // and do not conflict with μ_base. - // TODO: could simplify to requiring μ_base instead of μ_radon. - // but may complicate with sliding base's exgtra points that need to be - // after μ_candidate's extra points. - // TODO: doesn't seem to work, maybe need to merge μ_base as well? - // Although that doesn't seem to make sense. - let μ_radon = μ_candidate.sub_matching(&μ_base); - reg.verify_merge_candidate_radonsq(&mut τv, μ_candidate, τ, ε, &config, &μ_radon) - //let n = μ_candidate.dist_matching(μ_base); - //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none() - }); - } - stats.pruned += prune_with_stats(&mut μ); - - // Update residual - residual = calculate_residual(&μ, opA, b); - - let iter = state.iteration(); - stats.this_iters += 1; - - // Give statistics if needed - state.if_verbose(|| { - full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) - }); - - // Update main tolerance for next iteration - ε = tolerance.update(ε, iter); - } - - postprocess(μ, config, L2Squared, opA, b) -} - -/// Iteratively solve the pointsource localisation problem using simplified inertial forward-backward splitting. -/// -/// The settings in `config` have their [respective documentation][RadonFBConfig]. `opA` is the -/// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. -/// Finally, the `iterator` is an outer loop verbosity and iteration count control -/// as documented in [`alg_tools::iterate`]. -/// -/// For details on the mathematical formulation, see the [module level](self) documentation. -/// -/// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of -/// sums of simple functions usign bisection trees, and the related -/// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions -/// active at a specific points, and to maximise their sums. Through the implementation of the -/// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features -/// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. -/// -/// Returns the final iterate. -#[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_radon_fista_reg< - 'a, F, I, A, GA, BTA, S, Reg, const N : usize ->( - opA : &'a A, - b : &A::Observable, - reg : Reg, - fbconfig : &RadonFBConfig, - iterator : I, - mut plotter : SeqPlotter, -) -> RNDM -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg, - GA : SupportGenerator + Clone, - A : ForwardModel, F, PreadjointCodomain = BTFN>, - BTA : BTSearch>, - S: RealMapping + LocalAnalysis, N>, - BTNodeLookup: BTNode, N>, - Cube: P2Minimise, F>, - PlotLookup : Plotting, - RNDM : SpikeMerging, - Reg : RegTerm { - - // Set up parameters - let config = &fbconfig.insertion; - // We need L such that the descent inequality F(ν) - F(μ) - ⟨F'(μ),ν-μ⟩ ≤ (L/2)‖ν-μ‖²_ℳ ∀ ν,μ - // holds. Since the left hand side expands as (1/2)‖A(ν-μ)‖₂², this is to say, we need L such - // that ‖Aμ‖₂² ≤ L ‖μ‖²_ℳ ∀ μ. Thus `opnorm_bound` gives the square root of L. - let τ = fbconfig.τ0/opA.opnorm_bound(Radon, L2).powi(2); - let mut λ = 1.0; - // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled - // by τ compared to the conditional gradient approach. - let tolerance = config.tolerance * τ * reg.tolerance_scaling(); - let mut ε = tolerance.initial(); - - // Initialise iterates - let mut μ = DiscreteMeasure::new(); - let mut μ_prev = DiscreteMeasure::new(); - let mut residual = -b; - let mut warned_merging = false; - - // Statistics - let full_stats = |ν : &RNDM, ε, stats| IterInfo { - value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), - n_spikes : ν.len(), - ε, - // postprocessing: config.postprocessing.then(|| ν.clone()), - .. stats - }; - let mut stats = IterInfo::new(); - - // Run the algorithm - for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { - // Calculate smooth part of surrogate model. - let mut τv = opA.preadjoint().apply(residual * τ); - - // Save current base point - let mut μ_base = μ.clone(); - - // Insert new spikes and reweigh - insert_and_reweigh( - &mut μ, &mut τv, &mut μ_base, //None, - τ, ε, - config, ®, &state, &mut stats - ); - - // (Do not) merge spikes. - if config.merge_now(&state) { - match config.merging { - SpikeMergingMethod::None => { }, - _ => if !warned_merging { - let err = format!("Merging not supported for μFISTA"); - println!("{}", err.red()); - warned_merging = true; - } - } - } - - // Update inertial prameters - let λ_prev = λ; - λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); - let θ = λ / λ_prev - λ; - - // Perform inertial update on μ. - // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ - // and μ_prev have zero weight. Since both have weights from the finite-dimensional - // subproblem with a proximal projection step, this is likely to happen when the - // spike is not needed. A copy of the pruned μ without artithmetic performed is - // stored in μ_prev. - let n_before_prune = μ.len(); - μ.pruning_sub(1.0 + θ, θ, &mut μ_prev); - debug_assert!(μ.len() <= n_before_prune); - stats.pruned += n_before_prune - μ.len(); - - // Update residual - residual = calculate_residual(&μ, opA, b); - - let iter = state.iteration(); - stats.this_iters += 1; - - // Give statistics if needed - state.if_verbose(|| { - plotter.plot_spikes(iter, Option::<&S>::None, Some(&τv), &μ_prev); - full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new())) - }); - - // Update main tolerance for next iteration - ε = tolerance.update(ε, iter); - } - - postprocess(μ_prev, config, L2Squared, opA, b) -} diff -r fb911f72e698 -r c5d8bd1a7728 src/run.rs --- a/src/run.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/run.rs Thu Jan 23 23:35:28 2025 +0100 @@ -72,11 +72,6 @@ pointsource_fb_reg, pointsource_fista_reg, }; -use crate::radon_fb::{ - RadonFBConfig, - pointsource_radon_fb_reg, - pointsource_radon_fista_reg, -}; use crate::sliding_fb::{ SlidingFBConfig, TransportConfig, @@ -114,22 +109,33 @@ L1, L2Squared, }; +use crate::prox_penalty::{ + RadonSquared, + //ProxPenalty, +}; use alg_tools::norms::{L2, NormExponent}; use alg_tools::operator_arithmetic::Weighted; use anyhow::anyhow; +/// Available proximal terms +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +pub enum ProxTerm { + /// Partial-to-wave operator 𝒟. + Wave, + /// Radon-norm squared + RadonSquared +} + /// Available algorithms and their configurations #[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub enum AlgorithmConfig { - FB(FBConfig), - FISTA(FBConfig), + FB(FBConfig, ProxTerm), + FISTA(FBConfig, ProxTerm), FW(FWConfig), - PDPS(PDPSConfig), - RadonFB(RadonFBConfig), - RadonFISTA(RadonFBConfig), - SlidingFB(SlidingFBConfig), - ForwardPDPS(ForwardPDPSConfig), - SlidingPDPS(SlidingPDPSConfig), + PDPS(PDPSConfig, ProxTerm), + SlidingFB(SlidingFBConfig, ProxTerm), + ForwardPDPS(ForwardPDPSConfig, ProxTerm), + SlidingPDPS(SlidingPDPSConfig, ProxTerm), } fn unpack_tolerance(v : &Vec) -> Tolerance { @@ -165,45 +171,35 @@ use AlgorithmConfig::*; match self { - FB(fb) => FB(FBConfig { + FB(fb, prox) => FB(FBConfig { τ0 : cli.tau0.unwrap_or(fb.τ0), generic : override_fb_generic(fb.generic), .. fb - }), - FISTA(fb) => FISTA(FBConfig { + }, prox), + FISTA(fb, prox) => FISTA(FBConfig { τ0 : cli.tau0.unwrap_or(fb.τ0), generic : override_fb_generic(fb.generic), .. fb - }), - PDPS(pdps) => PDPS(PDPSConfig { + }, prox), + PDPS(pdps, prox) => PDPS(PDPSConfig { τ0 : cli.tau0.unwrap_or(pdps.τ0), σ0 : cli.sigma0.unwrap_or(pdps.σ0), acceleration : cli.acceleration.unwrap_or(pdps.acceleration), generic : override_fb_generic(pdps.generic), .. pdps - }), + }, prox), FW(fw) => FW(FWConfig { merging : cli.merging.clone().unwrap_or(fw.merging), tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), .. fw }), - RadonFB(fb) => RadonFB(RadonFBConfig { - τ0 : cli.tau0.unwrap_or(fb.τ0), - insertion : override_fb_generic(fb.insertion), - .. fb - }), - RadonFISTA(fb) => RadonFISTA(RadonFBConfig { - τ0 : cli.tau0.unwrap_or(fb.τ0), - insertion : override_fb_generic(fb.insertion), - .. fb - }), - SlidingFB(sfb) => SlidingFB(SlidingFBConfig { + SlidingFB(sfb, prox) => SlidingFB(SlidingFBConfig { τ0 : cli.tau0.unwrap_or(sfb.τ0), transport : override_transport(sfb.transport), insertion : override_fb_generic(sfb.insertion), .. sfb - }), - SlidingPDPS(spdps) => SlidingPDPS(SlidingPDPSConfig { + }, prox), + SlidingPDPS(spdps, prox) => SlidingPDPS(SlidingPDPSConfig { τ0 : cli.tau0.unwrap_or(spdps.τ0), σp0 : cli.sigmap0.unwrap_or(spdps.σp0), σd0 : cli.sigma0.unwrap_or(spdps.σd0), @@ -211,15 +207,15 @@ transport : override_transport(spdps.transport), insertion : override_fb_generic(spdps.insertion), .. spdps - }), - ForwardPDPS(fpdps) => ForwardPDPS(ForwardPDPSConfig { + }, prox), + ForwardPDPS(fpdps, prox) => ForwardPDPS(ForwardPDPSConfig { τ0 : cli.tau0.unwrap_or(fpdps.τ0), σp0 : cli.sigmap0.unwrap_or(fpdps.σp0), σd0 : cli.sigma0.unwrap_or(fpdps.σd0), //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), insertion : override_fb_generic(fpdps.insertion), .. fpdps - }), + }, prox), } } } @@ -250,12 +246,6 @@ /// The μPDPS primal-dual proximal splitting method #[clap(name = "pdps")] PDPS, - /// The RadonFB forward-backward method - #[clap(name = "radon_fb")] - RadonFB, - /// The RadonFISTA inertial forward-backward method - #[clap(name = "radon_fista")] - RadonFISTA, /// The sliding FB method #[clap(name = "sliding_fb", alias = "sfb")] SlidingFB, @@ -265,6 +255,27 @@ /// The PDPS method with a forward step for the smooth function #[clap(name = "forward_pdps", alias = "fpdps")] ForwardPDPS, + + // Radon variants + + /// The μFB forward-backward method with radon-norm squared proximal term + #[clap(name = "radon_fb")] + RadonFB, + /// The μFISTA inertial forward-backward method with radon-norm squared proximal term + #[clap(name = "radon_fista")] + RadonFISTA, + /// The μPDPS primal-dual proximal splitting method with radon-norm squared proximal term + #[clap(name = "radon_pdps")] + RadonPDPS, + /// The sliding FB method with radon-norm squared proximal term + #[clap(name = "radon_sliding_fb", alias = "radon_sfb")] + RadonSlidingFB, + /// The sliding PDPS method with radon-norm squared proximal term + #[clap(name = "radon_sliding_pdps", alias = "radon_spdps")] + RadonSlidingPDPS, + /// The PDPS method with a forward step for the smooth function with radon-norm squared proximal term + #[clap(name = "radon_forward_pdps", alias = "radon_fpdps")] + RadonForwardPDPS, } impl DefaultAlgorithm { @@ -272,19 +283,26 @@ pub fn default_config(&self) -> AlgorithmConfig { use DefaultAlgorithm::*; match *self { - FB => AlgorithmConfig::FB(Default::default()), - FISTA => AlgorithmConfig::FISTA(Default::default()), + FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave), + FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave), FW => AlgorithmConfig::FW(Default::default()), FWRelax => AlgorithmConfig::FW(FWConfig{ variant : FWVariant::Relaxed, .. Default::default() }), - PDPS => AlgorithmConfig::PDPS(Default::default()), - RadonFB => AlgorithmConfig::RadonFB(Default::default()), - RadonFISTA => AlgorithmConfig::RadonFISTA(Default::default()), - SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), - SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default()), - ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default()), + PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave), + SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave), + SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave), + ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave), + + // Radon variants + + RadonFB => AlgorithmConfig::FB(Default::default(), ProxTerm::RadonSquared), + RadonFISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::RadonSquared), + RadonPDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::RadonSquared), + RadonSlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::RadonSquared), + RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::RadonSquared), + RadonForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::RadonSquared), } } @@ -602,7 +620,7 @@ } #[replace_float_literals(F::cast_from(literal))] -impl RunnableExperiment for +impl RunnableExperiment for Named> where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField, @@ -628,7 +646,13 @@ DefaultBT : SensorGridBT + BTSearch, BTNodeLookup: BTNode, N>, RNDM : SpikeMerging, - NoiseDistr : Distribution + Serialize + std::fmt::Debug + NoiseDistr : Distribution + Serialize + std::fmt::Debug, + // DefaultSG : ForwardModel, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector>, + // PreadjointCodomain : Space + Bounded + DifferentiableRealMapping, + // DefaultSeminormOp : ProxPenalty, N>, + // DefaultSeminormOp : ProxPenalty, N>, + // RadonSquared : ProxPenalty, N>, + // RadonSquared : ProxPenalty, N>, { fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option> { @@ -663,7 +687,7 @@ let mut rng = StdRng::seed_from_u64(noise_seed); // Generate the data and calculate SSNR statistic - let b_hat : DVector<_> = opA.apply(μ_hat); + let b_hat = opA.apply(μ_hat); let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); let b = &b_hat + &noise; // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField @@ -683,128 +707,157 @@ |alg, iterator, plotter, running| { let μ = match alg { - AlgorithmConfig::FB(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + AlgorithmConfig::FB(ref algconfig, prox) => { + match (regularisation, dataterm, prox) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_fb_reg( &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, iterator, plotter ) }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_fb_reg( &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, iterator, plotter ) }), + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_fb_reg( + &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter + ) + }), + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_fb_reg( + &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter + ) + }), _ => Err(NotImplemented) } }, - AlgorithmConfig::FISTA(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + AlgorithmConfig::FISTA(ref algconfig, prox) => { + match (regularisation, dataterm, prox) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_fista_reg( &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, iterator, plotter ) }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_fista_reg( &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, iterator, plotter ) }), - _ => Err(NotImplemented), - } - }, - AlgorithmConfig::RadonFB(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ print!("{running}"); - pointsource_radon_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), algconfig, + pointsource_fista_reg( + &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, iterator, plotter ) }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ print!("{running}"); - pointsource_radon_fb_reg( - &opA, &b, RadonRegTerm(α), algconfig, + pointsource_fista_reg( + &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, iterator, plotter ) }), _ => Err(NotImplemented), } }, - AlgorithmConfig::RadonFISTA(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ - print!("{running}"); - pointsource_radon_fista_reg( - &opA, &b, NonnegRadonRegTerm(α), algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ - print!("{running}"); - pointsource_radon_fista_reg( - &opA, &b, RadonRegTerm(α), algconfig, - iterator, plotter - ) - }), - _ => Err(NotImplemented), - } - }, - AlgorithmConfig::SlidingFB(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + AlgorithmConfig::SlidingFB(ref algconfig, prox) => { + match (regularisation, dataterm, prox) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_sliding_fb_reg( &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, iterator, plotter ) }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_sliding_fb_reg( &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, iterator, plotter ) }), + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_sliding_fb_reg( + &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter + ) + }), + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_sliding_fb_reg( + &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter + ) + }), _ => Err(NotImplemented), } }, - AlgorithmConfig::PDPS(ref algconfig) => { + AlgorithmConfig::PDPS(ref algconfig, prox) => { print!("{running}"); - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + match (regularisation, dataterm, prox) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ pointsource_pdps_reg( &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, iterator, plotter, L2Squared ) }), - (Regularisation::Radon(α),DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::Wave) => Ok({ pointsource_pdps_reg( &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, iterator, plotter, L2Squared ) }), - (Regularisation::NonnegRadon(α), DataTerm::L1) => Ok({ + (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ pointsource_pdps_reg( &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, iterator, plotter, L1 ) }), - (Regularisation::Radon(α), DataTerm::L1) => Ok({ + (Regularisation::Radon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ pointsource_pdps_reg( &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, iterator, plotter, L1 ) }), + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + pointsource_pdps_reg( + &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, L2Squared + ) + }), + (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + pointsource_pdps_reg( + &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, L2Squared + ) + }), + (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ + pointsource_pdps_reg( + &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, L1 + ) + }), + (Regularisation::Radon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ + pointsource_pdps_reg( + &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, L1 + ) + }), + // _ => Err(NotImplemented), } }, AlgorithmConfig::FW(ref algconfig) => { @@ -831,7 +884,7 @@ #[replace_float_literals(F::cast_from(literal))] -impl RunnableExperiment for +impl RunnableExperiment for Named> where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField, @@ -859,6 +912,12 @@ RNDM : SpikeMerging, NoiseDistr : Distribution + Serialize + std::fmt::Debug, B : Mapping, Codomain = F> + Serialize + std::fmt::Debug, + // DefaultSG : ForwardModel, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector>, + // PreadjointCodomain : Bounded + DifferentiableRealMapping, + // DefaultSeminormOp : ProxPenalty, N>, + // DefaultSeminormOp : ProxPenalty, N>, + // RadonSquared : ProxPenalty, N>, + // RadonSquared : ProxPenalty, N>, { fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option> { @@ -937,9 +996,9 @@ |alg, iterator, plotter, running| { let Pair(μ, z) = match alg { - AlgorithmConfig::ForwardPDPS(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + AlgorithmConfig::ForwardPDPS(ref algconfig, prox) => { + match (regularisation, dataterm, prox) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_forward_pdps_pair( &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, @@ -947,7 +1006,7 @@ /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), ) }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_forward_pdps_pair( &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, @@ -955,12 +1014,28 @@ /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), ) }), + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_forward_pdps_pair( + &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, + /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), + ) + }), + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_forward_pdps_pair( + &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, + /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), + ) + }), _ => Err(NotImplemented) } }, - AlgorithmConfig::SlidingPDPS(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ + AlgorithmConfig::SlidingPDPS(ref algconfig, prox) => { + match (regularisation, dataterm, prox) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_sliding_pdps_pair( &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, @@ -968,7 +1043,7 @@ /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), ) }), - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ print!("{running}"); pointsource_sliding_pdps_pair( &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, @@ -976,6 +1051,22 @@ /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), ) }), + (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_sliding_pdps_pair( + &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, + /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), + ) + }), + (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ + print!("{running}"); + pointsource_sliding_pdps_pair( + &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, + iterator, plotter, + /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), + ) + }), _ => Err(NotImplemented) } }, diff -r fb911f72e698 -r c5d8bd1a7728 src/sliding_fb.rs --- a/src/sliding_fb.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/sliding_fb.rs Thu Jan 23 23:35:28 2025 +0100 @@ -12,38 +12,19 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; -use alg_tools::mapping::{Mapping, DifferentiableMapping, Instance}; +use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; use alg_tools::norms::Norm; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - //Bounded, -}; -use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::norms::{L2, Linfinity}; +use alg_tools::norms::L2; use crate::types::*; use crate::measures::{DiscreteMeasure, Radon, RNDM}; -use crate::measures::merging::{ - SpikeMergingMethod, - SpikeMerging, -}; +use crate::measures::merging::SpikeMerging; use crate::forward_model::{ ForwardModel, AdjointProductBoundedBy, LipschitzValues, }; -use crate::seminorms::DiscreteMeasureOp; //use crate::tolerance::Tolerance; use crate::plot::{ SeqPlotter, @@ -151,7 +132,7 @@ Observable : Euclidean, for<'a> &'a Observable : Instance, //for<'b> A::Preadjoint<'b> : LipschitzValues, - D : DifferentiableMapping, DerivativeDomain=Loc>, + D : DifferentiableRealMapping, { use TransportStepLength::*; @@ -353,40 +334,29 @@ /// The parametrisation is as for [`pointsource_fb_reg`]. /// Inertia is currently not supported. #[replace_float_literals(F::cast_from(literal))] -pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>( - opA : &'a A, +pub fn pointsource_sliding_fb_reg( + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, config : &SlidingFBConfig, iterator : I, mut plotter : SeqPlotter, ) -> RNDM -where F : Float + ToNalgebraRealField, - I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - for<'b> A::Preadjoint<'b> : LipschitzValues, - A::PreadjointCodomain : DifferentiableMapping< - Loc, DerivativeDomain=Loc, Codomain=F - >, - GA : SupportGenerator + Clone, - A : ForwardModel, F, PreadjointCodomain = BTFN> - + AdjointProductBoundedBy, 𝒟, FloatType=F>, - //+ TransportLipschitz, - BTA : BTSearch>, - G𝒟 : SupportGenerator + Clone, - 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN, - Codomain = BTFN>, - BT𝒟 : BTSearch>, - S: RealMapping + LocalAnalysis, N> - + DifferentiableMapping, DerivativeDomain=Loc>, - K: RealMapping + LocalAnalysis, N>, - //+ Differentiable, Derivative=Loc>, - BTNodeLookup: BTNode, N>, - Cube: P2Minimise, F>, - PlotLookup : Plotting, - RNDM : SpikeMerging, - Reg : SlidingRegTerm { +where + F : Float + ToNalgebraRealField, + I : AlgIteratorFactory>, + A : ForwardModel, F> + + AdjointProductBoundedBy, P, FloatType=F>, + //+ TransportLipschitz, + for<'b> &'b A::Observable : std::ops::Neg + Instance, + for<'b> A::Preadjoint<'b> : LipschitzValues, + A::PreadjointCodomain : DifferentiableRealMapping, + RNDM : SpikeMerging, + Reg : SlidingRegTerm, + P : ProxPenalty, + PlotLookup : Plotting, +{ // Check parameters assert!(config.τ0 > 0.0, "Invalid step length parameter"); @@ -398,13 +368,12 @@ let mut residual = -b; // Has to equal $Aμ-b$. // Set up parameters - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); let opAnorm = opA.opnorm_bound(Radon, L2); //let max_transport = config.max_transport.scale // * reg.radon_norm_bound(b.norm2_squared() / 2.0); //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; let ℓ = 0.0; - let τ = config.τ0 / opA.adjoint_product_bound(&op𝒟).unwrap(); + let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v @@ -446,15 +415,14 @@ // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. - let (d, _within_tolerances, τv̆) = 'adapt_transport: loop { + let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); - let τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); + let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. - let (d, within_tolerances) = insert_and_reweigh( - &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), - op𝒟, op𝒟norm, + let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, &config.insertion, ®, &state, &mut stats, ); @@ -464,7 +432,7 @@ &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, ε, &config.transport ) { - break 'adapt_transport (d, within_tolerances, τv̆) + break 'adapt_transport (maybe_d, within_tolerances, τv̆) } }; @@ -480,20 +448,20 @@ (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) }); - // Merge spikes. - // This expects the prune below to prune γ. - // TODO: This may not work correctly in all cases. - let ins = &config.insertion; - if ins.merge_now(&state) { - if let SpikeMergingMethod::None = ins.merging { - } else { - stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { - let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; - let mut d = &τv̆ + op𝒟.preapply(ν); - reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) - }); - } - } + // // Merge spikes. + // // This expects the prune below to prune γ. + // // TODO: This may not work correctly in all cases. + // let ins = &config.insertion; + // if ins.merge_now(&state) { + // if let SpikeMergingMethod::None = ins.merging { + // } else { + // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { + // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; + // let mut d = &τv̆ + op𝒟.preapply(ν); + // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) + // }); + // } + // } // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the // latter needs to be pruned when μ is. @@ -514,7 +482,7 @@ // Give statistics if requested state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) }); diff -r fb911f72e698 -r c5d8bd1a7728 src/sliding_pdps.rs --- a/src/sliding_pdps.rs Mon Jan 06 11:32:57 2025 -0500 +++ b/src/sliding_pdps.rs Thu Jan 23 23:35:28 2025 +0100 @@ -11,30 +11,15 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::euclidean::Euclidean; -use alg_tools::sets::Cube; -use alg_tools::loc::Loc; use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; use alg_tools::norms::Norm; use alg_tools::direct_product::Pair; -use alg_tools::bisection_tree::{ - BTFN, - PreBTFN, - Bounds, - BTNodeLookup, - BTNode, - BTSearch, - P2Minimise, - SupportGenerator, - LocalAnalysis, - //Bounded, -}; -use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::linops::{ BoundedLinear, AXPY, GEMV, Adjointable, IdOp, }; use alg_tools::convex::{Conjugable, Prox}; -use alg_tools::norms::{L2, Linfinity, PairNorm}; +use alg_tools::norms::{L2, PairNorm}; use crate::types::*; use crate::measures::{DiscreteMeasure, Radon, RNDM}; @@ -45,7 +30,6 @@ LipschitzValues, }; // use crate::transport::TransportLipschitz; -use crate::seminorms::DiscreteMeasureOp; //use crate::tolerance::Tolerance; use crate::plot::{ SeqPlotter, @@ -101,12 +85,12 @@ /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. #[replace_float_literals(F::cast_from(literal))] pub fn pointsource_sliding_pdps_pair< - 'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize + F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize >( - opA : &'a A, + opA : &A, b : &A::Observable, reg : Reg, - op𝒟 : &'a 𝒟, + prox_penalty : &P, config : &SlidingPDPSConfig, iterator : I, mut plotter : SeqPlotter, @@ -120,36 +104,25 @@ where F : Float + ToNalgebraRealField, I : AlgIteratorFactory>, - for<'b> &'b A::Observable : std::ops::Neg + Instance, - for<'b> A::Preadjoint<'b> : LipschitzValues, - BTFN : DifferentiableRealMapping, - GA : SupportGenerator + Clone, A : ForwardModel< MeasureZ, F, PairNorm, - PreadjointCodomain = Pair, Z>, + PreadjointCodomain = Pair, > - + AdjointProductPairBoundedBy, 𝒟, IdOp, FloatType=F>, - BTA : BTSearch>, - G𝒟 : SupportGenerator + Clone, - 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN, - Codomain = BTFN>, - BT𝒟 : BTSearch>, - S: RealMapping + LocalAnalysis, N> - + DifferentiableRealMapping, - K: RealMapping + LocalAnalysis, N>, - //+ Differentiable, Derivative=Loc>, - BTNodeLookup: BTNode, N>, - Cube: P2Minimise, F>, + + AdjointProductPairBoundedBy, P, IdOp, FloatType=F>, + S : DifferentiableRealMapping, + for<'b> &'b A::Observable : std::ops::Neg + Instance, + for<'b> A::Preadjoint<'b> : LipschitzValues, PlotLookup : Plotting, RNDM : SpikeMerging, Reg : SlidingRegTerm, + P : ProxPenalty, // KOpM : Linear, Codomain=Y> // + GEMV> // + Preadjointable< // RNDM, Y, - // PreadjointCodomain = BTFN, + // PreadjointCodomain = S, // > // + TransportLipschitz // + AdjointProductBoundedBy, 𝒟, FloatType=F>, @@ -185,7 +158,6 @@ let zero_z = z.similar_origin(); // Set up parameters - let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); // TODO: maybe this PairNorm doesn't make sense here? let opAnorm = opA.opnorm_bound(PairNorm(Radon, L2, L2), L2); let bigθ = 0.0; //opKμ.transport_lipschitz_factor(L2Squared); @@ -193,7 +165,7 @@ let nKz = opKz.opnorm_bound(L2, L2); let ℓ = 0.0; let opIdZ = IdOp::new(); - let (l, l_z) = opA.adjoint_product_pair_bound(&op𝒟, &opIdZ).unwrap(); + let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); // We need to satisfy // // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 @@ -278,18 +250,17 @@ // Solve finite-dimensional subproblem several times until the dual variable for the // regularisation term conforms to the assumptions made for the transport above. - let (d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop { + let (maybe_d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop { // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) let residual_μ̆ = calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); - let Pair(τv̆, τz) = opA.preadjoint().apply(residual_μ̆ * τ); + let mut τv̆z = opA.preadjoint().apply(residual_μ̆ * τ); // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. - let (d, within_tolerances) = insert_and_reweigh( - &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), - op𝒟, op𝒟norm, + let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( + &mut μ, &mut τv̆z.0, &γ1, Some(&μ_base_minus_γ0), τ, ε, &config.insertion, ®, &state, &mut stats, ); @@ -300,7 +271,7 @@ &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, ε, &config.transport ) { - break 'adapt_transport (d, within_tolerances, Pair(τv̆, τz)) + break 'adapt_transport (maybe_d, within_tolerances, τv̆z) } }; @@ -364,7 +335,7 @@ stats.this_iters += 1; state.if_verbose(|| { - plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); + plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) });