Tue, 29 Nov 2022 15:36:12 +0200
fubar
0 | 1 | // The main documentation is in the README. |
2 | #![doc = include_str!("../README.md")] | |
3 | ||
4 | // We use unicode. We would like to use much more of it than Rust allows. | |
5 | // Live with it. Embrace it. | |
6 | #![allow(uncommon_codepoints)] | |
7 | #![allow(mixed_script_confusables)] | |
8 | #![allow(confusable_idents)] | |
9 | // Linear operators may be written e.g. as `opA` for a resemblance | |
10 | // to mathematical convention. | |
11 | #![allow(non_snake_case)] | |
12 | // We need the drain filter for inertial prune | |
13 | #![feature(drain_filter)] | |
14 | ||
15 | use clap::Parser; | |
16 | use itertools::Itertools; | |
17 | use serde_json; | |
18 | use alg_tools::iterate::Verbose; | |
19 | use alg_tools::parallelism::{ | |
20 | set_num_threads, | |
21 | set_max_threads, | |
22 | }; | |
23 | use std::num::NonZeroUsize; | |
24 | ||
25 | pub mod types; | |
26 | pub mod measures; | |
27 | pub mod fourier; | |
28 | pub mod kernels; | |
29 | pub mod seminorms; | |
30 | pub mod forward_model; | |
31 | pub mod plot; | |
32 | pub mod subproblem; | |
2 | 33 | pub mod weight_optim; |
0 | 34 | pub mod tolerance; |
35 | pub mod fb; | |
36 | pub mod frank_wolfe; | |
37 | pub mod pdps; | |
38 | pub mod run; | |
39 | pub mod rand_distr; | |
40 | pub mod experiments; | |
41 | ||
42 | use types::{float, ClapFloat}; | |
43 | use run::{ | |
44 | DefaultAlgorithm, | |
45 | Configuration, | |
46 | PlotLevel, | |
47 | Named, | |
48 | AlgorithmConfig, | |
49 | }; | |
50 | use experiments::DefaultExperiment; | |
51 | use measures::merging::SpikeMergingMethod; | |
52 | use DefaultExperiment::*; | |
53 | use DefaultAlgorithm::*; | |
54 | ||
55 | /// Command line parameters | |
56 | #[derive(Parser, Debug)] | |
57 | #[clap( | |
58 | about = env!("CARGO_PKG_DESCRIPTION"), | |
59 | author = env!("CARGO_PKG_AUTHORS"), | |
60 | version = env!("CARGO_PKG_VERSION"), | |
61 | after_help = "Pass --help for longer descriptions.", | |
62 | after_long_help = "", | |
63 | )] | |
64 | pub struct CommandLineArgs { | |
65 | #[arg(long, short = 'm', value_name = "M")] | |
66 | /// Maximum iteration count | |
67 | max_iter : Option<usize>, | |
68 | ||
69 | #[arg(long, short = 'n', value_name = "N")] | |
70 | /// Output status every N iterations. Set to 0 to disable. | |
71 | verbose_iter : Option<usize>, | |
72 | ||
73 | #[arg(long, short = 'q')] | |
74 | /// Don't display iteration progress | |
75 | quiet : bool, | |
76 | ||
77 | /// List of experiments to perform. | |
78 | #[arg(value_enum, value_name = "EXPERIMENT", | |
79 | default_values_t = [Experiment1D, Experiment1DFast, | |
80 | Experiment2D, Experiment2DFast, | |
81 | Experiment1D_L1])] | |
82 | experiments : Vec<DefaultExperiment>, | |
83 | ||
84 | /// Default algorithm configration(s) to use on the experiments. | |
85 | /// | |
86 | /// Not all algorithms are available for all the experiments. | |
87 | /// In particular, only PDPS is available for the experiments with L¹ data term. | |
88 | #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', | |
89 | default_values_t = [FB, FISTA, PDPS, FW, FWRelax])] | |
90 | algorithm : Vec<DefaultAlgorithm>, | |
91 | ||
92 | /// Saved algorithm configration(s) to use on the experiments | |
93 | #[arg(value_name = "JSON_FILE", long)] | |
94 | saved_algorithm : Vec<String>, | |
95 | ||
96 | /// Write plots for every verbose iteration | |
97 | #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] | |
98 | plot : PlotLevel, | |
99 | ||
100 | /// Directory for saving results | |
101 | #[arg(long, short = 'o', default_value = "out")] | |
102 | outdir : String, | |
103 | ||
104 | #[arg(long, help_heading = "Multi-threading", default_value = "4")] | |
105 | /// Maximum number of threads | |
106 | max_threads : usize, | |
107 | ||
108 | #[arg(long, help_heading = "Multi-threading")] | |
109 | /// Number of threads. Overrides the maximum number. | |
110 | num_threads : Option<usize>, | |
111 | ||
112 | #[clap(flatten, next_help_heading = "Experiment overrides")] | |
113 | /// Experiment setup overrides | |
114 | experiment_overrides : ExperimentOverrides<float>, | |
115 | ||
116 | #[clap(flatten, next_help_heading = "Algorithm overrides")] | |
117 | /// Algorithm parametrisation overrides | |
118 | algoritm_overrides : AlgorithmOverrides<float>, | |
119 | } | |
120 | ||
121 | /// Command line experiment setup overrides | |
122 | #[derive(Parser, Debug)] | |
123 | pub struct ExperimentOverrides<F : ClapFloat> { | |
124 | #[arg(long)] | |
125 | /// Regularisation parameter override. | |
126 | /// | |
127 | /// Only use if running just a single experiment, as different experiments have different | |
128 | /// regularisation parameters. | |
129 | alpha : Option<F>, | |
130 | ||
131 | #[arg(long)] | |
132 | /// Gaussian noise variance override | |
133 | variance : Option<F>, | |
134 | ||
135 | #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])] | |
136 | /// Salt and pepper noise override. | |
137 | salt_and_pepper : Option<Vec<F>>, | |
138 | ||
139 | #[arg(long)] | |
140 | /// Noise seed | |
141 | noise_seed : Option<u64>, | |
142 | } | |
143 | ||
144 | /// Command line algorithm parametrisation overrides | |
145 | #[derive(Parser, Debug)] | |
146 | pub struct AlgorithmOverrides<F : ClapFloat> { | |
147 | #[arg(long, value_names = &["COUNT", "EACH"])] | |
148 | /// Override bootstrap insertion iterations for --algorithm. | |
149 | /// | |
150 | /// The first parameter is the number of bootstrap insertion iterations, and the second | |
151 | /// the maximum number of iterations on each of them. | |
152 | bootstrap_insertions : Option<Vec<usize>>, | |
153 | ||
154 | #[arg(long, requires = "algorithm")] | |
155 | /// Primal step length parameter override for --algorithm. | |
156 | /// | |
157 | /// Only use if running just a single algorithm, as different algorithms have different | |
158 | /// regularisation parameters. Does not affect the algorithms fw and fwrelax. | |
159 | tau0 : Option<F>, | |
160 | ||
161 | #[arg(long, requires = "algorithm")] | |
162 | /// Dual step length parameter override for --algorithm. | |
163 | /// | |
164 | /// Only use if running just a single algorithm, as different algorithms have different | |
165 | /// regularisation parameters. Only affects PDPS. | |
166 | sigma0 : Option<F>, | |
167 | ||
168 | #[arg(value_enum, long)] | |
169 | /// PDPS acceleration, when available. | |
170 | acceleration : Option<pdps::Acceleration>, | |
171 | ||
172 | #[arg(long)] | |
173 | /// Perform postprocess weight optimisation for saved iterations | |
174 | /// | |
175 | /// Only affects FB, FISTA, and PDPS. | |
176 | postprocessing : Option<bool>, | |
177 | ||
178 | #[arg(value_name = "n", long)] | |
179 | /// Merging frequency, if merging enabled (every n iterations) | |
180 | /// | |
181 | /// Only affects FB, FISTA, and PDPS. | |
182 | merge_every : Option<usize>, | |
183 | ||
184 | #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] | |
185 | /// Merging strategy | |
186 | /// | |
187 | /// Either the string "none", or a radius value for heuristic merging. | |
188 | merging : Option<SpikeMergingMethod<F>>, | |
189 | ||
190 | #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] | |
191 | /// Final merging strategy | |
192 | /// | |
193 | /// Either the string "none", or a radius value for heuristic merging. | |
194 | /// Only affects FB, FISTA, and PDPS. | |
195 | final_merging : Option<SpikeMergingMethod<F>>, | |
196 | } | |
197 | ||
198 | /// The entry point for the program. | |
199 | pub fn main() { | |
200 | let cli = CommandLineArgs::parse(); | |
201 | ||
202 | if let Some(n_threads) = cli.num_threads { | |
203 | let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); | |
204 | set_num_threads(n); | |
205 | } else { | |
206 | let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); | |
207 | set_max_threads(m); | |
208 | } | |
209 | ||
210 | for experiment_shorthand in cli.experiments.iter().unique() { | |
211 | let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); | |
212 | let mut config : Configuration<float> = experiment.default_config(); | |
213 | let mut algs : Vec<Named<AlgorithmConfig<float>>> | |
214 | = cli.algorithm.iter() | |
215 | .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides)) | |
216 | .collect(); | |
217 | for filename in cli.saved_algorithm.iter() { | |
218 | let f = std::fs::File::open(filename).unwrap(); | |
219 | let alg = serde_json::from_reader(f).unwrap(); | |
220 | algs.push(alg); | |
221 | } | |
222 | cli.max_iter.map(|m| config.iterator_options.max_iter = m); | |
223 | cli.verbose_iter.map(|n| config.iterator_options.verbose_iter = Verbose::Every(n)); | |
224 | config.plot = cli.plot; | |
225 | config.iterator_options.quiet = cli.quiet; | |
226 | config.outdir = cli.outdir.clone(); | |
227 | if !algs.is_empty() { | |
228 | config.algorithms = algs.clone(); | |
229 | } | |
230 | ||
231 | experiment.runall(config) | |
232 | .unwrap() | |
233 | } | |
234 | } |