--- a/src/fb.rs Tue Aug 01 10:32:12 2023 +0300
+++ b/src/fb.rs Thu Aug 29 00:00:00 2024 -0500
@@ -136,17 +136,6 @@
     DataTerm,
 };
 
-/// Method for constructing $μ$ on each iteration
-#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
-#[allow(dead_code)]
-pub enum InsertionStyle {
-    /// Resuse previous $μ$ from previous iteration, optimising weights
-    /// before inserting new spikes.
-    Reuse,
-    /// Start each iteration with $μ=0$.
-    Zero,
-}
-
 /// Settings for [`pointsource_fb_reg`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
@@ -154,7 +143,7 @@
     /// Step length scaling
     pub τ0 : F,
     /// Generic parameters
-    pub insertion : FBGenericConfig<F>,
+    pub generic : FBGenericConfig<F>,
 }
 
 /// Settings for the solution of the stepwise optimality condition in algorithms based on
@@ -162,29 +151,43 @@
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
 pub struct FBGenericConfig<F : Float> {
-    /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`].
-    pub insertion_style : InsertionStyle,
     /// Tolerance for point insertion.
     pub tolerance : Tolerance<F>,
+
     /// Stop looking for predual maximum (where to insert a new point) below
     /// `tolerance` multiplied by this factor.
+    ///
+    /// Not used by [`super::radon_fb`].
     pub insertion_cutoff_factor : F,
+
     /// Settings for branch and bound refinement when looking for predual maxima
     pub refinement : RefinementSettings<F>,
+
     /// Maximum insertions within each outer iteration
+    ///
+    /// Not used by [`super::radon_fb`].
    pub max_insertions : usize,
+
     /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
+    ///
+    /// Not used by [`super::radon_fb`].
     pub bootstrap_insertions : Option<(usize, usize)>,
+
     /// Inner method settings
     pub inner : InnerSettings<F>,
+
     /// Spike merging method
     pub merging : SpikeMergingMethod<F>,
+
     /// Tolerance multiplier for merges
     pub merge_tolerance_mult : F,
+
     /// Spike merging method after the last step
     pub final_merging : SpikeMergingMethod<F>,
+
     /// Iterations between merging heuristic tries
     pub merge_every : usize,
+
     /// Save $μ$ for postprocessing optimisation
     pub postprocessing : bool
 }
@@ -194,7 +197,7 @@
     fn default() -> Self {
         FBConfig {
             τ0 : 0.99,
-            insertion : Default::default()
+            generic : Default::default(),
         }
     }
 }
@@ -203,7 +206,6 @@
 impl<F : Float> Default for FBGenericConfig<F> {
     fn default() -> Self {
         FBGenericConfig {
-            insertion_style : InsertionStyle::Reuse,
             tolerance : Default::default(),
             insertion_cutoff_factor : 1.0,
             refinement : Default::default(),
@@ -211,7 +213,7 @@
             //bootstrap_insertions : None,
             bootstrap_insertions : Some((10, 1)),
             inner : InnerSettings {
-                method : InnerMethod::SSN,
+                method : InnerMethod::Default,
                 .. Default::default()
             },
             merging : SpikeMergingMethod::None,
@@ -224,35 +226,9 @@
     }
 }
 
-#[replace_float_literals(F::cast_from(literal))]
-pub(crate) fn μ_diff<F : Float, const N : usize>(
-    μ_new : &DiscreteMeasure<Loc<F, N>, F>,
-    μ_base : &DiscreteMeasure<Loc<F, N>, F>,
-    ν_delta : Option<&DiscreteMeasure<Loc<F, N>, F>>,
-    config : &FBGenericConfig<F>
-) -> DiscreteMeasure<Loc<F, N>, F> {
-    let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style {
-        InsertionStyle::Reuse => {
-            μ_new.iter_spikes()
-                 .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0)))
-                 .map(|(δ, α_base)| (δ.x, α_base - δ.α))
-                 .collect()
-        },
-        InsertionStyle::Zero => {
-            μ_new.iter_spikes()
-                 .map(|δ| -δ)
-                 .chain(μ_base.iter_spikes().copied())
-                 .collect()
-        }
-    };
-    ν.prune(); // Potential small performance improvement
-    // Add ν_delta if given
-    match ν_delta {
-        None => ν,
-        Some(ν_d) => ν + ν_d,
-    }
-}
-
+/// Insert new spikes into `μ` and optimise their weights.
+/// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike
+/// locations, while `ν_delta` may have different locations.
 #[replace_float_literals(F::cast_from(literal))]
 pub(crate) fn insert_and_reweigh<
     'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
@@ -284,24 +260,15 @@
       State : AlgIteratorState {
 
     // Maximum insertion count and measure difference calculation depend on insertion style.
-    let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
+    let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
         (i, Some((l, k))) if i <= l => (k, false),
         _ => (config.max_insertions, !state.is_quiet()),
     };
 
-    let max_insertions = match config.insertion_style {
-        InsertionStyle::Zero => {
-            todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled.");
-            // let n = μ.len();
-            // μ = DiscreteMeasure::new();
-            // n + m
-        },
-        InsertionStyle::Reuse => m,
-    };
-    // TODO: should avoid a second copy of μ here; μ_base already stores a copy.
+    // TODO: should avoid a copy of μ_base here.
     let ω0 = op𝒟.apply(match ν_delta {
-        None => μ.clone(),
-        Some(ν_d) => &*μ + ν_d,
+        None => μ_base.clone(),
+        Some(ν_d) => &*μ_base + ν_d,
     });
 
     // Add points to support until within error tolerance or maximum insertion count reached.
@@ -333,7 +300,10 @@
 
         // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
        // conditions in the predual space, and finding new points for insertion, if necessary.
-        let mut d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_base, ν_delta, config));
+        let mut d = minus_τv + op𝒟.preapply(match ν_delta {
+            None => μ_base.sub_matching(μ),
+            Some(ν) => μ_base.sub_matching(μ) + ν
+        });
 
         // If no merging heuristic is used, let's be more conservative about spike insertion,
        // and skip it after first round. If merging is done, being more greedy about spike
@@ -404,16 +374,10 @@
      Reg : RegTerm<F, N>,
      State : AlgIteratorState {
     if state.iteration() % config.merge_every == 0 {
-        let n_before_merge = μ.len();
-        μ.merge_spikes(config.merging, |μ_candidate| {
-            let μd = μ_diff(&μ_candidate, &μ_base, None, config);
-            let mut d = minus_τv + op𝒟.preapply(μd);
-
+        stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
+            let mut d = minus_τv + op𝒟.preapply(μ_base.sub_matching(&μ_candidate));
             reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
-                .then_some(())
         });
-        debug_assert!(μ.len() >= n_before_merge);
-        stats.merged += μ.len() - n_before_merge;
     }
 
     let n_before_prune = μ.len();
@@ -495,7 +459,7 @@
      Reg : RegTerm<F, N> {
 
    // Set up parameters
-    let config = &fbconfig.insertion;
+    let config = &fbconfig.generic;
     let op𝒟norm = op𝒟.opnorm_bound();
     let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
@@ -621,7 +585,7 @@
      Reg : RegTerm<F, N> {
 
    // Set up parameters
-    let config = &fbconfig.insertion;
+    let config = &fbconfig.generic;
     let op𝒟norm = op𝒟.opnorm_bound();
     let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
     let mut λ = 1.0;
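Note on the caller-visible change: the `insertion` field of `FBConfig` is renamed to `generic`, and `InsertionStyle` with its `insertion_style` field is removed outright, so there is nothing to set in its place. The sketch below is illustrative only and not part of the patch; the helper name and field values are made up, and `FBConfig` and `FBGenericConfig` are assumed to be in scope from this crate's `fb` module.

// Hypothetical usage sketch; arbitrary example values, not recommended settings.
fn example_fb_config() -> FBConfig<f64> {
    FBConfig {
        // Step length scaling; 0.99 matches the `Default` impl above.
        τ0 : 0.99,
        // Formerly `insertion : …`; renamed to `generic` by this patch.
        generic : FBGenericConfig {
            max_insertions : 100,
            bootstrap_insertions : Some((10, 1)),
            postprocessing : true,
            // Remaining fields (tolerance, refinement, merging, …) keep their defaults.
            .. Default::default()
        },
    }
}

The solver entry points then read this as `fbconfig.generic` where they previously read `fbconfig.insertion`, as the last two hunks show.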