diff -r efa60bc4f743 -r b087e3eab191 src/fb.rs
--- a/src/fb.rs Thu Aug 29 00:00:00 2024 -0500
+++ b/src/fb.rs Tue Dec 31 09:25:45 2024 -0500
@@ -6,10 +6,7 @@
* Valkonen T. - _Proximal methods for point source localisation_,
[arXiv:2212.02991](https://arxiv.org/abs/2212.02991).
-The main routine is [`pointsource_fb_reg`]. It is based on [`generic_pointsource_fb_reg`], which is
-also used by our [primal-dual proximal splitting][crate::pdps] implementation.
-
-FISTA-type inertia can also be enabled through [`FBConfig::meta`].
+The main routine is [`pointsource_fb_reg`].
## Problem
@@ -76,7 +73,7 @@
$$
-We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by
+We solve this with either SSN or FB as determined by
[`InnerSettings`] in [`FBGenericConfig::inner`].
*/
@@ -87,10 +84,11 @@
use alg_tools::iterate::{
AlgIteratorFactory,
- AlgIteratorState,
+ AlgIteratorIteration,
+ AlgIterator,
};
use alg_tools::euclidean::Euclidean;
-use alg_tools::linops::{Apply, GEMV};
+use alg_tools::linops::{Mapping, GEMV};
use alg_tools::sets::Cube;
use alg_tools::loc::Loc;
use alg_tools::bisection_tree::{
@@ -107,17 +105,24 @@
};
use alg_tools::mapping::RealMapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;
+use alg_tools::instance::Instance;
+use alg_tools::norms::Linfinity;
use crate::types::*;
use crate::measures::{
DiscreteMeasure,
+ RNDM,
DeltaMeasure,
+ Radon,
};
use crate::measures::merging::{
SpikeMergingMethod,
SpikeMerging,
};
-use crate::forward_model::ForwardModel;
+use crate::forward_model::{
+ ForwardModel,
+ AdjointProductBoundedBy
+};
use crate::seminorms::DiscreteMeasureOp;
use crate::subproblem::{
InnerSettings,
@@ -146,8 +151,7 @@
pub generic : FBGenericConfig,
}
-/// Settings for the solution of the stepwise optimality condition in algorithms based on
-/// [`generic_pointsource_fb_reg`].
+/// Settings for the solution of the stepwise optimality condition.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct FBGenericConfig {
@@ -188,8 +192,8 @@
/// Iterations between merging heuristic tries
pub merge_every : usize,
- /// Save $μ$ for postprocessing optimisation
- pub postprocessing : bool
+ // /// Save $μ$ for postprocessing optimisation
+ // pub postprocessing : bool
}
#[replace_float_literals(F::cast_from(literal))]
@@ -221,29 +225,36 @@
final_merging : Default::default(),
merge_every : 10,
merge_tolerance_mult : 2.0,
- postprocessing : false,
+ // postprocessing : false,
}
}
}
+impl FBGenericConfig {
+ /// Check if merging should be attempted this iteration
+ pub fn merge_now(&self, state : &AlgIteratorIteration) -> bool {
+ state.iteration() % self.merge_every == 0
+ }
+}
+
/// TODO: document.
/// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike
/// locations, while `ν_delta` may have different locations.
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn insert_and_reweigh<
- 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
+ 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, I, const N : usize
>(
- μ : &mut DiscreteMeasure, F>,
- minus_τv : &BTFN,
- μ_base : &DiscreteMeasure, F>,
- ν_delta: Option<&DiscreteMeasure, F>>,
+ μ : &mut RNDM,
+ τv : &BTFN,
+ μ_base : &RNDM,
+ ν_delta: Option<&RNDM>,
op𝒟 : &'a 𝒟,
op𝒟norm : F,
τ : F,
ε : F,
config : &FBGenericConfig,
reg : &Reg,
- state : &State,
+ state : &AlgIteratorIteration,
stats : &mut IterInfo,
) -> (BTFN, BTA, N>, bool)
where F : Float + ToNalgebraRealField,
@@ -255,9 +266,8 @@
S: RealMapping + LocalAnalysis, N>,
K: RealMapping + LocalAnalysis, N>,
BTNodeLookup: BTNode, N>,
- DiscreteMeasure, F> : SpikeMerging,
Reg : RegTerm,
- State : AlgIteratorState {
+ I : AlgIterator {
// Maximum insertion count and measure difference calculation depend on insertion style.
let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
@@ -265,11 +275,10 @@
_ => (config.max_insertions, !state.is_quiet()),
};
- // TODO: should avoid a copy of μ_base here.
- let ω0 = op𝒟.apply(match ν_delta {
- None => μ_base.clone(),
- Some(ν_d) => &*μ_base + ν_d,
- });
+ let ω0 = match ν_delta {
+ None => op𝒟.apply(μ_base),
+ Some(ν) => op𝒟.apply(μ_base + ν),
+ };
// Add points to support until within error tolerance or maximum insertion count reached.
let mut count = 0;
@@ -277,10 +286,12 @@
if μ.len() > 0 {
// Form finite-dimensional subproblem. The subproblem references to the original μ^k
// from the beginning of the iteration are all contained in the immutable c and g.
+ // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
+ // problems have not yet been updated to sign change.
let à = op𝒟.findim_matrix(μ.iter_locations());
let g̃ = DVector::from_iterator(μ.len(),
μ.iter_locations()
- .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ))
+ .map(|ζ| ω0.apply(ζ) - τv.apply(ζ))
.map(F::to_nalgebra_mixed));
let mut x = μ.masses_dvector();
@@ -298,12 +309,12 @@
μ.set_masses_dvector(&x);
}
- // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
+ // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality
// conditions in the predual space, and finding new points for insertion, if necessary.
- let mut d = minus_τv + op𝒟.preapply(match ν_delta {
- None => μ_base.sub_matching(μ),
- Some(ν) => μ_base.sub_matching(μ) + ν
- });
+ let mut d = τv + match ν_delta {
+ None => op𝒟.preapply(μ.sub_matching(μ_base)),
+ Some(ν) => op𝒟.preapply(μ.sub_matching(μ_base) - ν)
+ };
// If no merging heuristic is used, let's be more conservative about spike insertion,
// and skip it after first round. If merging is done, being more greedy about spike
@@ -330,11 +341,9 @@
// No point in optimising the weight here; the finite-dimensional algorithm is fast.
*μ += DeltaMeasure { x : ξ, α : 0.0 };
count += 1;
+ stats.inserted += 1;
};
- // TODO: should redo everything if some transports cause a problem.
- // Maybe implementation should call above loop as a closure.
-
if !within_tolerances && warn_insertions {
// Complain (but continue) if we failed to get within tolerances
// by inserting more points.
@@ -346,61 +355,33 @@
(d, within_tolerances)
}
-#[replace_float_literals(F::cast_from(literal))]
-pub(crate) fn prune_and_maybe_simple_merge<
- 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
->(
- μ : &mut DiscreteMeasure, F>,
- minus_τv : &BTFN,
- μ_base : &DiscreteMeasure, F>,
- op𝒟 : &'a 𝒟,
- τ : F,
- ε : F,
- config : &FBGenericConfig,
- reg : &Reg,
- state : &State,
- stats : &mut IterInfo,
-)
-where F : Float + ToNalgebraRealField,
- GA : SupportGenerator + Clone,
- BTA : BTSearch>,
- G𝒟 : SupportGenerator + Clone,
- 𝒟 : DiscreteMeasureOp, F, PreCodomain = PreBTFN>,
- 𝒟::Codomain : RealMapping,
- S: RealMapping + LocalAnalysis, N>,
- K: RealMapping + LocalAnalysis, N>,
- BTNodeLookup: BTNode, N>,
- DiscreteMeasure, F> : SpikeMerging,
- Reg : RegTerm,
- State : AlgIteratorState {
- if state.iteration() % config.merge_every == 0 {
- stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
- let mut d = minus_τv + op𝒟.preapply(μ_base.sub_matching(&μ_candidate));
- reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
- });
- }
-
+pub(crate) fn prune_with_stats(
+ μ : &mut RNDM,
+) -> usize {
let n_before_prune = μ.len();
μ.prune();
debug_assert!(μ.len() <= n_before_prune);
- stats.pruned += n_before_prune - μ.len();
+ n_before_prune - μ.len()
}
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn postprocess<
F : Float,
V : Euclidean + Clone,
- A : GEMV, F>, Codomain = V>,
+ A : GEMV, Codomain = V>,
D : DataTerm,
const N : usize
> (
- mut μ : DiscreteMeasure, F>,
+ mut μ : RNDM,
config : &FBGenericConfig,
dataterm : D,
opA : &A,
b : &V,
-) -> DiscreteMeasure, F>
-where DiscreteMeasure, F> : SpikeMerging {
+) -> RNDM
+where
+ RNDM : SpikeMerging,
+ for<'a> &'a RNDM : Instance>,
+{
μ.merge_spikes_fitness(config.merging,
|μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
|&v| v);
@@ -437,15 +418,13 @@
fbconfig : &FBConfig,
iterator : I,
mut plotter : SeqPlotter,
-) -> DiscreteMeasure, F>
+) -> RNDM
where F : Float + ToNalgebraRealField,
I : AlgIteratorFactory>,
for<'b> &'b A::Observable : std::ops::Neg