Tue, 06 Dec 2022 14:12:20 +0200
v1.0.0-pre-arxiv (missing arXiv links)
// The main documentation is in the README. // We need to uglify it in build.rs because rustdoc is stuck in the past. #![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))] // We use unicode. We would like to use much more of it than Rust allows. // Live with it. Embrace it. #![allow(uncommon_codepoints)] #![allow(mixed_script_confusables)] #![allow(confusable_idents)] // Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical // convention while referring to the type (trait) of the operator as `A`. #![allow(non_snake_case)] // We need the drain filter for inertial prune. #![feature(drain_filter)] use clap::Parser; use serde::{Serialize, Deserialize}; use serde_json; use itertools::Itertools; use std::num::NonZeroUsize; use alg_tools::parallelism::{ set_num_threads, set_max_threads, }; pub mod types; pub mod measures; pub mod fourier; pub mod kernels; pub mod seminorms; pub mod forward_model; pub mod plot; pub mod subproblem; pub mod tolerance; pub mod fb; pub mod frank_wolfe; pub mod pdps; pub mod run; pub mod rand_distr; pub mod experiments; use types::{float, ClapFloat}; use run::{ DefaultAlgorithm, PlotLevel, Named, AlgorithmConfig, }; use experiments::DefaultExperiment; use measures::merging::SpikeMergingMethod; use DefaultExperiment::*; use DefaultAlgorithm::*; /// Command line parameters #[derive(Parser, Debug, Serialize)] #[clap( about = env!("CARGO_PKG_DESCRIPTION"), author = env!("CARGO_PKG_AUTHORS"), version = env!("CARGO_PKG_VERSION"), after_help = "Pass --help for longer descriptions.", after_long_help = "", )] pub struct CommandLineArgs { #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)] /// Maximum iteration count max_iter : usize, #[arg(long, short = 'n', value_name = "N")] /// Output status every N iterations. Set to 0 to disable. /// /// The default is to output status based on logarithmic increments. verbose_iter : Option<usize>, #[arg(long, short = 'q')] /// Don't display iteration progress quiet : bool, /// List of experiments to perform. #[arg(value_enum, value_name = "EXPERIMENT", default_values_t = [Experiment1D, Experiment1DFast, Experiment2D, Experiment2DFast, Experiment1D_L1])] experiments : Vec<DefaultExperiment>, /// Default algorithm configration(s) to use on the experiments. /// /// Not all algorithms are available for all the experiments. /// In particular, only PDPS is available for the experiments with L¹ data term. #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', default_values_t = [FB, FISTA, PDPS, FW, FWRelax])] algorithm : Vec<DefaultAlgorithm>, /// Saved algorithm configration(s) to use on the experiments #[arg(value_name = "JSON_FILE", long)] saved_algorithm : Vec<String>, /// Plot saving scheme #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] plot : PlotLevel, /// Directory for saving results #[arg(long, short = 'o', required = true, default_value = "out")] outdir : String, #[arg(long, help_heading = "Multi-threading", default_value = "4")] /// Maximum number of threads max_threads : usize, #[arg(long, help_heading = "Multi-threading")] /// Number of threads. Overrides the maximum number. num_threads : Option<usize>, #[clap(flatten, next_help_heading = "Experiment overrides")] /// Experiment setup overrides experiment_overrides : ExperimentOverrides<float>, #[clap(flatten, next_help_heading = "Algorithm overrides")] /// Algorithm parametrisation overrides algoritm_overrides : AlgorithmOverrides<float>, } /// Command line experiment setup overrides #[derive(Parser, Debug, Serialize, Deserialize)] pub struct ExperimentOverrides<F : ClapFloat> { #[arg(long)] /// Regularisation parameter override. /// /// Only use if running just a single experiment, as different experiments have different /// regularisation parameters. alpha : Option<F>, #[arg(long)] /// Gaussian noise variance override variance : Option<F>, #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])] /// Salt and pepper noise override. salt_and_pepper : Option<Vec<F>>, #[arg(long)] /// Noise seed noise_seed : Option<u64>, } /// Command line algorithm parametrisation overrides #[derive(Parser, Debug, Serialize, Deserialize)] pub struct AlgorithmOverrides<F : ClapFloat> { #[arg(long, value_names = &["COUNT", "EACH"])] /// Override bootstrap insertion iterations for --algorithm. /// /// The first parameter is the number of bootstrap insertion iterations, and the second /// the maximum number of iterations on each of them. bootstrap_insertions : Option<Vec<usize>>, #[arg(long, requires = "algorithm")] /// Primal step length parameter override for --algorithm. /// /// Only use if running just a single algorithm, as different algorithms have different /// regularisation parameters. Does not affect the algorithms fw and fwrelax. tau0 : Option<F>, #[arg(long, requires = "algorithm")] /// Dual step length parameter override for --algorithm. /// /// Only use if running just a single algorithm, as different algorithms have different /// regularisation parameters. Only affects PDPS. sigma0 : Option<F>, #[arg(value_enum, long)] /// PDPS acceleration, when available. acceleration : Option<pdps::Acceleration>, #[arg(long)] /// Perform postprocess weight optimisation for saved iterations /// /// Only affects FB, FISTA, and PDPS. postprocessing : Option<bool>, #[arg(value_name = "n", long)] /// Merging frequency, if merging enabled (every n iterations) /// /// Only affects FB, FISTA, and PDPS. merge_every : Option<usize>, #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] /// Merging strategy /// /// Either the string "none", or a radius value for heuristic merging. merging : Option<SpikeMergingMethod<F>>, #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] /// Final merging strategy /// /// Either the string "none", or a radius value for heuristic merging. /// Only affects FB, FISTA, and PDPS. final_merging : Option<SpikeMergingMethod<F>>, } /// The entry point for the program. pub fn main() { let cli = CommandLineArgs::parse(); #[cfg(debug_assertions)] { use colored::Colorize; println!("{}", format!("\n\ ********\n\ WARNING: Compiled without optimisations; {}\n\ Please recompile with `--release` flag.\n\ ********\n\ ", "performance will be poor!".blink() ).red()); } if let Some(n_threads) = cli.num_threads { let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); set_num_threads(n); } else { let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); set_max_threads(m); } for experiment_shorthand in cli.experiments.iter().unique() { let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); let mut algs : Vec<Named<AlgorithmConfig<float>>> = cli.algorithm.iter() .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides)) .collect(); for filename in cli.saved_algorithm.iter() { let f = std::fs::File::open(filename).unwrap(); let alg = serde_json::from_reader(f).unwrap(); algs.push(alg); } experiment.runall(&cli, (!algs.is_empty()).then_some(algs)) .unwrap() } }