src/fb.rs

branch
dev
changeset 32
56c8adc32b09
parent 29
87649ccfa6a8
child 34
efa60bc4f743
equal deleted inserted replaced
30:bd13c2ae3450 32:56c8adc32b09
81 */ 81 */
82 82
83 use numeric_literals::replace_float_literals; 83 use numeric_literals::replace_float_literals;
84 use serde::{Serialize, Deserialize}; 84 use serde::{Serialize, Deserialize};
85 use colored::Colorize; 85 use colored::Colorize;
86 use nalgebra::{DVector, DMatrix}; 86 use nalgebra::DVector;
87 87
88 use alg_tools::iterate::{ 88 use alg_tools::iterate::{
89 AlgIteratorFactory, 89 AlgIteratorFactory,
90 AlgIteratorState, 90 AlgIteratorState,
91 }; 91 };
92 use alg_tools::euclidean::Euclidean; 92 use alg_tools::euclidean::Euclidean;
93 use alg_tools::linops::Apply; 93 use alg_tools::linops::{Apply, GEMV};
94 use alg_tools::sets::Cube; 94 use alg_tools::sets::Cube;
95 use alg_tools::loc::Loc; 95 use alg_tools::loc::Loc;
96 use alg_tools::mapping::Mapping;
97 use alg_tools::bisection_tree::{ 96 use alg_tools::bisection_tree::{
98 BTFN, 97 BTFN,
99 PreBTFN, 98 PreBTFN,
100 Bounds, 99 Bounds,
101 BTNodeLookup, 100 BTNodeLookup,
102 BTNode, 101 BTNode,
103 BTSearch, 102 BTSearch,
104 P2Minimise, 103 P2Minimise,
105 SupportGenerator, 104 SupportGenerator,
106 LocalAnalysis, 105 LocalAnalysis,
107 Bounded, 106 BothGenerators,
108 }; 107 };
109 use alg_tools::mapping::RealMapping; 108 use alg_tools::mapping::RealMapping;
110 use alg_tools::nalgebra_support::ToNalgebraRealField; 109 use alg_tools::nalgebra_support::ToNalgebraRealField;
111 110
112 use crate::types::*; 111 use crate::types::*;
117 use crate::measures::merging::{ 116 use crate::measures::merging::{
118 SpikeMergingMethod, 117 SpikeMergingMethod,
119 SpikeMerging, 118 SpikeMerging,
120 }; 119 };
121 use crate::forward_model::ForwardModel; 120 use crate::forward_model::ForwardModel;
122 use crate::seminorms::{ 121 use crate::seminorms::DiscreteMeasureOp;
123 DiscreteMeasureOp, Lipschitz
124 };
125 use crate::subproblem::{ 122 use crate::subproblem::{
126 nonneg::quadratic_nonneg,
127 unconstrained::quadratic_unconstrained,
128 InnerSettings, 123 InnerSettings,
129 InnerMethod, 124 InnerMethod,
130 }; 125 };
131 use crate::tolerance::Tolerance; 126 use crate::tolerance::Tolerance;
132 use crate::plot::{ 127 use crate::plot::{
133 SeqPlotter, 128 SeqPlotter,
134 Plotting, 129 Plotting,
135 PlotLookup 130 PlotLookup
136 }; 131 };
137 use crate::regularisation::{ 132 use crate::regularisation::RegTerm;
138 NonnegRadonRegTerm, 133 use crate::dataterm::{
139 RadonRegTerm, 134 calculate_residual,
135 L2Squared,
136 DataTerm,
140 }; 137 };
141 138
142 /// Method for constructing $μ$ on each iteration 139 /// Method for constructing $μ$ on each iteration
143 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 140 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
144 #[allow(dead_code)] 141 #[allow(dead_code)]
148 Reuse, 145 Reuse,
149 /// Start each iteration with $μ=0$. 146 /// Start each iteration with $μ=0$.
150 Zero, 147 Zero,
151 } 148 }
152 149
153 /// Meta-algorithm type
154 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
155 #[allow(dead_code)]
156 pub enum FBMetaAlgorithm {
157 /// No meta-algorithm
158 None,
159 /// FISTA-style inertia
160 InertiaFISTA,
161 }
162
163 /// Settings for [`pointsource_fb_reg`]. 150 /// Settings for [`pointsource_fb_reg`].
164 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 151 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
165 #[serde(default)] 152 #[serde(default)]
166 pub struct FBConfig<F : Float> { 153 pub struct FBConfig<F : Float> {
167 /// Step length scaling 154 /// Step length scaling
168 pub τ0 : F, 155 pub τ0 : F,
169 /// Meta-algorithm to apply
170 pub meta : FBMetaAlgorithm,
171 /// Generic parameters 156 /// Generic parameters
172 pub insertion : FBGenericConfig<F>, 157 pub insertion : FBGenericConfig<F>,
173 } 158 }
174 159
175 /// Settings for the solution of the stepwise optimality condition in algorithms based on 160 /// Settings for the solution of the stepwise optimality condition in algorithms based on
207 #[replace_float_literals(F::cast_from(literal))] 192 #[replace_float_literals(F::cast_from(literal))]
208 impl<F : Float> Default for FBConfig<F> { 193 impl<F : Float> Default for FBConfig<F> {
209 fn default() -> Self { 194 fn default() -> Self {
210 FBConfig { 195 FBConfig {
211 τ0 : 0.99, 196 τ0 : 0.99,
212 meta : FBMetaAlgorithm::None,
213 insertion : Default::default() 197 insertion : Default::default()
214 } 198 }
215 } 199 }
216 } 200 }
217 201
238 postprocessing : false, 222 postprocessing : false,
239 } 223 }
240 } 224 }
241 } 225 }
242 226
243 /// Trait for specialisation of [`generic_pointsource_fb_reg`] to basic FB, FISTA. 227 #[replace_float_literals(F::cast_from(literal))]
228 pub(crate) fn μ_diff<F : Float, const N : usize>(
229 μ_new : &DiscreteMeasure<Loc<F, N>, F>,
230 μ_base : &DiscreteMeasure<Loc<F, N>, F>,
231 ν_delta : Option<&DiscreteMeasure<Loc<F, N>, F>>,
232 config : &FBGenericConfig<F>
233 ) -> DiscreteMeasure<Loc<F, N>, F> {
234 let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style {
235 InsertionStyle::Reuse => {
236 μ_new.iter_spikes()
237 .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0)))
238 .map(|(δ, α_base)| (δ.x, α_base - δ.α))
239 .collect()
240 },
241 InsertionStyle::Zero => {
242 μ_new.iter_spikes()
243 .map(|δ| -δ)
244 .chain(μ_base.iter_spikes().copied())
245 .collect()
246 }
247 };
248 ν.prune(); // Potential small performance improvement
249 // Add ν_delta if given
250 match ν_delta {
251 None => ν,
252 Some(ν_d) => ν + ν_d,
253 }
254 }
255
256 #[replace_float_literals(F::cast_from(literal))]
257 pub(crate) fn insert_and_reweigh<
258 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
259 >(
260 μ : &mut DiscreteMeasure<Loc<F, N>, F>,
261 minus_τv : &BTFN<F, GA, BTA, N>,
262 μ_base : &DiscreteMeasure<Loc<F, N>, F>,
263 ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>,
264 op𝒟 : &'a 𝒟,
265 op𝒟norm : F,
266 τ : F,
267 ε : F,
268 config : &FBGenericConfig<F>,
269 reg : &Reg,
270 state : &State,
271 stats : &mut IterInfo<F, N>,
272 ) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool)
273 where F : Float + ToNalgebraRealField,
274 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
275 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
276 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
277 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
278 𝒟::Codomain : RealMapping<F, N>,
279 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
280 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
281 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
282 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
283 Reg : RegTerm<F, N>,
284 State : AlgIteratorState {
285
286 // Maximum insertion count and measure difference calculation depend on insertion style.
287 let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
288 (i, Some((l, k))) if i <= l => (k, false),
289 _ => (config.max_insertions, !state.is_quiet()),
290 };
291 let max_insertions = match config.insertion_style {
292 InsertionStyle::Zero => {
293 todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled.");
294 // let n = μ.len();
295 // μ = DiscreteMeasure::new();
296 // n + m
297 },
298 InsertionStyle::Reuse => m,
299 };
300
301 // TODO: should avoid a second copy of μ here; μ_base already stores a copy.
302 let ω0 = op𝒟.apply(match ν_delta {
303 None => μ.clone(),
304 Some(ν_d) => &*μ + ν_d,
305 });
306
307 // Add points to support until within error tolerance or maximum insertion count reached.
308 let mut count = 0;
309 let (within_tolerances, d) = 'insertion: loop {
310 if μ.len() > 0 {
311 // Form finite-dimensional subproblem. The subproblem references to the original μ^k
312 // from the beginning of the iteration are all contained in the immutable c and g.
313 let à = op𝒟.findim_matrix(μ.iter_locations());
314 let g̃ = DVector::from_iterator(μ.len(),
315 μ.iter_locations()
316 .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ))
317 .map(F::to_nalgebra_mixed));
318 let mut x = μ.masses_dvector();
319
320 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
321 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
322 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
323 // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2
324 // = n |𝒟| |x|_2, where n is the number of points. Therefore
325 let Ã_normest = op𝒟norm * F::cast_from(μ.len());
326
327 // Solve finite-dimensional subproblem.
328 stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config);
329
330 // Update masses of μ based on solution of finite-dimensional subproblem.
331 μ.set_masses_dvector(&x);
332 }
333
334 // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
335 // conditions in the predual space, and finding new points for insertion, if necessary.
336 let mut d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_base, ν_delta, config));
337
338 // If no merging heuristic is used, let's be more conservative about spike insertion,
339 // and skip it after first round. If merging is done, being more greedy about spike
340 // insertion also seems to improve performance.
341 let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging {
342 false
343 } else {
344 count > 0
345 };
346
347 // Find a spike to insert, if needed
348 let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation(
349 &mut d, τ, ε, skip_by_rough_check, config
350 ) {
351 None => break 'insertion (true, d),
352 Some(res) => res,
353 };
354
355 // Break if maximum insertion count reached
356 if count >= max_insertions {
357 break 'insertion (in_bounds, d)
358 }
359
360 // No point in optimising the weight here; the finite-dimensional algorithm is fast.
361 *μ += DeltaMeasure { x : ξ, α : 0.0 };
362 count += 1;
363 };
364
365 // TODO: should redo everything if some transports cause a problem.
366 // Maybe implementation should call above loop as a closure.
367
368 if !within_tolerances && warn_insertions {
369 // Complain (but continue) if we failed to get within tolerances
370 // by inserting more points.
371 let err = format!("Maximum insertions reached without achieving \
372 subproblem solution tolerance");
373 println!("{}", err.red());
374 }
375
376 (d, within_tolerances)
377 }
378
379 #[replace_float_literals(F::cast_from(literal))]
380 pub(crate) fn prune_and_maybe_simple_merge<
381 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
382 >(
383 μ : &mut DiscreteMeasure<Loc<F, N>, F>,
384 minus_τv : &BTFN<F, GA, BTA, N>,
385 μ_base : &DiscreteMeasure<Loc<F, N>, F>,
386 op𝒟 : &'a 𝒟,
387 τ : F,
388 ε : F,
389 config : &FBGenericConfig<F>,
390 reg : &Reg,
391 state : &State,
392 stats : &mut IterInfo<F, N>,
393 )
394 where F : Float + ToNalgebraRealField,
395 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
396 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
397 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
398 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
399 𝒟::Codomain : RealMapping<F, N>,
400 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
401 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
402 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
403 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
404 Reg : RegTerm<F, N>,
405 State : AlgIteratorState {
406 if state.iteration() % config.merge_every == 0 {
407 let n_before_merge = μ.len();
408 μ.merge_spikes(config.merging, |μ_candidate| {
409 let μd = μ_diff(&μ_candidate, &μ_base, None, config);
410 let mut d = minus_τv + op𝒟.preapply(μd);
411
412 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
413 .then_some(())
414 });
415 debug_assert!(μ.len() >= n_before_merge);
416 stats.merged += μ.len() - n_before_merge;
417 }
418
419 let n_before_prune = μ.len();
420 μ.prune();
421 debug_assert!(μ.len() <= n_before_prune);
422 stats.pruned += n_before_prune - μ.len();
423 }
424
425 #[replace_float_literals(F::cast_from(literal))]
426 pub(crate) fn postprocess<
427 F : Float,
428 V : Euclidean<F> + Clone,
429 A : GEMV<F, DiscreteMeasure<Loc<F, N>, F>, Codomain = V>,
430 D : DataTerm<F, V, N>,
431 const N : usize
432 > (
433 mut μ : DiscreteMeasure<Loc<F, N>, F>,
434 config : &FBGenericConfig<F>,
435 dataterm : D,
436 opA : &A,
437 b : &V,
438 ) -> DiscreteMeasure<Loc<F, N>, F>
439 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
440 μ.merge_spikes_fitness(config.merging,
441 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
442 |&v| v);
443 μ.prune();
444 μ
445 }
446
447 /// Iteratively solve the pointsource localisation problem using forward-backward splitting.
244 /// 448 ///
245 /// The idea is that the residual $Aμ - b$ in the forward step can be replaced by an arbitrary 449 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
246 /// value. For example, to implement [primal-dual proximal splitting][crate::pdps] we replace it
247 /// with the dual variable $y$. We can then also implement alternative data terms, as the
248 /// (pre)differential of $F(μ)=F\_0(Aμ-b)$ is $F\'(μ) = A\_*F\_0\'(Aμ-b)$. In the case of the
249 /// quadratic fidelity $F_0(y)=\frac{1}{2}\\|y\\|_2^2$ in a Hilbert space, of course,
250 /// $F\_0\'(Aμ-b)=Aμ-b$ is the residual.
251 pub trait FBSpecialisation<F : Float, Observable : Euclidean<F>, const N : usize> : Sized {
252 /// Updates the residual and does any necessary pruning of `μ`.
253 ///
254 /// Returns the new residual and possibly a new step length.
255 ///
256 /// The measure `μ` may also be modified to apply, e.g., inertia to it.
257 /// The updated residual should correspond to the residual at `μ`.
258 /// See the [trait documentation][FBSpecialisation] for the use and meaning of the residual.
259 ///
260 /// The parameter `μ_base` is the base point of the iteration, typically the previous iterate,
261 /// but for, e.g., FISTA has inertia applied to it.
262 fn update(
263 &mut self,
264 μ : &mut DiscreteMeasure<Loc<F, N>, F>,
265 μ_base : &DiscreteMeasure<Loc<F, N>, F>,
266 ) -> (Observable, Option<F>);
267
268 /// Calculates the data term value corresponding to iterate `μ` and available residual.
269 ///
270 /// Inertia and other modifications, as deemed necessary, should be applied to `μ`.
271 ///
272 /// The blanket implementation corresponds to the 2-norm-squared data fidelity
273 /// $\\|\text{residual}\\|\_2^2/2$.
274 fn calculate_fit(
275 &self,
276 _μ : &DiscreteMeasure<Loc<F, N>, F>,
277 residual : &Observable
278 ) -> F {
279 residual.norm2_squared_div2()
280 }
281
282 /// Calculates the data term value at $μ$.
283 ///
284 /// Unlike [`Self::calculate_fit`], no inertia, etc., should be applied to `μ`.
285 fn calculate_fit_simple(
286 &self,
287 μ : &DiscreteMeasure<Loc<F, N>, F>,
288 ) -> F;
289
290 /// Returns the final iterate after any necessary postprocess pruning, merging, etc.
291 fn postprocess(self, mut μ : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>)
292 -> DiscreteMeasure<Loc<F, N>, F>
293 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
294 μ.merge_spikes_fitness(merging,
295 |μ̃| self.calculate_fit_simple(μ̃),
296 |&v| v);
297 μ.prune();
298 μ
299 }
300
301 /// Returns measure to be used for value calculations, which may differ from μ.
302 fn value_μ<'c, 'b : 'c>(&'b self, μ : &'c DiscreteMeasure<Loc<F, N>, F>)
303 -> &'c DiscreteMeasure<Loc<F, N>, F> {
304 μ
305 }
306 }
307
308 /// Specialisation of [`generic_pointsource_fb_reg`] to basic μFB.
309 struct BasicFB<
310 'a,
311 F : Float + ToNalgebraRealField,
312 A : ForwardModel<Loc<F, N>, F>,
313 const N : usize
314 > {
315 /// The data
316 b : &'a A::Observable,
317 /// The forward operator
318 opA : &'a A,
319 }
320
321 /// Implementation of [`FBSpecialisation`] for basic μFB forward-backward splitting.
322 #[replace_float_literals(F::cast_from(literal))]
323 impl<'a, F : Float + ToNalgebraRealField , A : ForwardModel<Loc<F, N>, F>, const N : usize>
324 FBSpecialisation<F, A::Observable, N> for BasicFB<'a, F, A, N> {
325 fn update(
326 &mut self,
327 μ : &mut DiscreteMeasure<Loc<F, N>, F>,
328 _μ_base : &DiscreteMeasure<Loc<F, N>, F>
329 ) -> (A::Observable, Option<F>) {
330 μ.prune();
331 //*residual = self.opA.apply(μ) - self.b;
332 let mut residual = self.b.clone();
333 self.opA.gemv(&mut residual, 1.0, μ, -1.0);
334 (residual, None)
335 }
336
337 fn calculate_fit_simple(
338 &self,
339 μ : &DiscreteMeasure<Loc<F, N>, F>,
340 ) -> F {
341 let mut residual = self.b.clone();
342 self.opA.gemv(&mut residual, 1.0, μ, -1.0);
343 residual.norm2_squared_div2()
344 }
345 }
346
347 /// Specialisation of [`generic_pointsource_fb_reg`] to FISTA.
348 struct FISTA<
349 'a,
350 F : Float + ToNalgebraRealField,
351 A : ForwardModel<Loc<F, N>, F>,
352 const N : usize
353 > {
354 /// The data
355 b : &'a A::Observable,
356 /// The forward operator
357 opA : &'a A,
358 /// Current inertial parameter
359 λ : F,
360 /// Previous iterate without inertia applied.
361 /// We need to store this here because `μ_base` passed to [`FBSpecialisation::update`] will
362 /// have inertia applied to it, so is not useful to use.
363 μ_prev : DiscreteMeasure<Loc<F, N>, F>,
364 }
365
366 /// Implementation of [`FBSpecialisation`] for μFISTA inertial forward-backward splitting.
367 #[replace_float_literals(F::cast_from(literal))]
368 impl<'a, F : Float + ToNalgebraRealField, A : ForwardModel<Loc<F, N>, F>, const N : usize>
369 FBSpecialisation<F, A::Observable, N> for FISTA<'a, F, A, N> {
370 fn update(
371 &mut self,
372 μ : &mut DiscreteMeasure<Loc<F, N>, F>,
373 _μ_base : &DiscreteMeasure<Loc<F, N>, F>
374 ) -> (A::Observable, Option<F>) {
375 // Update inertial parameters
376 let λ_prev = self.λ;
377 self.λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() );
378 let θ = self.λ / λ_prev - self.λ;
379 // Perform inertial update on μ.
380 // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ
381 // and μ_prev have zero weight. Since both have weights from the finite-dimensional
382 // subproblem with a proximal projection step, this is likely to happen when the
383 // spike is not needed. A copy of the pruned μ without arithmetic performed is
384 // stored in μ_prev.
385 μ.pruning_sub(1.0 + θ, θ, &mut self.μ_prev);
386
387 //*residual = self.opA.apply(μ) - self.b;
388 let mut residual = self.b.clone();
389 self.opA.gemv(&mut residual, 1.0, μ, -1.0);
390 (residual, None)
391 }
392
393 fn calculate_fit_simple(
394 &self,
395 μ : &DiscreteMeasure<Loc<F, N>, F>,
396 ) -> F {
397 let mut residual = self.b.clone();
398 self.opA.gemv(&mut residual, 1.0, μ, -1.0);
399 residual.norm2_squared_div2()
400 }
401
402 fn calculate_fit(
403 &self,
404 _μ : &DiscreteMeasure<Loc<F, N>, F>,
405 _residual : &A::Observable
406 ) -> F {
407 self.calculate_fit_simple(&self.μ_prev)
408 }
409
410 // For FISTA we need to do a final pruning as well, due to the limited
411 // pruning that can be done on each step.
412 fn postprocess(mut self, μ_base : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>)
413 -> DiscreteMeasure<Loc<F, N>, F>
414 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> {
415 let mut μ = self.μ_prev;
416 self.μ_prev = μ_base;
417 μ.merge_spikes_fitness(merging,
418 |μ̃| self.calculate_fit_simple(μ̃),
419 |&v| v);
420 μ.prune();
421 μ
422 }
423
424 fn value_μ<'c, 'b : 'c>(&'c self, _μ : &'c DiscreteMeasure<Loc<F, N>, F>)
425 -> &'c DiscreteMeasure<Loc<F, N>, F> {
426 &self.μ_prev
427 }
428 }
429
430
431 /// Abstraction of regularisation terms for [`generic_pointsource_fb_reg`].
432 pub trait RegTerm<F : Float + ToNalgebraRealField, const N : usize>
433 : for<'a> Apply<&'a DiscreteMeasure<Loc<F, N>, F>, Output = F> {
434 /// Approximately solve the problem
435 /// <div>$$
436 /// \min_{x ∈ ℝ^n} \frac{1}{2} x^⊤Ax - g^⊤ x + τ G(x)
437 /// $$</div>
438 /// for $G$ depending on the trait implementation.
439 ///
440 /// The parameter `mA` is $A$. An estimate for its operator norm should be provided in
441 /// `mA_normest`. The initial iterate and output is `x`. The current main tolerance is `ε`.
442 ///
443 /// Returns the number of iterations taken.
444 fn solve_findim(
445 &self,
446 mA : &DMatrix<F::MixedType>,
447 g : &DVector<F::MixedType>,
448 τ : F,
449 x : &mut DVector<F::MixedType>,
450 mA_normest : F,
451 ε : F,
452 config : &FBGenericConfig<F>
453 ) -> usize;
454
455 /// Find a point where `d` may violate the tolerance `ε`.
456 ///
457 /// If `skip_by_rough_check` is set, do not find the point if a rough check indicates that we
458 /// are in bounds. `ε` is the current main tolerance and `τ` a scaling factor for the
459 /// regulariser.
460 ///
461 /// Returns `None` if `d` is in bounds either based on the rough check, or a more precise check
462 /// terminating early. Otherwise returns a possibly violating point, the value of `d` there,
463 /// and a boolean indicating whether the found point is in bounds.
464 fn find_tolerance_violation<G, BT>(
465 &self,
466 d : &mut BTFN<F, G, BT, N>,
467 τ : F,
468 ε : F,
469 skip_by_rough_check : bool,
470 config : &FBGenericConfig<F>,
471 ) -> Option<(Loc<F, N>, F, bool)>
472 where BT : BTSearch<F, N, Agg=Bounds<F>>,
473 G : SupportGenerator<F, N, Id=BT::Data>,
474 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
475 + LocalAnalysis<F, Bounds<F>, N>;
476
477 /// Verify that `d` is in bounds `ε` for a merge candidate `μ`
478 ///
479 /// `ε` is the current main tolerance and `τ` a scaling factor for the regulariser.
480 fn verify_merge_candidate<G, BT>(
481 &self,
482 d : &mut BTFN<F, G, BT, N>,
483 μ : &DiscreteMeasure<Loc<F, N>, F>,
484 τ : F,
485 ε : F,
486 config : &FBGenericConfig<F>,
487 ) -> bool
488 where BT : BTSearch<F, N, Agg=Bounds<F>>,
489 G : SupportGenerator<F, N, Id=BT::Data>,
490 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
491 + LocalAnalysis<F, Bounds<F>, N>;
492
493 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>>;
494
495 /// Returns a scaling factor for the tolerance sequence.
496 ///
497 /// Typically this is the regularisation parameter.
498 fn tolerance_scaling(&self) -> F;
499 }
500
501 #[replace_float_literals(F::cast_from(literal))]
502 impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for NonnegRadonRegTerm<F>
503 where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
504 fn solve_findim(
505 &self,
506 mA : &DMatrix<F::MixedType>,
507 g : &DVector<F::MixedType>,
508 τ : F,
509 x : &mut DVector<F::MixedType>,
510 mA_normest : F,
511 ε : F,
512 config : &FBGenericConfig<F>
513 ) -> usize {
514 let inner_tolerance = ε * config.inner.tolerance_mult;
515 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
516 let inner_τ = config.inner.τ0 / mA_normest;
517 quadratic_nonneg(config.inner.method, mA, g, τ * self.α(), x,
518 inner_τ, inner_it)
519 }
520
521 #[inline]
522 fn find_tolerance_violation<G, BT>(
523 &self,
524 d : &mut BTFN<F, G, BT, N>,
525 τ : F,
526 ε : F,
527 skip_by_rough_check : bool,
528 config : &FBGenericConfig<F>,
529 ) -> Option<(Loc<F, N>, F, bool)>
530 where BT : BTSearch<F, N, Agg=Bounds<F>>,
531 G : SupportGenerator<F, N, Id=BT::Data>,
532 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
533 + LocalAnalysis<F, Bounds<F>, N> {
534 let τα = τ * self.α();
535 let keep_below = τα + ε;
536 let maximise_above = τα + ε * config.insertion_cutoff_factor;
537 let refinement_tolerance = ε * config.refinement.tolerance_mult;
538
539 // If preliminary check indicates that we are in bounds, and if it otherwise matches
540 // the insertion strategy, skip insertion.
541 if skip_by_rough_check && d.bounds().upper() <= keep_below {
542 None
543 } else {
544 // If the rough check didn't indicate no insertion needed, find maximising point.
545 d.maximise_above(maximise_above, refinement_tolerance, config.refinement.max_steps)
546 .map(|(ξ, v_ξ)| (ξ, v_ξ, v_ξ <= keep_below))
547 }
548 }
549
550 fn verify_merge_candidate<G, BT>(
551 &self,
552 d : &mut BTFN<F, G, BT, N>,
553 μ : &DiscreteMeasure<Loc<F, N>, F>,
554 τ : F,
555 ε : F,
556 config : &FBGenericConfig<F>,
557 ) -> bool
558 where BT : BTSearch<F, N, Agg=Bounds<F>>,
559 G : SupportGenerator<F, N, Id=BT::Data>,
560 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
561 + LocalAnalysis<F, Bounds<F>, N> {
562 let τα = τ * self.α();
563 let refinement_tolerance = ε * config.refinement.tolerance_mult;
564 let merge_tolerance = config.merge_tolerance_mult * ε;
565 let keep_below = τα + merge_tolerance;
566 let keep_supp_above = τα - merge_tolerance;
567 let bnd = d.bounds();
568
569 return (
570 bnd.lower() >= keep_supp_above
571 ||
572 μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| {
573 (β == 0.0) || d.apply(x) >= keep_supp_above
574 }).all(std::convert::identity)
575 ) && (
576 bnd.upper() <= keep_below
577 ||
578 d.has_upper_bound(keep_below, refinement_tolerance, config.refinement.max_steps)
579 )
580 }
581
582 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> {
583 let τα = τ * self.α();
584 Some(Bounds(τα - ε, τα + ε))
585 }
586
587 fn tolerance_scaling(&self) -> F {
588 self.α()
589 }
590 }
591
592 #[replace_float_literals(F::cast_from(literal))]
593 impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for RadonRegTerm<F>
594 where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
595 fn solve_findim(
596 &self,
597 mA : &DMatrix<F::MixedType>,
598 g : &DVector<F::MixedType>,
599 τ : F,
600 x : &mut DVector<F::MixedType>,
601 mA_normest: F,
602 ε : F,
603 config : &FBGenericConfig<F>
604 ) -> usize {
605 let inner_tolerance = ε * config.inner.tolerance_mult;
606 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
607 let inner_τ = config.inner.τ0 / mA_normest;
608 quadratic_unconstrained(config.inner.method, mA, g, τ * self.α(), x,
609 inner_τ, inner_it)
610 }
611
612 fn find_tolerance_violation<G, BT>(
613 &self,
614 d : &mut BTFN<F, G, BT, N>,
615 τ : F,
616 ε : F,
617 skip_by_rough_check : bool,
618 config : &FBGenericConfig<F>,
619 ) -> Option<(Loc<F, N>, F, bool)>
620 where BT : BTSearch<F, N, Agg=Bounds<F>>,
621 G : SupportGenerator<F, N, Id=BT::Data>,
622 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
623 + LocalAnalysis<F, Bounds<F>, N> {
624 let τα = τ * self.α();
625 let keep_below = τα + ε;
626 let keep_above = -τα - ε;
627 let maximise_above = τα + ε * config.insertion_cutoff_factor;
628 let minimise_below = -τα - ε * config.insertion_cutoff_factor;
629 let refinement_tolerance = ε * config.refinement.tolerance_mult;
630
631 // If preliminary check indicates that we are in bounds, and if it otherwise matches
632 // the insertion strategy, skip insertion.
633 if skip_by_rough_check && Bounds(keep_above, keep_below).superset(&d.bounds()) {
634 None
635 } else {
636 // If the rough check didn't indicate no insertion needed, find maximising point.
637 let mx = d.maximise_above(maximise_above, refinement_tolerance,
638 config.refinement.max_steps);
639 let mi = d.minimise_below(minimise_below, refinement_tolerance,
640 config.refinement.max_steps);
641
642 match (mx, mi) {
643 (None, None) => None,
644 (Some((ξ, v_ξ)), None) => Some((ξ, v_ξ, keep_below >= v_ξ)),
645 (None, Some((ζ, v_ζ))) => Some((ζ, v_ζ, keep_above <= v_ζ)),
646 (Some((ξ, v_ξ)), Some((ζ, v_ζ))) => {
647 if v_ξ - τα > τα - v_ζ {
648 Some((ξ, v_ξ, keep_below >= v_ξ))
649 } else {
650 Some((ζ, v_ζ, keep_above <= v_ζ))
651 }
652 }
653 }
654 }
655 }
656
657 fn verify_merge_candidate<G, BT>(
658 &self,
659 d : &mut BTFN<F, G, BT, N>,
660 μ : &DiscreteMeasure<Loc<F, N>, F>,
661 τ : F,
662 ε : F,
663 config : &FBGenericConfig<F>,
664 ) -> bool
665 where BT : BTSearch<F, N, Agg=Bounds<F>>,
666 G : SupportGenerator<F, N, Id=BT::Data>,
667 G::SupportType : Mapping<Loc<F, N>,Codomain=F>
668 + LocalAnalysis<F, Bounds<F>, N> {
669 let τα = τ * self.α();
670 let refinement_tolerance = ε * config.refinement.tolerance_mult;
671 let merge_tolerance = config.merge_tolerance_mult * ε;
672 let keep_below = τα + merge_tolerance;
673 let keep_above = -τα - merge_tolerance;
674 let keep_supp_pos_above = τα - merge_tolerance;
675 let keep_supp_neg_below = -τα + merge_tolerance;
676 let bnd = d.bounds();
677
678 return (
679 (bnd.lower() >= keep_supp_pos_above && bnd.upper() <= keep_supp_neg_below)
680 ||
681 μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| {
682 use std::cmp::Ordering::*;
683 match β.partial_cmp(&0.0) {
684 Some(Greater) => d.apply(x) >= keep_supp_pos_above,
685 Some(Less) => d.apply(x) <= keep_supp_neg_below,
686 _ => true,
687 }
688 }).all(std::convert::identity)
689 ) && (
690 bnd.upper() <= keep_below
691 ||
692 d.has_upper_bound(keep_below, refinement_tolerance,
693 config.refinement.max_steps)
694 ) && (
695 bnd.lower() >= keep_above
696 ||
697 d.has_lower_bound(keep_above, refinement_tolerance,
698 config.refinement.max_steps)
699 )
700 }
701
702 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> {
703 let τα = τ * self.α();
704 Some(Bounds(-τα - ε, τα + ε))
705 }
706
707 fn tolerance_scaling(&self) -> F {
708 self.α()
709 }
710 }
711
712
713 /// Generic implementation of [`pointsource_fb_reg`].
714 ///
715 /// The method can be specialised to even primal-dual proximal splitting through the
716 /// [`FBSpecialisation`] parameter `specialisation`.
717 /// The settings in `config` have their [respective documentation](FBGenericConfig). `opA` is the
718 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. 450 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
719 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution 451 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
720 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control 452 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
721 /// as documented in [`alg_tools::iterate`]. 453 /// as documented in [`alg_tools::iterate`].
454 ///
455 /// For details on the mathematical formulation, see the [module level](self) documentation.
722 /// 456 ///
723 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of 457 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
724 /// sums of simple functions using bisection trees, and the related 458 /// sums of simple functions using bisection trees, and the related
725 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions 459 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
726 /// active at specific points, and to maximise their sums. Through the implementation of the 460 /// active at specific points, and to maximise their sums. Through the implementation of the
727 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features 461 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
728 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. 462 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
729 /// 463 ///
730 /// Returns the final iterate. 464 /// Returns the final iterate.
731 #[replace_float_literals(F::cast_from(literal))] 465 #[replace_float_literals(F::cast_from(literal))]
732 pub fn generic_pointsource_fb_reg< 466 pub fn pointsource_fb_reg<
733 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, Reg, const N : usize 467 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize
734 >( 468 >(
735 opA : &'a A,
736 reg : Reg,
737 op𝒟 : &'a 𝒟,
738 mut τ : F,
739 config : &FBGenericConfig<F>,
740 iterator : I,
741 mut plotter : SeqPlotter<F, N>,
742 mut residual : A::Observable,
743 mut specialisation : Spec
744 ) -> DiscreteMeasure<Loc<F, N>, F>
745 where F : Float + ToNalgebraRealField,
746 I : AlgIteratorFactory<IterInfo<F, N>>,
747 Spec : FBSpecialisation<F, A::Observable, N>,
748 A::Observable : std::ops::MulAssign<F>,
749 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
750 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
751 + Lipschitz<𝒟, FloatType=F>,
752 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
753 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
754 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
755 𝒟::Codomain : RealMapping<F, N>,
756 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
757 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
758 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
759 PlotLookup : Plotting<N>,
760 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
761 Reg : RegTerm<F, N> {
762
763 // Set up parameters
764 let quiet = iterator.is_quiet();
765 let op𝒟norm = op𝒟.opnorm_bound();
766 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
767 // by τ compared to the conditional gradient approach.
768 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
769 let mut ε = tolerance.initial();
770
771 // Initialise operators
772 let preadjA = opA.preadjoint();
773
774 // Initialise iterates
775 let mut μ = DiscreteMeasure::new();
776
777 let mut inner_iters = 0;
778 let mut this_iters = 0;
779 let mut pruned = 0;
780 let mut merged = 0;
781
782 let μ_diff = |μ_new : &DiscreteMeasure<Loc<F, N>, F>,
783 μ_base : &DiscreteMeasure<Loc<F, N>, F>| {
784 let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style {
785 InsertionStyle::Reuse => {
786 μ_new.iter_spikes()
787 .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0)))
788 .map(|(δ, α_base)| (δ.x, α_base - δ.α))
789 .collect()
790 },
791 InsertionStyle::Zero => {
792 μ_new.iter_spikes()
793 .map(|δ| -δ)
794 .chain(μ_base.iter_spikes().copied())
795 .collect()
796 }
797 };
798 ν.prune(); // Potential small performance improvement
799 ν
800 };
801
802 // Run the algorithm
803 iterator.iterate(|state| {
804 // Maximum insertion count and measure difference calculation depend on insertion style.
805 let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
806 (i, Some((l, k))) if i <= l => (k, false),
807 _ => (config.max_insertions, !quiet),
808 };
809 let max_insertions = match config.insertion_style {
810 InsertionStyle::Zero => {
811 todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled.");
812 // let n = μ.len();
813 // μ = DiscreteMeasure::new();
814 // n + m
815 },
816 InsertionStyle::Reuse => m,
817 };
818
819 // Calculate smooth part of surrogate model.
820 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
821 // has no significant overhead. For some reason Rust doesn't allow us simply moving
822 // the residual and replacing it below before the end of this closure.
823 residual *= -τ;
824 let r = std::mem::replace(&mut residual, opA.empty_observable());
825 let minus_τv = preadjA.apply(r); // minus_τv = -τA^*(Aμ^k-b)
826 // TODO: should avoid a second copy of μ here; μ_base already stores a copy.
827 let ω0 = op𝒟.apply(μ.clone()); // 𝒟μ^k
828 //let g = &minus_τv + ω0; // Linear term of surrogate model
829
830 // Save current base point
831 let μ_base = μ.clone();
832
833 // Add points to support until within error tolerance or maximum insertion count reached.
834 let mut count = 0;
835 let (within_tolerances, d) = 'insertion: loop {
836 if μ.len() > 0 {
837 // Form finite-dimensional subproblem. The subproblem references to the original μ^k
838 // from the beginning of the iteration are all contained in the immutable c and g.
839 let à = op𝒟.findim_matrix(μ.iter_locations());
840 let g̃ = DVector::from_iterator(μ.len(),
841 μ.iter_locations()
842 .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ))
843 .map(F::to_nalgebra_mixed));
844 let mut x = μ.masses_dvector();
845
846 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
847 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
848 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
849 // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2
850 // = n |𝒟| |x|_2, where n is the number of points. Therefore
851 let Ã_normest = op𝒟norm * F::cast_from(μ.len());
852
853 // Solve finite-dimensional subproblem.
854 inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config);
855
856 // Update masses of μ based on solution of finite-dimensional subproblem.
857 μ.set_masses_dvector(&x);
858 }
859
860 // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
861 // conditions in the predual space, and finding new points for insertion, if necessary.
862 let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ, &μ_base));
863
864 // If no merging heuristic is used, let's be more conservative about spike insertion,
865 // and skip it after first round. If merging is done, being more greedy about spike
866 // insertion also seems to improve performance.
867 let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging {
868 false
869 } else {
870 count > 0
871 };
872
873 // Find a spike to insert, if needed
874 let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation(
875 &mut d, τ, ε, skip_by_rough_check, config
876 ) {
877 None => break 'insertion (true, d),
878 Some(res) => res,
879 };
880
881 // Break if maximum insertion count reached
882 if count >= max_insertions {
883 break 'insertion (in_bounds, d)
884 }
885
886 // No point in optimising the weight here; the finite-dimensional algorithm is fast.
887 μ += DeltaMeasure { x : ξ, α : 0.0 };
888 count += 1;
889 };
890
891 if !within_tolerances && warn_insertions {
892 // Complain (but continue) if we failed to get within tolerances
893 // by inserting more points.
894 let err = format!("Maximum insertions reached without achieving \
895 subproblem solution tolerance");
896 println!("{}", err.red());
897 }
898
899 // Merge spikes
900 if state.iteration() % config.merge_every == 0 {
901 let n_before_merge = μ.len();
902 μ.merge_spikes(config.merging, |μ_candidate| {
903 let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ_candidate, &μ_base));
904
905 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
906 .then_some(())
907 });
908 debug_assert!(μ.len() >= n_before_merge);
909 merged += μ.len() - n_before_merge;
910 }
911
912 let n_before_prune = μ.len();
913 (residual, τ) = match specialisation.update(&mut μ, &μ_base) {
914 (r, None) => (r, τ),
915 (r, Some(new_τ)) => (r, new_τ)
916 };
917 debug_assert!(μ.len() <= n_before_prune);
918 pruned += n_before_prune - μ.len();
919
920 this_iters += 1;
921
922 // Update main tolerance for next iteration
923 let ε_prev = ε;
924 ε = tolerance.update(ε, state.iteration());
925
926 // Give function value if needed
927 state.if_verbose(|| {
928 let value_μ = specialisation.value_μ(&μ);
929 // Plot if so requested
930 plotter.plot_spikes(
931 format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
932 "start".to_string(), Some(&minus_τv),
933 reg.target_bounds(τ, ε_prev), value_μ,
934 );
935 // Calculate mean inner iterations and reset relevant counters.
936 // Return the statistics
937 let res = IterInfo {
938 value : specialisation.calculate_fit(&μ, &residual) + reg.apply(value_μ),
939 n_spikes : value_μ.len(),
940 inner_iters,
941 this_iters,
942 merged,
943 pruned,
944 ε : ε_prev,
945 postprocessing: config.postprocessing.then(|| value_μ.clone()),
946 };
947 inner_iters = 0;
948 this_iters = 0;
949 merged = 0;
950 pruned = 0;
951 res
952 })
953 });
954
955 specialisation.postprocess(μ, config.final_merging)
956 }
957
958 /// Iteratively solve the pointsource localisation problem using forward-backward splitting
959 ///
960 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the
961 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
962 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
963 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
964 /// as documented in [`alg_tools::iterate`].
965 ///
966 /// For details on the mathematical formulation, see the [module level](self) documentation.
967 ///
968 /// Returns the final iterate.
969 #[replace_float_literals(F::cast_from(literal))]
970 pub fn pointsource_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>(
971 opA : &'a A, 469 opA : &'a A,
972 b : &A::Observable, 470 b : &A::Observable,
973 reg : Reg, 471 reg : Reg,
974 op𝒟 : &'a 𝒟, 472 op𝒟 : &'a 𝒟,
975 config : &FBConfig<F>, 473 fbconfig : &FBConfig<F>,
976 iterator : I, 474 iterator : I,
977 plotter : SeqPlotter<F, N>, 475 mut plotter : SeqPlotter<F, N>,
978 ) -> DiscreteMeasure<Loc<F, N>, F> 476 ) -> DiscreteMeasure<Loc<F, N>, F>
979 where F : Float + ToNalgebraRealField, 477 where F : Float + ToNalgebraRealField,
980 I : AlgIteratorFactory<IterInfo<F, N>>, 478 I : AlgIteratorFactory<IterInfo<F, N>>,
981 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, 479 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
982 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow 480 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
983 A::Observable : std::ops::MulAssign<F>, 481 A::Observable : std::ops::MulAssign<F>,
984 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 482 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
985 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> 483 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
986 + Lipschitz<𝒟, FloatType=F>, 484 + Lipschitz<&'a 𝒟, FloatType=F>,
987 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 485 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
988 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 486 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
989 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 487 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
990 𝒟::Codomain : RealMapping<F, N>, 488 𝒟::Codomain : RealMapping<F, N>,
991 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 489 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
994 Cube<F, N>: P2Minimise<Loc<F, N>, F>, 492 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
995 PlotLookup : Plotting<N>, 493 PlotLookup : Plotting<N>,
996 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 494 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
997 Reg : RegTerm<F, N> { 495 Reg : RegTerm<F, N> {
998 496
999 let initial_residual = -b; 497 // Set up parameters
1000 let τ = config.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); 498 let config = &fbconfig.insertion;
1001 499 let op𝒟norm = op𝒟.opnorm_bound();
1002 match config.meta { 500 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
1003 FBMetaAlgorithm::None => generic_pointsource_fb_reg( 501 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
1004 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual, 502 // by τ compared to the conditional gradient approach.
1005 BasicFB{ b, opA }, 503 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
1006 ), 504 let mut ε = tolerance.initial();
1007 FBMetaAlgorithm::InertiaFISTA => generic_pointsource_fb_reg( 505
1008 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual, 506 // Initialise iterates
1009 FISTA{ b, opA, λ : 1.0, μ_prev : DiscreteMeasure::new() }, 507 let mut μ = DiscreteMeasure::new();
1010 ), 508 let mut residual = -b;
1011 } 509 let mut stats = IterInfo::new();
1012 } 510
511 // Run the algorithm
512 iterator.iterate(|state| {
513 // Calculate smooth part of surrogate model.
514 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
515 // has no significant overhead. For some reason Rust doesn't allow us simply moving
516 // the residual and replacing it below before the end of this closure.
517 residual *= -τ;
518 let r = std::mem::replace(&mut residual, opA.empty_observable());
519 let minus_τv = opA.preadjoint().apply(r);
520
521 // Save current base point
522 let μ_base = μ.clone();
523
524 // Insert and reweigh
525 let (d, within_tolerances) = insert_and_reweigh(
526 &mut μ, &minus_τv, &μ_base, None,
527 op𝒟, op𝒟norm,
528 τ, ε,
529 config, &reg, state, &mut stats
530 );
531
532 // Prune and possibly merge spikes
533 prune_and_maybe_simple_merge(
534 &mut μ, &minus_τv, &μ_base,
535 op𝒟,
536 τ, ε,
537 config, &reg, state, &mut stats
538 );
539
540 // Update residual
541 residual = calculate_residual(&μ, opA, b);
542
543 // Update main tolerance for next iteration
544 let ε_prev = ε;
545 ε = tolerance.update(ε, state.iteration());
546 stats.this_iters += 1;
547
548 // Give function value if needed
549 state.if_verbose(|| {
550 // Plot if so requested
551 plotter.plot_spikes(
552 format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
553 "start".to_string(), Some(&minus_τv),
554 reg.target_bounds(τ, ε_prev), &μ,
555 );
556 // Calculate mean inner iterations and reset relevant counters.
557 // Return the statistics
558 let res = IterInfo {
559 value : residual.norm2_squared_div2() + reg.apply(&μ),
560 n_spikes : μ.len(),
561 ε : ε_prev,
562 postprocessing: config.postprocessing.then(|| μ.clone()),
563 .. stats
564 };
565 stats = IterInfo::new();
566 res
567 })
568 });
569
570 postprocess(μ, config, L2Squared, opA, b)
571 }
572
573 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting.
574 ///
575 /// The settings in `fbconfig` have their [respective documentation](FBConfig). `opA` is the
576 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
577 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
578 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
579 /// as documented in [`alg_tools::iterate`].
580 ///
581 /// For details on the mathematical formulation, see the [module level](self) documentation.
582 ///
583 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of
584 /// sums of simple functions using bisection trees, and the related
585 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions
586 /// active at specific points, and to maximise their sums. Through the implementation of the
587 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features
588 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions.
589 ///
590 /// Returns the final iterate.
591 #[replace_float_literals(F::cast_from(literal))]
592 pub fn pointsource_fista_reg<
593 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize
594 >(
595 opA : &'a A,
596 b : &A::Observable,
597 reg : Reg,
598 op𝒟 : &'a 𝒟,
599 fbconfig : &FBConfig<F>,
600 iterator : I,
601 mut plotter : SeqPlotter<F, N>,
602 ) -> DiscreteMeasure<Loc<F, N>, F>
603 where F : Float + ToNalgebraRealField,
604 I : AlgIteratorFactory<IterInfo<F, N>>,
605 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
606 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
607 A::Observable : std::ops::MulAssign<F>,
608 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
609 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
610 + Lipschitz<&'a 𝒟, FloatType=F>,
611 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
612 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
613 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
614 𝒟::Codomain : RealMapping<F, N>,
615 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
616 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
617 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
618 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
619 PlotLookup : Plotting<N>,
620 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
621 Reg : RegTerm<F, N> {
622
623 // Set up parameters
624 let config = &fbconfig.insertion;
625 let op𝒟norm = op𝒟.opnorm_bound();
626 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
627 let mut λ = 1.0;
628 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
629 // by τ compared to the conditional gradient approach.
630 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
631 let mut ε = tolerance.initial();
632
633 // Initialise iterates
634 let mut μ = DiscreteMeasure::new();
635 let mut μ_prev = DiscreteMeasure::new();
636 let mut residual = -b;
637 let mut stats = IterInfo::new();
638 let mut warned_merging = false;
639
640 // Run the algorithm
641 iterator.iterate(|state| {
642 // Calculate smooth part of surrogate model.
643 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
644 // has no significant overhead. For some reason Rust doesn't allow us simply moving
645 // the residual and replacing it below before the end of this closure.
646 residual *= -τ;
647 let r = std::mem::replace(&mut residual, opA.empty_observable());
648 let minus_τv = opA.preadjoint().apply(r);
649
650 // Save current base point
651 let μ_base = μ.clone();
652
653 // Insert new spikes and reweigh
654 let (d, within_tolerances) = insert_and_reweigh(
655 &mut μ, &minus_τv, &μ_base, None,
656 op𝒟, op𝒟norm,
657 τ, ε,
658 config, &reg, state, &mut stats
659 );
660
661 // (Do not) merge spikes.
662 if state.iteration() % config.merge_every == 0 {
663 match config.merging {
664 SpikeMergingMethod::None => { },
665 _ => if !warned_merging {
666 let err = format!("Merging not supported for μFISTA");
667 println!("{}", err.red());
668 warned_merging = true;
669 }
670 }
671 }
672
673 // Update inertial parameters
674 let λ_prev = λ;
675 λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() );
676 let θ = λ / λ_prev - λ;
677
678 // Perform inertial update on μ.
679 // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ
680 // and μ_prev have zero weight. Since both have weights from the finite-dimensional
681 // subproblem with a proximal projection step, this is likely to happen when the
682 // spike is not needed. A copy of the pruned μ without arithmetic performed is
683 // stored in μ_prev.
684 let n_before_prune = μ.len();
685 μ.pruning_sub(1.0 + θ, θ, &mut μ_prev);
686 debug_assert!(μ.len() <= n_before_prune);
687 stats.pruned += n_before_prune - μ.len();
688
689 // Update residual
690 residual = calculate_residual(&μ, opA, b);
691
692 // Update main tolerance for next iteration
693 let ε_prev = ε;
694 ε = tolerance.update(ε, state.iteration());
695 stats.this_iters += 1;
696
697 // Give function value if needed
698 state.if_verbose(|| {
699 // Plot if so requested
700 plotter.plot_spikes(
701 format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
702 "start".to_string(), Some(&minus_τv),
703 reg.target_bounds(τ, ε_prev), &μ_prev,
704 );
705 // Calculate mean inner iterations and reset relevant counters.
706 // Return the statistics
707 let res = IterInfo {
708 value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev),
709 n_spikes : μ_prev.len(),
710 ε : ε_prev,
711 postprocessing: config.postprocessing.then(|| μ_prev.clone()),
712 .. stats
713 };
714 stats = IterInfo::new();
715 res
716 })
717 });
718
719 postprocess(μ_prev, config, L2Squared, opA, b)
720 }

mercurial