12 #![allow(non_snake_case)] |
12 #![allow(non_snake_case)] |
13 // We need the drain filter for inertial prune. |
13 // We need the drain filter for inertial prune. |
14 #![feature(drain_filter)] |
14 #![feature(drain_filter)] |
15 |
15 |
16 use clap::Parser; |
16 use clap::Parser; |
|
17 use serde::{Serialize, Deserialize}; |
|
18 use serde_json; |
17 use itertools::Itertools; |
19 use itertools::Itertools; |
18 use serde_json; |
|
19 use std::num::NonZeroUsize; |
20 use std::num::NonZeroUsize; |
20 |
21 |
21 use alg_tools::iterate::Verbose; |
|
22 use alg_tools::parallelism::{ |
22 use alg_tools::parallelism::{ |
23 set_num_threads, |
23 set_num_threads, |
24 set_max_threads, |
24 set_max_threads, |
25 }; |
25 }; |
26 |
26 |
41 pub mod experiments; |
41 pub mod experiments; |
42 |
42 |
43 use types::{float, ClapFloat}; |
43 use types::{float, ClapFloat}; |
44 use run::{ |
44 use run::{ |
45 DefaultAlgorithm, |
45 DefaultAlgorithm, |
46 Configuration, |
|
47 PlotLevel, |
46 PlotLevel, |
48 Named, |
47 Named, |
49 AlgorithmConfig, |
48 AlgorithmConfig, |
50 }; |
49 }; |
51 use experiments::DefaultExperiment; |
50 use experiments::DefaultExperiment; |
52 use measures::merging::SpikeMergingMethod; |
51 use measures::merging::SpikeMergingMethod; |
53 use DefaultExperiment::*; |
52 use DefaultExperiment::*; |
54 use DefaultAlgorithm::*; |
53 use DefaultAlgorithm::*; |
55 |
54 |
56 /// Command line parameters |
55 /// Command line parameters |
57 #[derive(Parser, Debug)] |
56 #[derive(Parser, Debug, Serialize)] |
58 #[clap( |
57 #[clap( |
59 about = env!("CARGO_PKG_DESCRIPTION"), |
58 about = env!("CARGO_PKG_DESCRIPTION"), |
60 author = env!("CARGO_PKG_AUTHORS"), |
59 author = env!("CARGO_PKG_AUTHORS"), |
61 version = env!("CARGO_PKG_VERSION"), |
60 version = env!("CARGO_PKG_VERSION"), |
62 after_help = "Pass --help for longer descriptions.", |
61 after_help = "Pass --help for longer descriptions.", |
63 after_long_help = "", |
62 after_long_help = "", |
64 )] |
63 )] |
65 pub struct CommandLineArgs { |
64 pub struct CommandLineArgs { |
66 #[arg(long, short = 'm', value_name = "M")] |
65 #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)] |
67 /// Maximum iteration count |
66 /// Maximum iteration count |
68 max_iter : Option<usize>, |
67 max_iter : usize, |
69 |
68 |
70 #[arg(long, short = 'n', value_name = "N")] |
69 #[arg(long, short = 'n', value_name = "N")] |
71 /// Output status every N iterations. Set to 0 to disable. |
70 /// Output status every N iterations. Set to 0 to disable. |
|
71 /// |
|
72 /// The default is to output status based on logarithmic increments. |
72 verbose_iter : Option<usize>, |
73 verbose_iter : Option<usize>, |
73 |
74 |
74 #[arg(long, short = 'q')] |
75 #[arg(long, short = 'q')] |
75 /// Don't display iteration progress |
76 /// Don't display iteration progress |
76 quiet : bool, |
77 quiet : bool, |
92 |
93 |
93 /// Saved algorithm configuration(s) to use on the experiments |
94 /// Saved algorithm configuration(s) to use on the experiments |
94 #[arg(value_name = "JSON_FILE", long)] |
95 #[arg(value_name = "JSON_FILE", long)] |
95 saved_algorithm : Vec<String>, |
96 saved_algorithm : Vec<String>, |
96 |
97 |
97 /// Write plots for every verbose iteration |
98 /// Plot saving scheme |
98 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] |
99 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] |
99 plot : PlotLevel, |
100 plot : PlotLevel, |
100 |
101 |
101 /// Directory for saving results |
102 /// Directory for saving results |
102 #[arg(long, short = 'o', default_value = "out")] |
103 #[arg(long, short = 'o', required = true, default_value = "out")] |
103 outdir : String, |
104 outdir : String, |
104 |
105 |
105 #[arg(long, help_heading = "Multi-threading", default_value = "4")] |
106 #[arg(long, help_heading = "Multi-threading", default_value = "4")] |
106 /// Maximum number of threads |
107 /// Maximum number of threads |
107 max_threads : usize, |
108 max_threads : usize, |
118 /// Algorithm parametrisation overrides |
119 /// Algorithm parametrisation overrides |
119 algoritm_overrides : AlgorithmOverrides<float>, |
120 algoritm_overrides : AlgorithmOverrides<float>, |
120 } |
121 } |
121 |
122 |
122 /// Command line experiment setup overrides |
123 /// Command line experiment setup overrides |
123 #[derive(Parser, Debug)] |
124 #[derive(Parser, Debug, Serialize, Deserialize)] |
124 pub struct ExperimentOverrides<F : ClapFloat> { |
125 pub struct ExperimentOverrides<F : ClapFloat> { |
125 #[arg(long)] |
126 #[arg(long)] |
126 /// Regularisation parameter override. |
127 /// Regularisation parameter override. |
127 /// |
128 /// |
128 /// Only use if running just a single experiment, as different experiments have different |
129 /// Only use if running just a single experiment, as different experiments have different |
141 /// Noise seed |
142 /// Noise seed |
142 noise_seed : Option<u64>, |
143 noise_seed : Option<u64>, |
143 } |
144 } |
144 |
145 |
145 /// Command line algorithm parametrisation overrides |
146 /// Command line algorithm parametrisation overrides |
146 #[derive(Parser, Debug)] |
147 #[derive(Parser, Debug, Serialize, Deserialize)] |
147 pub struct AlgorithmOverrides<F : ClapFloat> { |
148 pub struct AlgorithmOverrides<F : ClapFloat> { |
148 #[arg(long, value_names = &["COUNT", "EACH"])] |
149 #[arg(long, value_names = &["COUNT", "EACH"])] |
149 /// Override bootstrap insertion iterations for --algorithm. |
150 /// Override bootstrap insertion iterations for --algorithm. |
150 /// |
151 /// |
151 /// The first parameter is the number of bootstrap insertion iterations, and the second |
152 /// The first parameter is the number of bootstrap insertion iterations, and the second |
220 set_max_threads(m); |
221 set_max_threads(m); |
221 } |
222 } |
222 |
223 |
223 for experiment_shorthand in cli.experiments.iter().unique() { |
224 for experiment_shorthand in cli.experiments.iter().unique() { |
224 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); |
225 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); |
225 let mut config : Configuration<float> = experiment.default_config(); |
|
226 let mut algs : Vec<Named<AlgorithmConfig<float>>> |
226 let mut algs : Vec<Named<AlgorithmConfig<float>>> |
227 = cli.algorithm.iter() |
227 = cli.algorithm.iter() |
228 .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides)) |
228 .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides)) |
229 .collect(); |
229 .collect(); |
230 for filename in cli.saved_algorithm.iter() { |
230 for filename in cli.saved_algorithm.iter() { |
231 let f = std::fs::File::open(filename).unwrap(); |
231 let f = std::fs::File::open(filename).unwrap(); |
232 let alg = serde_json::from_reader(f).unwrap(); |
232 let alg = serde_json::from_reader(f).unwrap(); |
233 algs.push(alg); |
233 algs.push(alg); |
234 } |
234 } |
235 cli.max_iter.map(|m| config.iterator_options.max_iter = m); |
235 experiment.runall(&cli, (!algs.is_empty()).then_some(algs)) |
236 cli.verbose_iter.map(|n| config.iterator_options.verbose_iter = Verbose::Every(n)); |
|
237 config.plot = cli.plot; |
|
238 config.iterator_options.quiet = cli.quiet; |
|
239 config.outdir = cli.outdir.clone(); |
|
240 if !algs.is_empty() { |
|
241 config.algorithms = algs.clone(); |
|
242 } |
|
243 |
|
244 experiment.runall(config) |
|
245 .unwrap() |
236 .unwrap() |
246 } |
237 } |
247 } |
238 } |