| 219 merging : SpikeMergingMethod::None, |
223 merging : SpikeMergingMethod::None, |
| 220 //merging : Default::default(), |
224 //merging : Default::default(), |
| 221 final_merging : Default::default(), |
225 final_merging : Default::default(), |
| 222 merge_every : 10, |
226 merge_every : 10, |
| 223 merge_tolerance_mult : 2.0, |
227 merge_tolerance_mult : 2.0, |
| 224 postprocessing : false, |
228 // postprocessing : false, |
| 225 } |
229 } |
| |
230 } |
| |
231 } |
| |
232 |
| |
233 impl<F : Float> FBGenericConfig<F> { |
| |
234 /// Check if merging should be attempted this iteration |
| |
235 pub fn merge_now<I : AlgIterator>(&self, state : &AlgIteratorIteration<I>) -> bool { |
| |
236 state.iteration() % self.merge_every == 0 |
| 226 } |
237 } |
| 227 } |
238 } |
| 228 |
239 |
| 229 /// TODO: document. |
240 /// TODO: document. |
| 230 /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike |
241 /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike |
| 231 /// locations, while `ν_delta` may have different locations. |
242 /// locations, while `ν_delta` may have different locations. |
| 232 #[replace_float_literals(F::cast_from(literal))] |
243 #[replace_float_literals(F::cast_from(literal))] |
| 233 pub(crate) fn insert_and_reweigh< |
244 pub(crate) fn insert_and_reweigh< |
| 234 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize |
245 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, I, const N : usize |
| 235 >( |
246 >( |
| 236 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
247 μ : &mut RNDM<F, N>, |
| 237 minus_τv : &BTFN<F, GA, BTA, N>, |
248 τv : &BTFN<F, GA, BTA, N>, |
| 238 μ_base : &DiscreteMeasure<Loc<F, N>, F>, |
249 μ_base : &RNDM<F, N>, |
| 239 ν_delta: Option<&DiscreteMeasure<Loc<F, N>, F>>, |
250 ν_delta: Option<&RNDM<F, N>>, |
| 240 op𝒟 : &'a 𝒟, |
251 op𝒟 : &'a 𝒟, |
| 241 op𝒟norm : F, |
252 op𝒟norm : F, |
| 242 τ : F, |
253 τ : F, |
| 243 ε : F, |
254 ε : F, |
| 244 config : &FBGenericConfig<F>, |
255 config : &FBGenericConfig<F>, |
| 245 reg : &Reg, |
256 reg : &Reg, |
| 246 state : &State, |
257 state : &AlgIteratorIteration<I>, |
| 247 stats : &mut IterInfo<F, N>, |
258 stats : &mut IterInfo<F, N>, |
| 248 ) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool) |
259 ) -> (BTFN<F, BothGenerators<GA, G𝒟>, BTA, N>, bool) |
| 249 where F : Float + ToNalgebraRealField, |
260 where F : Float + ToNalgebraRealField, |
| 250 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
261 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
| 251 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
262 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
| 253 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
264 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
| 254 𝒟::Codomain : RealMapping<F, N>, |
265 𝒟::Codomain : RealMapping<F, N>, |
| 255 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
266 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
| 256 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
267 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
| 257 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
268 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
| 258 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
| 259 Reg : RegTerm<F, N>, |
269 Reg : RegTerm<F, N>, |
| 260 State : AlgIteratorState { |
270 I : AlgIterator { |
| 261 |
271 |
| 262 // Maximum insertion count and measure difference calculation depend on insertion style. |
272 // Maximum insertion count and measure difference calculation depend on insertion style. |
| 263 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
273 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
| 264 (i, Some((l, k))) if i <= l => (k, false), |
274 (i, Some((l, k))) if i <= l => (k, false), |
| 265 _ => (config.max_insertions, !state.is_quiet()), |
275 _ => (config.max_insertions, !state.is_quiet()), |
| 266 }; |
276 }; |
| 267 |
277 |
| 268 // TODO: should avoid a copy of μ_base here. |
278 let ω0 = match ν_delta { |
| 269 let ω0 = op𝒟.apply(match ν_delta { |
279 None => op𝒟.apply(μ_base), |
| 270 None => μ_base.clone(), |
280 Some(ν) => op𝒟.apply(μ_base + ν), |
| 271 Some(ν_d) => &*μ_base + ν_d, |
281 }; |
| 272 }); |
|
| 273 |
282 |
| 274 // Add points to support until within error tolerance or maximum insertion count reached. |
283 // Add points to support until within error tolerance or maximum insertion count reached. |
| 275 let mut count = 0; |
284 let mut count = 0; |
| 276 let (within_tolerances, d) = 'insertion: loop { |
285 let (within_tolerances, d) = 'insertion: loop { |
| 277 if μ.len() > 0 { |
286 if μ.len() > 0 { |
| 278 // Form finite-dimensional subproblem. The subproblem references to the original μ^k |
287 // Form finite-dimensional subproblem. The subproblem references to the original μ^k |
| 279 // from the beginning of the iteration are all contained in the immutable c and g. |
288 // from the beginning of the iteration are all contained in the immutable c and g. |
| |
289 // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional |
| |
290 // problems have not yet been updated to sign change. |
| 280 let à = op𝒟.findim_matrix(μ.iter_locations()); |
291 let à = op𝒟.findim_matrix(μ.iter_locations()); |
| 281 let g̃ = DVector::from_iterator(μ.len(), |
292 let g̃ = DVector::from_iterator(μ.len(), |
| 282 μ.iter_locations() |
293 μ.iter_locations() |
| 283 .map(|ζ| minus_τv.apply(ζ) + ω0.apply(ζ)) |
294 .map(|ζ| ω0.apply(ζ) - τv.apply(ζ)) |
| 284 .map(F::to_nalgebra_mixed)); |
295 .map(F::to_nalgebra_mixed)); |
| 285 let mut x = μ.masses_dvector(); |
296 let mut x = μ.masses_dvector(); |
| 286 |
297 |
| 287 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. |
298 // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃. |
| 288 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ |
299 // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩ |
| 344 } |
353 } |
| 345 |
354 |
| 346 (d, within_tolerances) |
355 (d, within_tolerances) |
| 347 } |
356 } |
| 348 |
357 |
| 349 #[replace_float_literals(F::cast_from(literal))] |
358 pub(crate) fn prune_with_stats<F : Float, const N : usize>( |
| 350 pub(crate) fn prune_and_maybe_simple_merge< |
359 μ : &mut RNDM<F, N>, |
| 351 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize |
360 ) -> usize { |
| 352 >( |
|
| 353 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
| 354 minus_τv : &BTFN<F, GA, BTA, N>, |
|
| 355 μ_base : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 356 op𝒟 : &'a 𝒟, |
|
| 357 τ : F, |
|
| 358 ε : F, |
|
| 359 config : &FBGenericConfig<F>, |
|
| 360 reg : &Reg, |
|
| 361 state : &State, |
|
| 362 stats : &mut IterInfo<F, N>, |
|
| 363 ) |
|
| 364 where F : Float + ToNalgebraRealField, |
|
| 365 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
| 366 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
| 367 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
|
| 368 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
|
| 369 𝒟::Codomain : RealMapping<F, N>, |
|
| 370 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 371 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 372 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
| 373 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
| 374 Reg : RegTerm<F, N>, |
|
| 375 State : AlgIteratorState { |
|
| 376 if state.iteration() % config.merge_every == 0 { |
|
| 377 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { |
|
| 378 let mut d = minus_τv + op𝒟.preapply(μ_base.sub_matching(&μ_candidate)); |
|
| 379 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
|
| 380 }); |
|
| 381 } |
|
| 382 |
|
| 383 let n_before_prune = μ.len(); |
361 let n_before_prune = μ.len(); |
| 384 μ.prune(); |
362 μ.prune(); |
| 385 debug_assert!(μ.len() <= n_before_prune); |
363 debug_assert!(μ.len() <= n_before_prune); |
| 386 stats.pruned += n_before_prune - μ.len(); |
364 n_before_prune - μ.len() |
| 387 } |
365 } |
| 388 |
366 |
| 389 #[replace_float_literals(F::cast_from(literal))] |
367 #[replace_float_literals(F::cast_from(literal))] |
| 390 pub(crate) fn postprocess< |
368 pub(crate) fn postprocess< |
| 391 F : Float, |
369 F : Float, |
| 392 V : Euclidean<F> + Clone, |
370 V : Euclidean<F> + Clone, |
| 393 A : GEMV<F, DiscreteMeasure<Loc<F, N>, F>, Codomain = V>, |
371 A : GEMV<F, RNDM<F, N>, Codomain = V>, |
| 394 D : DataTerm<F, V, N>, |
372 D : DataTerm<F, V, N>, |
| 395 const N : usize |
373 const N : usize |
| 396 > ( |
374 > ( |
| 397 mut μ : DiscreteMeasure<Loc<F, N>, F>, |
375 mut μ : RNDM<F, N>, |
| 398 config : &FBGenericConfig<F>, |
376 config : &FBGenericConfig<F>, |
| 399 dataterm : D, |
377 dataterm : D, |
| 400 opA : &A, |
378 opA : &A, |
| 401 b : &V, |
379 b : &V, |
| 402 ) -> DiscreteMeasure<Loc<F, N>, F> |
380 ) -> RNDM<F, N> |
| 403 where DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F> { |
381 where |
| |
382 RNDM<F, N> : SpikeMerging<F>, |
| |
383 for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>, |
| |
384 { |
| 404 μ.merge_spikes_fitness(config.merging, |
385 μ.merge_spikes_fitness(config.merging, |
| 405 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
386 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
| 406 |&v| v); |
387 |&v| v); |
| 407 μ.prune(); |
388 μ.prune(); |
| 408 μ |
389 μ |
| 435 reg : Reg, |
416 reg : Reg, |
| 436 op𝒟 : &'a 𝒟, |
417 op𝒟 : &'a 𝒟, |
| 437 fbconfig : &FBConfig<F>, |
418 fbconfig : &FBConfig<F>, |
| 438 iterator : I, |
419 iterator : I, |
| 439 mut plotter : SeqPlotter<F, N>, |
420 mut plotter : SeqPlotter<F, N>, |
| 440 ) -> DiscreteMeasure<Loc<F, N>, F> |
421 ) -> RNDM<F, N> |
| 441 where F : Float + ToNalgebraRealField, |
422 where F : Float + ToNalgebraRealField, |
| 442 I : AlgIteratorFactory<IterInfo<F, N>>, |
423 I : AlgIteratorFactory<IterInfo<F, N>>, |
| 443 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
424 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
| 444 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow |
|
| 445 A::Observable : std::ops::MulAssign<F>, |
|
| 446 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
425 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
| 447 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
426 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
| 448 + Lipschitz<&'a 𝒟, FloatType=F>, |
427 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, |
| 449 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
428 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
| 450 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
429 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
| 451 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
430 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
| 452 𝒟::Codomain : RealMapping<F, N>, |
431 𝒟::Codomain : RealMapping<F, N>, |
| 453 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
432 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
| 454 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
433 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
| 455 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
434 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
| 456 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
435 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
| 457 PlotLookup : Plotting<N>, |
436 PlotLookup : Plotting<N>, |
| 458 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
437 RNDM<F, N> : SpikeMerging<F>, |
| 459 Reg : RegTerm<F, N> { |
438 Reg : RegTerm<F, N> { |
| 460 |
439 |
| 461 // Set up parameters |
440 // Set up parameters |
| 462 let config = &fbconfig.generic; |
441 let config = &fbconfig.generic; |
| 463 let op𝒟norm = op𝒟.opnorm_bound(); |
442 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); |
| 464 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); |
443 let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap(); |
| 465 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
444 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 466 // by τ compared to the conditional gradient approach. |
445 // by τ compared to the conditional gradient approach. |
| 467 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
446 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| 468 let mut ε = tolerance.initial(); |
447 let mut ε = tolerance.initial(); |
| 469 |
448 |
| 470 // Initialise iterates |
449 // Initialise iterates |
| 471 let mut μ = DiscreteMeasure::new(); |
450 let mut μ = DiscreteMeasure::new(); |
| 472 let mut residual = -b; |
451 let mut residual = -b; |
| |
452 |
| |
453 // Statistics |
| |
454 let full_stats = |residual : &A::Observable, |
| |
455 μ : &RNDM<F, N>, |
| |
456 ε, stats| IterInfo { |
| |
457 value : residual.norm2_squared_div2() + reg.apply(μ), |
| |
458 n_spikes : μ.len(), |
| |
459 ε, |
| |
460 //postprocessing: config.postprocessing.then(|| μ.clone()), |
| |
461 .. stats |
| |
462 }; |
| 473 let mut stats = IterInfo::new(); |
463 let mut stats = IterInfo::new(); |
| 474 |
464 |
| 475 // Run the algorithm |
465 // Run the algorithm |
| 476 iterator.iterate(|state| { |
466 for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) { |
| 477 // Calculate smooth part of surrogate model. |
467 // Calculate smooth part of surrogate model. |
| 478 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
468 let τv = opA.preadjoint().apply(residual * τ); |
| 479 // has no significant overhead. For some reosn Rust doesn't allow us simply moving |
|
| 480 // the residual and replacing it below before the end of this closure. |
|
| 481 residual *= -τ; |
|
| 482 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
|
| 483 let minus_τv = opA.preadjoint().apply(r); |
|
| 484 |
469 |
| 485 // Save current base point |
470 // Save current base point |
| 486 let μ_base = μ.clone(); |
471 let μ_base = μ.clone(); |
| 487 |
472 |
| 488 // Insert and reweigh |
473 // Insert and reweigh |
| 489 let (d, within_tolerances) = insert_and_reweigh( |
474 let (d, _within_tolerances) = insert_and_reweigh( |
| 490 &mut μ, &minus_τv, &μ_base, None, |
475 &mut μ, &τv, &μ_base, None, |
| 491 op𝒟, op𝒟norm, |
476 op𝒟, op𝒟norm, |
| 492 τ, ε, |
477 τ, ε, |
| 493 config, ®, state, &mut stats |
478 config, ®, &state, &mut stats |
| 494 ); |
479 ); |
| 495 |
480 |
| 496 // Prune and possibly merge spikes |
481 // Prune and possibly merge spikes |
| 497 prune_and_maybe_simple_merge( |
482 if config.merge_now(&state) { |
| 498 &mut μ, &minus_τv, &μ_base, |
483 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { |
| 499 op𝒟, |
484 let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base)); |
| 500 τ, ε, |
485 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
| 501 config, ®, state, &mut stats |
486 }); |
| 502 ); |
487 } |
| |
488 stats.pruned += prune_with_stats(&mut μ); |
| 503 |
489 |
| 504 // Update residual |
490 // Update residual |
| 505 residual = calculate_residual(&μ, opA, b); |
491 residual = calculate_residual(&μ, opA, b); |
| 506 |
492 |
| |
493 let iter = state.iteration(); |
| |
494 stats.this_iters += 1; |
| |
495 |
| |
496 // Give statistics if needed |
| |
497 state.if_verbose(|| { |
| |
498 plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); |
| |
499 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| |
500 }); |
| |
501 |
| 507 // Update main tolerance for next iteration |
502 // Update main tolerance for next iteration |
| 508 let ε_prev = ε; |
503 ε = tolerance.update(ε, iter); |
| 509 ε = tolerance.update(ε, state.iteration()); |
504 } |
| 510 stats.this_iters += 1; |
|
| 511 |
|
| 512 // Give function value if needed |
|
| 513 state.if_verbose(|| { |
|
| 514 // Plot if so requested |
|
| 515 plotter.plot_spikes( |
|
| 516 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, |
|
| 517 "start".to_string(), Some(&minus_τv), |
|
| 518 reg.target_bounds(τ, ε_prev), &μ, |
|
| 519 ); |
|
| 520 // Calculate mean inner iterations and reset relevant counters. |
|
| 521 // Return the statistics |
|
| 522 let res = IterInfo { |
|
| 523 value : residual.norm2_squared_div2() + reg.apply(&μ), |
|
| 524 n_spikes : μ.len(), |
|
| 525 ε : ε_prev, |
|
| 526 postprocessing: config.postprocessing.then(|| μ.clone()), |
|
| 527 .. stats |
|
| 528 }; |
|
| 529 stats = IterInfo::new(); |
|
| 530 res |
|
| 531 }) |
|
| 532 }); |
|
| 533 |
505 |
| 534 postprocess(μ, config, L2Squared, opA, b) |
506 postprocess(μ, config, L2Squared, opA, b) |
| 535 } |
507 } |
| 536 |
508 |
| 537 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. |
509 /// Iteratively solve the pointsource localisation problem using inertial forward-backward splitting. |
| 561 reg : Reg, |
533 reg : Reg, |
| 562 op𝒟 : &'a 𝒟, |
534 op𝒟 : &'a 𝒟, |
| 563 fbconfig : &FBConfig<F>, |
535 fbconfig : &FBConfig<F>, |
| 564 iterator : I, |
536 iterator : I, |
| 565 mut plotter : SeqPlotter<F, N>, |
537 mut plotter : SeqPlotter<F, N>, |
| 566 ) -> DiscreteMeasure<Loc<F, N>, F> |
538 ) -> RNDM<F, N> |
| 567 where F : Float + ToNalgebraRealField, |
539 where F : Float + ToNalgebraRealField, |
| 568 I : AlgIteratorFactory<IterInfo<F, N>>, |
540 I : AlgIteratorFactory<IterInfo<F, N>>, |
| 569 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
541 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, |
| 570 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow |
|
| 571 A::Observable : std::ops::MulAssign<F>, |
|
| 572 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
542 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
| 573 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
543 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
| 574 + Lipschitz<&'a 𝒟, FloatType=F>, |
544 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, |
| 575 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
545 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
| 576 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
546 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
| 577 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
547 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
| 578 𝒟::Codomain : RealMapping<F, N>, |
548 𝒟::Codomain : RealMapping<F, N>, |
| 579 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
549 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
| 580 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
550 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
| 581 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
551 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
| 582 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
552 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
| 583 PlotLookup : Plotting<N>, |
553 PlotLookup : Plotting<N>, |
| 584 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
554 RNDM<F, N> : SpikeMerging<F>, |
| 585 Reg : RegTerm<F, N> { |
555 Reg : RegTerm<F, N> { |
| 586 |
556 |
| 587 // Set up parameters |
557 // Set up parameters |
| 588 let config = &fbconfig.generic; |
558 let config = &fbconfig.generic; |
| 589 let op𝒟norm = op𝒟.opnorm_bound(); |
559 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); |
| 590 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); |
560 let τ = fbconfig.τ0/opA.adjoint_product_bound(&op𝒟).unwrap(); |
| 591 let mut λ = 1.0; |
561 let mut λ = 1.0; |
| 592 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
562 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 593 // by τ compared to the conditional gradient approach. |
563 // by τ compared to the conditional gradient approach. |
| 594 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
564 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| 595 let mut ε = tolerance.initial(); |
565 let mut ε = tolerance.initial(); |
| 596 |
566 |
| 597 // Initialise iterates |
567 // Initialise iterates |
| 598 let mut μ = DiscreteMeasure::new(); |
568 let mut μ = DiscreteMeasure::new(); |
| 599 let mut μ_prev = DiscreteMeasure::new(); |
569 let mut μ_prev = DiscreteMeasure::new(); |
| 600 let mut residual = -b; |
570 let mut residual = -b; |
| |
571 let mut warned_merging = false; |
| |
572 |
| |
573 // Statistics |
| |
574 let full_stats = |ν : &RNDM<F, N>, ε, stats| IterInfo { |
| |
575 value : L2Squared.calculate_fit_op(ν, opA, b) + reg.apply(ν), |
| |
576 n_spikes : ν.len(), |
| |
577 ε, |
| |
578 // postprocessing: config.postprocessing.then(|| ν.clone()), |
| |
579 .. stats |
| |
580 }; |
| 601 let mut stats = IterInfo::new(); |
581 let mut stats = IterInfo::new(); |
| 602 let mut warned_merging = false; |
|
| 603 |
582 |
| 604 // Run the algorithm |
583 // Run the algorithm |
| 605 iterator.iterate(|state| { |
584 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 606 // Calculate smooth part of surrogate model. |
585 // Calculate smooth part of surrogate model. |
| 607 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
586 let τv = opA.preadjoint().apply(residual * τ); |
| 608 // has no significant overhead. For some reosn Rust doesn't allow us simply moving |
|
| 609 // the residual and replacing it below before the end of this closure. |
|
| 610 residual *= -τ; |
|
| 611 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
|
| 612 let minus_τv = opA.preadjoint().apply(r); |
|
| 613 |
587 |
| 614 // Save current base point |
588 // Save current base point |
| 615 let μ_base = μ.clone(); |
589 let μ_base = μ.clone(); |
| 616 |
590 |
| 617 // Insert new spikes and reweigh |
591 // Insert new spikes and reweigh |
| 618 let (d, within_tolerances) = insert_and_reweigh( |
592 let (d, _within_tolerances) = insert_and_reweigh( |
| 619 &mut μ, &minus_τv, &μ_base, None, |
593 &mut μ, &τv, &μ_base, None, |
| 620 op𝒟, op𝒟norm, |
594 op𝒟, op𝒟norm, |
| 621 τ, ε, |
595 τ, ε, |
| 622 config, ®, state, &mut stats |
596 config, ®, &state, &mut stats |
| 623 ); |
597 ); |
| 624 |
598 |
| 625 // (Do not) merge spikes. |
599 // (Do not) merge spikes. |
| 626 if state.iteration() % config.merge_every == 0 { |
600 if config.merge_now(&state) { |
| 627 match config.merging { |
601 match config.merging { |
| 628 SpikeMergingMethod::None => { }, |
602 SpikeMergingMethod::None => { }, |
| 629 _ => if !warned_merging { |
603 _ => if !warned_merging { |
| 630 let err = format!("Merging not supported for μFISTA"); |
604 let err = format!("Merging not supported for μFISTA"); |
| 631 println!("{}", err.red()); |
605 println!("{}", err.red()); |
| 651 stats.pruned += n_before_prune - μ.len(); |
625 stats.pruned += n_before_prune - μ.len(); |
| 652 |
626 |
| 653 // Update residual |
627 // Update residual |
| 654 residual = calculate_residual(&μ, opA, b); |
628 residual = calculate_residual(&μ, opA, b); |
| 655 |
629 |
| |
630 let iter = state.iteration(); |
| |
631 stats.this_iters += 1; |
| |
632 |
| |
633 // Give statistics if needed |
| |
634 state.if_verbose(|| { |
| |
635 plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ_prev); |
| |
636 full_stats(&μ_prev, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| |
637 }); |
| |
638 |
| 656 // Update main tolerance for next iteration |
639 // Update main tolerance for next iteration |
| 657 let ε_prev = ε; |
640 ε = tolerance.update(ε, iter); |
| 658 ε = tolerance.update(ε, state.iteration()); |
641 } |
| 659 stats.this_iters += 1; |
|
| 660 |
|
| 661 // Give function value if needed |
|
| 662 state.if_verbose(|| { |
|
| 663 // Plot if so requested |
|
| 664 plotter.plot_spikes( |
|
| 665 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, |
|
| 666 "start".to_string(), Some(&minus_τv), |
|
| 667 reg.target_bounds(τ, ε_prev), &μ_prev, |
|
| 668 ); |
|
| 669 // Calculate mean inner iterations and reset relevant counters. |
|
| 670 // Return the statistics |
|
| 671 let res = IterInfo { |
|
| 672 value : L2Squared.calculate_fit_op(&μ_prev, opA, b) + reg.apply(&μ_prev), |
|
| 673 n_spikes : μ_prev.len(), |
|
| 674 ε : ε_prev, |
|
| 675 postprocessing: config.postprocessing.then(|| μ_prev.clone()), |
|
| 676 .. stats |
|
| 677 }; |
|
| 678 stats = IterInfo::new(); |
|
| 679 res |
|
| 680 }) |
|
| 681 }); |
|
| 682 |
642 |
| 683 postprocess(μ_prev, config, L2Squared, opA, b) |
643 postprocess(μ_prev, config, L2Squared, opA, b) |
| 684 } |
644 } |