src/main.rs

branch
dev
changeset 39
6316d68b58af
parent 37
c5d8bd1a7728
child 41
b6bdb6cb4d44
equal deleted inserted replaced
37:c5d8bd1a7728 39:6316d68b58af
14 #![feature(dec2flt)] 14 #![feature(dec2flt)]
15 15
16 use clap::Parser; 16 use clap::Parser;
17 use serde::{Serialize, Deserialize}; 17 use serde::{Serialize, Deserialize};
18 use serde_json; 18 use serde_json;
19 use serde_with::skip_serializing_none;
19 use itertools::Itertools; 20 use itertools::Itertools;
20 use std::num::NonZeroUsize; 21 use std::num::NonZeroUsize;
21 22
22 use alg_tools::parallelism::{ 23 use alg_tools::parallelism::{
23 set_num_threads, 24 set_num_threads,
54 PlotLevel, 55 PlotLevel,
55 Named, 56 Named,
56 AlgorithmConfig, 57 AlgorithmConfig,
57 }; 58 };
58 use experiments::DefaultExperiment; 59 use experiments::DefaultExperiment;
59 use measures::merging::SpikeMergingMethod;
60 use DefaultExperiment::*; 60 use DefaultExperiment::*;
61 use DefaultAlgorithm::*; 61 use DefaultAlgorithm::*;
62 62
63 /// Command line parameters 63 /// Command line parameters
64 #[derive(Parser, Debug, Serialize)] 64 #[skip_serializing_none]
65 #[derive(Parser, Debug, Serialize, Default, Clone)]
65 #[clap( 66 #[clap(
66 about = env!("CARGO_PKG_DESCRIPTION"), 67 about = env!("CARGO_PKG_DESCRIPTION"),
67 author = env!("CARGO_PKG_AUTHORS"), 68 author = env!("CARGO_PKG_AUTHORS"),
68 version = env!("CARGO_PKG_VERSION"), 69 version = env!("CARGO_PKG_VERSION"),
69 after_help = "Pass --help for longer descriptions.", 70 after_help = "Pass --help for longer descriptions.",
94 /// Default algorithm configration(s) to use on the experiments. 95 /// Default algorithm configration(s) to use on the experiments.
95 /// 96 ///
96 /// Not all algorithms are available for all the experiments. 97 /// Not all algorithms are available for all the experiments.
97 /// In particular, only PDPS is available for the experiments with L¹ data term. 98 /// In particular, only PDPS is available for the experiments with L¹ data term.
98 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', 99 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a',
99 default_values_t = [FB, FISTA, PDPS, SlidingFB, FW, FWRelax])] 100 default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])]
100 algorithm : Vec<DefaultAlgorithm>, 101 algorithm : Vec<DefaultAlgorithm>,
101 102
102 /// Saved algorithm configration(s) to use on the experiments 103 /// Saved algorithm configration(s) to use on the experiments
103 #[arg(value_name = "JSON_FILE", long)] 104 #[arg(value_name = "JSON_FILE", long)]
104 saved_algorithm : Vec<String>, 105 saved_algorithm : Vec<String>,
117 118
118 #[arg(long, help_heading = "Multi-threading")] 119 #[arg(long, help_heading = "Multi-threading")]
119 /// Number of threads. Overrides the maximum number. 120 /// Number of threads. Overrides the maximum number.
120 num_threads : Option<usize>, 121 num_threads : Option<usize>,
121 122
123 #[arg(long, default_value_t = false)]
124 /// Load saved value ranges (if exists) to do partial update.
125 load_valuerange : bool,
126
122 #[clap(flatten, next_help_heading = "Experiment overrides")] 127 #[clap(flatten, next_help_heading = "Experiment overrides")]
123 /// Experiment setup overrides 128 /// Experiment setup overrides
124 experiment_overrides : ExperimentOverrides<float>, 129 experiment_overrides : ExperimentOverrides<float>,
125 130
126 #[clap(flatten, next_help_heading = "Algorithm overrides")] 131 #[clap(flatten, next_help_heading = "Algorithm overrides")]
127 /// Algorithm parametrisation overrides 132 /// Algorithm parametrisation overrides
128 algoritm_overrides : AlgorithmOverrides<float>, 133 algoritm_overrides : AlgorithmOverrides<float>,
129 } 134 }
130 135
131 /// Command line experiment setup overrides 136 /// Command line experiment setup overrides
132 #[derive(Parser, Debug, Serialize, Deserialize)] 137 #[skip_serializing_none]
138 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
133 pub struct ExperimentOverrides<F : ClapFloat> { 139 pub struct ExperimentOverrides<F : ClapFloat> {
134 #[arg(long)] 140 #[arg(long)]
135 /// Regularisation parameter override. 141 /// Regularisation parameter override.
136 /// 142 ///
137 /// Only use if running just a single experiment, as different experiments have different 143 /// Only use if running just a single experiment, as different experiments have different
150 /// Noise seed 156 /// Noise seed
151 noise_seed : Option<u64>, 157 noise_seed : Option<u64>,
152 } 158 }
153 159
154 /// Command line algorithm parametrisation overrides 160 /// Command line algorithm parametrisation overrides
155 #[derive(Parser, Debug, Serialize, Deserialize)] 161 #[skip_serializing_none]
162 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
156 pub struct AlgorithmOverrides<F : ClapFloat> { 163 pub struct AlgorithmOverrides<F : ClapFloat> {
157 #[arg(long, value_names = &["COUNT", "EACH"])] 164 #[arg(long, value_names = &["COUNT", "EACH"])]
158 /// Override bootstrap insertion iterations for --algorithm. 165 /// Override bootstrap insertion iterations for --algorithm.
159 /// 166 ///
160 /// The first parameter is the number of bootstrap insertion iterations, and the second 167 /// The first parameter is the number of bootstrap insertion iterations, and the second
185 #[arg(long)] 192 #[arg(long)]
186 /// Normalised transport step length for sliding methods. 193 /// Normalised transport step length for sliding methods.
187 theta0 : Option<F>, 194 theta0 : Option<F>,
188 195
189 #[arg(long)] 196 #[arg(long)]
190 /// Transport toleranced wrt. ω 197 /// A priori transport tolerance multiplier (C_pri)
191 transport_tolerance_omega : Option<F>, 198 transport_tolerance_pri : Option<F>,
192 199
193 #[arg(long)] 200 #[arg(long)]
194 /// Transport toleranced wrt. ∇v 201 /// A posteriori transport tolerance multiplier (C_pos)
195 transport_tolerance_dv : Option<F>, 202 transport_tolerance_pos : Option<F>,
196 203
197 #[arg(long)] 204 #[arg(long)]
198 /// Transport adaptation factor. Must be in (0, 1). 205 /// Transport adaptation factor. Must be in (0, 1).
199 transport_adaptation : Option<F>, 206 transport_adaptation : Option<F>,
200 207
216 /// Merging frequency, if merging enabled (every n iterations) 223 /// Merging frequency, if merging enabled (every n iterations)
217 /// 224 ///
218 /// Only affects FB, FISTA, and PDPS. 225 /// Only affects FB, FISTA, and PDPS.
219 merge_every : Option<usize>, 226 merge_every : Option<usize>,
220 227
221 #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] 228 #[arg(long)]
222 /// Merging strategy 229 /// Enable merging (default: determined by algorithm)
223 /// 230 merge : Option<bool>,
224 /// Either the string "none", or a radius value for heuristic merging. 231
225 merging : Option<SpikeMergingMethod<F>>, 232 #[arg(long)]
226 233 /// Merging radius (default: determined by experiment)
227 #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] 234 merge_radius : Option<F>,
228 /// Final merging strategy 235
229 /// 236 #[arg(long)]
230 /// Either the string "none", or a radius value for heuristic merging. 237 /// Interpolate when merging (default : determined by algorithm)
231 /// Only affects FB, FISTA, and PDPS. 238 merge_interp : Option<bool>,
232 final_merging : Option<SpikeMergingMethod<F>>, 239
240 #[arg(long)]
241 /// Enable final merging (default: determined by algorithm)
242 final_merging : Option<bool>,
243
244 #[arg(long)]
245 /// Enable fitness-based merging for relevant FB-type methods.
246 /// This has worse convergence guarantees than merging based on optimality conditions.
247 fitness_merging : Option<bool>,
233 248
234 #[arg(long, value_names = &["ε", "θ", "p"])] 249 #[arg(long, value_names = &["ε", "θ", "p"])]
235 /// Set the tolerance to ε_k = ε/(1+θk)^p 250 /// Set the tolerance to ε_k = ε/(1+θk)^p
236 tolerance : Option<Vec<F>>, 251 tolerance : Option<Vec<F>>,
237 252
264 for experiment_shorthand in cli.experiments.iter().unique() { 279 for experiment_shorthand in cli.experiments.iter().unique() {
265 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); 280 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap();
266 let mut algs : Vec<Named<AlgorithmConfig<float>>> 281 let mut algs : Vec<Named<AlgorithmConfig<float>>>
267 = cli.algorithm 282 = cli.algorithm
268 .iter() 283 .iter()
269 .map(|alg| alg.to_named( 284 .map(|alg| {
270 experiment.algorithm_defaults(*alg) 285 let cfg = alg.default_config()
271 .unwrap_or_else(|| alg.default_config()) 286 .cli_override(&experiment.algorithm_overrides(*alg))
272 .cli_override(&cli.algoritm_overrides) 287 .cli_override(&cli.algoritm_overrides);
273 )) 288 alg.to_named(cfg)
289 })
274 .collect(); 290 .collect();
275 for filename in cli.saved_algorithm.iter() { 291 for filename in cli.saved_algorithm.iter() {
276 let f = std::fs::File::open(filename).unwrap(); 292 let f = std::fs::File::open(filename).unwrap();
277 let alg = serde_json::from_reader(f).unwrap(); 293 let alg = serde_json::from_reader(f).unwrap();
278 algs.push(alg); 294 algs.push(alg);

mercurial