| 10 use itertools::izip; |
10 use itertools::izip; |
| 11 use std::iter::Iterator; |
11 use std::iter::Iterator; |
| 12 |
12 |
| 13 use alg_tools::iterate::AlgIteratorFactory; |
13 use alg_tools::iterate::AlgIteratorFactory; |
| 14 use alg_tools::euclidean::Euclidean; |
14 use alg_tools::euclidean::Euclidean; |
| 15 use alg_tools::sets::Cube; |
15 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; |
| 16 use alg_tools::loc::Loc; |
|
| 17 use alg_tools::mapping::{Mapping, DifferentiableMapping, Instance}; |
|
| 18 use alg_tools::norms::Norm; |
16 use alg_tools::norms::Norm; |
| 19 use alg_tools::bisection_tree::{ |
|
| 20 BTFN, |
|
| 21 PreBTFN, |
|
| 22 Bounds, |
|
| 23 BTNodeLookup, |
|
| 24 BTNode, |
|
| 25 BTSearch, |
|
| 26 P2Minimise, |
|
| 27 SupportGenerator, |
|
| 28 LocalAnalysis, |
|
| 29 //Bounded, |
|
| 30 }; |
|
| 31 use alg_tools::mapping::RealMapping; |
|
| 32 use alg_tools::nalgebra_support::ToNalgebraRealField; |
17 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 33 use alg_tools::norms::{L2, Linfinity}; |
18 use alg_tools::norms::L2; |
| 34 |
19 |
| 35 use crate::types::*; |
20 use crate::types::*; |
| 36 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
21 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
| 37 use crate::measures::merging::{ |
22 use crate::measures::merging::SpikeMerging; |
| 38 SpikeMergingMethod, |
|
| 39 SpikeMerging, |
|
| 40 }; |
|
| 41 use crate::forward_model::{ |
23 use crate::forward_model::{ |
| 42 ForwardModel, |
24 ForwardModel, |
| 43 AdjointProductBoundedBy, |
25 AdjointProductBoundedBy, |
| 44 LipschitzValues, |
26 LipschitzValues, |
| 45 }; |
27 }; |
| 46 use crate::seminorms::DiscreteMeasureOp; |
|
| 47 //use crate::tolerance::Tolerance; |
28 //use crate::tolerance::Tolerance; |
| 48 use crate::plot::{ |
29 use crate::plot::{ |
| 49 SeqPlotter, |
30 SeqPlotter, |
| 50 Plotting, |
31 Plotting, |
| 51 PlotLookup |
32 PlotLookup |
| 351 /// splitting |
332 /// splitting |
| 352 /// |
333 /// |
| 353 /// The parametrisation is as for [`pointsource_fb_reg`]. |
334 /// The parametrisation is as for [`pointsource_fb_reg`]. |
| 354 /// Inertia is currently not supported. |
335 /// Inertia is currently not supported. |
| 355 #[replace_float_literals(F::cast_from(literal))] |
336 #[replace_float_literals(F::cast_from(literal))] |
| 356 pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>( |
337 pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N : usize>( |
| 357 opA : &'a A, |
338 opA : &A, |
| 358 b : &A::Observable, |
339 b : &A::Observable, |
| 359 reg : Reg, |
340 reg : Reg, |
| 360 op𝒟 : &'a 𝒟, |
341 prox_penalty : &P, |
| 361 config : &SlidingFBConfig<F>, |
342 config : &SlidingFBConfig<F>, |
| 362 iterator : I, |
343 iterator : I, |
| 363 mut plotter : SeqPlotter<F, N>, |
344 mut plotter : SeqPlotter<F, N>, |
| 364 ) -> RNDM<F, N> |
345 ) -> RNDM<F, N> |
| 365 where F : Float + ToNalgebraRealField, |
346 where |
| 366 I : AlgIteratorFactory<IterInfo<F, N>>, |
347 F : Float + ToNalgebraRealField, |
| 367 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
348 I : AlgIteratorFactory<IterInfo<F, N>>, |
| 368 for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, |
349 A : ForwardModel<RNDM<F, N>, F> |
| 369 A::PreadjointCodomain : DifferentiableMapping< |
350 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
| 370 Loc<F, N>, DerivativeDomain=Loc<F, N>, Codomain=F |
351 //+ TransportLipschitz<L2Squared, FloatType=F>, |
| 371 >, |
352 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
| 372 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
353 for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, |
| 373 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
354 A::PreadjointCodomain : DifferentiableRealMapping<F, N>, |
| 374 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, |
355 RNDM<F, N> : SpikeMerging<F>, |
| 375 //+ TransportLipschitz<L2Squared, FloatType=F>, |
356 Reg : SlidingRegTerm<F, N>, |
| 376 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
357 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
| 377 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
358 PlotLookup : Plotting<N>, |
| 378 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, |
359 { |
| 379 Codomain = BTFN<F, G𝒟, BT𝒟, N>>, |
|
| 380 BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
| 381 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N> |
|
| 382 + DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F,N>>, |
|
| 383 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 384 //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>, |
|
| 385 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
| 386 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
|
| 387 PlotLookup : Plotting<N>, |
|
| 388 RNDM<F, N> : SpikeMerging<F>, |
|
| 389 Reg : SlidingRegTerm<F, N> { |
|
| 390 |
360 |
| 391 // Check parameters |
361 // Check parameters |
| 392 assert!(config.τ0 > 0.0, "Invalid step length parameter"); |
362 assert!(config.τ0 > 0.0, "Invalid step length parameter"); |
| 393 config.transport.check(); |
363 config.transport.check(); |
| 394 |
364 |
| 396 let mut μ = DiscreteMeasure::new(); |
366 let mut μ = DiscreteMeasure::new(); |
| 397 let mut γ1 = DiscreteMeasure::new(); |
367 let mut γ1 = DiscreteMeasure::new(); |
| 398 let mut residual = -b; // Has to equal $Aμ-b$. |
368 let mut residual = -b; // Has to equal $Aμ-b$. |
| 399 |
369 |
| 400 // Set up parameters |
370 // Set up parameters |
| 401 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); |
|
| 402 let opAnorm = opA.opnorm_bound(Radon, L2); |
371 let opAnorm = opA.opnorm_bound(Radon, L2); |
| 403 //let max_transport = config.max_transport.scale |
372 //let max_transport = config.max_transport.scale |
| 404 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); |
373 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); |
| 405 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; |
374 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; |
| 406 let ℓ = 0.0; |
375 let ℓ = 0.0; |
| 407 let τ = config.τ0 / opA.adjoint_product_bound(&op𝒟).unwrap(); |
376 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); |
| 408 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); |
377 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); |
| 409 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { |
378 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { |
| 410 // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v |
379 // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v |
| 411 // (the uniform Lipschitz factor of ∇v). |
380 // (the uniform Lipschitz factor of ∇v). |
| 412 // We assume that the residual is decreasing. |
381 // We assume that the residual is decreasing. |
| 444 v, &config.transport, |
413 v, &config.transport, |
| 445 ); |
414 ); |
| 446 |
415 |
| 447 // Solve finite-dimensional subproblem several times until the dual variable for the |
416 // Solve finite-dimensional subproblem several times until the dual variable for the |
| 448 // regularisation term conforms to the assumptions made for the transport above. |
417 // regularisation term conforms to the assumptions made for the transport above. |
| 449 let (d, _within_tolerances, τv̆) = 'adapt_transport: loop { |
418 let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop { |
| 450 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
419 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
| 451 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); |
420 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); |
| 452 let τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); |
421 let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); |
| 453 |
422 |
| 454 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
423 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
| 455 let (d, within_tolerances) = insert_and_reweigh( |
424 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
| 456 &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), |
425 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), |
| 457 op𝒟, op𝒟norm, |
|
| 458 τ, ε, &config.insertion, |
426 τ, ε, &config.insertion, |
| 459 ®, &state, &mut stats, |
427 ®, &state, &mut stats, |
| 460 ); |
428 ); |
| 461 |
429 |
| 462 // A posteriori transport adaptation. |
430 // A posteriori transport adaptation. |
| 463 if aposteriori_transport( |
431 if aposteriori_transport( |
| 464 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, |
432 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, |
| 465 ε, &config.transport |
433 ε, &config.transport |
| 466 ) { |
434 ) { |
| 467 break 'adapt_transport (d, within_tolerances, τv̆) |
435 break 'adapt_transport (maybe_d, within_tolerances, τv̆) |
| 468 } |
436 } |
| 469 }; |
437 }; |
| 470 |
438 |
| 471 stats.untransported_fraction = Some({ |
439 stats.untransported_fraction = Some({ |
| 472 assert_eq!(μ_base_masses.len(), γ1.len()); |
440 assert_eq!(μ_base_masses.len(), γ1.len()); |
| 478 assert_eq!(μ_base_masses.len(), γ1.len()); |
446 assert_eq!(μ_base_masses.len(), γ1.len()); |
| 479 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); |
447 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); |
| 480 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) |
448 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) |
| 481 }); |
449 }); |
| 482 |
450 |
| 483 // Merge spikes. |
451 // // Merge spikes. |
| 484 // This expects the prune below to prune γ. |
452 // // This expects the prune below to prune γ. |
| 485 // TODO: This may not work correctly in all cases. |
453 // // TODO: This may not work correctly in all cases. |
| 486 let ins = &config.insertion; |
454 // let ins = &config.insertion; |
| 487 if ins.merge_now(&state) { |
455 // if ins.merge_now(&state) { |
| 488 if let SpikeMergingMethod::None = ins.merging { |
456 // if let SpikeMergingMethod::None = ins.merging { |
| 489 } else { |
457 // } else { |
| 490 stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { |
458 // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { |
| 491 let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; |
459 // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; |
| 492 let mut d = &τv̆ + op𝒟.preapply(ν); |
460 // let mut d = &τv̆ + op𝒟.preapply(ν); |
| 493 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) |
461 // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) |
| 494 }); |
462 // }); |
| 495 } |
463 // } |
| 496 } |
464 // } |
| 497 |
465 |
| 498 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the |
466 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the |
| 499 // latter needs to be pruned when μ is. |
467 // latter needs to be pruned when μ is. |
| 500 // TODO: This could do with a two-vector Vec::retain to avoid copies. |
468 // TODO: This could do with a two-vector Vec::retain to avoid copies. |
| 501 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); |
469 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); |
| 512 let iter = state.iteration(); |
480 let iter = state.iteration(); |
| 513 stats.this_iters += 1; |
481 stats.this_iters += 1; |
| 514 |
482 |
| 515 // Give statistics if requested |
483 // Give statistics if requested |
| 516 state.if_verbose(|| { |
484 state.if_verbose(|| { |
| 517 plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); |
485 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); |
| 518 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
486 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| 519 }); |
487 }); |
| 520 |
488 |
| 521 // Update main tolerance for next iteration |
489 // Update main tolerance for next iteration |
| 522 ε = tolerance.update(ε, iter); |
490 ε = tolerance.update(ε, iter); |