src/main.rs

changeset 9
21b0e537ac0e
parent 6
bcb508479948
child 20
90f77ad9a98d
equal deleted inserted replaced
8:ea3ca78873e8 9:21b0e537ac0e
12 #![allow(non_snake_case)] 12 #![allow(non_snake_case)]
13 // We need the drain filter for inertial prune. 13 // We need the drain filter for inertial prune.
14 #![feature(drain_filter)] 14 #![feature(drain_filter)]
15 15
16 use clap::Parser; 16 use clap::Parser;
17 use serde::{Serialize, Deserialize};
18 use serde_json;
17 use itertools::Itertools; 19 use itertools::Itertools;
18 use serde_json;
19 use std::num::NonZeroUsize; 20 use std::num::NonZeroUsize;
20 21
21 use alg_tools::iterate::Verbose;
22 use alg_tools::parallelism::{ 22 use alg_tools::parallelism::{
23 set_num_threads, 23 set_num_threads,
24 set_max_threads, 24 set_max_threads,
25 }; 25 };
26 26
41 pub mod experiments; 41 pub mod experiments;
42 42
43 use types::{float, ClapFloat}; 43 use types::{float, ClapFloat};
44 use run::{ 44 use run::{
45 DefaultAlgorithm, 45 DefaultAlgorithm,
46 Configuration,
47 PlotLevel, 46 PlotLevel,
48 Named, 47 Named,
49 AlgorithmConfig, 48 AlgorithmConfig,
50 }; 49 };
51 use experiments::DefaultExperiment; 50 use experiments::DefaultExperiment;
52 use measures::merging::SpikeMergingMethod; 51 use measures::merging::SpikeMergingMethod;
53 use DefaultExperiment::*; 52 use DefaultExperiment::*;
54 use DefaultAlgorithm::*; 53 use DefaultAlgorithm::*;
55 54
56 /// Command line parameters 55 /// Command line parameters
57 #[derive(Parser, Debug)] 56 #[derive(Parser, Debug, Serialize)]
58 #[clap( 57 #[clap(
59 about = env!("CARGO_PKG_DESCRIPTION"), 58 about = env!("CARGO_PKG_DESCRIPTION"),
60 author = env!("CARGO_PKG_AUTHORS"), 59 author = env!("CARGO_PKG_AUTHORS"),
61 version = env!("CARGO_PKG_VERSION"), 60 version = env!("CARGO_PKG_VERSION"),
62 after_help = "Pass --help for longer descriptions.", 61 after_help = "Pass --help for longer descriptions.",
63 after_long_help = "", 62 after_long_help = "",
64 )] 63 )]
65 pub struct CommandLineArgs { 64 pub struct CommandLineArgs {
66 #[arg(long, short = 'm', value_name = "M")] 65 #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)]
67 /// Maximum iteration count 66 /// Maximum iteration count
68 max_iter : Option<usize>, 67 max_iter : usize,
69 68
70 #[arg(long, short = 'n', value_name = "N")] 69 #[arg(long, short = 'n', value_name = "N")]
71 /// Output status every N iterations. Set to 0 to disable. 70 /// Output status every N iterations. Set to 0 to disable.
71 ///
72 /// The default is to output status based on logarithmic increments.
72 verbose_iter : Option<usize>, 73 verbose_iter : Option<usize>,
73 74
74 #[arg(long, short = 'q')] 75 #[arg(long, short = 'q')]
75 /// Don't display iteration progress 76 /// Don't display iteration progress
76 quiet : bool, 77 quiet : bool,
92 93
93 /// Saved algorithm configration(s) to use on the experiments 94 /// Saved algorithm configration(s) to use on the experiments
94 #[arg(value_name = "JSON_FILE", long)] 95 #[arg(value_name = "JSON_FILE", long)]
95 saved_algorithm : Vec<String>, 96 saved_algorithm : Vec<String>,
96 97
97 /// Write plots for every verbose iteration 98 /// Plot saving scheme
98 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] 99 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)]
99 plot : PlotLevel, 100 plot : PlotLevel,
100 101
101 /// Directory for saving results 102 /// Directory for saving results
102 #[arg(long, short = 'o', default_value = "out")] 103 #[arg(long, short = 'o', required = true, default_value = "out")]
103 outdir : String, 104 outdir : String,
104 105
105 #[arg(long, help_heading = "Multi-threading", default_value = "4")] 106 #[arg(long, help_heading = "Multi-threading", default_value = "4")]
106 /// Maximum number of threads 107 /// Maximum number of threads
107 max_threads : usize, 108 max_threads : usize,
118 /// Algorithm parametrisation overrides 119 /// Algorithm parametrisation overrides
119 algoritm_overrides : AlgorithmOverrides<float>, 120 algoritm_overrides : AlgorithmOverrides<float>,
120 } 121 }
121 122
122 /// Command line experiment setup overrides 123 /// Command line experiment setup overrides
123 #[derive(Parser, Debug)] 124 #[derive(Parser, Debug, Serialize, Deserialize)]
124 pub struct ExperimentOverrides<F : ClapFloat> { 125 pub struct ExperimentOverrides<F : ClapFloat> {
125 #[arg(long)] 126 #[arg(long)]
126 /// Regularisation parameter override. 127 /// Regularisation parameter override.
127 /// 128 ///
128 /// Only use if running just a single experiment, as different experiments have different 129 /// Only use if running just a single experiment, as different experiments have different
141 /// Noise seed 142 /// Noise seed
142 noise_seed : Option<u64>, 143 noise_seed : Option<u64>,
143 } 144 }
144 145
145 /// Command line algorithm parametrisation overrides 146 /// Command line algorithm parametrisation overrides
146 #[derive(Parser, Debug)] 147 #[derive(Parser, Debug, Serialize, Deserialize)]
147 pub struct AlgorithmOverrides<F : ClapFloat> { 148 pub struct AlgorithmOverrides<F : ClapFloat> {
148 #[arg(long, value_names = &["COUNT", "EACH"])] 149 #[arg(long, value_names = &["COUNT", "EACH"])]
149 /// Override bootstrap insertion iterations for --algorithm. 150 /// Override bootstrap insertion iterations for --algorithm.
150 /// 151 ///
151 /// The first parameter is the number of bootstrap insertion iterations, and the second 152 /// The first parameter is the number of bootstrap insertion iterations, and the second
220 set_max_threads(m); 221 set_max_threads(m);
221 } 222 }
222 223
223 for experiment_shorthand in cli.experiments.iter().unique() { 224 for experiment_shorthand in cli.experiments.iter().unique() {
224 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); 225 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap();
225 let mut config : Configuration<float> = experiment.default_config();
226 let mut algs : Vec<Named<AlgorithmConfig<float>>> 226 let mut algs : Vec<Named<AlgorithmConfig<float>>>
227 = cli.algorithm.iter() 227 = cli.algorithm.iter()
228 .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides)) 228 .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides))
229 .collect(); 229 .collect();
230 for filename in cli.saved_algorithm.iter() { 230 for filename in cli.saved_algorithm.iter() {
231 let f = std::fs::File::open(filename).unwrap(); 231 let f = std::fs::File::open(filename).unwrap();
232 let alg = serde_json::from_reader(f).unwrap(); 232 let alg = serde_json::from_reader(f).unwrap();
233 algs.push(alg); 233 algs.push(alg);
234 } 234 }
235 cli.max_iter.map(|m| config.iterator_options.max_iter = m); 235 experiment.runall(&cli, (!algs.is_empty()).then_some(algs))
236 cli.verbose_iter.map(|n| config.iterator_options.verbose_iter = Verbose::Every(n));
237 config.plot = cli.plot;
238 config.iterator_options.quiet = cli.quiet;
239 config.outdir = cli.outdir.clone();
240 if !algs.is_empty() {
241 config.algorithms = algs.clone();
242 }
243
244 experiment.runall(config)
245 .unwrap() 236 .unwrap()
246 } 237 }
247 } 238 }

mercurial