10 use itertools::izip; |
10 use itertools::izip; |
11 use std::iter::Iterator; |
11 use std::iter::Iterator; |
12 |
12 |
13 use alg_tools::iterate::AlgIteratorFactory; |
13 use alg_tools::iterate::AlgIteratorFactory; |
14 use alg_tools::euclidean::Euclidean; |
14 use alg_tools::euclidean::Euclidean; |
15 use alg_tools::sets::Cube; |
15 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; |
16 use alg_tools::loc::Loc; |
|
17 use alg_tools::mapping::{Mapping, DifferentiableMapping, Instance}; |
|
18 use alg_tools::norms::Norm; |
16 use alg_tools::norms::Norm; |
19 use alg_tools::bisection_tree::{ |
|
20 BTFN, |
|
21 PreBTFN, |
|
22 Bounds, |
|
23 BTNodeLookup, |
|
24 BTNode, |
|
25 BTSearch, |
|
26 P2Minimise, |
|
27 SupportGenerator, |
|
28 LocalAnalysis, |
|
29 //Bounded, |
|
30 }; |
|
31 use alg_tools::mapping::RealMapping; |
|
32 use alg_tools::nalgebra_support::ToNalgebraRealField; |
17 use alg_tools::nalgebra_support::ToNalgebraRealField; |
33 use alg_tools::norms::{L2, Linfinity}; |
18 use alg_tools::norms::L2; |
34 |
19 |
35 use crate::types::*; |
20 use crate::types::*; |
36 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
21 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
37 use crate::measures::merging::{ |
22 use crate::measures::merging::SpikeMerging; |
38 SpikeMergingMethod, |
|
39 SpikeMerging, |
|
40 }; |
|
41 use crate::forward_model::{ |
23 use crate::forward_model::{ |
42 ForwardModel, |
24 ForwardModel, |
43 AdjointProductBoundedBy, |
25 AdjointProductBoundedBy, |
44 LipschitzValues, |
26 LipschitzValues, |
45 }; |
27 }; |
46 use crate::seminorms::DiscreteMeasureOp; |
|
47 //use crate::tolerance::Tolerance; |
28 //use crate::tolerance::Tolerance; |
48 use crate::plot::{ |
29 use crate::plot::{ |
49 SeqPlotter, |
30 SeqPlotter, |
50 Plotting, |
31 Plotting, |
51 PlotLookup |
32 PlotLookup |
351 /// splitting |
332 /// splitting |
352 /// |
333 /// |
353 /// The parametrisation is as for [`pointsource_fb_reg`]. |
334 /// The parametrisation is as for [`pointsource_fb_reg`]. |
354 /// Inertia is currently not supported. |
335 /// Inertia is currently not supported. |
355 #[replace_float_literals(F::cast_from(literal))] |
336 #[replace_float_literals(F::cast_from(literal))] |
356 pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>( |
337 pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N : usize>( |
357 opA : &'a A, |
338 opA : &A, |
358 b : &A::Observable, |
339 b : &A::Observable, |
359 reg : Reg, |
340 reg : Reg, |
360 op𝒟 : &'a 𝒟, |
341 prox_penalty : &P, |
361 config : &SlidingFBConfig<F>, |
342 config : &SlidingFBConfig<F>, |
362 iterator : I, |
343 iterator : I, |
363 mut plotter : SeqPlotter<F, N>, |
344 mut plotter : SeqPlotter<F, N>, |
364 ) -> RNDM<F, N> |
345 ) -> RNDM<F, N> |
365 where F : Float + ToNalgebraRealField, |
346 where |
366 I : AlgIteratorFactory<IterInfo<F, N>>, |
347 F : Float + ToNalgebraRealField, |
367 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
348 I : AlgIteratorFactory<IterInfo<F, N>>, |
368 for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, |
349 A : ForwardModel<RNDM<F, N>, F> |
369 A::PreadjointCodomain : DifferentiableMapping< |
350 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
370 Loc<F, N>, DerivativeDomain=Loc<F, N>, Codomain=F |
351 //+ TransportLipschitz<L2Squared, FloatType=F>, |
371 >, |
352 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
372 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
353 for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, |
373 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
354 A::PreadjointCodomain : DifferentiableRealMapping<F, N>, |
374 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, |
355 RNDM<F, N> : SpikeMerging<F>, |
375 //+ TransportLipschitz<L2Squared, FloatType=F>, |
356 Reg : SlidingRegTerm<F, N>, |
376 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
357 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
377 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
358 PlotLookup : Plotting<N>, |
378 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, |
359 { |
379 Codomain = BTFN<F, G𝒟, BT𝒟, N>>, |
|
380 BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
381 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N> |
|
382 + DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F,N>>, |
|
383 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
384 //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>, |
|
385 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
386 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
|
387 PlotLookup : Plotting<N>, |
|
388 RNDM<F, N> : SpikeMerging<F>, |
|
389 Reg : SlidingRegTerm<F, N> { |
|
390 |
360 |
391 // Check parameters |
361 // Check parameters |
392 assert!(config.τ0 > 0.0, "Invalid step length parameter"); |
362 assert!(config.τ0 > 0.0, "Invalid step length parameter"); |
393 config.transport.check(); |
363 config.transport.check(); |
394 |
364 |
396 let mut μ = DiscreteMeasure::new(); |
366 let mut μ = DiscreteMeasure::new(); |
397 let mut γ1 = DiscreteMeasure::new(); |
367 let mut γ1 = DiscreteMeasure::new(); |
398 let mut residual = -b; // Has to equal $Aμ-b$. |
368 let mut residual = -b; // Has to equal $Aμ-b$. |
399 |
369 |
400 // Set up parameters |
370 // Set up parameters |
401 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); |
|
402 let opAnorm = opA.opnorm_bound(Radon, L2); |
371 let opAnorm = opA.opnorm_bound(Radon, L2); |
403 //let max_transport = config.max_transport.scale |
372 //let max_transport = config.max_transport.scale |
404 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); |
373 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); |
405 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; |
374 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; |
406 let ℓ = 0.0; |
375 let ℓ = 0.0; |
407 let τ = config.τ0 / opA.adjoint_product_bound(&op𝒟).unwrap(); |
376 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
408 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); |
377 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); |
409 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { |
378 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { |
410 // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v |
379 // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v |
411 // (the uniform Lipschitz factor of ∇v). |
380 // (the uniform Lipschitz factor of ∇v). |
412 // We assume that the residual is decreasing. |
381 // We assume that the residual is decreasing. |
444 v, &config.transport, |
413 v, &config.transport, |
445 ); |
414 ); |
446 |
415 |
447 // Solve finite-dimensional subproblem several times until the dual variable for the |
416 // Solve finite-dimensional subproblem several times until the dual variable for the |
448 // regularisation term conforms to the assumptions made for the transport above. |
417 // regularisation term conforms to the assumptions made for the transport above. |
449 let (d, _within_tolerances, τv̆) = 'adapt_transport: loop { |
418 let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop { |
450 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
419 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
451 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); |
420 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); |
452 let τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); |
421 let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); |
453 |
422 |
454 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
423 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
455 let (d, within_tolerances) = insert_and_reweigh( |
424 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
456 &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), |
425 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), |
457 op𝒟, op𝒟norm, |
|
458 τ, ε, &config.insertion, |
426 τ, ε, &config.insertion, |
459 ®, &state, &mut stats, |
427 ®, &state, &mut stats, |
460 ); |
428 ); |
461 |
429 |
462 // A posteriori transport adaptation. |
430 // A posteriori transport adaptation. |
463 if aposteriori_transport( |
431 if aposteriori_transport( |
464 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, |
432 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, |
465 ε, &config.transport |
433 ε, &config.transport |
466 ) { |
434 ) { |
467 break 'adapt_transport (d, within_tolerances, τv̆) |
435 break 'adapt_transport (maybe_d, within_tolerances, τv̆) |
468 } |
436 } |
469 }; |
437 }; |
470 |
438 |
471 stats.untransported_fraction = Some({ |
439 stats.untransported_fraction = Some({ |
472 assert_eq!(μ_base_masses.len(), γ1.len()); |
440 assert_eq!(μ_base_masses.len(), γ1.len()); |
478 assert_eq!(μ_base_masses.len(), γ1.len()); |
446 assert_eq!(μ_base_masses.len(), γ1.len()); |
479 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); |
447 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); |
480 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) |
448 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) |
481 }); |
449 }); |
482 |
450 |
483 // Merge spikes. |
451 // // Merge spikes. |
484 // This expects the prune below to prune γ. |
452 // // This expects the prune below to prune γ. |
485 // TODO: This may not work correctly in all cases. |
453 // // TODO: This may not work correctly in all cases. |
486 let ins = &config.insertion; |
454 // let ins = &config.insertion; |
487 if ins.merge_now(&state) { |
455 // if ins.merge_now(&state) { |
488 if let SpikeMergingMethod::None = ins.merging { |
456 // if let SpikeMergingMethod::None = ins.merging { |
489 } else { |
457 // } else { |
490 stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { |
458 // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { |
491 let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; |
459 // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; |
492 let mut d = &τv̆ + op𝒟.preapply(ν); |
460 // let mut d = &τv̆ + op𝒟.preapply(ν); |
493 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) |
461 // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) |
494 }); |
462 // }); |
495 } |
463 // } |
496 } |
464 // } |
497 |
465 |
498 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the |
466 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the |
499 // latter needs to be pruned when μ is. |
467 // latter needs to be pruned when μ is. |
500 // TODO: This could do with a two-vector Vec::retain to avoid copies. |
468 // TODO: This could do with a two-vector Vec::retain to avoid copies. |
501 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); |
469 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); |
512 let iter = state.iteration(); |
480 let iter = state.iteration(); |
513 stats.this_iters += 1; |
481 stats.this_iters += 1; |
514 |
482 |
515 // Give statistics if requested |
483 // Give statistics if requested |
516 state.if_verbose(|| { |
484 state.if_verbose(|| { |
517 plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); |
485 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); |
518 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
486 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
519 }); |
487 }); |
520 |
488 |
521 // Update main tolerance for next iteration |
489 // Update main tolerance for next iteration |
522 ε = tolerance.update(ε, iter); |
490 ε = tolerance.update(ε, iter); |