src/main.rs

changeset 0
eb3c7813b67a
child 2
7a953a87b6c1
child 5
df971c81282e
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/main.rs	Thu Dec 01 23:07:35 2022 +0200
@@ -0,0 +1,233 @@
+// The main documentation is in the README.
+#![doc = include_str!("../README.md")]
+
+// We use unicode. We would like to use much more of it than Rust allows.
+// Live with it. Embrace it.
+#![allow(uncommon_codepoints)]
+#![allow(mixed_script_confusables)]
+#![allow(confusable_idents)]
+// Linear operators may be writtten e.g. as `opA` for a resemblance
+// to mathematical convention.
+#![allow(non_snake_case)]
+// We need the drain filter for inertial prune
+#![feature(drain_filter)]
+
+use clap::Parser;
+use itertools::Itertools;
+use serde_json;
+use alg_tools::iterate::Verbose;
+use alg_tools::parallelism::{
+    set_num_threads,
+    set_max_threads,
+};
+use std::num::NonZeroUsize;
+
+pub mod types;
+pub mod measures;
+pub mod fourier;
+pub mod kernels;
+pub mod seminorms;
+pub mod forward_model;
+pub mod plot;
+pub mod subproblem;
+pub mod tolerance;
+pub mod fb;
+pub mod frank_wolfe;
+pub mod pdps;
+pub mod run;
+pub mod rand_distr;
+pub mod experiments;
+
+use types::{float, ClapFloat};
+use run::{
+    DefaultAlgorithm,
+    Configuration,
+    PlotLevel,
+    Named,
+    AlgorithmConfig,
+};
+use experiments::DefaultExperiment;
+use measures::merging::SpikeMergingMethod;
+use DefaultExperiment::*;
+use DefaultAlgorithm::*;
+
+/// Command line parameters
+#[derive(Parser, Debug)]
+#[clap(
+    about = env!("CARGO_PKG_DESCRIPTION"),
+    author = env!("CARGO_PKG_AUTHORS"),
+    version = env!("CARGO_PKG_VERSION"),
+    after_help = "Pass --help for longer descriptions.",
+    after_long_help = "",
+)]
+pub struct CommandLineArgs {
+    #[arg(long, short = 'm', value_name = "M")]
+    /// Maximum iteration count
+    max_iter : Option<usize>,
+
+    #[arg(long, short = 'n', value_name = "N")]
+    /// Output status every N iterations. Set to 0 to disable.
+    verbose_iter : Option<usize>,
+
+    #[arg(long, short = 'q')]
+    /// Don't display iteration progress
+    quiet : bool,
+
+    /// List of experiments to perform.
+    #[arg(value_enum, value_name = "EXPERIMENT",
+           default_values_t = [Experiment1D, Experiment1DFast,
+                               Experiment2D, Experiment2DFast,
+                               Experiment1D_L1])]
+    experiments : Vec<DefaultExperiment>,
+
+    /// Default algorithm configration(s) to use on the experiments.
+    ///
+    /// Not all algorithms are available for all the experiments.
+    /// In particular, only PDPS is available for the experiments with L¹ data term.
+    #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a',
+           default_values_t = [FB, FISTA, PDPS, FW, FWRelax])]
+    algorithm : Vec<DefaultAlgorithm>,
+
+    /// Saved algorithm configration(s) to use on the experiments
+    #[arg(value_name = "JSON_FILE", long)]
+    saved_algorithm : Vec<String>,
+
+    /// Write plots for every verbose iteration
+    #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)]
+    plot : PlotLevel,
+
+    /// Directory for saving results
+    #[arg(long, short = 'o', default_value = "out")]
+    outdir : String,
+
+    #[arg(long, help_heading = "Multi-threading", default_value = "4")]
+    /// Maximum number of threads
+    max_threads : usize,
+
+    #[arg(long, help_heading = "Multi-threading")]
+    /// Number of threads. Overrides the maximum number.
+    num_threads : Option<usize>,
+
+    #[clap(flatten, next_help_heading = "Experiment overrides")]
+    /// Experiment setup overrides
+    experiment_overrides : ExperimentOverrides<float>,
+
+    #[clap(flatten, next_help_heading = "Algorithm overrides")]
+    /// Algorithm parametrisation overrides
+    algoritm_overrides : AlgorithmOverrides<float>,
+}
+
+/// Command line experiment setup overrides
+#[derive(Parser, Debug)]
+pub struct ExperimentOverrides<F : ClapFloat> {
+    #[arg(long)]
+    /// Regularisation parameter override.
+    ///
+    /// Only use if running just a single experiment, as different experiments have different
+    /// regularisation parameters.
+    alpha : Option<F>,
+
+    #[arg(long)]
+    /// Gaussian noise variance override
+    variance : Option<F>,
+
+    #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])]
+    /// Salt and pepper noise override.
+    salt_and_pepper : Option<Vec<F>>,
+
+    #[arg(long)]
+    /// Noise seed
+    noise_seed : Option<u64>,
+}
+
+/// Command line algorithm parametrisation overrides
+#[derive(Parser, Debug)]
+pub struct AlgorithmOverrides<F : ClapFloat> {
+    #[arg(long, value_names = &["COUNT", "EACH"])]
+    /// Override bootstrap insertion iterations for --algorithm.
+    ///
+    /// The first parameter is the number of bootstrap insertion iterations, and the second
+    /// the maximum number of iterations on each of them.
+    bootstrap_insertions : Option<Vec<usize>>,
+
+    #[arg(long, requires = "algorithm")]
+    /// Primal step length parameter override for --algorithm.
+    ///
+    /// Only use if running just a single algorithm, as different algorithms have different
+    /// regularisation parameters. Does not affect the algorithms fw and fwrelax.
+    tau0 : Option<F>,
+
+    #[arg(long, requires = "algorithm")]
+    /// Dual step length parameter override for --algorithm.
+    ///
+    /// Only use if running just a single algorithm, as different algorithms have different
+    /// regularisation parameters. Only affects PDPS.
+    sigma0 : Option<F>,
+
+    #[arg(value_enum, long)]
+    /// PDPS acceleration, when available.
+    acceleration : Option<pdps::Acceleration>,
+
+    #[arg(long)]
+    /// Perform postprocess weight optimisation for saved iterations
+    ///
+    /// Only affects FB, FISTA, and PDPS.
+    postprocessing : Option<bool>,
+
+    #[arg(value_name = "n", long)]
+    /// Merging frequency, if merging enabled (every n iterations)
+    ///
+    /// Only affects FB, FISTA, and PDPS.
+    merge_every : Option<usize>,
+
+    #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())]
+    /// Merging strategy
+    ///
+    /// Either the string "none", or a radius value for heuristic merging.
+    merging : Option<SpikeMergingMethod<F>>,
+
+    #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())]
+    /// Final merging strategy
+    ///
+    /// Either the string "none", or a radius value for heuristic merging.
+    /// Only affects FB, FISTA, and PDPS.
+    final_merging : Option<SpikeMergingMethod<F>>,
+}
+
+/// The entry point for the program.
+pub fn main() {
+    let cli = CommandLineArgs::parse();
+
+    if let Some(n_threads) = cli.num_threads {
+        let n = NonZeroUsize::new(n_threads).expect("Invalid thread count");
+        set_num_threads(n);
+    } else {
+        let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count");
+        set_max_threads(m);
+    }
+
+    for experiment_shorthand in cli.experiments.iter().unique() {
+        let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap();
+        let mut config : Configuration<float> = experiment.default_config();
+        let mut algs : Vec<Named<AlgorithmConfig<float>>>
+            = cli.algorithm.iter()
+                            .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides))
+                            .collect();
+        for filename in cli.saved_algorithm.iter() {
+            let f = std::fs::File::open(filename).unwrap();
+            let alg = serde_json::from_reader(f).unwrap();
+            algs.push(alg);
+        }
+        cli.max_iter.map(|m| config.iterator_options.max_iter = m);
+        cli.verbose_iter.map(|n| config.iterator_options.verbose_iter = Verbose::Every(n));
+        config.plot = cli.plot;
+        config.iterator_options.quiet = cli.quiet;
+        config.outdir = cli.outdir.clone();
+        if !algs.is_empty() {
+            config.algorithms = algs.clone();
+        }
+
+        experiment.runall(config)
+                  .unwrap()
+    }
+}

mercurial