diff -r 000000000000 -r eb3c7813b67a src/main.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/main.rs Thu Dec 01 23:07:35 2022 +0200 @@ -0,0 +1,233 @@ +// The main documentation is in the README. +#![doc = include_str!("../README.md")] + +// We use unicode. We would like to use much more of it than Rust allows. +// Live with it. Embrace it. +#![allow(uncommon_codepoints)] +#![allow(mixed_script_confusables)] +#![allow(confusable_idents)] +// Linear operators may be writtten e.g. as `opA` for a resemblance +// to mathematical convention. +#![allow(non_snake_case)] +// We need the drain filter for inertial prune +#![feature(drain_filter)] + +use clap::Parser; +use itertools::Itertools; +use serde_json; +use alg_tools::iterate::Verbose; +use alg_tools::parallelism::{ + set_num_threads, + set_max_threads, +}; +use std::num::NonZeroUsize; + +pub mod types; +pub mod measures; +pub mod fourier; +pub mod kernels; +pub mod seminorms; +pub mod forward_model; +pub mod plot; +pub mod subproblem; +pub mod tolerance; +pub mod fb; +pub mod frank_wolfe; +pub mod pdps; +pub mod run; +pub mod rand_distr; +pub mod experiments; + +use types::{float, ClapFloat}; +use run::{ + DefaultAlgorithm, + Configuration, + PlotLevel, + Named, + AlgorithmConfig, +}; +use experiments::DefaultExperiment; +use measures::merging::SpikeMergingMethod; +use DefaultExperiment::*; +use DefaultAlgorithm::*; + +/// Command line parameters +#[derive(Parser, Debug)] +#[clap( + about = env!("CARGO_PKG_DESCRIPTION"), + author = env!("CARGO_PKG_AUTHORS"), + version = env!("CARGO_PKG_VERSION"), + after_help = "Pass --help for longer descriptions.", + after_long_help = "", +)] +pub struct CommandLineArgs { + #[arg(long, short = 'm', value_name = "M")] + /// Maximum iteration count + max_iter : Option, + + #[arg(long, short = 'n', value_name = "N")] + /// Output status every N iterations. Set to 0 to disable. + verbose_iter : Option, + + #[arg(long, short = 'q')] + /// Don't display iteration progress + quiet : bool, + + /// List of experiments to perform. + #[arg(value_enum, value_name = "EXPERIMENT", + default_values_t = [Experiment1D, Experiment1DFast, + Experiment2D, Experiment2DFast, + Experiment1D_L1])] + experiments : Vec, + + /// Default algorithm configration(s) to use on the experiments. + /// + /// Not all algorithms are available for all the experiments. + /// In particular, only PDPS is available for the experiments with L¹ data term. + #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', + default_values_t = [FB, FISTA, PDPS, FW, FWRelax])] + algorithm : Vec, + + /// Saved algorithm configration(s) to use on the experiments + #[arg(value_name = "JSON_FILE", long)] + saved_algorithm : Vec, + + /// Write plots for every verbose iteration + #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] + plot : PlotLevel, + + /// Directory for saving results + #[arg(long, short = 'o', default_value = "out")] + outdir : String, + + #[arg(long, help_heading = "Multi-threading", default_value = "4")] + /// Maximum number of threads + max_threads : usize, + + #[arg(long, help_heading = "Multi-threading")] + /// Number of threads. Overrides the maximum number. + num_threads : Option, + + #[clap(flatten, next_help_heading = "Experiment overrides")] + /// Experiment setup overrides + experiment_overrides : ExperimentOverrides, + + #[clap(flatten, next_help_heading = "Algorithm overrides")] + /// Algorithm parametrisation overrides + algoritm_overrides : AlgorithmOverrides, +} + +/// Command line experiment setup overrides +#[derive(Parser, Debug)] +pub struct ExperimentOverrides { + #[arg(long)] + /// Regularisation parameter override. + /// + /// Only use if running just a single experiment, as different experiments have different + /// regularisation parameters. + alpha : Option, + + #[arg(long)] + /// Gaussian noise variance override + variance : Option, + + #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])] + /// Salt and pepper noise override. + salt_and_pepper : Option>, + + #[arg(long)] + /// Noise seed + noise_seed : Option, +} + +/// Command line algorithm parametrisation overrides +#[derive(Parser, Debug)] +pub struct AlgorithmOverrides { + #[arg(long, value_names = &["COUNT", "EACH"])] + /// Override bootstrap insertion iterations for --algorithm. + /// + /// The first parameter is the number of bootstrap insertion iterations, and the second + /// the maximum number of iterations on each of them. + bootstrap_insertions : Option>, + + #[arg(long, requires = "algorithm")] + /// Primal step length parameter override for --algorithm. + /// + /// Only use if running just a single algorithm, as different algorithms have different + /// regularisation parameters. Does not affect the algorithms fw and fwrelax. + tau0 : Option, + + #[arg(long, requires = "algorithm")] + /// Dual step length parameter override for --algorithm. + /// + /// Only use if running just a single algorithm, as different algorithms have different + /// regularisation parameters. Only affects PDPS. + sigma0 : Option, + + #[arg(value_enum, long)] + /// PDPS acceleration, when available. + acceleration : Option, + + #[arg(long)] + /// Perform postprocess weight optimisation for saved iterations + /// + /// Only affects FB, FISTA, and PDPS. + postprocessing : Option, + + #[arg(value_name = "n", long)] + /// Merging frequency, if merging enabled (every n iterations) + /// + /// Only affects FB, FISTA, and PDPS. + merge_every : Option, + + #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::::value_parser())] + /// Merging strategy + /// + /// Either the string "none", or a radius value for heuristic merging. + merging : Option>, + + #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::::value_parser())] + /// Final merging strategy + /// + /// Either the string "none", or a radius value for heuristic merging. + /// Only affects FB, FISTA, and PDPS. + final_merging : Option>, +} + +/// The entry point for the program. +pub fn main() { + let cli = CommandLineArgs::parse(); + + if let Some(n_threads) = cli.num_threads { + let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); + set_num_threads(n); + } else { + let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); + set_max_threads(m); + } + + for experiment_shorthand in cli.experiments.iter().unique() { + let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); + let mut config : Configuration = experiment.default_config(); + let mut algs : Vec>> + = cli.algorithm.iter() + .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides)) + .collect(); + for filename in cli.saved_algorithm.iter() { + let f = std::fs::File::open(filename).unwrap(); + let alg = serde_json::from_reader(f).unwrap(); + algs.push(alg); + } + cli.max_iter.map(|m| config.iterator_options.max_iter = m); + cli.verbose_iter.map(|n| config.iterator_options.verbose_iter = Verbose::Every(n)); + config.plot = cli.plot; + config.iterator_options.quiet = cli.quiet; + config.outdir = cli.outdir.clone(); + if !algs.is_empty() { + config.algorithms = algs.clone(); + } + + experiment.runall(config) + .unwrap() + } +}