src/sliding_fb.rs

branch
dev
changeset 37
c5d8bd1a7728
parent 35
b087e3eab191
child 39
6316d68b58af
equal deleted inserted replaced
36:fb911f72e698 37:c5d8bd1a7728
10 use itertools::izip; 10 use itertools::izip;
11 use std::iter::Iterator; 11 use std::iter::Iterator;
12 12
13 use alg_tools::iterate::AlgIteratorFactory; 13 use alg_tools::iterate::AlgIteratorFactory;
14 use alg_tools::euclidean::Euclidean; 14 use alg_tools::euclidean::Euclidean;
15 use alg_tools::sets::Cube; 15 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance};
16 use alg_tools::loc::Loc;
17 use alg_tools::mapping::{Mapping, DifferentiableMapping, Instance};
18 use alg_tools::norms::Norm; 16 use alg_tools::norms::Norm;
19 use alg_tools::bisection_tree::{
20 BTFN,
21 PreBTFN,
22 Bounds,
23 BTNodeLookup,
24 BTNode,
25 BTSearch,
26 P2Minimise,
27 SupportGenerator,
28 LocalAnalysis,
29 //Bounded,
30 };
31 use alg_tools::mapping::RealMapping;
32 use alg_tools::nalgebra_support::ToNalgebraRealField; 17 use alg_tools::nalgebra_support::ToNalgebraRealField;
33 use alg_tools::norms::{L2, Linfinity}; 18 use alg_tools::norms::L2;
34 19
35 use crate::types::*; 20 use crate::types::*;
36 use crate::measures::{DiscreteMeasure, Radon, RNDM}; 21 use crate::measures::{DiscreteMeasure, Radon, RNDM};
37 use crate::measures::merging::{ 22 use crate::measures::merging::SpikeMerging;
38 SpikeMergingMethod,
39 SpikeMerging,
40 };
41 use crate::forward_model::{ 23 use crate::forward_model::{
42 ForwardModel, 24 ForwardModel,
43 AdjointProductBoundedBy, 25 AdjointProductBoundedBy,
44 LipschitzValues, 26 LipschitzValues,
45 }; 27 };
46 use crate::seminorms::DiscreteMeasureOp;
47 //use crate::tolerance::Tolerance; 28 //use crate::tolerance::Tolerance;
48 use crate::plot::{ 29 use crate::plot::{
49 SeqPlotter, 30 SeqPlotter,
50 Plotting, 31 Plotting,
51 PlotLookup 32 PlotLookup
149 F : Float + ToNalgebraRealField, 130 F : Float + ToNalgebraRealField,
150 G : Fn(F, F) -> F, 131 G : Fn(F, F) -> F,
151 Observable : Euclidean<F, Output=Observable>, 132 Observable : Euclidean<F, Output=Observable>,
152 for<'a> &'a Observable : Instance<Observable>, 133 for<'a> &'a Observable : Instance<Observable>,
153 //for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, 134 //for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
154 D : DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F, N>>, 135 D : DifferentiableRealMapping<F, N>,
155 { 136 {
156 137
157 use TransportStepLength::*; 138 use TransportStepLength::*;
158 139
159 // Save current base point and shift μ to new positions. Idea is that 140 // Save current base point and shift μ to new positions. Idea is that
351 /// splitting 332 /// splitting
352 /// 333 ///
353 /// The parametrisation is as for [`pointsource_fb_reg`]. 334 /// The parametrisation is as for [`pointsource_fb_reg`].
354 /// Inertia is currently not supported. 335 /// Inertia is currently not supported.
355 #[replace_float_literals(F::cast_from(literal))] 336 #[replace_float_literals(F::cast_from(literal))]
356 pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>( 337 pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N : usize>(
357 opA : &'a A, 338 opA : &A,
358 b : &A::Observable, 339 b : &A::Observable,
359 reg : Reg, 340 reg : Reg,
360 op𝒟 : &'a 𝒟, 341 prox_penalty : &P,
361 config : &SlidingFBConfig<F>, 342 config : &SlidingFBConfig<F>,
362 iterator : I, 343 iterator : I,
363 mut plotter : SeqPlotter<F, N>, 344 mut plotter : SeqPlotter<F, N>,
364 ) -> RNDM<F, N> 345 ) -> RNDM<F, N>
365 where F : Float + ToNalgebraRealField, 346 where
366 I : AlgIteratorFactory<IterInfo<F, N>>, 347 F : Float + ToNalgebraRealField,
367 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, 348 I : AlgIteratorFactory<IterInfo<F, N>>,
368 for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, 349 A : ForwardModel<RNDM<F, N>, F>
369 A::PreadjointCodomain : DifferentiableMapping< 350 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>,
370 Loc<F, N>, DerivativeDomain=Loc<F, N>, Codomain=F 351 //+ TransportLipschitz<L2Squared, FloatType=F>,
371 >, 352 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
372 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 353 for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
373 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> 354 A::PreadjointCodomain : DifferentiableRealMapping<F, N>,
374 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, 355 RNDM<F, N> : SpikeMerging<F>,
375 //+ TransportLipschitz<L2Squared, FloatType=F>, 356 Reg : SlidingRegTerm<F, N>,
376 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 357 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
377 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 358 PlotLookup : Plotting<N>,
378 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, 359 {
379 Codomain = BTFN<F, G𝒟, BT𝒟, N>>,
380 BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
381 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>
382 + DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F,N>>,
383 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
384 //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>,
385 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
386 Cube<F, N>: P2Minimise<Loc<F, N>, F>,
387 PlotLookup : Plotting<N>,
388 RNDM<F, N> : SpikeMerging<F>,
389 Reg : SlidingRegTerm<F, N> {
390 360
391 // Check parameters 361 // Check parameters
392 assert!(config.τ0 > 0.0, "Invalid step length parameter"); 362 assert!(config.τ0 > 0.0, "Invalid step length parameter");
393 config.transport.check(); 363 config.transport.check();
394 364
396 let mut μ = DiscreteMeasure::new(); 366 let mut μ = DiscreteMeasure::new();
397 let mut γ1 = DiscreteMeasure::new(); 367 let mut γ1 = DiscreteMeasure::new();
398 let mut residual = -b; // Has to equal $Aμ-b$. 368 let mut residual = -b; // Has to equal $Aμ-b$.
399 369
400 // Set up parameters 370 // Set up parameters
401 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
402 let opAnorm = opA.opnorm_bound(Radon, L2); 371 let opAnorm = opA.opnorm_bound(Radon, L2);
403 //let max_transport = config.max_transport.scale 372 //let max_transport = config.max_transport.scale
404 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); 373 // * reg.radon_norm_bound(b.norm2_squared() / 2.0);
405 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; 374 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
406 let ℓ = 0.0; 375 let ℓ = 0.0;
407 let τ = config.τ0 / opA.adjoint_product_bound(&op𝒟).unwrap(); 376 let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap();
408 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); 377 let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v));
409 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { 378 let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() {
410 // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v 379 // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v
411 // (the uniform Lipschitz factor of ∇v). 380 // (the uniform Lipschitz factor of ∇v).
412 // We assume that the residual is decreasing. 381 // We assume that the residual is decreasing.
444 v, &config.transport, 413 v, &config.transport,
445 ); 414 );
446 415
447 // Solve finite-dimensional subproblem several times until the dual variable for the 416 // Solve finite-dimensional subproblem several times until the dual variable for the
448 // regularisation term conforms to the assumptions made for the transport above. 417 // regularisation term conforms to the assumptions made for the transport above.
449 let (d, _within_tolerances, τv̆) = 'adapt_transport: loop { 418 let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop {
450 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) 419 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
451 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); 420 let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
452 let τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); 421 let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ);
453 422
454 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 423 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
455 let (d, within_tolerances) = insert_and_reweigh( 424 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
456 &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), 425 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0),
457 op𝒟, op𝒟norm,
458 τ, ε, &config.insertion, 426 τ, ε, &config.insertion,
459 &reg, &state, &mut stats, 427 &reg, &state, &mut stats,
460 ); 428 );
461 429
462 // A posteriori transport adaptation. 430 // A posteriori transport adaptation.
463 if aposteriori_transport( 431 if aposteriori_transport(
464 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, 432 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses,
465 ε, &config.transport 433 ε, &config.transport
466 ) { 434 ) {
467 break 'adapt_transport (d, within_tolerances, τv̆) 435 break 'adapt_transport (maybe_d, within_tolerances, τv̆)
468 } 436 }
469 }; 437 };
470 438
471 stats.untransported_fraction = Some({ 439 stats.untransported_fraction = Some({
472 assert_eq!(μ_base_masses.len(), γ1.len()); 440 assert_eq!(μ_base_masses.len(), γ1.len());
478 assert_eq!(μ_base_masses.len(), γ1.len()); 446 assert_eq!(μ_base_masses.len(), γ1.len());
479 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); 447 let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
480 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) 448 (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
481 }); 449 });
482 450
483 // Merge spikes. 451 // // Merge spikes.
484 // This expects the prune below to prune γ. 452 // // This expects the prune below to prune γ.
485 // TODO: This may not work correctly in all cases. 453 // // TODO: This may not work correctly in all cases.
486 let ins = &config.insertion; 454 // let ins = &config.insertion;
487 if ins.merge_now(&state) { 455 // if ins.merge_now(&state) {
488 if let SpikeMergingMethod::None = ins.merging { 456 // if let SpikeMergingMethod::None = ins.merging {
489 } else { 457 // } else {
490 stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { 458 // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| {
491 let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; 459 // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0;
492 let mut d = &τv̆ + op𝒟.preapply(ν); 460 // let mut d = &τv̆ + op𝒟.preapply(ν);
493 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) 461 // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins)
494 }); 462 // });
495 } 463 // }
496 } 464 // }
497 465
498 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the 466 // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
499 // latter needs to be pruned when μ is. 467 // latter needs to be pruned when μ is.
500 // TODO: This could do with a two-vector Vec::retain to avoid copies. 468 // TODO: This could do with a two-vector Vec::retain to avoid copies.
501 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); 469 let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
512 let iter = state.iteration(); 480 let iter = state.iteration();
513 stats.this_iters += 1; 481 stats.this_iters += 1;
514 482
515 // Give statistics if requested 483 // Give statistics if requested
516 state.if_verbose(|| { 484 state.if_verbose(|| {
517 plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); 485 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ);
518 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) 486 full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
519 }); 487 });
520 488
521 // Update main tolerance for next iteration 489 // Update main tolerance for next iteration
522 ε = tolerance.update(ε, iter); 490 ε = tolerance.update(ε, iter);

mercurial