14 #![feature(dec2flt)] |
14 #![feature(dec2flt)] |
15 |
15 |
16 use clap::Parser; |
16 use clap::Parser; |
17 use serde::{Serialize, Deserialize}; |
17 use serde::{Serialize, Deserialize}; |
18 use serde_json; |
18 use serde_json; |
|
19 use serde_with::skip_serializing_none; |
19 use itertools::Itertools; |
20 use itertools::Itertools; |
20 use std::num::NonZeroUsize; |
21 use std::num::NonZeroUsize; |
21 |
22 |
22 use alg_tools::parallelism::{ |
23 use alg_tools::parallelism::{ |
23 set_num_threads, |
24 set_num_threads, |
54 PlotLevel, |
55 PlotLevel, |
55 Named, |
56 Named, |
56 AlgorithmConfig, |
57 AlgorithmConfig, |
57 }; |
58 }; |
58 use experiments::DefaultExperiment; |
59 use experiments::DefaultExperiment; |
59 use measures::merging::SpikeMergingMethod; |
|
60 use DefaultExperiment::*; |
60 use DefaultExperiment::*; |
61 use DefaultAlgorithm::*; |
61 use DefaultAlgorithm::*; |
62 |
62 |
63 /// Command line parameters |
63 /// Command line parameters |
64 #[derive(Parser, Debug, Serialize)] |
64 #[skip_serializing_none] |
|
65 #[derive(Parser, Debug, Serialize, Default, Clone)] |
65 #[clap( |
66 #[clap( |
66 about = env!("CARGO_PKG_DESCRIPTION"), |
67 about = env!("CARGO_PKG_DESCRIPTION"), |
67 author = env!("CARGO_PKG_AUTHORS"), |
68 author = env!("CARGO_PKG_AUTHORS"), |
68 version = env!("CARGO_PKG_VERSION"), |
69 version = env!("CARGO_PKG_VERSION"), |
69 after_help = "Pass --help for longer descriptions.", |
70 after_help = "Pass --help for longer descriptions.", |
94 /// Default algorithm configration(s) to use on the experiments. |
95 /// Default algorithm configration(s) to use on the experiments. |
95 /// |
96 /// |
96 /// Not all algorithms are available for all the experiments. |
97 /// Not all algorithms are available for all the experiments. |
97 /// In particular, only PDPS is available for the experiments with L¹ data term. |
98 /// In particular, only PDPS is available for the experiments with L¹ data term. |
98 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', |
99 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', |
99 default_values_t = [FB, FISTA, PDPS, SlidingFB, FW, FWRelax])] |
100 default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])] |
100 algorithm : Vec<DefaultAlgorithm>, |
101 algorithm : Vec<DefaultAlgorithm>, |
101 |
102 |
102 /// Saved algorithm configration(s) to use on the experiments |
103 /// Saved algorithm configration(s) to use on the experiments |
103 #[arg(value_name = "JSON_FILE", long)] |
104 #[arg(value_name = "JSON_FILE", long)] |
104 saved_algorithm : Vec<String>, |
105 saved_algorithm : Vec<String>, |
117 |
118 |
118 #[arg(long, help_heading = "Multi-threading")] |
119 #[arg(long, help_heading = "Multi-threading")] |
119 /// Number of threads. Overrides the maximum number. |
120 /// Number of threads. Overrides the maximum number. |
120 num_threads : Option<usize>, |
121 num_threads : Option<usize>, |
121 |
122 |
|
123 #[arg(long, default_value_t = false)] |
|
124 /// Load saved value ranges (if exists) to do partial update. |
|
125 load_valuerange : bool, |
|
126 |
122 #[clap(flatten, next_help_heading = "Experiment overrides")] |
127 #[clap(flatten, next_help_heading = "Experiment overrides")] |
123 /// Experiment setup overrides |
128 /// Experiment setup overrides |
124 experiment_overrides : ExperimentOverrides<float>, |
129 experiment_overrides : ExperimentOverrides<float>, |
125 |
130 |
126 #[clap(flatten, next_help_heading = "Algorithm overrides")] |
131 #[clap(flatten, next_help_heading = "Algorithm overrides")] |
127 /// Algorithm parametrisation overrides |
132 /// Algorithm parametrisation overrides |
128 algoritm_overrides : AlgorithmOverrides<float>, |
133 algoritm_overrides : AlgorithmOverrides<float>, |
129 } |
134 } |
130 |
135 |
131 /// Command line experiment setup overrides |
136 /// Command line experiment setup overrides |
132 #[derive(Parser, Debug, Serialize, Deserialize)] |
137 #[skip_serializing_none] |
|
138 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] |
133 pub struct ExperimentOverrides<F : ClapFloat> { |
139 pub struct ExperimentOverrides<F : ClapFloat> { |
134 #[arg(long)] |
140 #[arg(long)] |
135 /// Regularisation parameter override. |
141 /// Regularisation parameter override. |
136 /// |
142 /// |
137 /// Only use if running just a single experiment, as different experiments have different |
143 /// Only use if running just a single experiment, as different experiments have different |
150 /// Noise seed |
156 /// Noise seed |
151 noise_seed : Option<u64>, |
157 noise_seed : Option<u64>, |
152 } |
158 } |
153 |
159 |
154 /// Command line algorithm parametrisation overrides |
160 /// Command line algorithm parametrisation overrides |
155 #[derive(Parser, Debug, Serialize, Deserialize)] |
161 #[skip_serializing_none] |
|
162 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] |
156 pub struct AlgorithmOverrides<F : ClapFloat> { |
163 pub struct AlgorithmOverrides<F : ClapFloat> { |
157 #[arg(long, value_names = &["COUNT", "EACH"])] |
164 #[arg(long, value_names = &["COUNT", "EACH"])] |
158 /// Override bootstrap insertion iterations for --algorithm. |
165 /// Override bootstrap insertion iterations for --algorithm. |
159 /// |
166 /// |
160 /// The first parameter is the number of bootstrap insertion iterations, and the second |
167 /// The first parameter is the number of bootstrap insertion iterations, and the second |
185 #[arg(long)] |
192 #[arg(long)] |
186 /// Normalised transport step length for sliding methods. |
193 /// Normalised transport step length for sliding methods. |
187 theta0 : Option<F>, |
194 theta0 : Option<F>, |
188 |
195 |
189 #[arg(long)] |
196 #[arg(long)] |
190 /// Transport toleranced wrt. ω |
197 /// A priori transport tolerance multiplier (C_pri) |
191 transport_tolerance_omega : Option<F>, |
198 transport_tolerance_pri : Option<F>, |
192 |
199 |
193 #[arg(long)] |
200 #[arg(long)] |
194 /// Transport toleranced wrt. ∇v |
201 /// A posteriori transport tolerance multiplier (C_pos) |
195 transport_tolerance_dv : Option<F>, |
202 transport_tolerance_pos : Option<F>, |
196 |
203 |
197 #[arg(long)] |
204 #[arg(long)] |
198 /// Transport adaptation factor. Must be in (0, 1). |
205 /// Transport adaptation factor. Must be in (0, 1). |
199 transport_adaptation : Option<F>, |
206 transport_adaptation : Option<F>, |
200 |
207 |
216 /// Merging frequency, if merging enabled (every n iterations) |
223 /// Merging frequency, if merging enabled (every n iterations) |
217 /// |
224 /// |
218 /// Only affects FB, FISTA, and PDPS. |
225 /// Only affects FB, FISTA, and PDPS. |
219 merge_every : Option<usize>, |
226 merge_every : Option<usize>, |
220 |
227 |
221 #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] |
228 #[arg(long)] |
222 /// Merging strategy |
229 /// Enable merging (default: determined by algorithm) |
223 /// |
230 merge : Option<bool>, |
224 /// Either the string "none", or a radius value for heuristic merging. |
231 |
225 merging : Option<SpikeMergingMethod<F>>, |
232 #[arg(long)] |
226 |
233 /// Merging radius (default: determined by experiment) |
227 #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] |
234 merge_radius : Option<F>, |
228 /// Final merging strategy |
235 |
229 /// |
236 #[arg(long)] |
230 /// Either the string "none", or a radius value for heuristic merging. |
237 /// Interpolate when merging (default : determined by algorithm) |
231 /// Only affects FB, FISTA, and PDPS. |
238 merge_interp : Option<bool>, |
232 final_merging : Option<SpikeMergingMethod<F>>, |
239 |
|
240 #[arg(long)] |
|
241 /// Enable final merging (default: determined by algorithm) |
|
242 final_merging : Option<bool>, |
|
243 |
|
244 #[arg(long)] |
|
245 /// Enable fitness-based merging for relevant FB-type methods. |
|
246 /// This has worse convergence guarantees that merging based on optimality conditions. |
|
247 fitness_merging : Option<bool>, |
233 |
248 |
234 #[arg(long, value_names = &["ε", "θ", "p"])] |
249 #[arg(long, value_names = &["ε", "θ", "p"])] |
235 /// Set the tolerance to ε_k = ε/(1+θk)^p |
250 /// Set the tolerance to ε_k = ε/(1+θk)^p |
236 tolerance : Option<Vec<F>>, |
251 tolerance : Option<Vec<F>>, |
237 |
252 |
264 for experiment_shorthand in cli.experiments.iter().unique() { |
279 for experiment_shorthand in cli.experiments.iter().unique() { |
265 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); |
280 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); |
266 let mut algs : Vec<Named<AlgorithmConfig<float>>> |
281 let mut algs : Vec<Named<AlgorithmConfig<float>>> |
267 = cli.algorithm |
282 = cli.algorithm |
268 .iter() |
283 .iter() |
269 .map(|alg| alg.to_named( |
284 .map(|alg| { |
270 experiment.algorithm_defaults(*alg) |
285 let cfg = alg.default_config() |
271 .unwrap_or_else(|| alg.default_config()) |
286 .cli_override(&experiment.algorithm_overrides(*alg)) |
272 .cli_override(&cli.algoritm_overrides) |
287 .cli_override(&cli.algoritm_overrides); |
273 )) |
288 alg.to_named(cfg) |
|
289 }) |
274 .collect(); |
290 .collect(); |
275 for filename in cli.saved_algorithm.iter() { |
291 for filename in cli.saved_algorithm.iter() { |
276 let f = std::fs::File::open(filename).unwrap(); |
292 let f = std::fs::File::open(filename).unwrap(); |
277 let alg = serde_json::from_reader(f).unwrap(); |
293 let alg = serde_json::from_reader(f).unwrap(); |
278 algs.push(alg); |
294 algs.push(alg); |