src/run.rs

branch
dev
changeset 35
b087e3eab191
parent 34
efa60bc4f743
child 37
c5d8bd1a7728
equal deleted inserted replaced
34:efa60bc4f743 35:b087e3eab191
24 use alg_tools::iterate::{ 24 use alg_tools::iterate::{
25 Timed, 25 Timed,
26 AlgIteratorOptions, 26 AlgIteratorOptions,
27 Verbose, 27 Verbose,
28 AlgIteratorFactory, 28 AlgIteratorFactory,
29 LoggingIteratorFactory,
30 TimingIteratorFactory,
31 BasicAlgIteratorFactory,
29 }; 32 };
30 use alg_tools::logger::Logger; 33 use alg_tools::logger::Logger;
31 use alg_tools::error::DynError; 34 use alg_tools::error::{
35 DynError,
36 DynResult,
37 };
32 use alg_tools::tabledump::TableDump; 38 use alg_tools::tabledump::TableDump;
33 use alg_tools::sets::Cube; 39 use alg_tools::sets::Cube;
34 use alg_tools::mapping::{ 40 use alg_tools::mapping::{
35 RealMapping, 41 RealMapping,
36 DifferentiableRealMapping 42 DifferentiableMapping,
43 DifferentiableRealMapping,
44 Instance
37 }; 45 };
38 use alg_tools::nalgebra_support::ToNalgebraRealField; 46 use alg_tools::nalgebra_support::ToNalgebraRealField;
39 use alg_tools::euclidean::Euclidean; 47 use alg_tools::euclidean::Euclidean;
40 use alg_tools::lingrid::lingrid; 48 use alg_tools::lingrid::{lingrid, LinSpace};
41 use alg_tools::sets::SetOrd; 49 use alg_tools::sets::SetOrd;
50 use alg_tools::linops::{RowOp, IdOp /*, ZeroOp*/};
51 use alg_tools::discrete_gradient::{Grad, ForwardNeumann};
52 use alg_tools::convex::Zero;
53 use alg_tools::maputil::map3;
54 use alg_tools::direct_product::Pair;
42 55
43 use crate::kernels::*; 56 use crate::kernels::*;
44 use crate::types::*; 57 use crate::types::*;
45 use crate::measures::*; 58 use crate::measures::*;
46 use crate::measures::merging::SpikeMerging; 59 use crate::measures::merging::SpikeMerging;
47 use crate::forward_model::*; 60 use crate::forward_model::*;
61 use crate::forward_model::sensor_grid::{
62 SensorGrid,
63 SensorGridBT,
64 //SensorGridBTFN,
65 Sensor,
66 Spread,
67 };
68
48 use crate::fb::{ 69 use crate::fb::{
49 FBConfig, 70 FBConfig,
50 FBGenericConfig, 71 FBGenericConfig,
51 pointsource_fb_reg, 72 pointsource_fb_reg,
52 pointsource_fista_reg, 73 pointsource_fista_reg,
56 pointsource_radon_fb_reg, 77 pointsource_radon_fb_reg,
57 pointsource_radon_fista_reg, 78 pointsource_radon_fista_reg,
58 }; 79 };
59 use crate::sliding_fb::{ 80 use crate::sliding_fb::{
60 SlidingFBConfig, 81 SlidingFBConfig,
82 TransportConfig,
61 pointsource_sliding_fb_reg 83 pointsource_sliding_fb_reg
84 };
85 use crate::sliding_pdps::{
86 SlidingPDPSConfig,
87 pointsource_sliding_pdps_pair
88 };
89 use crate::forward_pdps::{
90 ForwardPDPSConfig,
91 pointsource_forward_pdps_pair
62 }; 92 };
63 use crate::pdps::{ 93 use crate::pdps::{
64 PDPSConfig, 94 PDPSConfig,
65 pointsource_pdps_reg, 95 pointsource_pdps_reg,
66 }; 96 };
67 use crate::frank_wolfe::{ 97 use crate::frank_wolfe::{
68 FWConfig, 98 FWConfig,
69 FWVariant, 99 FWVariant,
70 pointsource_fw_reg, 100 pointsource_fw_reg,
71 WeightOptim, 101 //WeightOptim,
72 }; 102 };
73 use crate::subproblem::InnerSettings; 103 //use crate::subproblem::InnerSettings;
74 use crate::seminorms::*; 104 use crate::seminorms::*;
75 use crate::plot::*; 105 use crate::plot::*;
76 use crate::{AlgorithmOverrides, CommandLineArgs}; 106 use crate::{AlgorithmOverrides, CommandLineArgs};
77 use crate::tolerance::Tolerance; 107 use crate::tolerance::Tolerance;
78 use crate::regularisation::{ 108 use crate::regularisation::{
80 RadonRegTerm, 110 RadonRegTerm,
81 NonnegRadonRegTerm 111 NonnegRadonRegTerm
82 }; 112 };
83 use crate::dataterm::{ 113 use crate::dataterm::{
84 L1, 114 L1,
85 L2Squared 115 L2Squared,
86 }; 116 };
87 use alg_tools::norms::L2; 117 use alg_tools::norms::{L2, NormExponent};
118 use alg_tools::operator_arithmetic::Weighted;
119 use anyhow::anyhow;
88 120
89 /// Available algorithms and their configurations 121 /// Available algorithms and their configurations
90 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 122 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
91 pub enum AlgorithmConfig<F : Float> { 123 pub enum AlgorithmConfig<F : Float> {
92 FB(FBConfig<F>), 124 FB(FBConfig<F>),
94 FW(FWConfig<F>), 126 FW(FWConfig<F>),
95 PDPS(PDPSConfig<F>), 127 PDPS(PDPSConfig<F>),
96 RadonFB(RadonFBConfig<F>), 128 RadonFB(RadonFBConfig<F>),
97 RadonFISTA(RadonFBConfig<F>), 129 RadonFISTA(RadonFBConfig<F>),
98 SlidingFB(SlidingFBConfig<F>), 130 SlidingFB(SlidingFBConfig<F>),
131 ForwardPDPS(ForwardPDPSConfig<F>),
132 SlidingPDPS(SlidingPDPSConfig<F>),
99 } 133 }
100 134
101 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { 135 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> {
102 assert!(v.len() == 3); 136 assert!(v.len() == 3);
103 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } 137 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] }
114 |n| Some((n[0], n[1]))), 148 |n| Some((n[0], n[1]))),
115 merge_every : cli.merge_every.unwrap_or(g.merge_every), 149 merge_every : cli.merge_every.unwrap_or(g.merge_every),
116 merging : cli.merging.clone().unwrap_or(g.merging), 150 merging : cli.merging.clone().unwrap_or(g.merging),
117 final_merging : cli.final_merging.clone().unwrap_or(g.final_merging), 151 final_merging : cli.final_merging.clone().unwrap_or(g.final_merging),
118 tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance), 152 tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance),
153 .. g
154 }
155 };
156 let override_transport = |g : TransportConfig<F>| {
157 TransportConfig {
158 θ0 : cli.theta0.unwrap_or(g.θ0),
159 tolerance_ω: cli.transport_tolerance_omega.unwrap_or(g.tolerance_ω),
160 tolerance_dv: cli.transport_tolerance_dv.unwrap_or(g.tolerance_dv),
161 adaptation: cli.transport_adaptation.unwrap_or(g.adaptation),
119 .. g 162 .. g
120 } 163 }
121 }; 164 };
122 165
123 use AlgorithmConfig::*; 166 use AlgorithmConfig::*;
154 insertion : override_fb_generic(fb.insertion), 197 insertion : override_fb_generic(fb.insertion),
155 .. fb 198 .. fb
156 }), 199 }),
157 SlidingFB(sfb) => SlidingFB(SlidingFBConfig { 200 SlidingFB(sfb) => SlidingFB(SlidingFBConfig {
158 τ0 : cli.tau0.unwrap_or(sfb.τ0), 201 τ0 : cli.tau0.unwrap_or(sfb.τ0),
159 θ0 : cli.theta0.unwrap_or(sfb.θ0), 202 transport : override_transport(sfb.transport),
160 transport_tolerance_ω: cli.transport_tolerance_omega.unwrap_or(sfb.transport_tolerance_ω),
161 transport_tolerance_dv: cli.transport_tolerance_dv.unwrap_or(sfb.transport_tolerance_dv),
162 insertion : override_fb_generic(sfb.insertion), 203 insertion : override_fb_generic(sfb.insertion),
163 .. sfb 204 .. sfb
164 }), 205 }),
206 SlidingPDPS(spdps) => SlidingPDPS(SlidingPDPSConfig {
207 τ0 : cli.tau0.unwrap_or(spdps.τ0),
208 σp0 : cli.sigmap0.unwrap_or(spdps.σp0),
209 σd0 : cli.sigma0.unwrap_or(spdps.σd0),
210 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
211 transport : override_transport(spdps.transport),
212 insertion : override_fb_generic(spdps.insertion),
213 .. spdps
214 }),
215 ForwardPDPS(fpdps) => ForwardPDPS(ForwardPDPSConfig {
216 τ0 : cli.tau0.unwrap_or(fpdps.τ0),
217 σp0 : cli.sigmap0.unwrap_or(fpdps.σp0),
218 σd0 : cli.sigma0.unwrap_or(fpdps.σd0),
219 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
220 insertion : override_fb_generic(fpdps.insertion),
221 .. fpdps
222 }),
165 } 223 }
166 } 224 }
167 } 225 }
168 226
169 /// Helper struct for tagging and [`AlgorithmConfig`] or [`Experiment`] with a name. 227 /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name.
170 #[derive(Clone, Debug, Serialize, Deserialize)] 228 #[derive(Clone, Debug, Serialize, Deserialize)]
171 pub struct Named<Data> { 229 pub struct Named<Data> {
172 pub name : String, 230 pub name : String,
173 #[serde(flatten)] 231 #[serde(flatten)]
174 pub data : Data, 232 pub data : Data,
196 #[clap(name = "radon_fb")] 254 #[clap(name = "radon_fb")]
197 RadonFB, 255 RadonFB,
198 /// The RadonFISTA inertial forward-backward method 256 /// The RadonFISTA inertial forward-backward method
199 #[clap(name = "radon_fista")] 257 #[clap(name = "radon_fista")]
200 RadonFISTA, 258 RadonFISTA,
201 /// The Sliding FB method 259 /// The sliding FB method
202 #[clap(name = "sliding_fb", alias = "sfb")] 260 #[clap(name = "sliding_fb", alias = "sfb")]
203 SlidingFB, 261 SlidingFB,
262 /// The sliding PDPS method
263 #[clap(name = "sliding_pdps", alias = "spdps")]
264 SlidingPDPS,
265 /// The PDPS method with a forward step for the smooth function
266 #[clap(name = "forward_pdps", alias = "fpdps")]
267 ForwardPDPS,
204 } 268 }
205 269
206 impl DefaultAlgorithm { 270 impl DefaultAlgorithm {
207 /// Returns the algorithm configuration corresponding to the algorithm shorthand 271 /// Returns the algorithm configuration corresponding to the algorithm shorthand
208 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { 272 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> {
217 }), 281 }),
218 PDPS => AlgorithmConfig::PDPS(Default::default()), 282 PDPS => AlgorithmConfig::PDPS(Default::default()),
219 RadonFB => AlgorithmConfig::RadonFB(Default::default()), 283 RadonFB => AlgorithmConfig::RadonFB(Default::default()),
220 RadonFISTA => AlgorithmConfig::RadonFISTA(Default::default()), 284 RadonFISTA => AlgorithmConfig::RadonFISTA(Default::default()),
221 SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), 285 SlidingFB => AlgorithmConfig::SlidingFB(Default::default()),
286 SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default()),
287 ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default()),
222 } 288 }
223 } 289 }
224 290
225 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand 291 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand
226 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { 292 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> {
276 #[derive(Serialize)] 342 #[derive(Serialize)]
277 struct CSVLog<F> { 343 struct CSVLog<F> {
278 iter : usize, 344 iter : usize,
279 cpu_time : f64, 345 cpu_time : f64,
280 value : F, 346 value : F,
281 post_value : F, 347 relative_value : F,
348 //post_value : F,
282 n_spikes : usize, 349 n_spikes : usize,
283 inner_iters : usize, 350 inner_iters : usize,
284 merged : usize, 351 merged : usize,
285 pruned : usize, 352 pruned : usize,
286 this_iters : usize, 353 this_iters : usize,
353 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. 420 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$.
354 pub spread : P, 421 pub spread : P,
355 /// Kernel $ρ$ of $𝒟$. 422 /// Kernel $ρ$ of $𝒟$.
356 pub kernel : K, 423 pub kernel : K,
357 /// True point sources 424 /// True point sources
358 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, 425 pub μ_hat : RNDM<F, N>,
359 /// Regularisation term and parameter 426 /// Regularisation term and parameter
360 pub regularisation : Regularisation<F>, 427 pub regularisation : Regularisation<F>,
361 /// For plotting : how wide should the kernels be plotted 428 /// For plotting : how wide should the kernels be plotted
362 pub kernel_plot_width : F, 429 pub kernel_plot_width : F,
363 /// Data term 430 /// Data term
365 /// A map of default configurations for algorithms 432 /// A map of default configurations for algorithms
366 #[serde(skip)] 433 #[serde(skip)]
367 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, 434 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>,
368 } 435 }
369 436
437 #[derive(Debug, Clone, Serialize)]
438 pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N : usize>
439 where F : Float,
440 [usize; N] : Serialize,
441 NoiseDistr : Distribution<F>,
442 S : Sensor<F, N>,
443 P : Spread<F, N>,
444 K : SimpleConvolutionKernel<F, N>,
445 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug,
446 {
447 /// Basic setup
448 pub base : ExperimentV2<F, NoiseDistr, S, K, P, N>,
449 /// Weight of TV term
450 pub λ : F,
451 /// Bias function
452 pub bias : B,
453 }
454
370 /// Trait for runnable experiments 455 /// Trait for runnable experiments
371 pub trait RunnableExperiment<F : ClapFloat> { 456 pub trait RunnableExperiment<F : ClapFloat> {
372 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. 457 /// Run all algorithms provided, or default algorithms if none provided, on the experiment.
373 fn runall(&self, cli : &CommandLineArgs, 458 fn runall(&self, cli : &CommandLineArgs,
374 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; 459 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError;
375 460
376 /// Return algorithm default config 461 /// Return algorithm default config
377 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) 462 fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>>;
378 -> Named<AlgorithmConfig<F>>; 463 }
379 } 464
380 465 /// Helper function to print experiment start message and save setup.
381 // *** macro boilerplate *** 466 /// Returns saving prefix.
382 macro_rules! impl_experiment { 467 fn start_experiment<E, S>(
383 ($type:ident, $reg_field:ident, $reg_convert:path) => { 468 experiment : &Named<E>,
384 // *** macro *** 469 cli : &CommandLineArgs,
470 stats : S,
471 ) -> DynResult<String>
472 where
473 E : Serialize + std::fmt::Debug,
474 S : Serialize,
475 {
476 let Named { name : experiment_name, data } = experiment;
477
478 println!("{}\n{}",
479 format!("Performing experiment {}…", experiment_name).cyan(),
480 format!("{:?}", data).bright_black());
481
482 // Set up output directory
483 let prefix = format!("{}/{}/", cli.outdir, experiment_name);
484
485 // Save experiment configuration and statistics
486 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t);
487 std::fs::create_dir_all(&prefix)?;
488 write_json(mkname_e("experiment"), experiment)?;
489 write_json(mkname_e("config"), cli)?;
490 write_json(mkname_e("stats"), &stats)?;
491
492 Ok(prefix)
493 }
494
495 /// Error codes for running an algorithm on an experiment.
496 enum RunError {
497 /// Algorithm not implemented for this experiment
498 NotImplemented,
499 }
500
501 use RunError::*;
502
503 type DoRunAllIt<'a, F, const N : usize> = LoggingIteratorFactory<
504 'a,
505 Timed<IterInfo<F, N>>,
506 TimingIteratorFactory<BasicAlgIteratorFactory<IterInfo<F, N>>>
507 >;
508
509 /// Helper function to run all algorithms on an experiment.
510 fn do_runall<F : Float, Z, const N : usize>(
511 experiment_name : &String,
512 prefix : &String,
513 cli : &CommandLineArgs,
514 algorithms : Vec<Named<AlgorithmConfig<F>>>,
515 plotgrid : LinSpace<Loc<F, N>, [usize; N]>,
516 mut save_extra : impl FnMut(String, Z) -> DynError,
517 mut do_alg : impl FnMut(
518 &AlgorithmConfig<F>,
519 DoRunAllIt<F, N>,
520 SeqPlotter<F, N>,
521 String,
522 ) -> Result<(RNDM<F, N>, Z), RunError>,
523 ) -> DynError
524 where
525 PlotLookup : Plotting<N>,
526 {
527 let mut logs = Vec::new();
528
529 let iterator_options = AlgIteratorOptions{
530 max_iter : cli.max_iter,
531 verbose_iter : cli.verbose_iter
532 .map_or(Verbose::Logarithmic(10),
533 |n| Verbose::Every(n)),
534 quiet : cli.quiet,
535 };
536
537 // Run the algorithm(s)
538 for named @ Named { name : alg_name, data : alg } in algorithms.iter() {
539 let this_prefix = format!("{}{}/", prefix, alg_name);
540
541 // Create Logger and IteratorFactory
542 let mut logger = Logger::new();
543 let iterator = iterator_options.instantiate()
544 .timed()
545 .into_log(&mut logger);
546
547 let running = if !cli.quiet {
548 format!("{}\n{}\n{}\n",
549 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
550 format!("{:?}", iterator_options).bright_black(),
551 format!("{:?}", alg).bright_black())
552 } else {
553 "".to_string()
554 };
555 //
556 // The following is for postprocessing, which has been disabled anyway.
557 //
558 // let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation {
559 // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)),
560 // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)),
561 // };
562 //let findim_data = reg.prepare_optimise_weights(&opA, &b);
563 //let inner_config : InnerSettings<F> = Default::default();
564 //let inner_it = inner_config.iterator_options;
565
566 // Create plotter and directory if needed.
567 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };
568 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid.clone());
569
570 let start = Instant::now();
571 let start_cpu = ProcessTime::now();
572
573 let (μ, z) = match do_alg(alg, iterator, plotter, running) {
574 Ok(μ) => μ,
575 Err(RunError::NotImplemented) => {
576 let msg = format!("Algorithm “{alg_name}” not implemented for {experiment_name}. \
577 Skipping.").red();
578 eprintln!("{}", msg);
579 continue
580 }
581 };
582
583 let elapsed = start.elapsed().as_secs_f64();
584 let cpu_time = start_cpu.elapsed().as_secs_f64();
585
586 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow());
587
588 // Save results
589 println!("{}", "Saving results …".green());
590
591 let mkname = |t| format!("{prefix}{alg_name}_{t}");
592
593 write_json(mkname("config.json"), &named)?;
594 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
595 μ.write_csv(mkname("reco.txt"))?;
596 save_extra(mkname(""), z)?;
597 //logger.write_csv(mkname("log.txt"))?;
598 logs.push((mkname("log.txt"), logger));
599 }
600
601 save_logs(logs)
602 }
603
604 #[replace_float_literals(F::cast_from(literal))]
385 impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for 605 impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for
386 Named<$type<F, NoiseDistr, S, K, P, N>> 606 Named<ExperimentV2<F, NoiseDistr, S, K, P, N>>
387 where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, 607 where
388 [usize; N] : Serialize, 608 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
389 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, 609 [usize; N] : Serialize,
390 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, 610 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
391 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy 611 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
392 // TODO: shold not have differentiability as a requirement, but 612 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy
393 // decide availability of sliding based on it. 613 // TODO: shold not have differentiability as a requirement, but
394 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, 614 // decide availability of sliding based on it.
395 // TODO: very weird that rust only compiles with Differentiable 615 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
396 // instead of the above one on references, which is required by 616 // TODO: very weird that rust only compiles with Differentiable
397 // poitsource_sliding_fb_reg. 617 // instead of the above one on references, which is required by
398 + DifferentiableRealMapping<F, N> 618 // poitsource_sliding_fb_reg.
399 + Lipschitz<L2, FloatType=F>, 619 + DifferentiableRealMapping<F, N>
400 // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, 620 + Lipschitz<L2, FloatType=F>,
401 AutoConvolution<P> : BoundedBy<F, K>, 621 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb.
402 K : SimpleConvolutionKernel<F, N> 622 AutoConvolution<P> : BoundedBy<F, K>,
403 + LocalAnalysis<F, Bounds<F>, N> 623 K : SimpleConvolutionKernel<F, N>
404 + Copy + Serialize + std::fmt::Debug, 624 + LocalAnalysis<F, Bounds<F>, N>
405 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, 625 + Copy + Serialize + std::fmt::Debug,
406 PlotLookup : Plotting<N>, 626 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
407 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, 627 PlotLookup : Plotting<N>,
408 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 628 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
409 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 629 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
410 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug { 630 RNDM<F, N> : SpikeMerging<F>,
411 631 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug
412 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) 632 {
413 -> Named<AlgorithmConfig<F>> { 633
414 alg.to_named( 634 fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> {
415 self.data 635 self.data.algorithm_defaults.get(&alg).cloned()
416 .algorithm_defaults
417 .get(&alg)
418 .map_or_else(|| alg.default_config(),
419 |config| config.clone())
420 .cli_override(cli)
421 )
422 } 636 }
423 637
424 fn runall(&self, cli : &CommandLineArgs, 638 fn runall(&self, cli : &CommandLineArgs,
425 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { 639 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError {
426 // Get experiment configuration 640 // Get experiment configuration
427 let &Named { 641 let &Named {
428 name : ref experiment_name, 642 name : ref experiment_name,
429 data : $type { 643 data : ExperimentV2 {
430 domain, sensor_count, ref noise_distr, sensor, spread, kernel, 644 domain, sensor_count, ref noise_distr, sensor, spread, kernel,
431 ref μ_hat, /*regularisation,*/ kernel_plot_width, dataterm, noise_seed, 645 ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed,
432 .. 646 ..
433 } 647 }
434 } = self; 648 } = self;
435 let regularisation = $reg_convert(self.data.$reg_field);
436
437 println!("{}\n{}",
438 format!("Performing experiment {}…", experiment_name).cyan(),
439 format!("{:?}", &self.data).bright_black());
440
441 // Set up output directory
442 let prefix = format!("{}/{}/", cli.outdir, self.name);
443 649
444 // Set up algorithms 650 // Set up algorithms
445 let iterator_options = AlgIteratorOptions{ 651 let algorithms = match (algs, dataterm) {
446 max_iter : cli.max_iter,
447 verbose_iter : cli.verbose_iter
448 .map_or(Verbose::Logarithmic(10),
449 |n| Verbose::Every(n)),
450 quiet : cli.quiet,
451 };
452 let algorithms = match (algs, self.data.dataterm) {
453 (Some(algs), _) => algs, 652 (Some(algs), _) => algs,
454 (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], 653 (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()],
455 (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], 654 (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()],
456 }; 655 };
457 656
462 661
463 // Set up random number generator. 662 // Set up random number generator.
464 let mut rng = StdRng::seed_from_u64(noise_seed); 663 let mut rng = StdRng::seed_from_u64(noise_seed);
465 664
466 // Generate the data and calculate SSNR statistic 665 // Generate the data and calculate SSNR statistic
467 let b_hat = opA.apply(μ_hat); 666 let b_hat : DVector<_> = opA.apply(μ_hat);
468 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); 667 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
469 let b = &b_hat + &noise; 668 let b = &b_hat + &noise;
470 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField 669 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
471 // overloading log10 and conflicting with standard NumTraits one. 670 // overloading log10 and conflicting with standard NumTraits one.
472 let stats = ExperimentStats::new(&b, &noise); 671 let stats = ExperimentStats::new(&b, &noise);
473 672
474 // Save experiment configuration and statistics 673 let prefix = start_experiment(&self, cli, stats)?;
475 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t);
476 std::fs::create_dir_all(&prefix)?;
477 write_json(mkname_e("experiment"), self)?;
478 write_json(mkname_e("config"), cli)?;
479 write_json(mkname_e("stats"), &stats)?;
480 674
481 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, 675 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread,
482 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; 676 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?;
483 677
484 // Run the algorithm(s) 678 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]);
485 for named @ Named { name : alg_name, data : alg } in algorithms.iter() { 679
486 let this_prefix = format!("{}{}/", prefix, alg_name); 680 let save_extra = |_, ()| Ok(());
487 681
488 let running = || if !cli.quiet { 682 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra,
489 println!("{}\n{}\n{}", 683 |alg, iterator, plotter, running|
490 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), 684 {
491 format!("{:?}", iterator_options).bright_black(),
492 format!("{:?}", alg).bright_black());
493 };
494 let not_implemented = || {
495 let msg = format!("Algorithm “{alg_name}” not implemented for \
496 dataterm {dataterm:?} and regularisation {regularisation:?}. \
497 Skipping.").red();
498 eprintln!("{}", msg);
499 };
500 // Create Logger and IteratorFactory
501 let mut logger = Logger::new();
502 let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation {
503 Regularisation::Radon(α) => Box::new(RadonRegTerm(α)),
504 Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)),
505 };
506 let findim_data = reg.prepare_optimise_weights(&opA, &b);
507 let inner_config : InnerSettings<F> = Default::default();
508 let inner_it = inner_config.iterator_options;
509 let logmap = |iter, Timed { cpu_time, data }| {
510 let IterInfo {
511 value,
512 n_spikes,
513 inner_iters,
514 merged,
515 pruned,
516 postprocessing,
517 this_iters,
518 ..
519 } = data;
520 let post_value = match (postprocessing, dataterm) {
521 (Some(mut μ), DataTerm::L2Squared) => {
522 // Comparison postprocessing is only implemented for the case handled
523 // by the FW variants.
524 reg.optimise_weights(
525 &mut μ, &opA, &b, &findim_data, &inner_config,
526 inner_it
527 );
528 dataterm.value_at_residual(opA.apply(&μ) - &b)
529 + regularisation.apply(&μ)
530 },
531 _ => value,
532 };
533 CSVLog {
534 iter,
535 value,
536 post_value,
537 n_spikes,
538 cpu_time : cpu_time.as_secs_f64(),
539 inner_iters,
540 merged,
541 pruned,
542 this_iters
543 }
544 };
545 let iterator = iterator_options.instantiate()
546 .timed()
547 .mapped(logmap)
548 .into_log(&mut logger);
549 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]);
550
551 // Create plotter and directory if needed.
552 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };
553 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid);
554
555 // Run the algorithm
556 let start = Instant::now();
557 let start_cpu = ProcessTime::now();
558 let μ = match alg { 685 let μ = match alg {
559 AlgorithmConfig::FB(ref algconfig) => { 686 AlgorithmConfig::FB(ref algconfig) => {
560 match (regularisation, dataterm) { 687 match (regularisation, dataterm) {
561 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { 688 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({
562 running(); 689 print!("{running}");
563 pointsource_fb_reg( 690 pointsource_fb_reg(
564 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 691 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
565 iterator, plotter 692 iterator, plotter
566 ) 693 )
567 }, 694 }),
568 (Regularisation::Radon(α), DataTerm::L2Squared) => { 695 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({
569 running(); 696 print!("{running}");
570 pointsource_fb_reg( 697 pointsource_fb_reg(
571 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 698 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
572 iterator, plotter 699 iterator, plotter
573 ) 700 )
574 }, 701 }),
575 _ => { 702 _ => Err(NotImplemented)
576 not_implemented();
577 continue
578 }
579 } 703 }
580 }, 704 },
581 AlgorithmConfig::FISTA(ref algconfig) => { 705 AlgorithmConfig::FISTA(ref algconfig) => {
582 match (regularisation, dataterm) { 706 match (regularisation, dataterm) {
583 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { 707 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({
584 running(); 708 print!("{running}");
585 pointsource_fista_reg( 709 pointsource_fista_reg(
586 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 710 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
587 iterator, plotter 711 iterator, plotter
588 ) 712 )
589 }, 713 }),
590 (Regularisation::Radon(α), DataTerm::L2Squared) => { 714 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({
591 running(); 715 print!("{running}");
592 pointsource_fista_reg( 716 pointsource_fista_reg(
593 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 717 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
594 iterator, plotter 718 iterator, plotter
595 ) 719 )
596 }, 720 }),
597 _ => { 721 _ => Err(NotImplemented),
598 not_implemented();
599 continue
600 }
601 } 722 }
602 }, 723 },
603 AlgorithmConfig::RadonFB(ref algconfig) => { 724 AlgorithmConfig::RadonFB(ref algconfig) => {
604 match (regularisation, dataterm) { 725 match (regularisation, dataterm) {
605 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { 726 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({
606 running(); 727 print!("{running}");
607 pointsource_radon_fb_reg( 728 pointsource_radon_fb_reg(
608 &opA, &b, NonnegRadonRegTerm(α), algconfig, 729 &opA, &b, NonnegRadonRegTerm(α), algconfig,
609 iterator, plotter 730 iterator, plotter
610 ) 731 )
611 }, 732 }),
612 (Regularisation::Radon(α), DataTerm::L2Squared) => { 733 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({
613 running(); 734 print!("{running}");
614 pointsource_radon_fb_reg( 735 pointsource_radon_fb_reg(
615 &opA, &b, RadonRegTerm(α), algconfig, 736 &opA, &b, RadonRegTerm(α), algconfig,
616 iterator, plotter 737 iterator, plotter
617 ) 738 )
618 }, 739 }),
619 _ => { 740 _ => Err(NotImplemented),
620 not_implemented();
621 continue
622 }
623 } 741 }
624 }, 742 },
625 AlgorithmConfig::RadonFISTA(ref algconfig) => { 743 AlgorithmConfig::RadonFISTA(ref algconfig) => {
626 match (regularisation, dataterm) { 744 match (regularisation, dataterm) {
627 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { 745 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({
628 running(); 746 print!("{running}");
629 pointsource_radon_fista_reg( 747 pointsource_radon_fista_reg(
630 &opA, &b, NonnegRadonRegTerm(α), algconfig, 748 &opA, &b, NonnegRadonRegTerm(α), algconfig,
631 iterator, plotter 749 iterator, plotter
632 ) 750 )
633 }, 751 }),
634 (Regularisation::Radon(α), DataTerm::L2Squared) => { 752 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({
635 running(); 753 print!("{running}");
636 pointsource_radon_fista_reg( 754 pointsource_radon_fista_reg(
637 &opA, &b, RadonRegTerm(α), algconfig, 755 &opA, &b, RadonRegTerm(α), algconfig,
638 iterator, plotter 756 iterator, plotter
639 ) 757 )
640 }, 758 }),
641 _ => { 759 _ => Err(NotImplemented),
642 not_implemented();
643 continue
644 }
645 } 760 }
646 }, 761 },
647 AlgorithmConfig::SlidingFB(ref algconfig) => { 762 AlgorithmConfig::SlidingFB(ref algconfig) => {
648 match (regularisation, dataterm) { 763 match (regularisation, dataterm) {
649 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { 764 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({
650 running(); 765 print!("{running}");
651 pointsource_sliding_fb_reg( 766 pointsource_sliding_fb_reg(
652 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 767 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
653 iterator, plotter 768 iterator, plotter
654 ) 769 )
655 }, 770 }),
656 (Regularisation::Radon(α), DataTerm::L2Squared) => { 771 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({
657 running(); 772 print!("{running}");
658 pointsource_sliding_fb_reg( 773 pointsource_sliding_fb_reg(
659 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 774 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
660 iterator, plotter 775 iterator, plotter
661 ) 776 )
662 }, 777 }),
663 _ => { 778 _ => Err(NotImplemented),
664 not_implemented();
665 continue
666 }
667 } 779 }
668 }, 780 },
669 AlgorithmConfig::PDPS(ref algconfig) => { 781 AlgorithmConfig::PDPS(ref algconfig) => {
670 running(); 782 print!("{running}");
671 match (regularisation, dataterm) { 783 match (regularisation, dataterm) {
672 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { 784 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({
673 pointsource_pdps_reg( 785 pointsource_pdps_reg(
674 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 786 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
675 iterator, plotter, L2Squared 787 iterator, plotter, L2Squared
676 ) 788 )
677 }, 789 }),
678 (Regularisation::Radon(α),DataTerm::L2Squared) => { 790 (Regularisation::Radon(α),DataTerm::L2Squared) => Ok({
679 pointsource_pdps_reg( 791 pointsource_pdps_reg(
680 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 792 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
681 iterator, plotter, L2Squared 793 iterator, plotter, L2Squared
682 ) 794 )
683 }, 795 }),
684 (Regularisation::NonnegRadon(α), DataTerm::L1) => { 796 (Regularisation::NonnegRadon(α), DataTerm::L1) => Ok({
685 pointsource_pdps_reg( 797 pointsource_pdps_reg(
686 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 798 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
687 iterator, plotter, L1 799 iterator, plotter, L1
688 ) 800 )
689 }, 801 }),
690 (Regularisation::Radon(α), DataTerm::L1) => { 802 (Regularisation::Radon(α), DataTerm::L1) => Ok({
691 pointsource_pdps_reg( 803 pointsource_pdps_reg(
692 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 804 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
693 iterator, plotter, L1 805 iterator, plotter, L1
694 ) 806 )
695 }, 807 }),
696 } 808 }
697 }, 809 },
698 AlgorithmConfig::FW(ref algconfig) => { 810 AlgorithmConfig::FW(ref algconfig) => {
699 match (regularisation, dataterm) { 811 match (regularisation, dataterm) {
700 (Regularisation::Radon(α), DataTerm::L2Squared) => { 812 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({
701 running(); 813 print!("{running}");
702 pointsource_fw_reg(&opA, &b, RadonRegTerm(α), 814 pointsource_fw_reg(&opA, &b, RadonRegTerm(α),
703 algconfig, iterator, plotter) 815 algconfig, iterator, plotter)
704 }, 816 }),
705 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { 817 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({
706 running(); 818 print!("{running}");
707 pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α), 819 pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α),
708 algconfig, iterator, plotter) 820 algconfig, iterator, plotter)
709 }, 821 }),
710 _ => { 822 _ => Err(NotImplemented),
711 not_implemented();
712 continue
713 }
714 } 823 }
824 },
825 _ => Err(NotImplemented),
826 }?;
827 Ok((μ, ()))
828 })
829 }
830 }
831
832
833 #[replace_float_literals(F::cast_from(literal))]
834 impl<F, NoiseDistr, S, K, P, B, const N : usize> RunnableExperiment<F> for
835 Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>>
836 where
837 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
838 [usize; N] : Serialize,
839 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
840 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
841 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy
842 // TODO: shold not have differentiability as a requirement, but
843 // decide availability of sliding based on it.
844 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
845 // TODO: very weird that rust only compiles with Differentiable
846 // instead of the above one on references, which is required by
847 // poitsource_sliding_fb_reg.
848 + DifferentiableRealMapping<F, N>
849 + Lipschitz<L2, FloatType=F>,
850 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb.
851 AutoConvolution<P> : BoundedBy<F, K>,
852 K : SimpleConvolutionKernel<F, N>
853 + LocalAnalysis<F, Bounds<F>, N>
854 + Copy + Serialize + std::fmt::Debug,
855 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
856 PlotLookup : Plotting<N>,
857 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
858 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
859 RNDM<F, N> : SpikeMerging<F>,
860 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug,
861 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug,
862 {
863
864 fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> {
865 self.data.base.algorithm_defaults.get(&alg).cloned()
866 }
867
868 fn runall(&self, cli : &CommandLineArgs,
869 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError {
870 // Get experiment configuration
871 let &Named {
872 name : ref experiment_name,
873 data : ExperimentBiased {
874 λ,
875 ref bias,
876 base : ExperimentV2 {
877 domain, sensor_count, ref noise_distr, sensor, spread, kernel,
878 ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed,
879 ..
715 } 880 }
716 }; 881 }
717 882 } = self;
718 let elapsed = start.elapsed().as_secs_f64(); 883
719 let cpu_time = start_cpu.elapsed().as_secs_f64(); 884 // Set up algorithms
720 885 let algorithms = match (algs, dataterm) {
721 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); 886 (Some(algs), _) => algs,
722 887 _ => vec![DefaultAlgorithm::SlidingPDPS.get_named()],
723 // Save results 888 };
724 println!("{}", "Saving results…".green()); 889
725 890 // Set up operators
726 let mkname = |t| format!("{prefix}{alg_name}_{t}"); 891 let depth = DynamicDepth(8);
727 892 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
728 write_json(mkname("config.json"), &named)?; 893 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel);
729 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; 894 let opAext = RowOp(opA.clone(), IdOp::new());
730 μ.write_csv(mkname("reco.txt"))?; 895 let fnR = Zero::new();
731 logger.write_csv(mkname("log.txt"))?; 896 let h = map3(domain.span_start(), domain.span_end(), sensor_count,
897 |a, b, n| (b-a)/F::cast_from(n))
898 .into_iter()
899 .reduce(NumTraitsFloat::max)
900 .unwrap();
901 let z = DVector::zeros(sensor_count.iter().product());
902 let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap();
903 let y = opKz.apply(&z);
904 let fnH = Weighted{ base_fn : L1.as_mapping(), weight : λ}; // TODO: L_{2,1}
905 // let zero_y = y.clone();
906 // let zeroBTFN = opA.preadjoint().apply(&zero_y);
907 // let opKμ = ZeroOp::new(&zero_y, zeroBTFN);
908
909 // Set up random number generator.
910 let mut rng = StdRng::seed_from_u64(noise_seed);
911
912 // Generate the data and calculate SSNR statistic
913 let bias_vec = DVector::from_vec(opA.grid()
914 .into_iter()
915 .map(|v| bias.apply(v))
916 .collect::<Vec<F>>());
917 let b_hat : DVector<_> = opA.apply(μ_hat) + &bias_vec;
918 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
919 let b = &b_hat + &noise;
920 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
921 // overloading log10 and conflicting with standard NumTraits one.
922 let stats = ExperimentStats::new(&b, &noise);
923
924 let prefix = start_experiment(&self, cli, stats)?;
925
926 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread,
927 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?;
928
929 opA.write_observable(&bias_vec, format!("{prefix}bias"))?;
930
931 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]);
932
933 let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z"));
934
935 // Run the algorithms
936 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra,
937 |alg, iterator, plotter, running|
938 {
939 let Pair(μ, z) = match alg {
940 AlgorithmConfig::ForwardPDPS(ref algconfig) => {
941 match (regularisation, dataterm) {
942 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({
943 print!("{running}");
944 pointsource_forward_pdps_pair(
945 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
946 iterator, plotter,
947 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
948 )
949 }),
950 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({
951 print!("{running}");
952 pointsource_forward_pdps_pair(
953 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig,
954 iterator, plotter,
955 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
956 )
957 }),
958 _ => Err(NotImplemented)
959 }
960 },
961 AlgorithmConfig::SlidingPDPS(ref algconfig) => {
962 match (regularisation, dataterm) {
963 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({
964 print!("{running}");
965 pointsource_sliding_pdps_pair(
966 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
967 iterator, plotter,
968 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
969 )
970 }),
971 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({
972 print!("{running}");
973 pointsource_sliding_pdps_pair(
974 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig,
975 iterator, plotter,
976 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
977 )
978 }),
979 _ => Err(NotImplemented)
980 }
981 },
982 _ => Err(NotImplemented)
983 }?;
984 Ok((μ, z))
985 })
986 }
987 }
988
989
990 /// Calculative minimum and maximum values of all the `logs`, and save them into
991 /// corresponding file names given as the first elements of the tuples in the vectors.
992 fn save_logs<F : Float, const N : usize>(
993 logs : Vec<(String, Logger<Timed<IterInfo<F, N>>>)>
994 ) -> DynError {
995 // Process logs for relative values
996 println!("{}", "Processing logs…");
997
998
999 // Find minimum value and initial value within a single log
1000 let proc_single_log = |log : &Logger<Timed<IterInfo<F, N>>>| {
1001 let d = log.data();
1002 let mi = d.iter()
1003 .map(|i| i.data.value)
1004 .reduce(NumTraitsFloat::min);
1005 d.first()
1006 .map(|i| i.data.value)
1007 .zip(mi)
1008 };
1009
1010 // Find minimum and maximum value over all logs
1011 let (v_ini, v_min) = logs.iter()
1012 .filter_map(|&(_, ref log)| proc_single_log(log))
1013 .reduce(|(i1, m1), (i2, m2)| (i1.max(i2), m1.min(m2)))
1014 .ok_or(anyhow!("No algorithms found"))?;
1015
1016 let logmap = |Timed { cpu_time, iter, data }| {
1017 let IterInfo {
1018 value,
1019 n_spikes,
1020 inner_iters,
1021 merged,
1022 pruned,
1023 //postprocessing,
1024 this_iters,
1025 ..
1026 } = data;
1027 // let post_value = match (postprocessing, dataterm) {
1028 // (Some(mut μ), DataTerm::L2Squared) => {
1029 // // Comparison postprocessing is only implemented for the case handled
1030 // // by the FW variants.
1031 // reg.optimise_weights(
1032 // &mut μ, &opA, &b, &findim_data, &inner_config,
1033 // inner_it
1034 // );
1035 // dataterm.value_at_residual(opA.apply(&μ) - &b)
1036 // + regularisation.apply(&μ)
1037 // },
1038 // _ => value,
1039 // };
1040 let relative_value = (value - v_min)/(v_ini - v_min);
1041 CSVLog {
1042 iter,
1043 value,
1044 relative_value,
1045 //post_value,
1046 n_spikes,
1047 cpu_time : cpu_time.as_secs_f64(),
1048 inner_iters,
1049 merged,
1050 pruned,
1051 this_iters
732 } 1052 }
733 1053 };
734 Ok(()) 1054
1055 println!("{}", "Saving logs …".green());
1056
1057 for (name, logger) in logs {
1058 logger.map(logmap).write_csv(name)?;
735 } 1059 }
736 } 1060
737 // *** macro end boiler plate *** 1061 Ok(())
738 }} 1062 }
739 // *** actual code *** 1063
740
741 impl_experiment!(ExperimentV2, regularisation, std::convert::identity);
742 1064
743 /// Plot experiment setup 1065 /// Plot experiment setup
744 #[replace_float_literals(F::cast_from(literal))] 1066 #[replace_float_literals(F::cast_from(literal))]
745 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( 1067 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>(
746 cli : &CommandLineArgs, 1068 cli : &CommandLineArgs,
747 prefix : &String, 1069 prefix : &String,
748 domain : &Cube<F, N>, 1070 domain : &Cube<F, N>,
749 sensor : &Sensor, 1071 sensor : &Sensor,
750 kernel : &Kernel, 1072 kernel : &Kernel,
751 spread : &Spread, 1073 spread : &Spread,
752 μ_hat : &DiscreteMeasure<Loc<F, N>, F>, 1074 μ_hat : &RNDM<F, N>,
753 op𝒟 : &𝒟, 1075 op𝒟 : &𝒟,
754 opA : &A, 1076 opA : &A,
755 b_hat : &A::Observable, 1077 b_hat : &A::Observable,
756 b : &A::Observable, 1078 b : &A::Observable,
757 kernel_plot_width : F, 1079 kernel_plot_width : F,
759 where F : Float + ToNalgebraRealField, 1081 where F : Float + ToNalgebraRealField,
760 Sensor : RealMapping<F, N> + Support<F, N> + Clone, 1082 Sensor : RealMapping<F, N> + Support<F, N> + Clone,
761 Spread : RealMapping<F, N> + Support<F, N> + Clone, 1083 Spread : RealMapping<F, N> + Support<F, N> + Clone,
762 Kernel : RealMapping<F, N> + Support<F, N>, 1084 Kernel : RealMapping<F, N> + Support<F, N>,
763 Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>, 1085 Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>,
764 //Differential<Loc<F, N>, Convolution<Sensor, Spread>> : RealVectorField<F, N, N>,
765 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, 1086 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>,
766 𝒟::Codomain : RealMapping<F, N>, 1087 𝒟::Codomain : RealMapping<F, N>,
767 A : ForwardModel<Loc<F, N>, F>, 1088 A : ForwardModel<RNDM<F, N>, F>,
1089 for<'a> &'a A::Observable : Instance<A::Observable>,
768 A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>, 1090 A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>,
769 PlotLookup : Plotting<N>, 1091 PlotLookup : Plotting<N>,
770 Cube<F, N> : SetOrd { 1092 Cube<F, N> : SetOrd {
771 1093
772 if cli.plot < PlotLevel::Data { 1094 if cli.plot < PlotLevel::Data {
774 } 1096 }
775 1097
776 let base = Convolution(sensor.clone(), spread.clone()); 1098 let base = Convolution(sensor.clone(), spread.clone());
777 1099
778 let resolution = if N==1 { 100 } else { 40 }; 1100 let resolution = if N==1 { 100 } else { 40 };
779 let pfx = |n| format!("{}{}", prefix, n); 1101 let pfx = |n| format!("{prefix}{n}");
780 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); 1102 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]);
781 1103
782 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); 1104 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"));
783 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); 1105 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"));
784 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); 1106 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"));
785 PlotLookup::plot_into_file_diff(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); 1107 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"));
786 1108
787 let plotgrid2 = lingrid(&domain, &[resolution; N]); 1109 let plotgrid2 = lingrid(&domain, &[resolution; N]);
788 1110
789 let ω_hat = op𝒟.apply(μ_hat); 1111 let ω_hat = op𝒟.apply(μ_hat);
790 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); 1112 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b);
791 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string()); 1113 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"));
792 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"), 1114 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"));
793 "noise Aᵀ(Aμ̂ - b)".to_string());
794 1115
795 let preadj_b = opA.preadjoint().apply(b); 1116 let preadj_b = opA.preadjoint().apply(b);
796 let preadj_b_hat = opA.preadjoint().apply(b_hat); 1117 let preadj_b_hat = opA.preadjoint().apply(b_hat);
797 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); 1118 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds());
798 PlotLookup::plot_into_file_spikes( 1119 PlotLookup::plot_into_file_spikes(
799 "Aᵀb".to_string(), &preadj_b, 1120 Some(&preadj_b),
800 "Aᵀb̂".to_string(), Some(&preadj_b_hat), 1121 Some(&preadj_b_hat),
801 plotgrid2, None, &μ_hat, 1122 plotgrid2,
1123 &μ_hat,
802 pfx("omega_b") 1124 pfx("omega_b")
803 ); 1125 );
804 PlotLookup::plot_into_file_diff(&preadj_b, plotgrid2, pfx("preadj_b"), 1126 PlotLookup::plot_into_file(&preadj_b, plotgrid2, pfx("preadj_b"));
805 "preadj_b".to_string()); 1127 PlotLookup::plot_into_file(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat"));
806 PlotLookup::plot_into_file_diff(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat"),
807 "preadj_b_hat".to_string());
808 1128
809 // Save true solution and observables 1129 // Save true solution and observables
810 let pfx = |n| format!("{}{}", prefix, n);
811 μ_hat.write_csv(pfx("orig.txt"))?; 1130 μ_hat.write_csv(pfx("orig.txt"))?;
812 opA.write_observable(&b_hat, pfx("b_hat"))?; 1131 opA.write_observable(&b_hat, pfx("b_hat"))?;
813 opA.write_observable(&b, pfx("b_noisy")) 1132 opA.write_observable(&b, pfx("b_noisy"))
814 } 1133 }

mercurial