238 postprocessing : false, |
222 postprocessing : false, |
239 } |
223 } |
240 } |
224 } |
241 } |
225 } |
242 |
226 |
243 /// Trait for specialisation of [`generic_pointsource_fb_reg`] to basic FB, FISTA. |
227 #[replace_float_literals(F::cast_from(literal))] |
|
228 pub(crate) fn μ_diff<F : Float, const N : usize>( |
|
229 μ_new : &DiscreteMeasure<Loc<F, N>, F>, |
|
230 μ_base : &DiscreteMeasure<Loc<F, N>, F>, |
|
231 ν_delta : Option<&DiscreteMeasure<Loc<F, N>, F>>, |
|
232 config : &FBGenericConfig<F> |
|
233 ) -> DiscreteMeasure<Loc<F, N>, F> { |
|
234 let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style { |
|
235 InsertionStyle::Reuse => { |
|
236 μ_new.iter_spikes() |
|
237 .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0))) |
|
238 .map(|(δ, α_base)| (δ.x, α_base - δ.α)) |
|
239 .collect() |
|
240 }, |
|
241 InsertionStyle::Zero => { |
|
242 μ_new.iter_spikes() |
|
243 .map(|δ| -δ) |
|
244 .chain(μ_base.iter_spikes().copied()) |
|
245 .collect() |
|
246 } |
|
247 }; |
|
248 ν.prune(); // Potential small performance improvement |
|
249 // Add ν_delta if given |
|
250 match ν_delta { |
|
251 None => ν, |
|
252 Some(ν_d) => ν + ν_d, |
|
253 } |
|
254 } |
|
255 |
|
256 #[replace_float_literals(F::cast_from(literal))] |
|
257 pub(crate) fn insert_and_reweigh< |
|
258 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize |
|
259 >( |
|
260 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
261 minus_τv : &BTFN<F, GA, BTA, N>, |
|
262 μ_base : &DiscreteMeasure<Loc<F, N>, F>, |
|
263 ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>, |
|
264 op𝒟 : &'a 𝒟, |
|
265 op𝒟norm : F, |
|
266 τ : F, |
|
267 ε : F, |
|
268 config : &FBGenericConfig<F>, |
|
269 reg : &Reg, |
|
270 state : &State, |
|
271 stats : &mut IterInfo<F, N>, |
|
272 ) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool) |
|
273 where F : Float + ToNalgebraRealField, |
|
274 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
275 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
276 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
|
277 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
|
278 𝒟::Codomain : RealMapping<F, N>, |
|
279 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
280 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
281 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
282 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
283 Reg : RegTerm<F, N>, |
|
284 State : AlgIteratorState { |
|
285 |
|
286 // Maximum insertion count and measure difference calculation depend on insertion style. |
|
287 let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
|
288 (i, Some((l, k))) if i <= l => (k, false), |
|
289 _ => (config.max_insertions, !state.is_quiet()), |
|
290 }; |
|
291 let max_insertions = match config.insertion_style { |
|
292 InsertionStyle::Zero => { |
|
293 todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled."); |
|
294 // let n = μ.len(); |
|
295 // μ = DiscreteMeasure::new(); |
|
296 // n + m |
|
297 }, |
|
298 InsertionStyle::Reuse => m, |
|
299 }; |
|
300 |
|
301 // TODO: should avoid a second copy of μ here; μ_base already stores a copy. |
|
302 let ω0 = op𝒟.apply(match ν_delta { |
|
303 None => μ.clone(), |
|
304 Some(ν_d) => &*μ + ν_d, |
|
305 }); |
|
306 |
|
307 // Add points to support until within error tolerance or maximum insertion count reached. |
|
308 let mut count = 0; |
|
309 let (within_tolerances, d) = 'insertion: loop { |
|
310 if μ.len() > 0 { |
|
311 // Form finite-dimensional subproblem. The subproblem references to the original μ^k |
|
312 // from the beginning of the iteration are all contained in the immutable c and g. |
|
313 let à = op𝒟.findim_matrix(μ.iter_locations()); |
|
314 let g̃ = DVector::from_iterator(μ.len(), |
|
315 μ.iter_locations() |
|
316 .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ)) |
|
317 .map(F::to_nalgebra_mixed)); |
|
318 let mut x = μ.masses_dvector(); |
|
319 |
|
320 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. |
|
321 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ |
|
322 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ |
|
323 // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 |
|
324 // = n |𝒟| |x|_2, where n is the number of points. Therefore |
|
325 let Ã_normest = op𝒟norm * F::cast_from(μ.len()); |
|
326 |
|
327 // Solve finite-dimensional subproblem. |
|
328 stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config); |
|
329 |
|
330 // Update masses of μ based on solution of finite-dimensional subproblem. |
|
331 μ.set_masses_dvector(&x); |
|
332 } |
|
333 |
|
334 // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality |
|
335 // conditions in the predual space, and finding new points for insertion, if necessary. |
|
336 let mut d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_base, ν_delta, config)); |
|
337 |
|
338 // If no merging heuristic is used, let's be more conservative about spike insertion, |
|
339 // and skip it after first round. If merging is done, being more greedy about spike |
|
340 // insertion also seems to improve performance. |
|
341 let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging { |
|
342 false |
|
343 } else { |
|
344 count > 0 |
|
345 }; |
|
346 |
|
347 // Find a spike to insert, if needed |
|
348 let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( |
|
349 &mut d, τ, ε, skip_by_rough_check, config |
|
350 ) { |
|
351 None => break 'insertion (true, d), |
|
352 Some(res) => res, |
|
353 }; |
|
354 |
|
355 // Break if maximum insertion count reached |
|
356 if count >= max_insertions { |
|
357 break 'insertion (in_bounds, d) |
|
358 } |
|
359 |
|
360 // No point in optimising the weight here; the finite-dimensional algorithm is fast. |
|
361 *μ += DeltaMeasure { x : ξ, α : 0.0 }; |
|
362 count += 1; |
|
363 }; |
|
364 |
|
365 // TODO: should redo everything if some transports cause a problem. |
|
366 // Maybe implementation should call above loop as a closure. |
|
367 |
|
368 if !within_tolerances && warn_insertions { |
|
369 // Complain (but continue) if we failed to get within tolerances |
|
370 // by inserting more points. |
|
371 let err = format!("Maximum insertions reached without achieving \ |
|
372 subproblem solution tolerance"); |
|
373 println!("{}", err.red()); |
|
374 } |
|
375 |
|
376 (d, within_tolerances) |
|
377 } |
|
378 |
|
379 #[replace_float_literals(F::cast_from(literal))] |
|
380 pub(crate) fn prune_and_maybe_simple_merge< |
|
381 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize |
|
382 >( |
|
383 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
384 minus_τv : &BTFN<F, GA, BTA, N>, |
|
385 μ_base : &DiscreteMeasure<Loc<F, N>, F>, |
|
386 op𝒟 : &'a 𝒟, |
|
387 τ : F, |
|
388 ε : F, |
|
389 config : &FBGenericConfig<F>, |
|
390 reg : &Reg, |
|
391 state : &State, |
|
392 stats : &mut IterInfo<F, N>, |
|
393 ) |
|
394 where F : Float + ToNalgebraRealField, |
|
395 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
396 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
397 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
|
398 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
|
399 𝒟::Codomain : RealMapping<F, N>, |
|
400 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
401 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
402 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
403 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
404 Reg : RegTerm<F, N>, |
|
405 State : AlgIteratorState { |
|
406 if state.iteration() % config.merge_every == 0 { |
|
407 let n_before_merge = μ.len(); |
|
408 μ.merge_spikes(config.merging, |μ_candidate| { |
|
409 let μd = μ_diff(&μ_candidate, &μ_base, None, config); |
|
410 let mut d = minus_τv + op𝒟.preapply(μd); |
|
411 |
|
412 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
|
413 .then_some(()) |
|
414 }); |
|
415 debug_assert!(μ.len() >= n_before_merge); |
|
416 stats.merged += μ.len() - n_before_merge; |
|
417 } |
|
418 |
|
419 let n_before_prune = μ.len(); |
|
420 μ.prune(); |
|
421 debug_assert!(μ.len() <= n_before_prune); |
|
422 stats.pruned += n_before_prune - μ.len(); |
|
423 } |
|
424 |
|
425 #[replace_float_literals(F::cast_from(literal))] |
|
426 pub(crate) fn postprocess< |
|
427 F : Float, |
|
428 V : Euclidean<F> + Clone, |
|
429 A : GEMV<F, DiscreteMeasure<Loc<F, N>, F>, Codomain = V>, |
|
430 D : DataTerm<F, V, N>, |
|
431 const N : usize |
|
432 > ( |
|
433 mut μ : DiscreteMeasure<Loc<F, N>, F>, |
|
434 config : &FBGenericConfig<F>, |
|
435 dataterm : D, |
|
436 opA : &A, |
|
437 b : &V, |
|
438 ) -> DiscreteMeasure<Loc<F, N>, F> |
|
439 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { |
|
440 μ.merge_spikes_fitness(config.merging, |
|
441 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
|
442 |&v| v); |
|
443 μ.prune(); |
|
444 μ |
|
445 } |
|
446 |
|
447 /// Iteratively solve the pointsource localisation problem using forward-backward splitting. |
244 /// |
448 /// |
245 /// The idea is that the residual $Aμ - b$ in the forward step can be replaced by an arbitrary |
449 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
246 /// value. For example, to implement [primal-dual proximal splitting][crate::pdps] we replace it |
|
247 /// with the dual variable $y$. We can then also implement alternative data terms, as the |
|
248 /// (pre)differential of $F(μ)=F\_0(Aμ-b)$ is $F\'(μ) = A\_*F\_0\'(Aμ-b)$. In the case of the |
|
249 /// quadratic fidelity $F_0(y)=\frac{1}{2}\\|y\\|_2^2$ in a Hilbert space, of course, |
|
250 /// $F\_0\'(Aμ-b)=Aμ-b$ is the residual. |
|
251 pub trait FBSpecialisation<F : Float, Observable : Euclidean<F>, const N : usize> : Sized { |
|
252 /// Updates the residual and does any necessary pruning of `μ`. |
|
253 /// |
|
254 /// Returns the new residual and possibly a new step length. |
|
255 /// |
|
256 /// The measure `μ` may also be modified to apply, e.g., inertia to it. |
|
257 /// The updated residual should correspond to the residual at `μ`. |
|
258 /// See the [trait documentation][FBSpecialisation] for the use and meaning of the residual. |
|
259 /// |
|
260 /// The parameter `μ_base` is the base point of the iteration, typically the previous iterate, |
|
261 /// but for, e.g., FISTA has inertia applied to it. |
|
262 fn update( |
|
263 &mut self, |
|
264 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
265 μ_base : &DiscreteMeasure<Loc<F, N>, F>, |
|
266 ) -> (Observable, Option<F>); |
|
267 |
|
268 /// Calculates the data term value corresponding to iterate `μ` and available residual. |
|
269 /// |
|
270 /// Inertia and other modifications, as deemed, necessary, should be applied to `μ`. |
|
271 /// |
|
272 /// The blanket implementation correspondsn to the 2-norm-squared data fidelity |
|
273 /// $\\|\text{residual}\\|\_2^2/2$. |
|
274 fn calculate_fit( |
|
275 &self, |
|
276 _μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
277 residual : &Observable |
|
278 ) -> F { |
|
279 residual.norm2_squared_div2() |
|
280 } |
|
281 |
|
282 /// Calculates the data term value at $μ$. |
|
283 /// |
|
284 /// Unlike [`Self::calculate_fit`], no inertia, etc., should be applied to `μ`. |
|
285 fn calculate_fit_simple( |
|
286 &self, |
|
287 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
288 ) -> F; |
|
289 |
|
290 /// Returns the final iterate after any necessary postprocess pruning, merging, etc. |
|
291 fn postprocess(self, mut μ : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>) |
|
292 -> DiscreteMeasure<Loc<F, N>, F> |
|
293 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { |
|
294 μ.merge_spikes_fitness(merging, |
|
295 |μ̃| self.calculate_fit_simple(μ̃), |
|
296 |&v| v); |
|
297 μ.prune(); |
|
298 μ |
|
299 } |
|
300 |
|
301 /// Returns measure to be used for value calculations, which may differ from μ. |
|
302 fn value_μ<'c, 'b : 'c>(&'b self, μ : &'c DiscreteMeasure<Loc<F, N>, F>) |
|
303 -> &'c DiscreteMeasure<Loc<F, N>, F> { |
|
304 μ |
|
305 } |
|
306 } |
|
307 |
|
308 /// Specialisation of [`generic_pointsource_fb_reg`] to basic μFB. |
|
309 struct BasicFB< |
|
310 'a, |
|
311 F : Float + ToNalgebraRealField, |
|
312 A : ForwardModel<Loc<F, N>, F>, |
|
313 const N : usize |
|
314 > { |
|
315 /// The data |
|
316 b : &'a A::Observable, |
|
317 /// The forward operator |
|
318 opA : &'a A, |
|
319 } |
|
320 |
|
321 /// Implementation of [`FBSpecialisation`] for basic μFB forward-backward splitting. |
|
322 #[replace_float_literals(F::cast_from(literal))] |
|
323 impl<'a, F : Float + ToNalgebraRealField , A : ForwardModel<Loc<F, N>, F>, const N : usize> |
|
324 FBSpecialisation<F, A::Observable, N> for BasicFB<'a, F, A, N> { |
|
325 fn update( |
|
326 &mut self, |
|
327 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
328 _μ_base : &DiscreteMeasure<Loc<F, N>, F> |
|
329 ) -> (A::Observable, Option<F>) { |
|
330 μ.prune(); |
|
331 //*residual = self.opA.apply(μ) - self.b; |
|
332 let mut residual = self.b.clone(); |
|
333 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
|
334 (residual, None) |
|
335 } |
|
336 |
|
337 fn calculate_fit_simple( |
|
338 &self, |
|
339 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
340 ) -> F { |
|
341 let mut residual = self.b.clone(); |
|
342 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
|
343 residual.norm2_squared_div2() |
|
344 } |
|
345 } |
|
346 |
|
347 /// Specialisation of [`generic_pointsource_fb_reg`] to FISTA. |
|
348 struct FISTA< |
|
349 'a, |
|
350 F : Float + ToNalgebraRealField, |
|
351 A : ForwardModel<Loc<F, N>, F>, |
|
352 const N : usize |
|
353 > { |
|
354 /// The data |
|
355 b : &'a A::Observable, |
|
356 /// The forward operator |
|
357 opA : &'a A, |
|
358 /// Current inertial parameter |
|
359 λ : F, |
|
360 /// Previous iterate without inertia applied. |
|
361 /// We need to store this here because `μ_base` passed to [`FBSpecialisation::update`] will |
|
362 /// have inertia applied to it, so is not useful to use. |
|
363 μ_prev : DiscreteMeasure<Loc<F, N>, F>, |
|
364 } |
|
365 |
|
366 /// Implementation of [`FBSpecialisation`] for μFISTA inertial forward-backward splitting. |
|
367 #[replace_float_literals(F::cast_from(literal))] |
|
368 impl<'a, F : Float + ToNalgebraRealField, A : ForwardModel<Loc<F, N>, F>, const N : usize> |
|
369 FBSpecialisation<F, A::Observable, N> for FISTA<'a, F, A, N> { |
|
370 fn update( |
|
371 &mut self, |
|
372 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
373 _μ_base : &DiscreteMeasure<Loc<F, N>, F> |
|
374 ) -> (A::Observable, Option<F>) { |
|
375 // Update inertial parameters |
|
376 let λ_prev = self.λ; |
|
377 self.λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); |
|
378 let θ = self.λ / λ_prev - self.λ; |
|
379 // Perform inertial update on μ. |
|
380 // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ |
|
381 // and μ_prev have zero weight. Since both have weights from the finite-dimensional |
|
382 // subproblem with a proximal projection step, this is likely to happen when the |
|
383 // spike is not needed. A copy of the pruned μ without artithmetic performed is |
|
384 // stored in μ_prev. |
|
385 μ.pruning_sub(1.0 + θ, θ, &mut self.μ_prev); |
|
386 |
|
387 //*residual = self.opA.apply(μ) - self.b; |
|
388 let mut residual = self.b.clone(); |
|
389 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
|
390 (residual, None) |
|
391 } |
|
392 |
|
393 fn calculate_fit_simple( |
|
394 &self, |
|
395 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
396 ) -> F { |
|
397 let mut residual = self.b.clone(); |
|
398 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
|
399 residual.norm2_squared_div2() |
|
400 } |
|
401 |
|
402 fn calculate_fit( |
|
403 &self, |
|
404 _μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
405 _residual : &A::Observable |
|
406 ) -> F { |
|
407 self.calculate_fit_simple(&self.μ_prev) |
|
408 } |
|
409 |
|
410 // For FISTA we need to do a final pruning as well, due to the limited |
|
411 // pruning that can be done on each step. |
|
412 fn postprocess(mut self, μ_base : DiscreteMeasure<Loc<F, N>, F>, merging : SpikeMergingMethod<F>) |
|
413 -> DiscreteMeasure<Loc<F, N>, F> |
|
414 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { |
|
415 let mut μ = self.μ_prev; |
|
416 self.μ_prev = μ_base; |
|
417 μ.merge_spikes_fitness(merging, |
|
418 |μ̃| self.calculate_fit_simple(μ̃), |
|
419 |&v| v); |
|
420 μ.prune(); |
|
421 μ |
|
422 } |
|
423 |
|
424 fn value_μ<'c, 'b : 'c>(&'c self, _μ : &'c DiscreteMeasure<Loc<F, N>, F>) |
|
425 -> &'c DiscreteMeasure<Loc<F, N>, F> { |
|
426 &self.μ_prev |
|
427 } |
|
428 } |
|
429 |
|
430 |
|
431 /// Abstraction of regularisation terms for [`generic_pointsource_fb_reg`]. |
|
432 pub trait RegTerm<F : Float + ToNalgebraRealField, const N : usize> |
|
433 : for<'a> Apply<&'a DiscreteMeasure<Loc<F, N>, F>, Output = F> { |
|
434 /// Approximately solve the problem |
|
435 /// <div>$$ |
|
436 /// \min_{x ∈ ℝ^n} \frac{1}{2} x^⊤Ax - g^⊤ x + τ G(x) |
|
437 /// $$</div> |
|
438 /// for $G$ depending on the trait implementation. |
|
439 /// |
|
440 /// The parameter `mA` is $A$. An estimate for its opeator norm should be provided in |
|
441 /// `mA_normest`. The initial iterate and output is `x`. The current main tolerance is `ε`. |
|
442 /// |
|
443 /// Returns the number of iterations taken. |
|
444 fn solve_findim( |
|
445 &self, |
|
446 mA : &DMatrix<F::MixedType>, |
|
447 g : &DVector<F::MixedType>, |
|
448 τ : F, |
|
449 x : &mut DVector<F::MixedType>, |
|
450 mA_normest : F, |
|
451 ε : F, |
|
452 config : &FBGenericConfig<F> |
|
453 ) -> usize; |
|
454 |
|
455 /// Find a point where `d` may violate the tolerance `ε`. |
|
456 /// |
|
457 /// If `skip_by_rough_check` is set, do not find the point if a rough check indicates that we |
|
458 /// are in bounds. `ε` is the current main tolerance and `τ` a scaling factor for the |
|
459 /// regulariser. |
|
460 /// |
|
461 /// Returns `None` if `d` is in bounds either based on the rough check, or a more precise check |
|
462 /// terminating early. Otherwise returns a possibly violating point, the value of `d` there, |
|
463 /// and a boolean indicating whether the found point is in bounds. |
|
464 fn find_tolerance_violation<G, BT>( |
|
465 &self, |
|
466 d : &mut BTFN<F, G, BT, N>, |
|
467 τ : F, |
|
468 ε : F, |
|
469 skip_by_rough_check : bool, |
|
470 config : &FBGenericConfig<F>, |
|
471 ) -> Option<(Loc<F, N>, F, bool)> |
|
472 where BT : BTSearch<F, N, Agg=Bounds<F>>, |
|
473 G : SupportGenerator<F, N, Id=BT::Data>, |
|
474 G::SupportType : Mapping<Loc<F, N>,Codomain=F> |
|
475 + LocalAnalysis<F, Bounds<F>, N>; |
|
476 |
|
477 /// Verify that `d` is in bounds `ε` for a merge candidate `μ` |
|
478 /// |
|
479 /// `ε` is the current main tolerance and `τ` a scaling factor for the regulariser. |
|
480 fn verify_merge_candidate<G, BT>( |
|
481 &self, |
|
482 d : &mut BTFN<F, G, BT, N>, |
|
483 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
484 τ : F, |
|
485 ε : F, |
|
486 config : &FBGenericConfig<F>, |
|
487 ) -> bool |
|
488 where BT : BTSearch<F, N, Agg=Bounds<F>>, |
|
489 G : SupportGenerator<F, N, Id=BT::Data>, |
|
490 G::SupportType : Mapping<Loc<F, N>,Codomain=F> |
|
491 + LocalAnalysis<F, Bounds<F>, N>; |
|
492 |
|
493 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>>; |
|
494 |
|
495 /// Returns a scaling factor for the tolerance sequence. |
|
496 /// |
|
497 /// Typically this is the regularisation parameter. |
|
498 fn tolerance_scaling(&self) -> F; |
|
499 } |
|
500 |
|
501 #[replace_float_literals(F::cast_from(literal))] |
|
502 impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for NonnegRadonRegTerm<F> |
|
503 where Cube<F, N> : P2Minimise<Loc<F, N>, F> { |
|
504 fn solve_findim( |
|
505 &self, |
|
506 mA : &DMatrix<F::MixedType>, |
|
507 g : &DVector<F::MixedType>, |
|
508 τ : F, |
|
509 x : &mut DVector<F::MixedType>, |
|
510 mA_normest : F, |
|
511 ε : F, |
|
512 config : &FBGenericConfig<F> |
|
513 ) -> usize { |
|
514 let inner_tolerance = ε * config.inner.tolerance_mult; |
|
515 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance); |
|
516 let inner_τ = config.inner.τ0 / mA_normest; |
|
517 quadratic_nonneg(config.inner.method, mA, g, τ * self.α(), x, |
|
518 inner_τ, inner_it) |
|
519 } |
|
520 |
|
521 #[inline] |
|
522 fn find_tolerance_violation<G, BT>( |
|
523 &self, |
|
524 d : &mut BTFN<F, G, BT, N>, |
|
525 τ : F, |
|
526 ε : F, |
|
527 skip_by_rough_check : bool, |
|
528 config : &FBGenericConfig<F>, |
|
529 ) -> Option<(Loc<F, N>, F, bool)> |
|
530 where BT : BTSearch<F, N, Agg=Bounds<F>>, |
|
531 G : SupportGenerator<F, N, Id=BT::Data>, |
|
532 G::SupportType : Mapping<Loc<F, N>,Codomain=F> |
|
533 + LocalAnalysis<F, Bounds<F>, N> { |
|
534 let τα = τ * self.α(); |
|
535 let keep_below = τα + ε; |
|
536 let maximise_above = τα + ε * config.insertion_cutoff_factor; |
|
537 let refinement_tolerance = ε * config.refinement.tolerance_mult; |
|
538 |
|
539 // If preliminary check indicates that we are in bonds, and if it otherwise matches |
|
540 // the insertion strategy, skip insertion. |
|
541 if skip_by_rough_check && d.bounds().upper() <= keep_below { |
|
542 None |
|
543 } else { |
|
544 // If the rough check didn't indicate no insertion needed, find maximising point. |
|
545 d.maximise_above(maximise_above, refinement_tolerance, config.refinement.max_steps) |
|
546 .map(|(ξ, v_ξ)| (ξ, v_ξ, v_ξ <= keep_below)) |
|
547 } |
|
548 } |
|
549 |
|
550 fn verify_merge_candidate<G, BT>( |
|
551 &self, |
|
552 d : &mut BTFN<F, G, BT, N>, |
|
553 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
554 τ : F, |
|
555 ε : F, |
|
556 config : &FBGenericConfig<F>, |
|
557 ) -> bool |
|
558 where BT : BTSearch<F, N, Agg=Bounds<F>>, |
|
559 G : SupportGenerator<F, N, Id=BT::Data>, |
|
560 G::SupportType : Mapping<Loc<F, N>,Codomain=F> |
|
561 + LocalAnalysis<F, Bounds<F>, N> { |
|
562 let τα = τ * self.α(); |
|
563 let refinement_tolerance = ε * config.refinement.tolerance_mult; |
|
564 let merge_tolerance = config.merge_tolerance_mult * ε; |
|
565 let keep_below = τα + merge_tolerance; |
|
566 let keep_supp_above = τα - merge_tolerance; |
|
567 let bnd = d.bounds(); |
|
568 |
|
569 return ( |
|
570 bnd.lower() >= keep_supp_above |
|
571 || |
|
572 μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| { |
|
573 (β == 0.0) || d.apply(x) >= keep_supp_above |
|
574 }).all(std::convert::identity) |
|
575 ) && ( |
|
576 bnd.upper() <= keep_below |
|
577 || |
|
578 d.has_upper_bound(keep_below, refinement_tolerance, config.refinement.max_steps) |
|
579 ) |
|
580 } |
|
581 |
|
582 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> { |
|
583 let τα = τ * self.α(); |
|
584 Some(Bounds(τα - ε, τα + ε)) |
|
585 } |
|
586 |
|
587 fn tolerance_scaling(&self) -> F { |
|
588 self.α() |
|
589 } |
|
590 } |
|
591 |
|
592 #[replace_float_literals(F::cast_from(literal))] |
|
593 impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for RadonRegTerm<F> |
|
594 where Cube<F, N> : P2Minimise<Loc<F, N>, F> { |
|
595 fn solve_findim( |
|
596 &self, |
|
597 mA : &DMatrix<F::MixedType>, |
|
598 g : &DVector<F::MixedType>, |
|
599 τ : F, |
|
600 x : &mut DVector<F::MixedType>, |
|
601 mA_normest: F, |
|
602 ε : F, |
|
603 config : &FBGenericConfig<F> |
|
604 ) -> usize { |
|
605 let inner_tolerance = ε * config.inner.tolerance_mult; |
|
606 let inner_it = config.inner.iterator_options.stop_target(inner_tolerance); |
|
607 let inner_τ = config.inner.τ0 / mA_normest; |
|
608 quadratic_unconstrained(config.inner.method, mA, g, τ * self.α(), x, |
|
609 inner_τ, inner_it) |
|
610 } |
|
611 |
|
612 fn find_tolerance_violation<G, BT>( |
|
613 &self, |
|
614 d : &mut BTFN<F, G, BT, N>, |
|
615 τ : F, |
|
616 ε : F, |
|
617 skip_by_rough_check : bool, |
|
618 config : &FBGenericConfig<F>, |
|
619 ) -> Option<(Loc<F, N>, F, bool)> |
|
620 where BT : BTSearch<F, N, Agg=Bounds<F>>, |
|
621 G : SupportGenerator<F, N, Id=BT::Data>, |
|
622 G::SupportType : Mapping<Loc<F, N>,Codomain=F> |
|
623 + LocalAnalysis<F, Bounds<F>, N> { |
|
624 let τα = τ * self.α(); |
|
625 let keep_below = τα + ε; |
|
626 let keep_above = -τα - ε; |
|
627 let maximise_above = τα + ε * config.insertion_cutoff_factor; |
|
628 let minimise_below = -τα - ε * config.insertion_cutoff_factor; |
|
629 let refinement_tolerance = ε * config.refinement.tolerance_mult; |
|
630 |
|
631 // If preliminary check indicates that we are in bonds, and if it otherwise matches |
|
632 // the insertion strategy, skip insertion. |
|
633 if skip_by_rough_check && Bounds(keep_above, keep_below).superset(&d.bounds()) { |
|
634 None |
|
635 } else { |
|
636 // If the rough check didn't indicate no insertion needed, find maximising point. |
|
637 let mx = d.maximise_above(maximise_above, refinement_tolerance, |
|
638 config.refinement.max_steps); |
|
639 let mi = d.minimise_below(minimise_below, refinement_tolerance, |
|
640 config.refinement.max_steps); |
|
641 |
|
642 match (mx, mi) { |
|
643 (None, None) => None, |
|
644 (Some((ξ, v_ξ)), None) => Some((ξ, v_ξ, keep_below >= v_ξ)), |
|
645 (None, Some((ζ, v_ζ))) => Some((ζ, v_ζ, keep_above <= v_ζ)), |
|
646 (Some((ξ, v_ξ)), Some((ζ, v_ζ))) => { |
|
647 if v_ξ - τα > τα - v_ζ { |
|
648 Some((ξ, v_ξ, keep_below >= v_ξ)) |
|
649 } else { |
|
650 Some((ζ, v_ζ, keep_above <= v_ζ)) |
|
651 } |
|
652 } |
|
653 } |
|
654 } |
|
655 } |
|
656 |
|
657 fn verify_merge_candidate<G, BT>( |
|
658 &self, |
|
659 d : &mut BTFN<F, G, BT, N>, |
|
660 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
661 τ : F, |
|
662 ε : F, |
|
663 config : &FBGenericConfig<F>, |
|
664 ) -> bool |
|
665 where BT : BTSearch<F, N, Agg=Bounds<F>>, |
|
666 G : SupportGenerator<F, N, Id=BT::Data>, |
|
667 G::SupportType : Mapping<Loc<F, N>,Codomain=F> |
|
668 + LocalAnalysis<F, Bounds<F>, N> { |
|
669 let τα = τ * self.α(); |
|
670 let refinement_tolerance = ε * config.refinement.tolerance_mult; |
|
671 let merge_tolerance = config.merge_tolerance_mult * ε; |
|
672 let keep_below = τα + merge_tolerance; |
|
673 let keep_above = -τα - merge_tolerance; |
|
674 let keep_supp_pos_above = τα - merge_tolerance; |
|
675 let keep_supp_neg_below = -τα + merge_tolerance; |
|
676 let bnd = d.bounds(); |
|
677 |
|
678 return ( |
|
679 (bnd.lower() >= keep_supp_pos_above && bnd.upper() <= keep_supp_neg_below) |
|
680 || |
|
681 μ.iter_spikes().map(|&DeltaMeasure{ α : β, ref x }| { |
|
682 use std::cmp::Ordering::*; |
|
683 match β.partial_cmp(&0.0) { |
|
684 Some(Greater) => d.apply(x) >= keep_supp_pos_above, |
|
685 Some(Less) => d.apply(x) <= keep_supp_neg_below, |
|
686 _ => true, |
|
687 } |
|
688 }).all(std::convert::identity) |
|
689 ) && ( |
|
690 bnd.upper() <= keep_below |
|
691 || |
|
692 d.has_upper_bound(keep_below, refinement_tolerance, |
|
693 config.refinement.max_steps) |
|
694 ) && ( |
|
695 bnd.lower() >= keep_above |
|
696 || |
|
697 d.has_lower_bound(keep_above, refinement_tolerance, |
|
698 config.refinement.max_steps) |
|
699 ) |
|
700 } |
|
701 |
|
702 fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> { |
|
703 let τα = τ * self.α(); |
|
704 Some(Bounds(-τα - ε, τα + ε)) |
|
705 } |
|
706 |
|
707 fn tolerance_scaling(&self) -> F { |
|
708 self.α() |
|
709 } |
|
710 } |
|
711 |
|
712 |
|
713 /// Generic implementation of [`pointsource_fb_reg`]. |
|
714 /// |
|
715 /// The method can be specialised to even primal-dual proximal splitting through the |
|
716 /// [`FBSpecialisation`] parameter `specialisation`. |
|
717 /// The settings in `config` have their [respective documentation](FBGenericConfig). `opA` is the |
|
718 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
450 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
719 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution |
451 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution |
720 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
452 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
721 /// as documented in [`alg_tools::iterate`]. |
453 /// as documented in [`alg_tools::iterate`]. |
|
454 /// |
|
455 /// For details on the mathematical formulation, see the [module level](self) documentation. |
722 /// |
456 /// |
723 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
457 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
724 /// sums of simple functions usign bisection trees, and the related |
458 /// sums of simple functions usign bisection trees, and the related |
725 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions |
459 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions |
726 /// active at a specific points, and to maximise their sums. Through the implementation of the |
460 /// active at a specific points, and to maximise their sums. Through the implementation of the |
727 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
461 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
728 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
462 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
729 /// |
463 /// |
730 /// Returns the final iterate. |
464 /// Returns the final iterate. |
731 #[replace_float_literals(F::cast_from(literal))] |
465 #[replace_float_literals(F::cast_from(literal))] |
732 pub fn generic_pointsource_fb_reg< |
466 pub fn pointsource_fb_reg< |
733 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Spec, Reg, const N : usize |
467 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize |
734 >( |
468 >( |
735 opA : &'a A, |
|
736 reg : Reg, |
|
737 op𝒟 : &'a 𝒟, |
|
738 mut τ : F, |
|
739 config : &FBGenericConfig<F>, |
|
740 iterator : I, |
|
741 mut plotter : SeqPlotter<F, N>, |
|
742 mut residual : A::Observable, |
|
743 mut specialisation : Spec |
|
744 ) -> DiscreteMeasure<Loc<F, N>, F> |
|
745 where F : Float + ToNalgebraRealField, |
|
746 I : AlgIteratorFactory<IterInfo<F, N>>, |
|
747 Spec : FBSpecialisation<F, A::Observable, N>, |
|
748 A::Observable : std::ops::MulAssign<F>, |
|
749 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
750 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
|
751 + Lipschitz<𝒟, FloatType=F>, |
|
752 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
753 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
|
754 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
|
755 𝒟::Codomain : RealMapping<F, N>, |
|
756 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
757 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
758 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
759 PlotLookup : Plotting<N>, |
|
760 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
761 Reg : RegTerm<F, N> { |
|
762 |
|
763 // Set up parameters |
|
764 let quiet = iterator.is_quiet(); |
|
765 let op𝒟norm = op𝒟.opnorm_bound(); |
|
766 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
|
767 // by τ compared to the conditional gradient approach. |
|
768 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
|
769 let mut ε = tolerance.initial(); |
|
770 |
|
771 // Initialise operators |
|
772 let preadjA = opA.preadjoint(); |
|
773 |
|
774 // Initialise iterates |
|
775 let mut μ = DiscreteMeasure::new(); |
|
776 |
|
777 let mut inner_iters = 0; |
|
778 let mut this_iters = 0; |
|
779 let mut pruned = 0; |
|
780 let mut merged = 0; |
|
781 |
|
782 let μ_diff = |μ_new : &DiscreteMeasure<Loc<F, N>, F>, |
|
783 μ_base : &DiscreteMeasure<Loc<F, N>, F>| { |
|
784 let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style { |
|
785 InsertionStyle::Reuse => { |
|
786 μ_new.iter_spikes() |
|
787 .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0))) |
|
788 .map(|(δ, α_base)| (δ.x, α_base - δ.α)) |
|
789 .collect() |
|
790 }, |
|
791 InsertionStyle::Zero => { |
|
792 μ_new.iter_spikes() |
|
793 .map(|δ| -δ) |
|
794 .chain(μ_base.iter_spikes().copied()) |
|
795 .collect() |
|
796 } |
|
797 }; |
|
798 ν.prune(); // Potential small performance improvement |
|
799 ν |
|
800 }; |
|
801 |
|
802 // Run the algorithm |
|
803 iterator.iterate(|state| { |
|
804 // Maximum insertion count and measure difference calculation depend on insertion style. |
|
805 let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
|
806 (i, Some((l, k))) if i <= l => (k, false), |
|
807 _ => (config.max_insertions, !quiet), |
|
808 }; |
|
809 let max_insertions = match config.insertion_style { |
|
810 InsertionStyle::Zero => { |
|
811 todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled."); |
|
812 // let n = μ.len(); |
|
813 // μ = DiscreteMeasure::new(); |
|
814 // n + m |
|
815 }, |
|
816 InsertionStyle::Reuse => m, |
|
817 }; |
|
818 |
|
819 // Calculate smooth part of surrogate model. |
|
820 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
|
821 // has no significant overhead. For some reason Rust doesn't allow us simply moving |
|
822 // the residual and replacing it below before the end of this closure. |
|
823 residual *= -τ; |
|
824 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
|
825 let minus_τv = preadjA.apply(r); // minus_τv = -τA^*(Aμ^k-b) |
|
826 // TODO: should avoid a second copy of μ here; μ_base already stores a copy. |
|
827 let ω0 = op𝒟.apply(μ.clone()); // 𝒟μ^k |
|
828 //let g = &minus_τv + ω0; // Linear term of surrogate model |
|
829 |
|
830 // Save current base point |
|
831 let μ_base = μ.clone(); |
|
832 |
|
833 // Add points to support until within error tolerance or maximum insertion count reached. |
|
834 let mut count = 0; |
|
835 let (within_tolerances, d) = 'insertion: loop { |
|
836 if μ.len() > 0 { |
|
837 // Form finite-dimensional subproblem. The subproblem references to the original μ^k |
|
838 // from the beginning of the iteration are all contained in the immutable c and g. |
|
839 let à = op𝒟.findim_matrix(μ.iter_locations()); |
|
840 let g̃ = DVector::from_iterator(μ.len(), |
|
841 μ.iter_locations() |
|
842 .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ)) |
|
843 .map(F::to_nalgebra_mixed)); |
|
844 let mut x = μ.masses_dvector(); |
|
845 |
|
846 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. |
|
847 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ |
|
848 // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ |
|
849 // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2 |
|
850 // = n |𝒟| |x|_2, where n is the number of points. Therefore |
|
851 let Ã_normest = op𝒟norm * F::cast_from(μ.len()); |
|
852 |
|
853 // Solve finite-dimensional subproblem. |
|
854 inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config); |
|
855 |
|
856 // Update masses of μ based on solution of finite-dimensional subproblem. |
|
857 μ.set_masses_dvector(&x); |
|
858 } |
|
859 |
|
860 // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality |
|
861 // conditions in the predual space, and finding new points for insertion, if necessary. |
|
862 let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ, &μ_base)); |
|
863 |
|
864 // If no merging heuristic is used, let's be more conservative about spike insertion, |
|
865 // and skip it after first round. If merging is done, being more greedy about spike |
|
866 // insertion also seems to improve performance. |
|
867 let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging { |
|
868 false |
|
869 } else { |
|
870 count > 0 |
|
871 }; |
|
872 |
|
873 // Find a spike to insert, if needed |
|
874 let (ξ, _v_ξ, in_bounds) = match reg.find_tolerance_violation( |
|
875 &mut d, τ, ε, skip_by_rough_check, config |
|
876 ) { |
|
877 None => break 'insertion (true, d), |
|
878 Some(res) => res, |
|
879 }; |
|
880 |
|
881 // Break if maximum insertion count reached |
|
882 if count >= max_insertions { |
|
883 break 'insertion (in_bounds, d) |
|
884 } |
|
885 |
|
886 // No point in optimising the weight here; the finite-dimensional algorithm is fast. |
|
887 μ += DeltaMeasure { x : ξ, α : 0.0 }; |
|
888 count += 1; |
|
889 }; |
|
890 |
|
891 if !within_tolerances && warn_insertions { |
|
892 // Complain (but continue) if we failed to get within tolerances |
|
893 // by inserting more points. |
|
894 let err = format!("Maximum insertions reached without achieving \ |
|
895 subproblem solution tolerance"); |
|
896 println!("{}", err.red()); |
|
897 } |
|
898 |
|
899 // Merge spikes |
|
900 if state.iteration() % config.merge_every == 0 { |
|
901 let n_before_merge = μ.len(); |
|
902 μ.merge_spikes(config.merging, |μ_candidate| { |
|
903 let mut d = &minus_τv + op𝒟.preapply(μ_diff(&μ_candidate, &μ_base)); |
|
904 |
|
905 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
|
906 .then_some(()) |
|
907 }); |
|
908 debug_assert!(μ.len() >= n_before_merge); |
|
909 merged += μ.len() - n_before_merge; |
|
910 } |
|
911 |
|
912 let n_before_prune = μ.len(); |
|
913 (residual, τ) = match specialisation.update(&mut μ, &μ_base) { |
|
914 (r, None) => (r, τ), |
|
915 (r, Some(new_τ)) => (r, new_τ) |
|
916 }; |
|
917 debug_assert!(μ.len() <= n_before_prune); |
|
918 pruned += n_before_prune - μ.len(); |
|
919 |
|
920 this_iters += 1; |
|
921 |
|
922 // Update main tolerance for next iteration |
|
923 let ε_prev = ε; |
|
924 ε = tolerance.update(ε, state.iteration()); |
|
925 |
|
926 // Give function value if needed |
|
927 state.if_verbose(|| { |
|
928 let value_μ = specialisation.value_μ(&μ); |
|
929 // Plot if so requested |
|
930 plotter.plot_spikes( |
|
931 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, |
|
932 "start".to_string(), Some(&minus_τv), |
|
933 reg.target_bounds(τ, ε_prev), value_μ, |
|
934 ); |
|
935 // Calculate mean inner iterations and reset relevant counters. |
|
936 // Return the statistics |
|
937 let res = IterInfo { |
|
938 value : specialisation.calculate_fit(&μ, &residual) + reg.apply(value_μ), |
|
939 n_spikes : value_μ.len(), |
|
940 inner_iters, |
|
941 this_iters, |
|
942 merged, |
|
943 pruned, |
|
944 ε : ε_prev, |
|
945 postprocessing: config.postprocessing.then(|| value_μ.clone()), |
|
946 }; |
|
947 inner_iters = 0; |
|
948 this_iters = 0; |
|
949 merged = 0; |
|
950 pruned = 0; |
|
951 res |
|
952 }) |
|
953 }); |
|
954 |
|
955 specialisation.postprocess(μ, config.final_merging) |
|
956 } |
|
957 |
|
958 /// Iteratively solve the pointsource localisation problem using forward-backward splitting |
|
959 /// |
|
960 /// The settings in `config` have their [respective documentation](FBConfig). `opA` is the |
|
961 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
|
962 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution |
|
963 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
|
964 /// as documented in [`alg_tools::iterate`]. |
|
965 /// |
|
966 /// For details on the mathematical formulation, see the [module level](self) documentation. |
|
967 /// |
|
968 /// Returns the final iterate. |
|
969 #[replace_float_literals(F::cast_from(literal))] |
|
970 pub fn pointsource_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>( |
|
971 opA : &'a A, |
469 opA : &'a A, |
972 b : &A::Observable, |
470 b : &A::Observable, |
973 reg : Reg, |
471 reg : Reg, |
974 op𝒟 : &'a 𝒟, |
472 op𝒟 : &'a 𝒟, |
975 config : &FBConfig<F>, |
473 fbconfig : &FBConfig<F>, |
976 iterator : I, |
474 iterator : I, |
977 plotter : SeqPlotter<F, N>, |
475 mut plotter : SeqPlotter<F, N>, |
978 ) -> DiscreteMeasure<Loc<F, N>, F> |
476 ) -> DiscreteMeasure<Loc<F, N>, F> |
979 where F : Float + ToNalgebraRealField, |
477 where F : Float + ToNalgebraRealField, |
980 I : AlgIteratorFactory<IterInfo<F, N>>, |
478 I : AlgIteratorFactory<IterInfo<F, N>>, |
981 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
479 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
982 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow |
480 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow |
983 A::Observable : std::ops::MulAssign<F>, |
481 A::Observable : std::ops::MulAssign<F>, |
984 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
482 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
985 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
483 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
986 + Lipschitz<𝒟, FloatType=F>, |
484 + Lipschitz<&'a 𝒟, FloatType=F>, |
987 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
485 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
988 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
486 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
989 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
487 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
990 𝒟::Codomain : RealMapping<F, N>, |
488 𝒟::Codomain : RealMapping<F, N>, |
991 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
489 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
994 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
492 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
995 PlotLookup : Plotting<N>, |
493 PlotLookup : Plotting<N>, |
996 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
494 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
997 Reg : RegTerm<F, N> { |
495 Reg : RegTerm<F, N> { |
998 |
496 |
999 let initial_residual = -b; |
497 // Set up parameters |
1000 let τ = config.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); |
498 let config = &fbconfig.insertion; |
1001 |
499 let op𝒟norm = op𝒟.opnorm_bound(); |
1002 match config.meta { |
500 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); |
1003 FBMetaAlgorithm::None => generic_pointsource_fb_reg( |
501 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
1004 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual, |
502 // by τ compared to the conditional gradient approach. |
1005 BasicFB{ b, opA }, |
503 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
1006 ), |
504 let mut ε = tolerance.initial(); |
1007 FBMetaAlgorithm::InertiaFISTA => generic_pointsource_fb_reg( |
505 |
1008 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, initial_residual, |
506 // Initialise iterates |
1009 FISTA{ b, opA, λ : 1.0, μ_prev : DiscreteMeasure::new() }, |
507 let mut μ = DiscreteMeasure::new(); |
1010 ), |
508 let mut residual = -b; |
1011 } |
509 let mut stats = IterInfo::new(); |
1012 } |
510 |
|
511 // Run the algorithm |
|
512 iterator.iterate(|state| { |
|
513 // Calculate smooth part of surrogate model. |
|
514 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
|
515 // has no significant overhead. For some reason Rust doesn't allow us simply moving |
|
516 // the residual and replacing it below before the end of this closure. |
|
517 residual *= -τ; |
|
518 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
|
519 let minus_τv = opA.preadjoint().apply(r); |
|
520 |
|
521 // Save current base point |
|
522 let μ_base = μ.clone(); |
|
523 |
|
524 // Insert and reweigh |
|
525 let (d, within_tolerances) = insert_and_reweigh( |
|
526 &mut μ, &minus_τv, &μ_base, None, |
|
527 op𝒟, op𝒟norm, |
|
528 τ, ε, |
|
529 config, &reg, state, &mut stats |
|
530 ); |
|
531 |
|
532 // Prune and possibly merge spikes |
|
533 prune_and_maybe_simple_merge( |
|
534 &mut μ, &minus_τv, &μ_base, |
|
535 op𝒟, |
|
536 τ, ε, |
|
537 config, &reg, state, &mut stats |
|
538 ); |
|
539 |
|
540 // Update residual |
|
541 residual = calculate_residual(&μ, opA, b); |
|
542 |
|
543 // Update main tolerance for next iteration |
|
544 let ε_prev = ε; |
|
545 ε = tolerance.update(ε, state.iteration()); |
|
546 stats.this_iters += 1; |
|
547 |
|
548 // Give function value if needed |
|
549 state.if_verbose(|| { |
|
550 // Plot if so requested |
|
551 plotter.plot_spikes( |
|
552 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, |
|
553 "start".to_string(), Some(&minus_τv), |
|
554 reg.target_bounds(τ, ε_prev), &μ, |
|
555 ); |
|
556 // Calculate mean inner iterations and reset relevant counters. |
|
557 // Return the statistics |
|
558 let res = IterInfo { |
|
559 value : residual.norm2_squared_div2() + reg.apply(&μ), |
|
560 n_spikes : μ.len(), |
|
561 ε : ε_prev, |
|
562 postprocessing: config.postprocessing.then(|| μ.clone()), |
|
563 .. stats |
|
564 }; |
|
565 stats = IterInfo::new(); |
|
566 res |
|
567 }) |
|
568 }); |
|
569 |
|
570 postprocess(μ, config, L2Squared, opA, b) |
|
571 } |
|
572 |
|
573 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. |
|
574 /// |
|
575 /// The settings in `fbconfig` have their [respective documentation](FBConfig). `opA` is the |
|
576 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
|
577 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution |
|
578 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
|
579 /// as documented in [`alg_tools::iterate`]. |
|
580 /// |
|
581 /// For details on the mathematical formulation, see the [module level](self) documentation. |
|
582 /// |
|
583 /// The implementation relies on [`alg_tools::bisection_tree::BTFN`] presentations of |
|
584 /// sums of simple functions using bisection trees, and the related |
|
585 /// [`alg_tools::bisection_tree::Aggregator`]s, to efficiently search for component functions |
|
586 /// active at specific points, and to maximise their sums. Through the implementation of the |
|
587 /// [`alg_tools::bisection_tree::BT`] bisection trees, it also relies on the copy-on-write features |
|
588 /// of [`std::sync::Arc`] to only update relevant parts of the bisection tree when adding functions. |
|
589 /// |
|
590 /// Returns the final iterate. |
|
591 #[replace_float_literals(F::cast_from(literal))] |
|
592 pub fn pointsource_fista_reg< |
|
593 'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize |
|
594 >( |
|
595 opA : &'a A, |
|
596 b : &A::Observable, |
|
597 reg : Reg, |
|
598 op𝒟 : &'a 𝒟, |
|
599 fbconfig : &FBConfig<F>, |
|
600 iterator : I, |
|
601 mut plotter : SeqPlotter<F, N>, |
|
602 ) -> DiscreteMeasure<Loc<F, N>, F> |
|
603 where F : Float + ToNalgebraRealField, |
|
604 I : AlgIteratorFactory<IterInfo<F, N>>, |
|
605 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
|
606 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow |
|
607 A::Observable : std::ops::MulAssign<F>, |
|
608 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
609 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
|
610 + Lipschitz<&'a 𝒟, FloatType=F>, |
|
611 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
612 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
|
613 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
|
614 𝒟::Codomain : RealMapping<F, N>, |
|
615 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
616 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
617 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
618 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
|
619 PlotLookup : Plotting<N>, |
|
620 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
621 Reg : RegTerm<F, N> { |
|
622 |
|
623 // Set up parameters |
|
624 let config = &fbconfig.insertion; |
|
625 let op𝒟norm = op𝒟.opnorm_bound(); |
|
626 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); |
|
627 let mut λ = 1.0; |
|
628 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
|
629 // by τ compared to the conditional gradient approach. |
|
630 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
|
631 let mut ε = tolerance.initial(); |
|
632 |
|
633 // Initialise iterates |
|
634 let mut μ = DiscreteMeasure::new(); |
|
635 let mut μ_prev = DiscreteMeasure::new(); |
|
636 let mut residual = -b; |
|
637 let mut stats = IterInfo::new(); |
|
638 let mut warned_merging = false; |
|
639 |
|
640 // Run the algorithm |
|
641 iterator.iterate(|state| { |
|
642 // Calculate smooth part of surrogate model. |
|
643 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
|
644 // has no significant overhead. For some reason Rust doesn't allow us simply moving |
|
645 // the residual and replacing it below before the end of this closure. |
|
646 residual *= -τ; |
|
647 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
|
648 let minus_τv = opA.preadjoint().apply(r); |
|
649 |
|
650 // Save current base point |
|
651 let μ_base = μ.clone(); |
|
652 |
|
653 // Insert new spikes and reweigh |
|
654 let (d, within_tolerances) = insert_and_reweigh( |
|
655 &mut μ, &minus_τv, &μ_base, None, |
|
656 op𝒟, op𝒟norm, |
|
657 τ, ε, |
|
658 config, &reg, state, &mut stats |
|
659 ); |
|
660 |
|
661 // (Do not) merge spikes. |
|
662 if state.iteration() % config.merge_every == 0 { |
|
663 match config.merging { |
|
664 SpikeMergingMethod::None => { }, |
|
665 _ => if !warned_merging { |
|
666 let err = format!("Merging not supported for μFISTA"); |
|
667 println!("{}", err.red()); |
|
668 warned_merging = true; |
|
669 } |
|
670 } |
|
671 } |
|
672 |
|
673 // Update inertial parameters |
|
674 let λ_prev = λ; |
|
675 λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); |
|
676 let θ = λ / λ_prev - λ; |
|
677 |
|
678 // Perform inertial update on μ. |
|
679 // This computes μ ← (1 + θ) * μ - θ * μ_prev, pruning spikes where both μ |
|
680 // and μ_prev have zero weight. Since both have weights from the finite-dimensional |
|
681 // subproblem with a proximal projection step, this is likely to happen when the |
|
682 // spike is not needed. A copy of the pruned μ without arithmetic performed is |
|
683 // stored in μ_prev. |
|
684 let n_before_prune = μ.len(); |
|
685 μ.pruning_sub(1.0 + θ, θ, &mut μ_prev); |
|
686 debug_assert!(μ.len() <= n_before_prune); |
|
687 stats.pruned += n_before_prune - μ.len(); |
|
688 |
|
689 // Update residual |
|
690 residual = calculate_residual(&μ, opA, b); |
|
691 |
|
692 // Update main tolerance for next iteration |
|
693 let ε_prev = ε; |
|
694 ε = tolerance.update(ε, state.iteration()); |
|
695 stats.this_iters += 1; |
|
696 |
|
697 // Give function value if needed |
|
698 state.if_verbose(|| { |
|
699 // Plot if so requested |
|
700 plotter.plot_spikes( |
|
701 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, |
|
702 "start".to_string(), Some(&minus_τv), |
|
703 reg.target_bounds(τ, ε_prev), &μ_prev, |
|
704 ); |
|
705 // Calculate mean inner iterations and reset relevant counters. |
|
706 // Return the statistics |
|
707 let res = IterInfo { |
|
708 value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev), |
|
709 n_spikes : μ_prev.len(), |
|
710 ε : ε_prev, |
|
711 postprocessing: config.postprocessing.then(|| μ_prev.clone()), |
|
712 .. stats |
|
713 }; |
|
714 stats = IterInfo::new(); |
|
715 res |
|
716 }) |
|
717 }); |
|
718 |
|
719 postprocess(μ_prev, config, L2Squared, opA, b) |
|
720 } |