src/run.rs

changeset 24
d29d1fcf5423
parent 23
9869fa1e0ccd
child 25
79943be70720
equal deleted inserted replaced
23:9869fa1e0ccd 24:d29d1fcf5423
32 use alg_tools::tabledump::TableDump; 32 use alg_tools::tabledump::TableDump;
33 use alg_tools::sets::Cube; 33 use alg_tools::sets::Cube;
34 use alg_tools::mapping::RealMapping; 34 use alg_tools::mapping::RealMapping;
35 use alg_tools::nalgebra_support::ToNalgebraRealField; 35 use alg_tools::nalgebra_support::ToNalgebraRealField;
36 use alg_tools::euclidean::Euclidean; 36 use alg_tools::euclidean::Euclidean;
37 use alg_tools::norms::{Norm, L1}; 37 use alg_tools::norms::L1;
38 use alg_tools::lingrid::lingrid; 38 use alg_tools::lingrid::lingrid;
39 use alg_tools::sets::SetOrd; 39 use alg_tools::sets::SetOrd;
40 40
41 use crate::kernels::*; 41 use crate::kernels::*;
42 use crate::types::*; 42 use crate::types::*;
43 use crate::measures::*; 43 use crate::measures::*;
44 use crate::measures::merging::SpikeMerging; 44 use crate::measures::merging::SpikeMerging;
45 use crate::forward_model::*; 45 use crate::forward_model::*;
46 use crate::fb::{ 46 use crate::fb::{
47 FBConfig, 47 FBConfig,
48 pointsource_fb, 48 pointsource_fb_reg,
49 FBMetaAlgorithm, FBGenericConfig, 49 FBMetaAlgorithm,
50 FBGenericConfig,
50 }; 51 };
51 use crate::pdps::{ 52 use crate::pdps::{
52 PDPSConfig, 53 PDPSConfig,
53 L2Squared, 54 L2Squared,
54 pointsource_pdps, 55 pointsource_pdps_reg,
55 }; 56 };
56 use crate::frank_wolfe::{ 57 use crate::frank_wolfe::{
57 FWConfig, 58 FWConfig,
58 FWVariant, 59 FWVariant,
59 pointsource_fw, 60 pointsource_fw,
63 use crate::subproblem::InnerSettings; 64 use crate::subproblem::InnerSettings;
64 use crate::seminorms::*; 65 use crate::seminorms::*;
65 use crate::plot::*; 66 use crate::plot::*;
66 use crate::{AlgorithmOverrides, CommandLineArgs}; 67 use crate::{AlgorithmOverrides, CommandLineArgs};
67 use crate::tolerance::Tolerance; 68 use crate::tolerance::Tolerance;
69 use crate::regularisation::{Regularisation, RadonRegTerm, NonnegRadonRegTerm};
68 70
69 /// Available algorithms and their configurations 71 /// Available algorithms and their configurations
70 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 72 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
71 pub enum AlgorithmConfig<F : Float> { 73 pub enum AlgorithmConfig<F : Float> {
72 FB(FBConfig<F>), 74 FB(FBConfig<F>),
271 fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { 273 fn write_json<T : Serialize>(filename : String, data : &T) -> DynError {
272 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; 274 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?;
273 Ok(()) 275 Ok(())
274 } 276 }
275 277
278
279 /// Struct for experiment configurations
280 #[derive(Debug, Clone, Serialize)]
281 pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize>
282 where F : Float,
283 [usize; N] : Serialize,
284 NoiseDistr : Distribution<F>,
285 S : Sensor<F, N>,
286 P : Spread<F, N>,
287 K : SimpleConvolutionKernel<F, N>,
288 {
289 /// Domain $Ω$.
290 pub domain : Cube<F, N>,
291 /// Number of sensors along each dimension
292 pub sensor_count : [usize; N],
293 /// Noise distribution
294 pub noise_distr : NoiseDistr,
295 /// Seed for random noise generation (for repeatable experiments)
296 pub noise_seed : u64,
297 /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$.
298 pub sensor : S,
299 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$.
300 pub spread : P,
301 /// Kernel $ρ$ of $𝒟$.
302 pub kernel : K,
303 /// True point sources
304 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>,
305 /// Regularisation term and parameter
306 pub regularisation : Regularisation<F>,
307 /// For plotting : how wide should the kernels be plotted
308 pub kernel_plot_width : F,
309 /// Data term
310 pub dataterm : DataTerm,
311 /// A map of default configurations for algorithms
312 #[serde(skip)]
313 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>,
314 }
315
316 /// Trait for runnable experiments
317 pub trait RunnableExperiment<F : ClapFloat> {
318 /// Run all algorithms provided, or default algorithms if none provided, on the experiment.
319 fn runall(&self, cli : &CommandLineArgs,
320 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError;
321
322 /// Return algorithm default config
323 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
324 -> Named<AlgorithmConfig<F>>;
325 }
326
// *** macro boilerplate ***
// `impl_experiment!($type, $reg_field, $reg_convert)` generates the `RunnableExperiment`
// implementation for `Named<$type<…>>`:
//  * `$type` is the experiment struct to implement the trait for;
//  * `$reg_field` names the field of that struct from which the regularisation is read;
//  * `$reg_convert` is applied to that field, and its result is what `runall` matches
//    against the `Regularisation` variants.
macro_rules! impl_experiment {
($type:ident, $reg_field:ident, $reg_convert:path) => {
// *** macro ***
impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for
Named<$type<F, NoiseDistr, S, K, P, N>>
where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
      [usize; N] : Serialize,
      S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
      P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
      Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy,
      AutoConvolution<P> : BoundedBy<F, K>,
      K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N>
          + Copy + Serialize + std::fmt::Debug,
      Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
      PlotLookup : Plotting<N>,
      DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug {

    // Looks `alg` up in the experiment's own `algorithm_defaults` map, falls back to
    // `alg.default_config()` when there is no entry, applies the overrides in `cli`,
    // and wraps the result with `alg.to_named`.
    fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
    -> Named<AlgorithmConfig<F>> {
        alg.to_named(
            self.data
                .algorithm_defaults
                .get(&alg)
                .map_or_else(|| alg.default_config(),
                             |config| config.clone())
                .cli_override(cli)
        )
    }

    // Generates the noisy data, saves the experiment setup under
    // `{cli.outdir}/{experiment name}/`, then runs each selected algorithm and writes
    // its configuration, timing statistics, reconstruction and iteration log.
    // Any I/O or serialisation error aborts the whole run through `?`.
    fn runall(&self, cli : &CommandLineArgs,
              algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError {
        // Get experiment configuration
        let &Named {
            name : ref experiment_name,
            data : $type {
                domain, sensor_count, ref noise_distr, sensor, spread, kernel,
                ref μ_hat, /*regularisation,*/ kernel_plot_width, dataterm, noise_seed,
                ..
            }
        } = self;
        // The regularisation is read through `$reg_field` rather than destructured above,
        // because the field name differs between the structs this macro is applied to.
        // The lint is silenced because that field may be marked `#[deprecated]`.
        #[allow(deprecated)]
        let regularisation = $reg_convert(self.data.$reg_field);

        println!("{}\n{}",
                 format!("Performing experiment {}…", experiment_name).cyan(),
                 format!("{:?}", &self.data).bright_black());

        // Set up output directory
        let prefix = format!("{}/{}/", cli.outdir, self.name);

        // Set up algorithms
        let iterator_options = AlgIteratorOptions{
            max_iter : cli.max_iter,
            verbose_iter : cli.verbose_iter
                              .map_or(Verbose::Logarithmic(10),
                                      |n| Verbose::Every(n)),
            quiet : cli.quiet,
        };
        // When no algorithms are given, the default depends on the data term:
        // FB for the squared L2 data term, PDPS for the L1 data term.
        let algorithms = match (algs, self.data.dataterm) {
            (Some(algs), _) => algs,
            (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()],
            (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()],
        };

        // Set up operators
        let depth = DynamicDepth(8);
        let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
        let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel);

        // Set up random number generator.
        let mut rng = StdRng::seed_from_u64(noise_seed);

        // Generate the data and calculate SSNR statistic
        // `b_hat` is the noise-free data; `b` is `b_hat` with the sampled noise added.
        let b_hat = opA.apply(μ_hat);
        let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
        let b = &b_hat + &noise;
        // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
        // overloading log10 and conflicting with standard NumTraits one.
        let stats = ExperimentStats::new(&b, &noise);

        // Save experiment configuration and statistics
        let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t);
        std::fs::create_dir_all(&prefix)?;
        write_json(mkname_e("experiment"), self)?;
        write_json(mkname_e("config"), cli)?;
        write_json(mkname_e("stats"), &stats)?;

        plotall(cli, &prefix, &domain, &sensor, &kernel, &spread,
                &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?;

        // Run the algorithm(s)
        for named @ Named { name : alg_name, data : alg } in algorithms.iter() {
            let this_prefix = format!("{}{}/", prefix, alg_name);

            // Announces the run unless `cli.quiet` is set. It is invoked only from the
            // match arms below that actually start an algorithm.
            let running = || if !cli.quiet {
                println!("{}\n{}\n{}",
                         format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
                         format!("{:?}", iterator_options).bright_black(),
                         format!("{:?}", alg).bright_black());
            };
            // Reports on stderr an unsupported algorithm / data term / regularisation
            // combination; the caller then `continue`s to the next algorithm.
            let not_implemented = || {
                let msg = format!("Algorithm “{alg_name}” not implemented for \
                                   dataterm {dataterm:?} and regularisation {regularisation:?}. \
                                   Skipping.").red();
                eprintln!("{}", msg);
            };
            // Create Logger and IteratorFactory
            let mut logger = Logger::new();
            // Data prepared from `opA` that is handed to `optimise_weights` in `logmap`.
            let findim_data = prepare_optimise_weights(&opA);
            let inner_config : InnerSettings<F> = Default::default();
            let inner_it = inner_config.iterator_options;
            // Converts one timed `IterInfo` record into a `CSVLog` row.
            let logmap = |iter, Timed { cpu_time, data }| {
                let IterInfo {
                    value,
                    n_spikes,
                    inner_iters,
                    merged,
                    pruned,
                    postprocessing,
                    this_iters,
                    ..
                } = data;
                // `post_value` is the objective value after re-optimising the weights of
                // the postprocessing measure; in every other case it equals `value`.
                let post_value = match (postprocessing, dataterm, regularisation) {
                    (Some(mut μ), DataTerm::L2Squared, Regularisation::Radon(α)) => {
                        // Comparison postprocessing is only implemented for the case handled
                        // by the FW variants.
                        optimise_weights(
                            &mut μ, &opA, &b, α, &findim_data, &inner_config,
                            inner_it
                        );
                        dataterm.value_at_residual(opA.apply(&μ) - &b)
                            + regularisation.apply(&μ)
                    },
                    _ => value,
                };
                CSVLog {
                    iter,
                    value,
                    post_value,
                    n_spikes,
                    cpu_time : cpu_time.as_secs_f64(),
                    inner_iters,
                    merged,
                    pruned,
                    this_iters
                }
            };
            let iterator = iterator_options.instantiate()
                                           .timed()
                                           .mapped(logmap)
                                           .into_log(&mut logger);
            // Plot grid over the domain: 1000 points per axis when N == 1, otherwise 100.
            let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]);

            // Create plotter and directory if needed.
            // NOTE(review): `plot_count` is presumably an upper bound on the number of
            // iteration plots (0 when iteration plotting is off) — confirm in `SeqPlotter`.
            let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };
            let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid);

            // Run the algorithm
            let start = Instant::now();
            let start_cpu = ProcessTime::now();
            let μ = match alg {
                AlgorithmConfig::FB(ref algconfig) => {
                    // FB is dispatched only for the squared L2 data term.
                    match (regularisation, dataterm) {
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
                            running();
                            pointsource_fb_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter
                            )
                        },
                        (Regularisation::Radon(α), DataTerm::L2Squared) => {
                            running();
                            pointsource_fb_reg(
                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter
                            )
                        },
                        _ => {
                            not_implemented();
                            continue
                        }
                    }
                },
                AlgorithmConfig::PDPS(ref algconfig) => {
                    // Unlike the FB and FW arms, `running()` is called up front: the match
                    // below has an arm for every listed combination and no fallback arm.
                    running();
                    match (regularisation, dataterm) {
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
                            pointsource_pdps_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter, L2Squared
                            )
                        },
                        (Regularisation::Radon(α),DataTerm::L2Squared) => {
                            pointsource_pdps_reg(
                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter, L2Squared
                            )
                        },
                        (Regularisation::NonnegRadon(α), DataTerm::L1) => {
                            pointsource_pdps_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter, L1
                            )
                        },
                        (Regularisation::Radon(α), DataTerm::L1) => {
                            pointsource_pdps_reg(
                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter, L1
                            )
                        },
                    }
                },
                AlgorithmConfig::FW(ref algconfig) => {
                    // FW is dispatched only for `Regularisation::Radon` with the squared
                    // L2 data term.
                    match (regularisation, dataterm) {
                        (Regularisation::Radon(α), DataTerm::L2Squared) => {
                            running();
                            pointsource_fw(&opA, &b, α, algconfig, iterator, plotter)
                        },
                        _ => {
                            not_implemented();
                            continue
                        }
                    }
                }
            };

            // Wall-clock and process CPU time of the algorithm call above.
            let elapsed = start.elapsed().as_secs_f64();
            let cpu_time = start_cpu.elapsed().as_secs_f64();

            println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow());

            // Save results
            println!("{}", "Saving results…".green());

            // Result files are written as `{prefix}{alg_name}_{t}`, i.e. directly in the
            // experiment directory rather than under `this_prefix`.
            let mkname = |t| format!("{prefix}{alg_name}_{t}");

            write_json(mkname("config.json"), &named)?;
            write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
            μ.write_csv(mkname("reco.txt"))?;
            logger.write_csv(mkname("log.txt"))?;
        }

        Ok(())
    }
}
// *** macro end boiler plate ***
}}
// *** actual code ***

// `ExperimentV2` stores a `Regularisation` directly in its `regularisation` field,
// so the conversion is the identity.
impl_experiment!(ExperimentV2, regularisation, std::convert::identity);
581
582 /// Plot experiment setup
583 #[replace_float_literals(F::cast_from(literal))]
584 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>(
585 cli : &CommandLineArgs,
586 prefix : &String,
587 domain : &Cube<F, N>,
588 sensor : &Sensor,
589 kernel : &Kernel,
590 spread : &Spread,
591 μ_hat : &DiscreteMeasure<Loc<F, N>, F>,
592 op𝒟 : &𝒟,
593 opA : &A,
594 b_hat : &A::Observable,
595 b : &A::Observable,
596 kernel_plot_width : F,
597 ) -> DynError
598 where F : Float + ToNalgebraRealField,
599 Sensor : RealMapping<F, N> + Support<F, N> + Clone,
600 Spread : RealMapping<F, N> + Support<F, N> + Clone,
601 Kernel : RealMapping<F, N> + Support<F, N>,
602 Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>,
603 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>,
604 𝒟::Codomain : RealMapping<F, N>,
605 A : ForwardModel<Loc<F, N>, F>,
606 A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>,
607 PlotLookup : Plotting<N>,
608 Cube<F, N> : SetOrd {
609
610 if cli.plot < PlotLevel::Data {
611 return Ok(())
612 }
613
614 let base = Convolution(sensor.clone(), spread.clone());
615
616 let resolution = if N==1 { 100 } else { 40 };
617 let pfx = |n| format!("{}{}", prefix, n);
618 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]);
619
620 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string());
621 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string());
622 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string());
623 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string());
624
625 let plotgrid2 = lingrid(&domain, &[resolution; N]);
626
627 let ω_hat = op𝒟.apply(μ_hat);
628 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b);
629 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string());
630 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"),
631 "noise Aᵀ(Aμ̂ - b)".to_string());
632
633 let preadj_b = opA.preadjoint().apply(b);
634 let preadj_b_hat = opA.preadjoint().apply(b_hat);
635 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds());
636 PlotLookup::plot_into_file_spikes(
637 "Aᵀb".to_string(), &preadj_b,
638 "Aᵀb̂".to_string(), Some(&preadj_b_hat),
639 plotgrid2, None, &μ_hat,
640 pfx("omega_b")
641 );
642
643 // Save true solution and observables
644 let pfx = |n| format!("{}{}", prefix, n);
645 μ_hat.write_csv(pfx("orig.txt"))?;
646 opA.write_observable(&b_hat, pfx("b_hat"))?;
647 opA.write_observable(&b, pfx("b_noisy"))
648 }
649
650 //
651 // Deprecated interface
652 //
276 653
277 /// Struct for experiment configurations 654 /// Struct for experiment configurations
278 #[derive(Debug, Clone, Serialize)] 655 #[derive(Debug, Clone, Serialize)]
279 pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize> 656 pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize>
280 where F : Float, 657 where F : Float,
299 /// Kernel $ρ$ of $𝒟$. 676 /// Kernel $ρ$ of $𝒟$.
300 pub kernel : K, 677 pub kernel : K,
301 /// True point sources 678 /// True point sources
302 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, 679 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>,
303 /// Regularisation parameter 680 /// Regularisation parameter
681 #[deprecated(note = "Use [`ExperimentV2`], which replaces `α` by more generic `regularisation`")]
304 pub α : F, 682 pub α : F,
305 /// For plotting : how wide should the kernels be plotted 683 /// For plotting : how wide should the kernels be plotted
306 pub kernel_plot_width : F, 684 pub kernel_plot_width : F,
307 /// Data term 685 /// Data term
308 pub dataterm : DataTerm, 686 pub dataterm : DataTerm,
309 /// A map of default configurations for algorithms 687 /// A map of default configurations for algorithms
310 #[serde(skip)] 688 #[serde(skip)]
311 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, 689 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>,
312 } 690 }
313 691
314 /// Trait for runnable experiments 692 impl_experiment!(Experiment, α, Regularisation::NonnegRadon);
315 pub trait RunnableExperiment<F : ClapFloat> {
316 /// Run all algorithms provided, or default algorithms if none provided, on the experiment.
317 fn runall(&self, cli : &CommandLineArgs,
318 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError;
319
320 /// Return algorithm default config
321 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
322 -> Named<AlgorithmConfig<F>>;
323 }
324
// `RunnableExperiment` implementation for the original `Experiment` struct, whose
// regularisation is given by the single parameter `α`.
impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for
Named<Experiment<F, NoiseDistr, S, K, P, N>>
where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
      [usize; N] : Serialize,
      S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
      P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
      Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy,
      AutoConvolution<P> : BoundedBy<F, K>,
      K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N>
          + Copy + Serialize + std::fmt::Debug,
      Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
      PlotLookup : Plotting<N>,
      DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug {

    // Looks `alg` up in the experiment's own `algorithm_defaults` map, falls back to
    // `alg.default_config()` when there is no entry, applies the overrides in `cli`,
    // and wraps the result with `alg.to_named`.
    fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
    -> Named<AlgorithmConfig<F>> {
        alg.to_named(
            self.data
                .algorithm_defaults
                .get(&alg)
                .map_or_else(|| alg.default_config(),
                             |config| config.clone())
                .cli_override(cli)
        )
    }

    // Generates the noisy data, saves the experiment setup under
    // `{cli.outdir}/{experiment name}/`, then runs each selected algorithm and writes
    // its configuration, timing statistics, reconstruction and iteration log.
    // Any I/O or serialisation error aborts the whole run through `?`.
    fn runall(&self, cli : &CommandLineArgs,
              algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError {
        // Get experiment configuration
        let &Named {
            name : ref experiment_name,
            data : Experiment {
                domain, sensor_count, ref noise_distr, sensor, spread, kernel,
                ref μ_hat, α, kernel_plot_width, dataterm, noise_seed,
                ..
            }
        } = self;

        println!("{}\n{}",
                 format!("Performing experiment {}…", experiment_name).cyan(),
                 format!("{:?}", &self.data).bright_black());

        // Set up output directory
        let prefix = format!("{}/{}/", cli.outdir, self.name);

        // Set up algorithms
        let iterator_options = AlgIteratorOptions{
            max_iter : cli.max_iter,
            verbose_iter : cli.verbose_iter
                              .map_or(Verbose::Logarithmic(10),
                                      |n| Verbose::Every(n)),
            quiet : cli.quiet,
        };
        // When no algorithms are given, the default depends on the data term:
        // FB for the squared L2 data term, PDPS for the L1 data term.
        let algorithms = match (algs, self.data.dataterm) {
            (Some(algs), _) => algs,
            (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()],
            (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()],
        };

        // Set up operators
        let depth = DynamicDepth(8);
        let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
        let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel);

        // Set up random number generator.
        let mut rng = StdRng::seed_from_u64(noise_seed);

        // Generate the data and calculate SSNR statistic
        // `b_hat` is the noise-free data; `b` is `b_hat` with the sampled noise added.
        let b_hat = opA.apply(μ_hat);
        let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
        let b = &b_hat + &noise;
        // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
        // overloading log10 and conflicting with standard NumTraits one.
        let stats = ExperimentStats::new(&b, &noise);

        // Save experiment configuration and statistics
        let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t);
        std::fs::create_dir_all(&prefix)?;
        write_json(mkname_e("experiment"), self)?;
        write_json(mkname_e("config"), cli)?;
        write_json(mkname_e("stats"), &stats)?;

        plotall(cli, &prefix, &domain, &sensor, &kernel, &spread,
                &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?;

        // Run the algorithm(s)
        for named @ Named { name : alg_name, data : alg } in algorithms.iter() {
            let this_prefix = format!("{}{}/", prefix, alg_name);

            // Announces the run unless `cli.quiet` is set. It is invoked only from the
            // match arms below that actually start an algorithm.
            let running = || if !cli.quiet {
                println!("{}\n{}\n{}",
                         format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
                         format!("{:?}", iterator_options).bright_black(),
                         format!("{:?}", alg).bright_black());
            };

            // Create Logger and IteratorFactory
            let mut logger = Logger::new();
            // Data prepared from `opA` that is handed to `optimise_weights` in `logmap`.
            let findim_data = prepare_optimise_weights(&opA);
            let inner_config : InnerSettings<F> = Default::default();
            let inner_it = inner_config.iterator_options;
            // Converts one timed `IterInfo` record into a `CSVLog` row.
            let logmap = |iter, Timed { cpu_time, data }| {
                let IterInfo {
                    value,
                    n_spikes,
                    inner_iters,
                    merged,
                    pruned,
                    postprocessing,
                    this_iters,
                    ..
                } = data;
                // `post_value` is the objective value after re-optimising the weights of
                // the postprocessing measure. This is done only for the squared L2 data
                // term; in every other case `post_value` equals `value`.
                let post_value = match postprocessing {
                    None => value,
                    Some(mut μ) => {
                        match dataterm {
                            DataTerm::L2Squared => {
                                optimise_weights(
                                    &mut μ, &opA, &b, α, &findim_data, &inner_config,
                                    inner_it
                                );
                                dataterm.value_at_residual(opA.apply(&μ) - &b) + α * μ.norm(Radon)
                            },
                            _ => value,
                        }
                    }
                };
                CSVLog {
                    iter,
                    value,
                    post_value,
                    n_spikes,
                    cpu_time : cpu_time.as_secs_f64(),
                    inner_iters,
                    merged,
                    pruned,
                    this_iters
                }
            };
            let iterator = iterator_options.instantiate()
                                           .timed()
                                           .mapped(logmap)
                                           .into_log(&mut logger);
            // Plot grid over the domain: 1000 points per axis when N == 1, otherwise 100.
            let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]);

            // Create plotter and directory if needed.
            // NOTE(review): `plot_count` is presumably an upper bound on the number of
            // iteration plots (0 when iteration plotting is off) — confirm in `SeqPlotter`.
            let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };
            let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid);

            // Run the algorithm
            let start = Instant::now();
            let start_cpu = ProcessTime::now();
            // Supported combinations: FB, FW and PDPS with the squared L2 data term, and
            // PDPS with the L1 data term. Anything else is reported and skipped.
            let μ : DiscreteMeasure<Loc<F, N>, F> = match (alg, dataterm) {
                (AlgorithmConfig::FB(ref algconfig), DataTerm::L2Squared) => {
                    running();
                    pointsource_fb(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter)
                },
                (AlgorithmConfig::FW(ref algconfig), DataTerm::L2Squared) => {
                    running();
                    pointsource_fw(&opA, &b, α, &algconfig, iterator, plotter)
                },
                (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L2Squared) => {
                    running();
                    pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L2Squared)
                },
                (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L1) => {
                    running();
                    pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1)
                },
                _ => {
                    let msg = format!("Algorithm “{alg_name}” not implemented for \
                                       dataterm {dataterm:?}. Skipping.").red();
                    eprintln!("{}", msg);
                    continue
                }
            };
            // Wall-clock and process CPU time of the algorithm call above.
            let elapsed = start.elapsed().as_secs_f64();
            let cpu_time = start_cpu.elapsed().as_secs_f64();

            println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow());

            // Save results
            println!("{}", "Saving results…".green());

            // Result files are written as `{prefix}{alg_name}_{t}`, i.e. directly in the
            // experiment directory rather than under `this_prefix`.
            let mkname = |t| format!("{prefix}{alg_name}_{t}");

            write_json(mkname("config.json"), &named)?;
            write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
            μ.write_csv(mkname("reco.txt"))?;
            logger.write_csv(mkname("log.txt"))?;
        }

        Ok(())
    }
}
523
524 /// Plot experiment setup
525 #[replace_float_literals(F::cast_from(literal))]
526 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>(
527 cli : &CommandLineArgs,
528 prefix : &String,
529 domain : &Cube<F, N>,
530 sensor : &Sensor,
531 kernel : &Kernel,
532 spread : &Spread,
533 μ_hat : &DiscreteMeasure<Loc<F, N>, F>,
534 op𝒟 : &𝒟,
535 opA : &A,
536 b_hat : &A::Observable,
537 b : &A::Observable,
538 kernel_plot_width : F,
539 ) -> DynError
540 where F : Float + ToNalgebraRealField,
541 Sensor : RealMapping<F, N> + Support<F, N> + Clone,
542 Spread : RealMapping<F, N> + Support<F, N> + Clone,
543 Kernel : RealMapping<F, N> + Support<F, N>,
544 Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>,
545 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>,
546 𝒟::Codomain : RealMapping<F, N>,
547 A : ForwardModel<Loc<F, N>, F>,
548 A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>,
549 PlotLookup : Plotting<N>,
550 Cube<F, N> : SetOrd {
551
552 if cli.plot < PlotLevel::Data {
553 return Ok(())
554 }
555
556 let base = Convolution(sensor.clone(), spread.clone());
557
558 let resolution = if N==1 { 100 } else { 40 };
559 let pfx = |n| format!("{}{}", prefix, n);
560 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]);
561
562 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string());
563 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string());
564 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string());
565 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string());
566
567 let plotgrid2 = lingrid(&domain, &[resolution; N]);
568
569 let ω_hat = op𝒟.apply(μ_hat);
570 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b);
571 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string());
572 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"),
573 "noise Aᵀ(Aμ̂ - b)".to_string());
574
575 let preadj_b = opA.preadjoint().apply(b);
576 let preadj_b_hat = opA.preadjoint().apply(b_hat);
577 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds());
578 PlotLookup::plot_into_file_spikes(
579 "Aᵀb".to_string(), &preadj_b,
580 "Aᵀb̂".to_string(), Some(&preadj_b_hat),
581 plotgrid2, None, &μ_hat,
582 pfx("omega_b")
583 );
584
585 // Save true solution and observables
586 let pfx = |n| format!("{}{}", prefix, n);
587 μ_hat.write_csv(pfx("orig.txt"))?;
588 opA.write_observable(&b_hat, pfx("b_hat"))?;
589 opA.write_observable(&b, pfx("b_noisy"))
590 }
591

mercurial