365 /// A map of default configurations for algorithms |
432 /// A map of default configurations for algorithms |
366 #[serde(skip)] |
433 #[serde(skip)] |
367 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, |
434 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, |
368 } |
435 } |
369 |
436 |
|
437 #[derive(Debug, Clone, Serialize)] |
|
438 pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N : usize> |
|
439 where F : Float, |
|
440 [usize; N] : Serialize, |
|
441 NoiseDistr : Distribution<F>, |
|
442 S : Sensor<F, N>, |
|
443 P : Spread<F, N>, |
|
444 K : SimpleConvolutionKernel<F, N>, |
|
445 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, |
|
446 { |
|
447 /// Basic setup |
|
448 pub base : ExperimentV2<F, NoiseDistr, S, K, P, N>, |
|
449 /// Weight of TV term |
|
450 pub λ : F, |
|
451 /// Bias function |
|
452 pub bias : B, |
|
453 } |
|
454 |
370 /// Trait for runnable experiments |
455 /// Trait for runnable experiments |
371 pub trait RunnableExperiment<F : ClapFloat> { |
456 pub trait RunnableExperiment<F : ClapFloat> { |
372 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. |
457 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. |
373 fn runall(&self, cli : &CommandLineArgs, |
458 fn runall(&self, cli : &CommandLineArgs, |
374 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; |
459 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; |
375 |
460 |
376 /// Return algorithm default config |
461 /// Return algorithm default config |
377 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
462 fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>>; |
378 -> Named<AlgorithmConfig<F>>; |
463 } |
379 } |
464 |
380 |
465 /// Helper function to print experiment start message and save setup. |
381 // *** macro boilerplate *** |
466 /// Returns saving prefix. |
382 macro_rules! impl_experiment { |
467 fn start_experiment<E, S>( |
383 ($type:ident, $reg_field:ident, $reg_convert:path) => { |
468 experiment : &Named<E>, |
384 // *** macro *** |
469 cli : &CommandLineArgs, |
|
470 stats : S, |
|
471 ) -> DynResult<String> |
|
472 where |
|
473 E : Serialize + std::fmt::Debug, |
|
474 S : Serialize, |
|
475 { |
|
476 let Named { name : experiment_name, data } = experiment; |
|
477 |
|
478 println!("{}\n{}", |
|
479 format!("Performing experiment {}…", experiment_name).cyan(), |
|
480 format!("{:?}", data).bright_black()); |
|
481 |
|
482 // Set up output directory |
|
483 let prefix = format!("{}/{}/", cli.outdir, experiment_name); |
|
484 |
|
485 // Save experiment configuration and statistics |
|
486 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); |
|
487 std::fs::create_dir_all(&prefix)?; |
|
488 write_json(mkname_e("experiment"), experiment)?; |
|
489 write_json(mkname_e("config"), cli)?; |
|
490 write_json(mkname_e("stats"), &stats)?; |
|
491 |
|
492 Ok(prefix) |
|
493 } |
|
494 |
|
495 /// Error codes for running an algorithm on an experiment. |
|
496 enum RunError { |
|
497 /// Algorithm not implemented for this experiment |
|
498 NotImplemented, |
|
499 } |
|
500 |
|
501 use RunError::*; |
|
502 |
|
503 type DoRunAllIt<'a, F, const N : usize> = LoggingIteratorFactory< |
|
504 'a, |
|
505 Timed<IterInfo<F, N>>, |
|
506 TimingIteratorFactory<BasicAlgIteratorFactory<IterInfo<F, N>>> |
|
507 >; |
|
508 |
|
509 /// Helper function to run all algorithms on an experiment. |
|
510 fn do_runall<F : Float, Z, const N : usize>( |
|
511 experiment_name : &String, |
|
512 prefix : &String, |
|
513 cli : &CommandLineArgs, |
|
514 algorithms : Vec<Named<AlgorithmConfig<F>>>, |
|
515 plotgrid : LinSpace<Loc<F, N>, [usize; N]>, |
|
516 mut save_extra : impl FnMut(String, Z) -> DynError, |
|
517 mut do_alg : impl FnMut( |
|
518 &AlgorithmConfig<F>, |
|
519 DoRunAllIt<F, N>, |
|
520 SeqPlotter<F, N>, |
|
521 String, |
|
522 ) -> Result<(RNDM<F, N>, Z), RunError>, |
|
523 ) -> DynError |
|
524 where |
|
525 PlotLookup : Plotting<N>, |
|
526 { |
|
527 let mut logs = Vec::new(); |
|
528 |
|
529 let iterator_options = AlgIteratorOptions{ |
|
530 max_iter : cli.max_iter, |
|
531 verbose_iter : cli.verbose_iter |
|
532 .map_or(Verbose::Logarithmic(10), |
|
533 |n| Verbose::Every(n)), |
|
534 quiet : cli.quiet, |
|
535 }; |
|
536 |
|
537 // Run the algorithm(s) |
|
538 for named @ Named { name : alg_name, data : alg } in algorithms.iter() { |
|
539 let this_prefix = format!("{}{}/", prefix, alg_name); |
|
540 |
|
541 // Create Logger and IteratorFactory |
|
542 let mut logger = Logger::new(); |
|
543 let iterator = iterator_options.instantiate() |
|
544 .timed() |
|
545 .into_log(&mut logger); |
|
546 |
|
547 let running = if !cli.quiet { |
|
548 format!("{}\n{}\n{}\n", |
|
549 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), |
|
550 format!("{:?}", iterator_options).bright_black(), |
|
551 format!("{:?}", alg).bright_black()) |
|
552 } else { |
|
553 "".to_string() |
|
554 }; |
|
555 // |
|
556 // The following is for postprocessing, which has been disabled anyway. |
|
557 // |
|
558 // let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation { |
|
559 // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)), |
|
560 // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)), |
|
561 // }; |
|
562 //let findim_data = reg.prepare_optimise_weights(&opA, &b); |
|
563 //let inner_config : InnerSettings<F> = Default::default(); |
|
564 //let inner_it = inner_config.iterator_options; |
|
565 |
|
566 // Create plotter and directory if needed. |
|
567 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; |
|
568 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()); |
|
569 |
|
570 let start = Instant::now(); |
|
571 let start_cpu = ProcessTime::now(); |
|
572 |
|
573 let (μ, z) = match do_alg(alg, iterator, plotter, running) { |
|
574 Ok(μ) => μ, |
|
575 Err(RunError::NotImplemented) => { |
|
576 let msg = format!("Algorithm “{alg_name}” not implemented for {experiment_name}. \ |
|
577 Skipping.").red(); |
|
578 eprintln!("{}", msg); |
|
579 continue |
|
580 } |
|
581 }; |
|
582 |
|
583 let elapsed = start.elapsed().as_secs_f64(); |
|
584 let cpu_time = start_cpu.elapsed().as_secs_f64(); |
|
585 |
|
586 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); |
|
587 |
|
588 // Save results |
|
589 println!("{}", "Saving results …".green()); |
|
590 |
|
591 let mkname = |t| format!("{prefix}{alg_name}_{t}"); |
|
592 |
|
593 write_json(mkname("config.json"), &named)?; |
|
594 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; |
|
595 μ.write_csv(mkname("reco.txt"))?; |
|
596 save_extra(mkname(""), z)?; |
|
597 //logger.write_csv(mkname("log.txt"))?; |
|
598 logs.push((mkname("log.txt"), logger)); |
|
599 } |
|
600 |
|
601 save_logs(logs) |
|
602 } |
|
603 |
|
604 #[replace_float_literals(F::cast_from(literal))] |
385 impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for |
605 impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for |
386 Named<$type<F, NoiseDistr, S, K, P, N>> |
606 Named<ExperimentV2<F, NoiseDistr, S, K, P, N>> |
387 where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, |
607 where |
388 [usize; N] : Serialize, |
608 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, |
389 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
609 [usize; N] : Serialize, |
390 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
610 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
391 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy |
611 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
392 // TODO: shold not have differentiability as a requirement, but |
612 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy |
393 // decide availability of sliding based on it. |
613 // TODO: shold not have differentiability as a requirement, but |
394 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
614 // decide availability of sliding based on it. |
395 // TODO: very weird that rust only compiles with Differentiable |
615 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
396 // instead of the above one on references, which is required by |
616 // TODO: very weird that rust only compiles with Differentiable |
397 // poitsource_sliding_fb_reg. |
617 // instead of the above one on references, which is required by |
398 + DifferentiableRealMapping<F, N> |
618 // poitsource_sliding_fb_reg. |
399 + Lipschitz<L2, FloatType=F>, |
619 + DifferentiableRealMapping<F, N> |
400 // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
620 + Lipschitz<L2, FloatType=F>, |
401 AutoConvolution<P> : BoundedBy<F, K>, |
621 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb. |
402 K : SimpleConvolutionKernel<F, N> |
622 AutoConvolution<P> : BoundedBy<F, K>, |
403 + LocalAnalysis<F, Bounds<F>, N> |
623 K : SimpleConvolutionKernel<F, N> |
404 + Copy + Serialize + std::fmt::Debug, |
624 + LocalAnalysis<F, Bounds<F>, N> |
405 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
625 + Copy + Serialize + std::fmt::Debug, |
406 PlotLookup : Plotting<N>, |
626 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
407 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
627 PlotLookup : Plotting<N>, |
408 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
628 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
409 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
629 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
410 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug { |
630 RNDM<F, N> : SpikeMerging<F>, |
411 |
631 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug |
412 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
632 { |
413 -> Named<AlgorithmConfig<F>> { |
633 |
414 alg.to_named( |
634 fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> { |
415 self.data |
635 self.data.algorithm_defaults.get(&alg).cloned() |
416 .algorithm_defaults |
|
417 .get(&alg) |
|
418 .map_or_else(|| alg.default_config(), |
|
419 |config| config.clone()) |
|
420 .cli_override(cli) |
|
421 ) |
|
422 } |
636 } |
423 |
637 |
424 fn runall(&self, cli : &CommandLineArgs, |
638 fn runall(&self, cli : &CommandLineArgs, |
425 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { |
639 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { |
426 // Get experiment configuration |
640 // Get experiment configuration |
427 let &Named { |
641 let &Named { |
428 name : ref experiment_name, |
642 name : ref experiment_name, |
429 data : $type { |
643 data : ExperimentV2 { |
430 domain, sensor_count, ref noise_distr, sensor, spread, kernel, |
644 domain, sensor_count, ref noise_distr, sensor, spread, kernel, |
431 ref μ_hat, /*regularisation,*/ kernel_plot_width, dataterm, noise_seed, |
645 ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, |
432 .. |
646 .. |
433 } |
647 } |
434 } = self; |
648 } = self; |
435 let regularisation = $reg_convert(self.data.$reg_field); |
|
436 |
|
437 println!("{}\n{}", |
|
438 format!("Performing experiment {}…", experiment_name).cyan(), |
|
439 format!("{:?}", &self.data).bright_black()); |
|
440 |
|
441 // Set up output directory |
|
442 let prefix = format!("{}/{}/", cli.outdir, self.name); |
|
443 |
649 |
444 // Set up algorithms |
650 // Set up algorithms |
445 let iterator_options = AlgIteratorOptions{ |
651 let algorithms = match (algs, dataterm) { |
446 max_iter : cli.max_iter, |
|
447 verbose_iter : cli.verbose_iter |
|
448 .map_or(Verbose::Logarithmic(10), |
|
449 |n| Verbose::Every(n)), |
|
450 quiet : cli.quiet, |
|
451 }; |
|
452 let algorithms = match (algs, self.data.dataterm) { |
|
453 (Some(algs), _) => algs, |
652 (Some(algs), _) => algs, |
454 (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], |
653 (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], |
455 (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], |
654 (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], |
456 }; |
655 }; |
457 |
656 |
462 |
661 |
463 // Set up random number generator. |
662 // Set up random number generator. |
464 let mut rng = StdRng::seed_from_u64(noise_seed); |
663 let mut rng = StdRng::seed_from_u64(noise_seed); |
465 |
664 |
466 // Generate the data and calculate SSNR statistic |
665 // Generate the data and calculate SSNR statistic |
467 let b_hat = opA.apply(μ_hat); |
666 let b_hat : DVector<_> = opA.apply(μ_hat); |
468 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); |
667 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); |
469 let b = &b_hat + &noise; |
668 let b = &b_hat + &noise; |
470 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
669 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
471 // overloading log10 and conflicting with standard NumTraits one. |
670 // overloading log10 and conflicting with standard NumTraits one. |
472 let stats = ExperimentStats::new(&b, &noise); |
671 let stats = ExperimentStats::new(&b, &noise); |
473 |
672 |
474 // Save experiment configuration and statistics |
673 let prefix = start_experiment(&self, cli, stats)?; |
475 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); |
|
476 std::fs::create_dir_all(&prefix)?; |
|
477 write_json(mkname_e("experiment"), self)?; |
|
478 write_json(mkname_e("config"), cli)?; |
|
479 write_json(mkname_e("stats"), &stats)?; |
|
480 |
674 |
481 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, |
675 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, |
482 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; |
676 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; |
483 |
677 |
484 // Run the algorithm(s) |
678 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); |
485 for named @ Named { name : alg_name, data : alg } in algorithms.iter() { |
679 |
486 let this_prefix = format!("{}{}/", prefix, alg_name); |
680 let save_extra = |_, ()| Ok(()); |
487 |
681 |
488 let running = || if !cli.quiet { |
682 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, |
489 println!("{}\n{}\n{}", |
683 |alg, iterator, plotter, running| |
490 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), |
684 { |
491 format!("{:?}", iterator_options).bright_black(), |
|
492 format!("{:?}", alg).bright_black()); |
|
493 }; |
|
494 let not_implemented = || { |
|
495 let msg = format!("Algorithm “{alg_name}” not implemented for \ |
|
496 dataterm {dataterm:?} and regularisation {regularisation:?}. \ |
|
497 Skipping.").red(); |
|
498 eprintln!("{}", msg); |
|
499 }; |
|
500 // Create Logger and IteratorFactory |
|
501 let mut logger = Logger::new(); |
|
502 let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation { |
|
503 Regularisation::Radon(α) => Box::new(RadonRegTerm(α)), |
|
504 Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)), |
|
505 }; |
|
506 let findim_data = reg.prepare_optimise_weights(&opA, &b); |
|
507 let inner_config : InnerSettings<F> = Default::default(); |
|
508 let inner_it = inner_config.iterator_options; |
|
509 let logmap = |iter, Timed { cpu_time, data }| { |
|
510 let IterInfo { |
|
511 value, |
|
512 n_spikes, |
|
513 inner_iters, |
|
514 merged, |
|
515 pruned, |
|
516 postprocessing, |
|
517 this_iters, |
|
518 .. |
|
519 } = data; |
|
520 let post_value = match (postprocessing, dataterm) { |
|
521 (Some(mut μ), DataTerm::L2Squared) => { |
|
522 // Comparison postprocessing is only implemented for the case handled |
|
523 // by the FW variants. |
|
524 reg.optimise_weights( |
|
525 &mut μ, &opA, &b, &findim_data, &inner_config, |
|
526 inner_it |
|
527 ); |
|
528 dataterm.value_at_residual(opA.apply(&μ) - &b) |
|
529 + regularisation.apply(&μ) |
|
530 }, |
|
531 _ => value, |
|
532 }; |
|
533 CSVLog { |
|
534 iter, |
|
535 value, |
|
536 post_value, |
|
537 n_spikes, |
|
538 cpu_time : cpu_time.as_secs_f64(), |
|
539 inner_iters, |
|
540 merged, |
|
541 pruned, |
|
542 this_iters |
|
543 } |
|
544 }; |
|
545 let iterator = iterator_options.instantiate() |
|
546 .timed() |
|
547 .mapped(logmap) |
|
548 .into_log(&mut logger); |
|
549 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); |
|
550 |
|
551 // Create plotter and directory if needed. |
|
552 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; |
|
553 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid); |
|
554 |
|
555 // Run the algorithm |
|
556 let start = Instant::now(); |
|
557 let start_cpu = ProcessTime::now(); |
|
558 let μ = match alg { |
685 let μ = match alg { |
559 AlgorithmConfig::FB(ref algconfig) => { |
686 AlgorithmConfig::FB(ref algconfig) => { |
560 match (regularisation, dataterm) { |
687 match (regularisation, dataterm) { |
561 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
688 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ |
562 running(); |
689 print!("{running}"); |
563 pointsource_fb_reg( |
690 pointsource_fb_reg( |
564 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
691 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
565 iterator, plotter |
692 iterator, plotter |
566 ) |
693 ) |
567 }, |
694 }), |
568 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
695 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ |
569 running(); |
696 print!("{running}"); |
570 pointsource_fb_reg( |
697 pointsource_fb_reg( |
571 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
698 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
572 iterator, plotter |
699 iterator, plotter |
573 ) |
700 ) |
574 }, |
701 }), |
575 _ => { |
702 _ => Err(NotImplemented) |
576 not_implemented(); |
|
577 continue |
|
578 } |
|
579 } |
703 } |
580 }, |
704 }, |
581 AlgorithmConfig::FISTA(ref algconfig) => { |
705 AlgorithmConfig::FISTA(ref algconfig) => { |
582 match (regularisation, dataterm) { |
706 match (regularisation, dataterm) { |
583 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
707 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ |
584 running(); |
708 print!("{running}"); |
585 pointsource_fista_reg( |
709 pointsource_fista_reg( |
586 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
710 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
587 iterator, plotter |
711 iterator, plotter |
588 ) |
712 ) |
589 }, |
713 }), |
590 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
714 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ |
591 running(); |
715 print!("{running}"); |
592 pointsource_fista_reg( |
716 pointsource_fista_reg( |
593 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
717 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
594 iterator, plotter |
718 iterator, plotter |
595 ) |
719 ) |
596 }, |
720 }), |
597 _ => { |
721 _ => Err(NotImplemented), |
598 not_implemented(); |
|
599 continue |
|
600 } |
|
601 } |
722 } |
602 }, |
723 }, |
603 AlgorithmConfig::RadonFB(ref algconfig) => { |
724 AlgorithmConfig::RadonFB(ref algconfig) => { |
604 match (regularisation, dataterm) { |
725 match (regularisation, dataterm) { |
605 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
726 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ |
606 running(); |
727 print!("{running}"); |
607 pointsource_radon_fb_reg( |
728 pointsource_radon_fb_reg( |
608 &opA, &b, NonnegRadonRegTerm(α), algconfig, |
729 &opA, &b, NonnegRadonRegTerm(α), algconfig, |
609 iterator, plotter |
730 iterator, plotter |
610 ) |
731 ) |
611 }, |
732 }), |
612 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
733 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ |
613 running(); |
734 print!("{running}"); |
614 pointsource_radon_fb_reg( |
735 pointsource_radon_fb_reg( |
615 &opA, &b, RadonRegTerm(α), algconfig, |
736 &opA, &b, RadonRegTerm(α), algconfig, |
616 iterator, plotter |
737 iterator, plotter |
617 ) |
738 ) |
618 }, |
739 }), |
619 _ => { |
740 _ => Err(NotImplemented), |
620 not_implemented(); |
|
621 continue |
|
622 } |
|
623 } |
741 } |
624 }, |
742 }, |
625 AlgorithmConfig::RadonFISTA(ref algconfig) => { |
743 AlgorithmConfig::RadonFISTA(ref algconfig) => { |
626 match (regularisation, dataterm) { |
744 match (regularisation, dataterm) { |
627 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
745 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ |
628 running(); |
746 print!("{running}"); |
629 pointsource_radon_fista_reg( |
747 pointsource_radon_fista_reg( |
630 &opA, &b, NonnegRadonRegTerm(α), algconfig, |
748 &opA, &b, NonnegRadonRegTerm(α), algconfig, |
631 iterator, plotter |
749 iterator, plotter |
632 ) |
750 ) |
633 }, |
751 }), |
634 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
752 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ |
635 running(); |
753 print!("{running}"); |
636 pointsource_radon_fista_reg( |
754 pointsource_radon_fista_reg( |
637 &opA, &b, RadonRegTerm(α), algconfig, |
755 &opA, &b, RadonRegTerm(α), algconfig, |
638 iterator, plotter |
756 iterator, plotter |
639 ) |
757 ) |
640 }, |
758 }), |
641 _ => { |
759 _ => Err(NotImplemented), |
642 not_implemented(); |
|
643 continue |
|
644 } |
|
645 } |
760 } |
646 }, |
761 }, |
647 AlgorithmConfig::SlidingFB(ref algconfig) => { |
762 AlgorithmConfig::SlidingFB(ref algconfig) => { |
648 match (regularisation, dataterm) { |
763 match (regularisation, dataterm) { |
649 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
764 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ |
650 running(); |
765 print!("{running}"); |
651 pointsource_sliding_fb_reg( |
766 pointsource_sliding_fb_reg( |
652 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
767 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
653 iterator, plotter |
768 iterator, plotter |
654 ) |
769 ) |
655 }, |
770 }), |
656 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
771 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ |
657 running(); |
772 print!("{running}"); |
658 pointsource_sliding_fb_reg( |
773 pointsource_sliding_fb_reg( |
659 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
774 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
660 iterator, plotter |
775 iterator, plotter |
661 ) |
776 ) |
662 }, |
777 }), |
663 _ => { |
778 _ => Err(NotImplemented), |
664 not_implemented(); |
|
665 continue |
|
666 } |
|
667 } |
779 } |
668 }, |
780 }, |
669 AlgorithmConfig::PDPS(ref algconfig) => { |
781 AlgorithmConfig::PDPS(ref algconfig) => { |
670 running(); |
782 print!("{running}"); |
671 match (regularisation, dataterm) { |
783 match (regularisation, dataterm) { |
672 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
784 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ |
673 pointsource_pdps_reg( |
785 pointsource_pdps_reg( |
674 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
786 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
675 iterator, plotter, L2Squared |
787 iterator, plotter, L2Squared |
676 ) |
788 ) |
677 }, |
789 }), |
678 (Regularisation::Radon(α),DataTerm::L2Squared) => { |
790 (Regularisation::Radon(α),DataTerm::L2Squared) => Ok({ |
679 pointsource_pdps_reg( |
791 pointsource_pdps_reg( |
680 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
792 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
681 iterator, plotter, L2Squared |
793 iterator, plotter, L2Squared |
682 ) |
794 ) |
683 }, |
795 }), |
684 (Regularisation::NonnegRadon(α), DataTerm::L1) => { |
796 (Regularisation::NonnegRadon(α), DataTerm::L1) => Ok({ |
685 pointsource_pdps_reg( |
797 pointsource_pdps_reg( |
686 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
798 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
687 iterator, plotter, L1 |
799 iterator, plotter, L1 |
688 ) |
800 ) |
689 }, |
801 }), |
690 (Regularisation::Radon(α), DataTerm::L1) => { |
802 (Regularisation::Radon(α), DataTerm::L1) => Ok({ |
691 pointsource_pdps_reg( |
803 pointsource_pdps_reg( |
692 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
804 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
693 iterator, plotter, L1 |
805 iterator, plotter, L1 |
694 ) |
806 ) |
695 }, |
807 }), |
696 } |
808 } |
697 }, |
809 }, |
698 AlgorithmConfig::FW(ref algconfig) => { |
810 AlgorithmConfig::FW(ref algconfig) => { |
699 match (regularisation, dataterm) { |
811 match (regularisation, dataterm) { |
700 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
812 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ |
701 running(); |
813 print!("{running}"); |
702 pointsource_fw_reg(&opA, &b, RadonRegTerm(α), |
814 pointsource_fw_reg(&opA, &b, RadonRegTerm(α), |
703 algconfig, iterator, plotter) |
815 algconfig, iterator, plotter) |
704 }, |
816 }), |
705 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
817 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ |
706 running(); |
818 print!("{running}"); |
707 pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α), |
819 pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α), |
708 algconfig, iterator, plotter) |
820 algconfig, iterator, plotter) |
709 }, |
821 }), |
710 _ => { |
822 _ => Err(NotImplemented), |
711 not_implemented(); |
|
712 continue |
|
713 } |
|
714 } |
823 } |
|
824 }, |
|
825 _ => Err(NotImplemented), |
|
826 }?; |
|
827 Ok((μ, ())) |
|
828 }) |
|
829 } |
|
830 } |
|
831 |
|
832 |
|
833 #[replace_float_literals(F::cast_from(literal))] |
|
834 impl<F, NoiseDistr, S, K, P, B, const N : usize> RunnableExperiment<F> for |
|
835 Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>> |
|
836 where |
|
837 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, |
|
838 [usize; N] : Serialize, |
|
839 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
|
840 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
|
841 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy |
|
842 // TODO: shold not have differentiability as a requirement, but |
|
843 // decide availability of sliding based on it. |
|
844 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
|
845 // TODO: very weird that rust only compiles with Differentiable |
|
846 // instead of the above one on references, which is required by |
|
847 // poitsource_sliding_fb_reg. |
|
848 + DifferentiableRealMapping<F, N> |
|
849 + Lipschitz<L2, FloatType=F>, |
|
850 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb. |
|
851 AutoConvolution<P> : BoundedBy<F, K>, |
|
852 K : SimpleConvolutionKernel<F, N> |
|
853 + LocalAnalysis<F, Bounds<F>, N> |
|
854 + Copy + Serialize + std::fmt::Debug, |
|
855 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
|
856 PlotLookup : Plotting<N>, |
|
857 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
|
858 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
859 RNDM<F, N> : SpikeMerging<F>, |
|
860 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, |
|
861 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, |
|
862 { |
|
863 |
|
864 fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> { |
|
865 self.data.base.algorithm_defaults.get(&alg).cloned() |
|
866 } |
|
867 |
|
868 fn runall(&self, cli : &CommandLineArgs, |
|
869 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { |
|
870 // Get experiment configuration |
|
871 let &Named { |
|
872 name : ref experiment_name, |
|
873 data : ExperimentBiased { |
|
874 λ, |
|
875 ref bias, |
|
876 base : ExperimentV2 { |
|
877 domain, sensor_count, ref noise_distr, sensor, spread, kernel, |
|
878 ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, |
|
879 .. |
715 } |
880 } |
716 }; |
881 } |
717 |
882 } = self; |
718 let elapsed = start.elapsed().as_secs_f64(); |
883 |
719 let cpu_time = start_cpu.elapsed().as_secs_f64(); |
884 // Set up algorithms |
720 |
885 let algorithms = match (algs, dataterm) { |
721 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); |
886 (Some(algs), _) => algs, |
722 |
887 _ => vec![DefaultAlgorithm::SlidingPDPS.get_named()], |
723 // Save results |
888 }; |
724 println!("{}", "Saving results…".green()); |
889 |
725 |
890 // Set up operators |
726 let mkname = |t| format!("{prefix}{alg_name}_{t}"); |
891 let depth = DynamicDepth(8); |
727 |
892 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); |
728 write_json(mkname("config.json"), &named)?; |
893 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); |
729 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; |
894 let opAext = RowOp(opA.clone(), IdOp::new()); |
730 μ.write_csv(mkname("reco.txt"))?; |
895 let fnR = Zero::new(); |
731 logger.write_csv(mkname("log.txt"))?; |
896 let h = map3(domain.span_start(), domain.span_end(), sensor_count, |
|
897 |a, b, n| (b-a)/F::cast_from(n)) |
|
898 .into_iter() |
|
899 .reduce(NumTraitsFloat::max) |
|
900 .unwrap(); |
|
901 let z = DVector::zeros(sensor_count.iter().product()); |
|
902 let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap(); |
|
903 let y = opKz.apply(&z); |
|
904 let fnH = Weighted{ base_fn : L1.as_mapping(), weight : λ}; // TODO: L_{2,1} |
|
905 // let zero_y = y.clone(); |
|
906 // let zeroBTFN = opA.preadjoint().apply(&zero_y); |
|
907 // let opKμ = ZeroOp::new(&zero_y, zeroBTFN); |
|
908 |
|
909 // Set up random number generator. |
|
910 let mut rng = StdRng::seed_from_u64(noise_seed); |
|
911 |
|
912 // Generate the data and calculate SSNR statistic |
|
913 let bias_vec = DVector::from_vec(opA.grid() |
|
914 .into_iter() |
|
915 .map(|v| bias.apply(v)) |
|
916 .collect::<Vec<F>>()); |
|
917 let b_hat : DVector<_> = opA.apply(μ_hat) + &bias_vec; |
|
918 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); |
|
919 let b = &b_hat + &noise; |
|
920 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
|
921 // overloading log10 and conflicting with standard NumTraits one. |
|
922 let stats = ExperimentStats::new(&b, &noise); |
|
923 |
|
924 let prefix = start_experiment(&self, cli, stats)?; |
|
925 |
|
926 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, |
|
927 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; |
|
928 |
|
929 opA.write_observable(&bias_vec, format!("{prefix}bias"))?; |
|
930 |
|
931 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); |
|
932 |
|
933 let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z")); |
|
934 |
|
935 // Run the algorithms |
|
936 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, |
|
937 |alg, iterator, plotter, running| |
|
938 { |
|
939 let Pair(μ, z) = match alg { |
|
940 AlgorithmConfig::ForwardPDPS(ref algconfig) => { |
|
941 match (regularisation, dataterm) { |
|
942 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ |
|
943 print!("{running}"); |
|
944 pointsource_forward_pdps_pair( |
|
945 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
|
946 iterator, plotter, |
|
947 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
|
948 ) |
|
949 }), |
|
950 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ |
|
951 print!("{running}"); |
|
952 pointsource_forward_pdps_pair( |
|
953 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, |
|
954 iterator, plotter, |
|
955 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
|
956 ) |
|
957 }), |
|
958 _ => Err(NotImplemented) |
|
959 } |
|
960 }, |
|
961 AlgorithmConfig::SlidingPDPS(ref algconfig) => { |
|
962 match (regularisation, dataterm) { |
|
963 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ |
|
964 print!("{running}"); |
|
965 pointsource_sliding_pdps_pair( |
|
966 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
|
967 iterator, plotter, |
|
968 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
|
969 ) |
|
970 }), |
|
971 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ |
|
972 print!("{running}"); |
|
973 pointsource_sliding_pdps_pair( |
|
974 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, |
|
975 iterator, plotter, |
|
976 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
|
977 ) |
|
978 }), |
|
979 _ => Err(NotImplemented) |
|
980 } |
|
981 }, |
|
982 _ => Err(NotImplemented) |
|
983 }?; |
|
984 Ok((μ, z)) |
|
985 }) |
|
986 } |
|
987 } |
|
988 |
|
989 |
|
990 /// Calculative minimum and maximum values of all the `logs`, and save them into |
|
991 /// corresponding file names given as the first elements of the tuples in the vectors. |
|
992 fn save_logs<F : Float, const N : usize>( |
|
993 logs : Vec<(String, Logger<Timed<IterInfo<F, N>>>)> |
|
994 ) -> DynError { |
|
995 // Process logs for relative values |
|
996 println!("{}", "Processing logs…"); |
|
997 |
|
998 |
|
999 // Find minimum value and initial value within a single log |
|
1000 let proc_single_log = |log : &Logger<Timed<IterInfo<F, N>>>| { |
|
1001 let d = log.data(); |
|
1002 let mi = d.iter() |
|
1003 .map(|i| i.data.value) |
|
1004 .reduce(NumTraitsFloat::min); |
|
1005 d.first() |
|
1006 .map(|i| i.data.value) |
|
1007 .zip(mi) |
|
1008 }; |
|
1009 |
|
1010 // Find minimum and maximum value over all logs |
|
1011 let (v_ini, v_min) = logs.iter() |
|
1012 .filter_map(|&(_, ref log)| proc_single_log(log)) |
|
1013 .reduce(|(i1, m1), (i2, m2)| (i1.max(i2), m1.min(m2))) |
|
1014 .ok_or(anyhow!("No algorithms found"))?; |
|
1015 |
|
1016 let logmap = |Timed { cpu_time, iter, data }| { |
|
1017 let IterInfo { |
|
1018 value, |
|
1019 n_spikes, |
|
1020 inner_iters, |
|
1021 merged, |
|
1022 pruned, |
|
1023 //postprocessing, |
|
1024 this_iters, |
|
1025 .. |
|
1026 } = data; |
|
1027 // let post_value = match (postprocessing, dataterm) { |
|
1028 // (Some(mut μ), DataTerm::L2Squared) => { |
|
1029 // // Comparison postprocessing is only implemented for the case handled |
|
1030 // // by the FW variants. |
|
1031 // reg.optimise_weights( |
|
1032 // &mut μ, &opA, &b, &findim_data, &inner_config, |
|
1033 // inner_it |
|
1034 // ); |
|
1035 // dataterm.value_at_residual(opA.apply(&μ) - &b) |
|
1036 // + regularisation.apply(&μ) |
|
1037 // }, |
|
1038 // _ => value, |
|
1039 // }; |
|
1040 let relative_value = (value - v_min)/(v_ini - v_min); |
|
1041 CSVLog { |
|
1042 iter, |
|
1043 value, |
|
1044 relative_value, |
|
1045 //post_value, |
|
1046 n_spikes, |
|
1047 cpu_time : cpu_time.as_secs_f64(), |
|
1048 inner_iters, |
|
1049 merged, |
|
1050 pruned, |
|
1051 this_iters |
732 } |
1052 } |
733 |
1053 }; |
734 Ok(()) |
1054 |
|
1055 println!("{}", "Saving logs …".green()); |
|
1056 |
|
1057 for (name, logger) in logs { |
|
1058 logger.map(logmap).write_csv(name)?; |
735 } |
1059 } |
736 } |
1060 |
737 // *** macro end boiler plate *** |
1061 Ok(()) |
738 }} |
1062 } |
739 // *** actual code *** |
1063 |
740 |
|
741 impl_experiment!(ExperimentV2, regularisation, std::convert::identity); |
|
742 |
1064 |
743 /// Plot experiment setup |
1065 /// Plot experiment setup |
744 #[replace_float_literals(F::cast_from(literal))] |
1066 #[replace_float_literals(F::cast_from(literal))] |
745 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( |
1067 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( |
746 cli : &CommandLineArgs, |
1068 cli : &CommandLineArgs, |
747 prefix : &String, |
1069 prefix : &String, |
748 domain : &Cube<F, N>, |
1070 domain : &Cube<F, N>, |
749 sensor : &Sensor, |
1071 sensor : &Sensor, |
750 kernel : &Kernel, |
1072 kernel : &Kernel, |
751 spread : &Spread, |
1073 spread : &Spread, |
752 μ_hat : &DiscreteMeasure<Loc<F, N>, F>, |
1074 μ_hat : &RNDM<F, N>, |
753 op𝒟 : &𝒟, |
1075 op𝒟 : &𝒟, |
754 opA : &A, |
1076 opA : &A, |
755 b_hat : &A::Observable, |
1077 b_hat : &A::Observable, |
756 b : &A::Observable, |
1078 b : &A::Observable, |
757 kernel_plot_width : F, |
1079 kernel_plot_width : F, |