src/main.rs

changeset 0
eb3c7813b67a
child 2
7a953a87b6c1
child 5
df971c81282e
equal deleted inserted replaced
-1:000000000000 0:eb3c7813b67a
1 // The main documentation is in the README.
2 #![doc = include_str!("../README.md")]
3
4 // We use unicode. We would like to use much more of it than Rust allows.
5 // Live with it. Embrace it.
6 #![allow(uncommon_codepoints)]
7 #![allow(mixed_script_confusables)]
8 #![allow(confusable_idents)]
9 // Linear operators may be written e.g. as `opA` for a resemblance
10 // to mathematical convention.
11 #![allow(non_snake_case)]
12 // We need the drain filter for inertial prune
13 #![feature(drain_filter)]
14
15 use clap::Parser;
16 use itertools::Itertools;
17 use serde_json;
18 use alg_tools::iterate::Verbose;
19 use alg_tools::parallelism::{
20 set_num_threads,
21 set_max_threads,
22 };
23 use std::num::NonZeroUsize;
24
25 pub mod types;
26 pub mod measures;
27 pub mod fourier;
28 pub mod kernels;
29 pub mod seminorms;
30 pub mod forward_model;
31 pub mod plot;
32 pub mod subproblem;
33 pub mod tolerance;
34 pub mod fb;
35 pub mod frank_wolfe;
36 pub mod pdps;
37 pub mod run;
38 pub mod rand_distr;
39 pub mod experiments;
40
41 use types::{float, ClapFloat};
42 use run::{
43 DefaultAlgorithm,
44 Configuration,
45 PlotLevel,
46 Named,
47 AlgorithmConfig,
48 };
49 use experiments::DefaultExperiment;
50 use measures::merging::SpikeMergingMethod;
51 use DefaultExperiment::*;
52 use DefaultAlgorithm::*;
53
54 /// Command line parameters
55 #[derive(Parser, Debug)]
56 #[clap(
57 about = env!("CARGO_PKG_DESCRIPTION"),
58 author = env!("CARGO_PKG_AUTHORS"),
59 version = env!("CARGO_PKG_VERSION"),
60 after_help = "Pass --help for longer descriptions.",
61 after_long_help = "",
62 )]
63 pub struct CommandLineArgs {
64 #[arg(long, short = 'm', value_name = "M")]
65 /// Maximum iteration count
66 max_iter : Option<usize>,
67
68 #[arg(long, short = 'n', value_name = "N")]
69 /// Output status every N iterations. Set to 0 to disable.
70 verbose_iter : Option<usize>,
71
72 #[arg(long, short = 'q')]
73 /// Don't display iteration progress
74 quiet : bool,
75
76 /// List of experiments to perform.
77 #[arg(value_enum, value_name = "EXPERIMENT",
78 default_values_t = [Experiment1D, Experiment1DFast,
79 Experiment2D, Experiment2DFast,
80 Experiment1D_L1])]
81 experiments : Vec<DefaultExperiment>,
82
83 /// Default algorithm configration(s) to use on the experiments.
84 ///
85 /// Not all algorithms are available for all the experiments.
86 /// In particular, only PDPS is available for the experiments with L¹ data term.
87 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a',
88 default_values_t = [FB, FISTA, PDPS, FW, FWRelax])]
89 algorithm : Vec<DefaultAlgorithm>,
90
91 /// Saved algorithm configration(s) to use on the experiments
92 #[arg(value_name = "JSON_FILE", long)]
93 saved_algorithm : Vec<String>,
94
95 /// Write plots for every verbose iteration
96 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)]
97 plot : PlotLevel,
98
99 /// Directory for saving results
100 #[arg(long, short = 'o', default_value = "out")]
101 outdir : String,
102
103 #[arg(long, help_heading = "Multi-threading", default_value = "4")]
104 /// Maximum number of threads
105 max_threads : usize,
106
107 #[arg(long, help_heading = "Multi-threading")]
108 /// Number of threads. Overrides the maximum number.
109 num_threads : Option<usize>,
110
111 #[clap(flatten, next_help_heading = "Experiment overrides")]
112 /// Experiment setup overrides
113 experiment_overrides : ExperimentOverrides<float>,
114
115 #[clap(flatten, next_help_heading = "Algorithm overrides")]
116 /// Algorithm parametrisation overrides
117 algoritm_overrides : AlgorithmOverrides<float>,
118 }
119
120 /// Command line experiment setup overrides
121 #[derive(Parser, Debug)]
122 pub struct ExperimentOverrides<F : ClapFloat> {
123 #[arg(long)]
124 /// Regularisation parameter override.
125 ///
126 /// Only use if running just a single experiment, as different experiments have different
127 /// regularisation parameters.
128 alpha : Option<F>,
129
130 #[arg(long)]
131 /// Gaussian noise variance override
132 variance : Option<F>,
133
134 #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])]
135 /// Salt and pepper noise override.
136 salt_and_pepper : Option<Vec<F>>,
137
138 #[arg(long)]
139 /// Noise seed
140 noise_seed : Option<u64>,
141 }
142
143 /// Command line algorithm parametrisation overrides
144 #[derive(Parser, Debug)]
145 pub struct AlgorithmOverrides<F : ClapFloat> {
146 #[arg(long, value_names = &["COUNT", "EACH"])]
147 /// Override bootstrap insertion iterations for --algorithm.
148 ///
149 /// The first parameter is the number of bootstrap insertion iterations, and the second
150 /// the maximum number of iterations on each of them.
151 bootstrap_insertions : Option<Vec<usize>>,
152
153 #[arg(long, requires = "algorithm")]
154 /// Primal step length parameter override for --algorithm.
155 ///
156 /// Only use if running just a single algorithm, as different algorithms have different
157 /// regularisation parameters. Does not affect the algorithms fw and fwrelax.
158 tau0 : Option<F>,
159
160 #[arg(long, requires = "algorithm")]
161 /// Dual step length parameter override for --algorithm.
162 ///
163 /// Only use if running just a single algorithm, as different algorithms have different
164 /// regularisation parameters. Only affects PDPS.
165 sigma0 : Option<F>,
166
167 #[arg(value_enum, long)]
168 /// PDPS acceleration, when available.
169 acceleration : Option<pdps::Acceleration>,
170
171 #[arg(long)]
172 /// Perform postprocess weight optimisation for saved iterations
173 ///
174 /// Only affects FB, FISTA, and PDPS.
175 postprocessing : Option<bool>,
176
177 #[arg(value_name = "n", long)]
178 /// Merging frequency, if merging enabled (every n iterations)
179 ///
180 /// Only affects FB, FISTA, and PDPS.
181 merge_every : Option<usize>,
182
183 #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())]
184 /// Merging strategy
185 ///
186 /// Either the string "none", or a radius value for heuristic merging.
187 merging : Option<SpikeMergingMethod<F>>,
188
189 #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())]
190 /// Final merging strategy
191 ///
192 /// Either the string "none", or a radius value for heuristic merging.
193 /// Only affects FB, FISTA, and PDPS.
194 final_merging : Option<SpikeMergingMethod<F>>,
195 }
196
197 /// The entry point for the program.
198 pub fn main() {
199 let cli = CommandLineArgs::parse();
200
201 if let Some(n_threads) = cli.num_threads {
202 let n = NonZeroUsize::new(n_threads).expect("Invalid thread count");
203 set_num_threads(n);
204 } else {
205 let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count");
206 set_max_threads(m);
207 }
208
209 for experiment_shorthand in cli.experiments.iter().unique() {
210 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap();
211 let mut config : Configuration<float> = experiment.default_config();
212 let mut algs : Vec<Named<AlgorithmConfig<float>>>
213 = cli.algorithm.iter()
214 .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides))
215 .collect();
216 for filename in cli.saved_algorithm.iter() {
217 let f = std::fs::File::open(filename).unwrap();
218 let alg = serde_json::from_reader(f).unwrap();
219 algs.push(alg);
220 }
221 cli.max_iter.map(|m| config.iterator_options.max_iter = m);
222 cli.verbose_iter.map(|n| config.iterator_options.verbose_iter = Verbose::Every(n));
223 config.plot = cli.plot;
224 config.iterator_options.quiet = cli.quiet;
225 config.outdir = cli.outdir.clone();
226 if !algs.is_empty() {
227 config.algorithms = algs.clone();
228 }
229
230 experiment.runall(config)
231 .unwrap()
232 }
233 }

mercurial