src/run.rs

changeset 0
eb3c7813b67a
child 1
d4fd5f32d10e
child 2
7a953a87b6c1
child 9
21b0e537ac0e
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/run.rs	Thu Dec 01 23:07:35 2022 +0200
@@ -0,0 +1,602 @@
+/*!
+This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment.
+*/
+
+use numeric_literals::replace_float_literals;
+use colored::Colorize;
+use serde::{Serialize, Deserialize};
+use serde_json;
+use nalgebra::base::DVector;
+use std::hash::Hash;
+use chrono::{DateTime, Utc};
+use cpu_time::ProcessTime;
+use clap::ValueEnum;
+use std::collections::HashMap;
+use std::time::Instant;
+
+use rand::prelude::{
+    StdRng,
+    SeedableRng
+};
+use rand_distr::Distribution;
+
+use alg_tools::bisection_tree::*;
+use alg_tools::iterate::{
+    Timed,
+    AlgIteratorOptions,
+    Verbose,
+    AlgIteratorFactory,
+};
+use alg_tools::logger::Logger;
+use alg_tools::error::DynError;
+use alg_tools::tabledump::TableDump;
+use alg_tools::sets::Cube;
+use alg_tools::mapping::RealMapping;
+use alg_tools::nalgebra_support::ToNalgebraRealField;
+use alg_tools::euclidean::Euclidean;
+use alg_tools::norms::{Norm, L1};
+use alg_tools::lingrid::lingrid;
+use alg_tools::sets::SetOrd;
+
+use crate::kernels::*;
+use crate::types::*;
+use crate::measures::*;
+use crate::measures::merging::SpikeMerging;
+use crate::forward_model::*;
+use crate::fb::{
+    FBConfig,
+    pointsource_fb,
+    FBMetaAlgorithm, FBGenericConfig,
+};
+use crate::pdps::{
+    PDPSConfig,
+    L2Squared,
+    pointsource_pdps,
+};
+use crate::frank_wolfe::{
+    FWConfig,
+    FWVariant,
+    pointsource_fw,
+    prepare_optimise_weights,
+    optimise_weights,
+};
+use crate::subproblem::InnerSettings;
+use crate::seminorms::*;
+use crate::plot::*;
+use crate::AlgorithmOverrides;
+
+/// Available algorithms and their configurations
+///
+/// Each variant carries the configuration struct of one solver family.
+/// `runall` of [`RunnableExperiment`] dispatches on the variant together with the
+/// experiment's data term to pick the solver function to call.
+#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
+pub enum AlgorithmConfig<F : Float> {
+    /// Forward-backward configuration, passed to [`pointsource_fb`].
+    /// Also used for the FISTA shorthand (see [`DefaultAlgorithm::default_config`]).
+    FB(FBConfig<F>),
+    /// Frank–Wolfe (conditional gradient) configuration, passed to [`pointsource_fw`].
+    FW(FWConfig<F>),
+    /// Primal-dual proximal splitting configuration, passed to [`pointsource_pdps`].
+    PDPS(PDPSConfig<F>),
+}
+
+impl<F : ClapFloat> AlgorithmConfig<F> {
+    /// Returns a copy of this configuration in which every parameter that was given
+    /// on the command line (`cli`) replaces the stored value.
+    ///
+    /// Parameters that are absent from `cli` keep their current value, and overrides
+    /// that do not apply to the algorithm at hand are ignored.
+    pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self {
+        // Overrides shared by the algorithms that embed an `FBGenericConfig`
+        // (FB and PDPS, through their `insertion` field).
+        let apply_generic = |generic : FBGenericConfig<F>| {
+            let bootstrap_insertions = match cli.bootstrap_insertions.as_ref() {
+                Some(n) => Some((n[0], n[1])),
+                None => generic.bootstrap_insertions,
+            };
+            let merge_every = cli.merge_every.unwrap_or(generic.merge_every);
+            let merging = cli.merging.clone().unwrap_or(generic.merging);
+            let final_merging = cli.final_merging.clone().unwrap_or(generic.final_merging);
+            FBGenericConfig {
+                bootstrap_insertions,
+                merge_every,
+                merging,
+                final_merging,
+                .. generic
+            }
+        };
+
+        match self {
+            Self::FB(fb) => {
+                let τ0 = cli.tau0.unwrap_or(fb.τ0);
+                let insertion = apply_generic(fb.insertion);
+                Self::FB(FBConfig { τ0, insertion, .. fb })
+            },
+            Self::FW(fw) => {
+                let merging = cli.merging.clone().unwrap_or(fw.merging);
+                Self::FW(FWConfig { merging, .. fw })
+            },
+            Self::PDPS(pdps) => {
+                let τ0 = cli.tau0.unwrap_or(pdps.τ0);
+                let σ0 = cli.sigma0.unwrap_or(pdps.σ0);
+                let acceleration = cli.acceleration.unwrap_or(pdps.acceleration);
+                let insertion = apply_generic(pdps.insertion);
+                Self::PDPS(PDPSConfig { τ0, σ0, acceleration, insertion, .. pdps })
+            },
+        }
+    }
+}
+
+/// Helper struct for tagging an [`AlgorithmConfig`] or [`Experiment`] with a name.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct Named<Data> {
+    /// The name attached to `data`. In `runall` it is used as a path component
+    /// of the output files.
+    pub name : String,
+    /// The tagged value. Because of `#[serde(flatten)]`, its fields are (de)serialised
+    /// at the same level as `name` instead of in a nested `data` object.
+    #[serde(flatten)]
+    pub data : Data,
+}
+
+/// Shorthand algorithm configurations, to be used with the command line parser
+// NOTE: with the `ValueEnum` derive, the doc comment of each variant doubles as the
+// help text of the corresponding command line value, so keep them short and tidy.
+#[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash)]
+pub enum DefaultAlgorithm {
+    /// The μFB forward-backward method
+    #[clap(name = "fb")]
+    FB,
+    /// The μFISTA inertial forward-backward method
+    #[clap(name = "fista")]
+    FISTA,
+    /// The “fully corrective” conditional gradient method
+    #[clap(name = "fw")]
+    FW,
+    /// The “relaxed” conditional gradient method
+    #[clap(name = "fwrelax")]
+    FWRelax,
+    /// The μPDPS primal-dual proximal splitting method
+    #[clap(name = "pdps")]
+    PDPS,
+}
+
+impl DefaultAlgorithm {
+    /// Builds the stock [`AlgorithmConfig`] that this shorthand stands for.
+    ///
+    /// `FISTA` and `FWRelax` reuse the forward-backward and Frank–Wolfe defaults,
+    /// respectively, with only the meta-algorithm / variant changed.
+    pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> {
+        match *self {
+            Self::FB => AlgorithmConfig::FB(Default::default()),
+            Self::FISTA => {
+                let inertial = FBConfig {
+                    meta : FBMetaAlgorithm::InertiaFISTA,
+                    .. Default::default()
+                };
+                AlgorithmConfig::FB(inertial)
+            },
+            Self::FW => AlgorithmConfig::FW(Default::default()),
+            Self::FWRelax => {
+                let relaxed = FWConfig {
+                    variant : FWVariant::Relaxed,
+                    .. Default::default()
+                };
+                AlgorithmConfig::FW(relaxed)
+            },
+            Self::PDPS => AlgorithmConfig::PDPS(Default::default()),
+        }
+    }
+
+    /// Builds the stock configuration of this shorthand and tags it with the
+    /// shorthand's command line name, yielding a [`Named`] algorithm.
+    pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> {
+        let alg = self.default_config();
+        (*self).to_named(alg)
+    }
+
+    /// Tags the given configuration `alg` with the command line name of this shorthand.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `to_possible_value` returns `None` for this variant.
+    pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> {
+        let cli_value = self.to_possible_value().unwrap();
+        Named {
+            name : cli_value.get_name().to_string(),
+            data : alg,
+        }
+    }
+}
+
+
+// // Floats cannot be hashed directly, so just hash the debug formatting
+// // for use as file identifier.
+// impl<F : Float> Hash for AlgorithmConfig<F> {
+//     fn hash<H: Hasher>(&self, state: &mut H) {
+//         format!("{:?}", self).hash(state);
+//     }
+// }
+
+// The variants are declared in increasing order of verbosity. The derived `Ord` relies
+// on this declaration order: other code in this file compares levels with
+// `config.plot < PlotLevel::Data` and `config.plot >= PlotLevel::Iter`.
+/// Plotting level configuration
+#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, ValueEnum, Debug)]
+pub enum PlotLevel {
+    /// Plot nothing
+    #[clap(name = "none")]
+    None,
+    /// Plot problem data
+    #[clap(name = "data")]
+    Data,
+    /// Plot iterationwise state
+    #[clap(name = "iter")]
+    Iter,
+}
+
+/// Algorithm and iterator config for the experiments
+
+// NOTE(review): only `Serialize` is derived here, while `#[serde(default)]` is a
+// deserialisation attribute; presumably it has no effect at present. Confirm.
+#[derive(Clone, Debug, Serialize)]
+#[serde(default)]
+pub struct Configuration<F : Float> {
+    /// Algorithms to run
+    pub algorithms : Vec<Named<AlgorithmConfig<F>>>,
+    /// Options for algorithm step iteration (verbosity, etc.)
+    pub iterator_options : AlgIteratorOptions,
+    /// Plotting level
+    pub plot : PlotLevel,
+    /// Directory where to save results. Output of an experiment goes into the
+    /// subdirectory `{outdir}/{experiment name}/`.
+    pub outdir : String,
+    /// Bisection tree depth. Passed as the depth argument when constructing the
+    /// forward operator and the seminorm operator in `runall`.
+    pub bt_depth : DynamicDepth,
+}
+
+/// Bisection tree type shared by the operators constructed in this module.
+type DefaultBT<F, const N : usize> = BT<
+    DynamicDepth,
+    F,
+    usize,
+    Bounds<F>,
+    N
+>;
+/// Convolution operator with kernel `K`, backed by a [`DefaultBT`].
+/// Used in `runall` to construct the operator $𝒟$.
+type DefaultSeminormOp<F, K, const N : usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>;
+/// Sensor grid with the given `Sensor` and `Spread`, backed by a [`DefaultBT`].
+/// Used in `runall` to construct the forward operator $𝒜$.
+type DefaultSG<F, Sensor, Spread, const N : usize> = SensorGrid::<
+    F,
+    Sensor,
+    Spread,
+    DefaultBT<F, N>,
+    N
+>;
+
+/// This is a dirty workaround to rust-csv not supporting struct flattening etc.
+///
+/// One flat record per logged iteration, assembled in `runall` from the timing
+/// information and the `IterInfo` reported by the algorithm.
+// NOTE(review): the field order presumably determines the column order of the
+// written log; confirm before reordering.
+#[derive(Serialize)]
+struct CSVLog<F> {
+    /// Iteration number
+    iter : usize,
+    /// CPU time in seconds, as reported by the timed iterator
+    cpu_time : f64,
+    /// Function value reported by the algorithm
+    value : F,
+    /// Function value after postprocessing; equals `value` when no postprocessing
+    /// result is available or the data term is not `L2Squared`.
+    post_value : F,
+    /// Copied from the `n_spikes` field of `IterInfo`
+    n_spikes : usize,
+    /// Copied from the `inner_iters` field of `IterInfo`
+    inner_iters : usize,
+    /// Copied from the `merged` field of `IterInfo`
+    merged : usize,
+    /// Copied from the `pruned` field of `IterInfo`
+    pruned : usize,
+    /// Copied from the `this_iters` field of `IterInfo`
+    this_iters : usize,
+}
+
+/// Collected experiment statistics
+#[derive(Clone, Debug, Serialize)]
+struct ExperimentStats<F : Float> {
+    /// Signal-to-noise ratio in decibels, computed as $10 \log_{10}(‖signal‖² / ‖noise‖²)$.
+    ssnr : F,
+    /// Proportion of noise in the signal as a number in $[0, 1]$.
+    /// Computed as $‖noise‖ / ‖signal‖$.
+    // NOTE(review): nothing in the computation clamps this to [0, 1]; the range
+    // presumably holds for typical data only. Confirm.
+    noise_ratio : F,
+    /// When the experiment was run (UTC)
+    when : DateTime<Utc>,
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F : Float> ExperimentStats<F> {
+    /// Computes the statistics of an experiment from the noisy `signal` and the
+    /// `noise` that was added to it. The result is timestamped with the current UTC time.
+    fn new<E : Euclidean<F>>(signal : &E, noise : &E) -> Self {
+        // Squared 2-norms of the two inputs
+        let signal_energy = signal.norm2_squared();
+        let noise_energy = noise.norm2_squared();
+        ExperimentStats {
+            ssnr : 10.0 * (signal_energy / noise_energy).log10(),
+            noise_ratio : (noise_energy / signal_energy).sqrt(),
+            when : Utc::now(),
+        }
+    }
+}
+/// Collected algorithm statistics
+///
+/// Filled in by `runall` after each algorithm run, with both values in seconds.
+#[derive(Clone, Debug, Serialize)]
+struct AlgorithmStats<F : Float> {
+    /// Overall CPU time spent
+    cpu_time : F,
+    /// Real time spent
+    elapsed : F
+}
+
+
+/// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input
+/// and outputs a [`DynError`].
+///
+/// The file `filename` is created (truncating any existing file) and `data` is
+/// written into it as pretty-printed JSON. Errors from creating the file, from
+/// serialisation, and from flushing the output are all propagated to the caller.
+fn write_json<T : Serialize>(filename : String, data : &T) -> DynError {
+    let file = std::fs::File::create(filename)?;
+    // The serialiser emits many small writes; buffer them instead of issuing
+    // one system call per fragment.
+    let mut writer = std::io::BufWriter::new(file);
+    serde_json::to_writer_pretty(&mut writer, data)?;
+    // Flush explicitly: dropping a `BufWriter` silently discards write errors.
+    std::io::Write::flush(&mut writer)?;
+    Ok(())
+}
+
+
+/// Struct for experiment configurations
+///
+/// Wrapped in [`Named`], an experiment implements [`RunnableExperiment`].
+#[derive(Debug, Clone, Serialize)]
+pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize>
+where F : Float,
+      [usize; N] : Serialize,
+      NoiseDistr : Distribution<F>,
+      S : Sensor<F, N>,
+      P : Spread<F, N>,
+      K : SimpleConvolutionKernel<F, N>,
+{
+    /// Domain $Ω$.
+    pub domain : Cube<F, N>,
+    /// Number of sensors along each dimension
+    pub sensor_count : [usize; N],
+    /// Noise distribution. The noise added to the exact data is sampled from it.
+    pub noise_distr : NoiseDistr,
+    /// Seed for random noise generation (for repeatable experiments)
+    pub noise_seed : u64,
+    /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$.
+    pub sensor : S,
+    /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$.
+    pub spread : P,
+    /// Kernel $ρ$ of $𝒟$.
+    pub kernel : K,
+    /// True point sources
+    pub μ_hat : DiscreteMeasure<Loc<F, N>, F>,
+    /// Regularisation parameter
+    pub α : F,
+    /// For plotting : how wide should the kernels be plotted.
+    /// The kernels are plotted on $[-w, w]^N$ for this width $w$.
+    pub kernel_plot_width : F,
+    /// Data term
+    pub dataterm : DataTerm,
+    /// A map of default configurations for algorithms. Entries take precedence over
+    /// [`DefaultAlgorithm::default_config`]. Not included in the serialised output.
+    #[serde(skip)]
+    pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>,
+}
+
+/// Trait for runnable experiments
+pub trait RunnableExperiment<F : ClapFloat> {
+    /// Run all algorithms of the [`Configuration`] `config` on the experiment.
+    ///
+    /// Returns an error if saving any of the outputs fails.
+    fn runall(&self, config : Configuration<F>) -> DynError;
+
+    /// Returns the default configuration
+    fn default_config(&self) -> Configuration<F>;
+
+    /// Return algorithm default config
+    ///
+    /// The result is the configuration to use for the shorthand `alg` on this
+    /// experiment, with the command line overrides `cli` applied and tagged with
+    /// the name of `alg`.
+    fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
+    -> Named<AlgorithmConfig<F>>;
+}
+
+// Implementation of `RunnableExperiment` for named experiments. The long list of
+// bounds is what the operators, solvers and plotting routines called below require.
+impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for
+Named<Experiment<F, NoiseDistr, S, K, P, N>>
+where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
+      [usize; N] : Serialize,
+      S : Sensor<F, N> + Copy + Serialize,
+      P : Spread<F, N> + Copy + Serialize,
+      Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy,
+      AutoConvolution<P> : BoundedBy<F, K>,
+      K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> + Copy + Serialize,
+      Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
+      PlotLookup : Plotting<N>,
+      DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
+      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
+      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
+      NoiseDistr : Distribution<F> + Serialize {
+
+    /// Looks up the experiment-specific default for `alg`, falling back to the generic
+    /// default of the shorthand, then applies the command line overrides `cli` and
+    /// tags the result with the name of `alg`.
+    fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
+    -> Named<AlgorithmConfig<F>> {
+        alg.to_named(
+            self.data
+                .algorithm_defaults
+                .get(&alg)
+                .map_or_else(|| alg.default_config(),
+                            |config| config.clone())
+                .cli_override(cli)
+        )
+    }
+
+    /// Default configuration: a single algorithm chosen by the data term (FB for
+    /// `L2Squared`, PDPS for `L1`), at most 2000 iterations, plotting of problem data
+    /// only, output into `out`, and bisection tree depth 8.
+    fn default_config(&self) -> Configuration<F> {
+        let default_alg = match self.data.dataterm {
+            DataTerm::L2Squared => DefaultAlgorithm::FB.get_named(),
+            DataTerm::L1 => DefaultAlgorithm::PDPS.get_named(),
+        };
+
+        Configuration{
+            algorithms : vec![default_alg],
+            iterator_options : AlgIteratorOptions{
+                max_iter : 2000,
+                verbose_iter : Verbose::Logarithmic(10),
+                quiet : false,
+            },
+            plot : PlotLevel::Data,
+            outdir : "out".to_string(),
+            bt_depth : DynamicDepth(8),
+        }
+    }
+
+    /// Generates the noisy data of the experiment, saves the experiment description,
+    /// configuration and statistics, plots the setup, and then runs every algorithm
+    /// of `config` in turn, saving the results of each.
+    ///
+    /// Algorithms that are not implemented for the data term of the experiment are
+    /// skipped with a message on standard error. Everything is written below
+    /// `{outdir}/{experiment name}/`.
+    fn runall(&self, config : Configuration<F>) -> DynError {
+        // Unpack the experiment; `Copy` fields by value, the rest by reference.
+        let &Named {
+            name : ref experiment_name,
+            data : Experiment {
+                domain, sensor_count, ref noise_distr, sensor, spread, kernel,
+                ref μ_hat, α, kernel_plot_width, dataterm, noise_seed,
+                ..
+            }
+        } = self;
+
+        // Set path
+        let prefix = format!("{}/{}/", config.outdir, experiment_name);
+
+        // Set up operators
+        let depth = config.bt_depth;
+        let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
+        let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel);
+
+        // Set up random number generator.
+        let mut rng = StdRng::seed_from_u64(noise_seed);
+
+        // Generate the data and calculate SSNR statistic
+        let b_hat = opA.apply(μ_hat);
+        let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
+        let b = &b_hat + &noise;
+        // The statistics are calculated in a separate function (`ExperimentStats::new`)
+        // to hide ultra-lame nalgebra::RealField overloading log10 and conflicting with
+        // the standard NumTraits one.
+        let stats = ExperimentStats::new(&b, &noise);
+
+        // Save experiment configuration and statistics
+        let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t);
+        std::fs::create_dir_all(&prefix)?;
+        write_json(mkname_e("experiment"), self)?;
+        write_json(mkname_e("config"), &config)?;
+        write_json(mkname_e("stats"), &stats)?;
+
+        // Plot the experiment setup; this does nothing below `PlotLevel::Data`.
+        plotall(&config, &prefix, &domain, &sensor, &kernel, &spread,
+                &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?;
+
+        // Run the algorithm(s)
+        for named @ Named { name : alg_name, data : alg } in config.algorithms.iter() {
+            // Prefix handed to the iteration plotter of this algorithm
+            let this_prefix = format!("{}{}/", prefix, alg_name);
+
+            // Announces the run; only called for supported algorithm/data term pairs.
+            let running = || {
+                println!("{}\n{}\n{}",
+                        format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
+                        format!("{:?}", config.iterator_options).bright_black(),
+                        format!("{:?}", alg).bright_black());
+            };
+
+            // Create Logger and IteratorFactory
+            let mut logger = Logger::new();
+            let findim_data = prepare_optimise_weights(&opA);
+            let inner_config : InnerSettings<F> = Default::default();
+            let inner_it = inner_config.iterator_options;
+            // Converts the timed iteration info into a flat `CSVLog` record.
+            let logmap = |iter, Timed { cpu_time, data }| {
+                let IterInfo {
+                    value,
+                    n_spikes,
+                    inner_iters,
+                    merged,
+                    pruned,
+                    postprocessing,
+                    this_iters,
+                    ..
+                } = data;
+                // Function value after postprocessing: for the `L2Squared` data term
+                // the weights of the provided measure are re-optimised and the
+                // objective re-evaluated; otherwise the reported value is used as is.
+                let post_value = match postprocessing {
+                    None => value,
+                    Some(mut μ) => {
+                        match dataterm {
+                            DataTerm::L2Squared => {
+                                optimise_weights(
+                                    &mut μ, &opA, &b, α, &findim_data, &inner_config,
+                                    inner_it
+                                );
+                                dataterm.value_at_residual(opA.apply(&μ) - &b) + α * μ.norm(Radon)
+                            },
+                            _ => value,
+                        }
+                    }
+                };
+                CSVLog {
+                    iter,
+                    value,
+                    post_value,
+                    n_spikes,
+                    cpu_time : cpu_time.as_secs_f64(),
+                    inner_iters,
+                    merged,
+                    pruned,
+                    this_iters
+                }
+            };
+            let iterator = config.iterator_options
+                                 .instantiate()
+                                 .timed()
+                                 .mapped(logmap)
+                                 .into_log(&mut logger);
+            // Grid on the domain: 1000 points in one dimension, otherwise 100 per axis.
+            let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]);
+
+            // Create plotter and directory if needed.
+            // A plot count of zero is used unless iterationwise plotting was requested.
+            let plot_count = if config.plot >= PlotLevel::Iter { 2000 } else { 0 };
+            let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid);
+
+            // Run the algorithm
+            let start = Instant::now();
+            let start_cpu = ProcessTime::now();
+            let μ : DiscreteMeasure<Loc<F, N>, F> = match (alg, dataterm) {
+                (AlgorithmConfig::FB(ref algconfig), DataTerm::L2Squared) => {
+                    running();
+                    pointsource_fb(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter)
+                },
+                (AlgorithmConfig::FW(ref algconfig), DataTerm::L2Squared) => {
+                    running();
+                    pointsource_fw(&opA, &b, α, &algconfig, iterator, plotter)
+                },
+                (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L2Squared) => {
+                    running();
+                    pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L2Squared)
+                },
+                (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L1) => {
+                    running();
+                    pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1)
+                },
+                // Remaining combinations (FB and FW with the L1 data term)
+                _ =>  {
+                    let msg = format!("Algorithm “{}” not implemented for dataterm {:?}. Skipping.",
+                                      alg_name, dataterm).red();
+                    eprintln!("{}", msg);
+                    continue
+                }
+            };
+            let elapsed = start.elapsed().as_secs_f64();
+            let cpu_time = start_cpu.elapsed().as_secs_f64();
+
+            println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow());
+
+            // Save results
+            println!("{}", "Saving results…".green());
+
+            // Result files are named `{prefix}{algorithm name}_{suffix}`, i.e., they are
+            // placed next to, not inside, the directory named by `this_prefix`.
+            let mkname = |
+            t| format!("{p}{n}_{t}", p = prefix, n = alg_name, t = t);
+
+            write_json(mkname("config.json"), &named)?;
+            write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
+            μ.write_csv(mkname("reco.txt"))?;
+            logger.write_csv(mkname("log.txt"))?;
+        }
+
+        Ok(())
+    }
+}
+
+/// Plot experiment setup
+///
+/// Does nothing and returns `Ok(())` when the plotting level of `config` is below
+/// [`PlotLevel::Data`]. Otherwise plots the sensor, kernel, spread and their
+/// convolution around the origin, plots quantities derived from the true measure
+/// `μ_hat` and the data `b_hat` (exact) and `b` (noisy) on the `domain`, and finally
+/// saves the true measure and both observables. All output file names are formed by
+/// appending a suffix to `prefix`.
+///
+/// Returns an error if saving the true measure or the observables fails.
+#[replace_float_literals(F::cast_from(literal))]
+fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>(
+    config : &Configuration<F>,
+    prefix : &String,
+    domain : &Cube<F, N>,
+    sensor : &Sensor,
+    kernel : &Kernel,
+    spread : &Spread,
+    μ_hat : &DiscreteMeasure<Loc<F, N>, F>,
+    op𝒟 : &𝒟,
+    opA : &A,
+    b_hat : &A::Observable,
+    b : &A::Observable,
+    kernel_plot_width : F,
+) -> DynError
+where F : Float + ToNalgebraRealField,
+      Sensor : RealMapping<F, N> + Support<F, N> + Clone,
+      Spread : RealMapping<F, N> + Support<F, N> + Clone,
+      Kernel : RealMapping<F, N> + Support<F, N>,
+      Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>,
+      𝒟 : DiscreteMeasureOp<Loc<F, N>, F>,
+      𝒟::Codomain : RealMapping<F, N>,
+      A : ForwardModel<Loc<F, N>, F>,
+      A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>,
+      PlotLookup : Plotting<N>,
+      Cube<F, N> : SetOrd {
+
+    if config.plot < PlotLevel::Data {
+        return Ok(())
+    }
+
+    // The convolution of the sensor and the spread
+    let base = Convolution(sensor.clone(), spread.clone());
+
+    // Grid points per axis: 100 in one dimension, otherwise 40.
+    let resolution = if N==1 { 100 } else { 40 };
+    let pfx = |n| format!("{}{}", prefix, n);
+    // Grid on the cube $[-w, w]^N$ with $w$ the kernel plot width
+    let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]);
+
+    PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string());
+    PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string());
+    PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string());
+    PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string());
+
+    // Grid on the domain of the experiment
+    let plotgrid2 = lingrid(&domain, &[resolution; N]);
+
+    // $𝒟$ applied to the true measure, and the preadjoint of the forward operator
+    // applied to the residual $Aμ̂ - b$.
+    let ω_hat = op𝒟.apply(μ_hat);
+    let noise =  opA.preadjoint().apply(opA.apply(μ_hat) - b);
+    PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string());
+    PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"),
+                               "noise Aᵀ(Aμ̂ - b)".to_string());
+
+    // Preadjoint applied to the noisy and to the exact data, plotted together with
+    // the true measure.
+    let preadj_b =  opA.preadjoint().apply(b);
+    let preadj_b_hat =  opA.preadjoint().apply(b_hat);
+    //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds());
+    PlotLookup::plot_into_file_spikes(
+        "Aᵀb".to_string(), &preadj_b,
+        "Aᵀb̂".to_string(), Some(&preadj_b_hat),
+        plotgrid2, None, &μ_hat,
+        pfx("omega_b")
+    );
+
+    // Save true solution and observables
+    // NOTE(review): this closure repeats the definition of `pfx` above; the
+    // redefinition looks redundant.
+    let pfx = |n| format!("{}{}", prefix, n);
+    μ_hat.write_csv(pfx("orig.txt"))?;
+    opA.write_observable(&b_hat, pfx("b_hat"))?;
+    opA.write_observable(&b, pfx("b_noisy"))
+}
+

mercurial