src/main.rs

changeset 52
f0e8704d3f0e
parent 44
03251c546744
equal deleted inserted replaced
31:6105b5cd8d89 52:f0e8704d3f0e
8 #![allow(mixed_script_confusables)] 8 #![allow(mixed_script_confusables)]
9 #![allow(confusable_idents)] 9 #![allow(confusable_idents)]
10 // Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical 10 // Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical
11 // convention while referring to the type (trait) of the operator as `A`. 11 // convention while referring to the type (trait) of the operator as `A`.
12 #![allow(non_snake_case)] 12 #![allow(non_snake_case)]
13 // We need the drain filter for inertial prune. 13 // Need to create parse errors
14 #![feature(drain_filter)] 14 #![feature(dec2flt)]
15 15
16 use clap::Parser; 16 use clap::Parser;
17 use serde::{Serialize, Deserialize}; 17 use serde::{Serialize, Deserialize};
18 use serde_json; 18 use serde_json;
19 use serde_with::skip_serializing_none;
19 use itertools::Itertools; 20 use itertools::Itertools;
20 use std::num::NonZeroUsize; 21 use std::num::NonZeroUsize;
21 22
22 use alg_tools::parallelism::{ 23 use alg_tools::parallelism::{
23 set_num_threads, 24 set_num_threads,
28 pub mod measures; 29 pub mod measures;
29 pub mod fourier; 30 pub mod fourier;
30 pub mod kernels; 31 pub mod kernels;
31 pub mod seminorms; 32 pub mod seminorms;
32 pub mod forward_model; 33 pub mod forward_model;
34 pub mod preadjoint_helper;
33 pub mod plot; 35 pub mod plot;
34 pub mod subproblem; 36 pub mod subproblem;
35 pub mod tolerance; 37 pub mod tolerance;
36 pub mod regularisation; 38 pub mod regularisation;
39 pub mod dataterm;
40 pub mod prox_penalty;
37 pub mod fb; 41 pub mod fb;
42 pub mod sliding_fb;
43 pub mod sliding_pdps;
44 pub mod forward_pdps;
38 pub mod frank_wolfe; 45 pub mod frank_wolfe;
39 pub mod pdps; 46 pub mod pdps;
40 pub mod run; 47 pub mod run;
41 pub mod rand_distr; 48 pub mod rand_distr;
42 pub mod experiments; 49 pub mod experiments;
47 PlotLevel, 54 PlotLevel,
48 Named, 55 Named,
49 AlgorithmConfig, 56 AlgorithmConfig,
50 }; 57 };
51 use experiments::DefaultExperiment; 58 use experiments::DefaultExperiment;
52 use measures::merging::SpikeMergingMethod;
53 use DefaultExperiment::*; 59 use DefaultExperiment::*;
54 use DefaultAlgorithm::*; 60 use DefaultAlgorithm::*;
55 61
56 /// Command line parameters 62 /// Command line parameters
57 #[derive(Parser, Debug, Serialize)] 63 #[skip_serializing_none]
64 #[derive(Parser, Debug, Serialize, Default, Clone)]
58 #[clap( 65 #[clap(
59 about = env!("CARGO_PKG_DESCRIPTION"), 66 about = env!("CARGO_PKG_DESCRIPTION"),
60 author = env!("CARGO_PKG_AUTHORS"), 67 author = env!("CARGO_PKG_AUTHORS"),
61 version = env!("CARGO_PKG_VERSION"), 68 version = env!("CARGO_PKG_VERSION"),
62 after_help = "Pass --help for longer descriptions.", 69 after_help = "Pass --help for longer descriptions.",
87 /// Default algorithm configuration(s) to use on the experiments. 94 /// Default algorithm configuration(s) to use on the experiments.
88 /// 95 ///
89 /// Not all algorithms are available for all the experiments. 96 /// Not all algorithms are available for all the experiments.
90 /// In particular, only PDPS is available for the experiments with L¹ data term. 97 /// In particular, only PDPS is available for the experiments with L¹ data term.
91 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', 98 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a',
92 default_values_t = [FB, FISTA, PDPS, FW, FWRelax])] 99 default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])]
93 algorithm : Vec<DefaultAlgorithm>, 100 algorithm : Vec<DefaultAlgorithm>,
94 101
95 /// Saved algorithm configuration(s) to use on the experiments 102 /// Saved algorithm configuration(s) to use on the experiments
96 #[arg(value_name = "JSON_FILE", long)] 103 #[arg(value_name = "JSON_FILE", long)]
97 saved_algorithm : Vec<String>, 104 saved_algorithm : Vec<String>,
110 117
111 #[arg(long, help_heading = "Multi-threading")] 118 #[arg(long, help_heading = "Multi-threading")]
112 /// Number of threads. Overrides the maximum number. 119 /// Number of threads. Overrides the maximum number.
113 num_threads : Option<usize>, 120 num_threads : Option<usize>,
114 121
122 #[arg(long, default_value_t = false)]
123 /// Load saved value ranges (if exists) to do partial update.
124 load_valuerange : bool,
125
115 #[clap(flatten, next_help_heading = "Experiment overrides")] 126 #[clap(flatten, next_help_heading = "Experiment overrides")]
116 /// Experiment setup overrides 127 /// Experiment setup overrides
117 experiment_overrides : ExperimentOverrides<float>, 128 experiment_overrides : ExperimentOverrides<float>,
118 129
119 #[clap(flatten, next_help_heading = "Algorithm overrides")] 130 #[clap(flatten, next_help_heading = "Algorithm overrides")]
120 /// Algorithm parametrisation overrides 131 /// Algorithm parametrisation overrides
121 algoritm_overrides : AlgorithmOverrides<float>, 132 algoritm_overrides : AlgorithmOverrides<float>,
122 } 133 }
123 134
124 /// Command line experiment setup overrides 135 /// Command line experiment setup overrides
125 #[derive(Parser, Debug, Serialize, Deserialize)] 136 #[skip_serializing_none]
137 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
126 pub struct ExperimentOverrides<F : ClapFloat> { 138 pub struct ExperimentOverrides<F : ClapFloat> {
127 #[arg(long)] 139 #[arg(long)]
128 /// Regularisation parameter override. 140 /// Regularisation parameter override.
129 /// 141 ///
130 /// Only use if running just a single experiment, as different experiments have different 142 /// Only use if running just a single experiment, as different experiments have different
143 /// Noise seed 155 /// Noise seed
144 noise_seed : Option<u64>, 156 noise_seed : Option<u64>,
145 } 157 }
146 158
147 /// Command line algorithm parametrisation overrides 159 /// Command line algorithm parametrisation overrides
148 #[derive(Parser, Debug, Serialize, Deserialize)] 160 #[skip_serializing_none]
161 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
149 pub struct AlgorithmOverrides<F : ClapFloat> { 162 pub struct AlgorithmOverrides<F : ClapFloat> {
150 #[arg(long, value_names = &["COUNT", "EACH"])] 163 #[arg(long, value_names = &["COUNT", "EACH"])]
151 /// Override bootstrap insertion iterations for --algorithm. 164 /// Override bootstrap insertion iterations for --algorithm.
152 /// 165 ///
153 /// The first parameter is the number of bootstrap insertion iterations, and the second 166 /// The first parameter is the number of bootstrap insertion iterations, and the second
160 /// Only use if running just a single algorithm, as different algorithms have different 173 /// Only use if running just a single algorithm, as different algorithms have different
161 /// regularisation parameters. Does not affect the algorithms fw and fwrelax. 174 /// regularisation parameters. Does not affect the algorithms fw and fwrelax.
162 tau0 : Option<F>, 175 tau0 : Option<F>,
163 176
164 #[arg(long, requires = "algorithm")] 177 #[arg(long, requires = "algorithm")]
178 /// Second primal step length parameter override for SlidingPDPS.
179 ///
180 /// Only use if running just a single algorithm, as different algorithms have different
181 /// regularisation parameters.
182 sigmap0 : Option<F>,
183
184 #[arg(long, requires = "algorithm")]
165 /// Dual step length parameter override for --algorithm. 185 /// Dual step length parameter override for --algorithm.
166 /// 186 ///
167 /// Only use if running just a single algorithm, as different algorithms have different 187 /// Only use if running just a single algorithm, as different algorithms have different
168 /// regularisation parameters. Only affects PDPS. 188 /// regularisation parameters. Only affects PDPS.
169 sigma0 : Option<F>, 189 sigma0 : Option<F>,
170 190
191 #[arg(long)]
192 /// Normalised transport step length for sliding methods.
193 theta0 : Option<F>,
194
195 #[arg(long)]
196 /// A posteriori transport tolerance multiplier (C_pos)
197 transport_tolerance_pos : Option<F>,
198
199 #[arg(long)]
200 /// Transport adaptation factor. Must be in (0, 1).
201 transport_adaptation : Option<F>,
202
203 #[arg(long)]
204 /// Minimal step length parameter for sliding methods.
205 tau0_min : Option<F>,
206
171 #[arg(value_enum, long)] 207 #[arg(value_enum, long)]
172 /// PDPS acceleration, when available. 208 /// PDPS acceleration, when available.
173 acceleration : Option<pdps::Acceleration>, 209 acceleration : Option<pdps::Acceleration>,
174 210
175 #[arg(long)] 211 // #[arg(long)]
176 /// Perform postprocess weight optimisation for saved iterations 212 // /// Perform postprocess weight optimisation for saved iterations
177 /// 213 // ///
178 /// Only affects FB, FISTA, and PDPS. 214 // /// Only affects FB, FISTA, and PDPS.
179 postprocessing : Option<bool>, 215 // postprocessing : Option<bool>,
180 216
181 #[arg(value_name = "n", long)] 217 #[arg(value_name = "n", long)]
182 /// Merging frequency, if merging enabled (every n iterations) 218 /// Merging frequency, if merging enabled (every n iterations)
183 /// 219 ///
184 /// Only affects FB, FISTA, and PDPS. 220 /// Only affects FB, FISTA, and PDPS.
185 merge_every : Option<usize>, 221 merge_every : Option<usize>,
186 222
187 #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] 223 #[arg(long)]
188 /// Merging strategy 224 /// Enable merging (default: determined by algorithm)
189 /// 225 merge : Option<bool>,
190 /// Either the string "none", or a radius value for heuristic merging. 226
191 merging : Option<SpikeMergingMethod<F>>, 227 #[arg(long)]
192 228 /// Merging radius (default: determined by experiment)
193 #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] 229 merge_radius : Option<F>,
194 /// Final merging strategy 230
195 /// 231 #[arg(long)]
196 /// Either the string "none", or a radius value for heuristic merging. 232 /// Interpolate when merging (default: determined by algorithm)
197 /// Only affects FB, FISTA, and PDPS. 233 merge_interp : Option<bool>,
198 final_merging : Option<SpikeMergingMethod<F>>, 234
235 #[arg(long)]
236 /// Enable final merging (default: determined by algorithm)
237 final_merging : Option<bool>,
238
239 #[arg(long)]
240 /// Enable fitness-based merging for relevant FB-type methods.
241 /// This has worse convergence guarantees than merging based on optimality conditions.
242 fitness_merging : Option<bool>,
199 243
200 #[arg(long, value_names = &["ε", "θ", "p"])] 244 #[arg(long, value_names = &["ε", "θ", "p"])]
201 /// Set the tolerance to ε_k = ε/(1+θk)^p 245 /// Set the tolerance to ε_k = ε/(1+θk)^p
202 tolerance : Option<Vec<F>>, 246 tolerance : Option<Vec<F>>,
203 247
228 } 272 }
229 273
230 for experiment_shorthand in cli.experiments.iter().unique() { 274 for experiment_shorthand in cli.experiments.iter().unique() {
231 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); 275 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap();
232 let mut algs : Vec<Named<AlgorithmConfig<float>>> 276 let mut algs : Vec<Named<AlgorithmConfig<float>>>
233 = cli.algorithm.iter() 277 = cli.algorithm
234 .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides)) 278 .iter()
235 .collect(); 279 .map(|alg| {
280 let cfg = alg.default_config()
281 .cli_override(&experiment.algorithm_overrides(*alg))
282 .cli_override(&cli.algoritm_overrides);
283 alg.to_named(cfg)
284 })
285 .collect();
236 for filename in cli.saved_algorithm.iter() { 286 for filename in cli.saved_algorithm.iter() {
237 let f = std::fs::File::open(filename).unwrap(); 287 let f = std::fs::File::open(filename).unwrap();
238 let alg = serde_json::from_reader(f).unwrap(); 288 let alg = serde_json::from_reader(f).unwrap();
239 algs.push(alg); 289 algs.push(alg);
240 } 290 }

mercurial