| 10 use itertools::izip; | 10 use itertools::izip; | 
| 11 use std::iter::Iterator; | 11 use std::iter::Iterator; | 
| 12 | 12 | 
| 13 use alg_tools::iterate::AlgIteratorFactory; | 13 use alg_tools::iterate::AlgIteratorFactory; | 
| 14 use alg_tools::euclidean::Euclidean; | 14 use alg_tools::euclidean::Euclidean; | 
| 15 use alg_tools::sets::Cube; | 15 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; | 
| 16 use alg_tools::loc::Loc; |  | 
| 17 use alg_tools::mapping::{Mapping, DifferentiableMapping, Instance}; |  | 
| 18 use alg_tools::norms::Norm; | 16 use alg_tools::norms::Norm; | 
| 19 use alg_tools::bisection_tree::{ |  | 
| 20     BTFN, |  | 
| 21     PreBTFN, |  | 
| 22     Bounds, |  | 
| 23     BTNodeLookup, |  | 
| 24     BTNode, |  | 
| 25     BTSearch, |  | 
| 26     P2Minimise, |  | 
| 27     SupportGenerator, |  | 
| 28     LocalAnalysis, |  | 
| 29     //Bounded, |  | 
| 30 }; |  | 
| 31 use alg_tools::mapping::RealMapping; |  | 
| 32 use alg_tools::nalgebra_support::ToNalgebraRealField; | 17 use alg_tools::nalgebra_support::ToNalgebraRealField; | 
| 33 use alg_tools::norms::{L2, Linfinity}; | 18 use alg_tools::norms::L2; | 
| 34 | 19 | 
| 35 use crate::types::*; | 20 use crate::types::*; | 
| 36 use crate::measures::{DiscreteMeasure, Radon, RNDM}; | 21 use crate::measures::{DiscreteMeasure, Radon, RNDM}; | 
| 37 use crate::measures::merging::{ | 22 use crate::measures::merging::SpikeMerging; | 
| 38     SpikeMergingMethod, |  | 
| 39     SpikeMerging, |  | 
| 40 }; |  | 
| 41 use crate::forward_model::{ | 23 use crate::forward_model::{ | 
| 42     ForwardModel, | 24     ForwardModel, | 
| 43     AdjointProductBoundedBy, | 25     AdjointProductBoundedBy, | 
| 44     LipschitzValues, | 26     LipschitzValues, | 
| 45 }; | 27 }; | 
| 46 use crate::seminorms::DiscreteMeasureOp; |  | 
| 47 //use crate::tolerance::Tolerance; | 28 //use crate::tolerance::Tolerance; | 
| 48 use crate::plot::{ | 29 use crate::plot::{ | 
| 49     SeqPlotter, | 30     SeqPlotter, | 
| 50     Plotting, | 31     Plotting, | 
| 51     PlotLookup | 32     PlotLookup | 
| 351 /// splitting | 332 /// splitting | 
| 352 /// | 333 /// | 
| 353 /// The parametrisation is as for [`pointsource_fb_reg`]. | 334 /// The parametrisation is as for [`pointsource_fb_reg`]. | 
| 354 /// Inertia is currently not supported. | 335 /// Inertia is currently not supported. | 
| 355 #[replace_float_literals(F::cast_from(literal))] | 336 #[replace_float_literals(F::cast_from(literal))] | 
| 356 pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>( | 337 pub fn pointsource_sliding_fb_reg<F, I, A, Reg, P, const N : usize>( | 
| 357     opA : &'a A, | 338     opA : &A, | 
| 358     b : &A::Observable, | 339     b : &A::Observable, | 
| 359     reg : Reg, | 340     reg : Reg, | 
| 360     op𝒟 : &'a 𝒟, | 341     prox_penalty : &P, | 
| 361     config : &SlidingFBConfig<F>, | 342     config : &SlidingFBConfig<F>, | 
| 362     iterator : I, | 343     iterator : I, | 
| 363     mut plotter : SeqPlotter<F, N>, | 344     mut plotter : SeqPlotter<F, N>, | 
| 364 ) -> RNDM<F, N> | 345 ) -> RNDM<F, N> | 
| 365 where F : Float + ToNalgebraRealField, | 346 where | 
| 366       I : AlgIteratorFactory<IterInfo<F, N>>, | 347     F : Float + ToNalgebraRealField, | 
| 367       for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, | 348     I : AlgIteratorFactory<IterInfo<F, N>>, | 
| 368       for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, | 349     A : ForwardModel<RNDM<F, N>, F> | 
| 369       A::PreadjointCodomain : DifferentiableMapping< | 350         + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, | 
| 370         Loc<F, N>, DerivativeDomain=Loc<F, N>, Codomain=F | 351         //+ TransportLipschitz<L2Squared, FloatType=F>, | 
| 371       >, | 352     for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, | 
| 372       GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, | 353     for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>, | 
| 373       A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> | 354     A::PreadjointCodomain : DifferentiableRealMapping<F, N>, | 
| 374           + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, | 355     RNDM<F, N> : SpikeMerging<F>, | 
| 375           //+ TransportLipschitz<L2Squared, FloatType=F>, | 356     Reg : SlidingRegTerm<F, N>, | 
| 376       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, | 357     P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, | 
| 377       G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, | 358     PlotLookup : Plotting<N>, | 
| 378       𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, | 359 { | 
| 379                                           Codomain = BTFN<F, G𝒟, BT𝒟, N>>, |  | 
| 380       BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |  | 
| 381       S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N> |  | 
| 382          + DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F,N>>, |  | 
| 383       K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |  | 
| 384          //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>, |  | 
| 385       BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |  | 
| 386       Cube<F, N>: P2Minimise<Loc<F, N>, F>, |  | 
| 387       PlotLookup : Plotting<N>, |  | 
| 388       RNDM<F, N> : SpikeMerging<F>, |  | 
| 389       Reg : SlidingRegTerm<F, N> { |  | 
| 390 | 360 | 
| 391     // Check parameters | 361     // Check parameters | 
| 392     assert!(config.τ0 > 0.0, "Invalid step length parameter"); | 362     assert!(config.τ0 > 0.0, "Invalid step length parameter"); | 
| 393     config.transport.check(); | 363     config.transport.check(); | 
| 394 | 364 | 
| 396     let mut μ = DiscreteMeasure::new(); | 366     let mut μ = DiscreteMeasure::new(); | 
| 397     let mut γ1 = DiscreteMeasure::new(); | 367     let mut γ1 = DiscreteMeasure::new(); | 
| 398     let mut residual = -b; // Has to equal $Aμ-b$. | 368     let mut residual = -b; // Has to equal $Aμ-b$. | 
| 399 | 369 | 
| 400     // Set up parameters | 370     // Set up parameters | 
| 401     let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); |  | 
| 402     let opAnorm = opA.opnorm_bound(Radon, L2); | 371     let opAnorm = opA.opnorm_bound(Radon, L2); | 
| 403     //let max_transport = config.max_transport.scale | 372     //let max_transport = config.max_transport.scale | 
| 404     //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0); | 373     //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0); | 
| 405     //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; | 374     //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; | 
| 406     let ℓ = 0.0; | 375     let ℓ = 0.0; | 
| 407     let τ = config.τ0 / opA.adjoint_product_bound(&op𝒟).unwrap(); | 376     let τ = config.τ0 / opA.adjoint_product_bound(prox_penalty).unwrap(); | 
| 408     let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); | 377     let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v)); | 
| 409     let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { | 378     let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() { | 
| 410         // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v | 379         // We only estimate w (the uniform Lipschitz for of v), if we also estimate ℓ_v | 
| 411         // (the uniform Lipschitz factor of ∇v). | 380         // (the uniform Lipschitz factor of ∇v). | 
| 412         // We assume that the residual is decreasing. | 381         // We assume that the residual is decreasing. | 
| 444             v, &config.transport, | 413             v, &config.transport, | 
| 445         ); | 414         ); | 
| 446 | 415 | 
| 447         // Solve finite-dimensional subproblem several times until the dual variable for the | 416         // Solve finite-dimensional subproblem several times until the dual variable for the | 
| 448         // regularisation term conforms to the assumptions made for the transport above. | 417         // regularisation term conforms to the assumptions made for the transport above. | 
| 449         let (d, _within_tolerances, τv̆) = 'adapt_transport: loop { | 418         let (maybe_d, _within_tolerances, τv̆) = 'adapt_transport: loop { | 
| 450             // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) | 419             // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) | 
| 451             let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); | 420             let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b); | 
| 452             let τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); | 421             let mut τv̆ = opA.preadjoint().apply(residual_μ̆ * τ); | 
| 453 | 422 | 
| 454             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. | 423             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. | 
| 455             let (d, within_tolerances) = insert_and_reweigh( | 424             let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( | 
| 456                 &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0), | 425                 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), | 
| 457                 op𝒟, op𝒟norm, |  | 
| 458                 τ, ε, &config.insertion, | 426                 τ, ε, &config.insertion, | 
| 459                 ®, &state, &mut stats, | 427                 ®, &state, &mut stats, | 
| 460             ); | 428             ); | 
| 461 | 429 | 
| 462             // A posteriori transport adaptation. | 430             // A posteriori transport adaptation. | 
| 463             if aposteriori_transport( | 431             if aposteriori_transport( | 
| 464                 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, | 432                 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, | 
| 465                 ε, &config.transport | 433                 ε, &config.transport | 
| 466             ) { | 434             ) { | 
| 467                 break 'adapt_transport (d, within_tolerances, τv̆) | 435                 break 'adapt_transport (maybe_d, within_tolerances, τv̆) | 
| 468             } | 436             } | 
| 469         }; | 437         }; | 
| 470 | 438 | 
| 471         stats.untransported_fraction = Some({ | 439         stats.untransported_fraction = Some({ | 
| 472             assert_eq!(μ_base_masses.len(), γ1.len()); | 440             assert_eq!(μ_base_masses.len(), γ1.len()); | 
| 478             assert_eq!(μ_base_masses.len(), γ1.len()); | 446             assert_eq!(μ_base_masses.len(), γ1.len()); | 
| 479             let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); | 447             let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0)); | 
| 480             (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) | 448             (a + μ.dist_matching(&γ1), b + γ1.norm(Radon)) | 
| 481         }); | 449         }); | 
| 482 | 450 | 
| 483         // Merge spikes. | 451         // // Merge spikes. | 
| 484         // This expects the prune below to prune γ. | 452         // // This expects the prune below to prune γ. | 
| 485         // TODO: This may not work correctly in all cases. | 453         // // TODO: This may not work correctly in all cases. | 
| 486         let ins = &config.insertion; | 454         // let ins = &config.insertion; | 
| 487         if ins.merge_now(&state) { | 455         // if ins.merge_now(&state) { | 
| 488             if let SpikeMergingMethod::None = ins.merging { | 456         //     if let SpikeMergingMethod::None = ins.merging { | 
| 489             } else { | 457         //     } else { | 
| 490                 stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { | 458         //         stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { | 
| 491                     let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; | 459         //             let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; | 
| 492                     let mut d = &τv̆ + op𝒟.preapply(ν); | 460         //             let mut d = &τv̆ + op𝒟.preapply(ν); | 
| 493                     reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) | 461         //             reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) | 
| 494                 }); | 462         //         }); | 
| 495             } | 463         //     } | 
| 496         } | 464         // } | 
| 497 | 465 | 
| 498         // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the | 466         // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the | 
| 499         // latter needs to be pruned when μ is. | 467         // latter needs to be pruned when μ is. | 
| 500         // TODO: This could do with a two-vector Vec::retain to avoid copies. | 468         // TODO: This could do with a two-vector Vec::retain to avoid copies. | 
| 501         let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); | 469         let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned()); | 
| 512         let iter = state.iteration(); | 480         let iter = state.iteration(); | 
| 513         stats.this_iters += 1; | 481         stats.this_iters += 1; | 
| 514 | 482 | 
| 515         // Give statistics if requested | 483         // Give statistics if requested | 
| 516         state.if_verbose(|| { | 484         state.if_verbose(|| { | 
| 517             plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ); | 485             plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv̆), &μ); | 
| 518             full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) | 486             full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new())) | 
| 519         }); | 487         }); | 
| 520 | 488 | 
| 521         // Update main tolerance for next iteration | 489         // Update main tolerance for next iteration | 
| 522         ε = tolerance.update(ε, iter); | 490         ε = tolerance.update(ε, iter); |