| |
1 // The main documentation is in the README. |
| |
2 #![doc = include_str!("../README.md")] |
| |
3 |
| |
4 // We use unicode. We would like to use much more of it than Rust allows. |
| |
5 // Live with it. Embrace it. |
| |
6 #![allow(uncommon_codepoints)] |
| |
7 #![allow(mixed_script_confusables)] |
| |
8 #![allow(confusable_idents)] |
| |
9 // Linear operators may be written e.g. as `opA` for a resemblance |
| |
10 // to mathematical convention. |
| |
11 #![allow(non_snake_case)] |
| |
12 // We need the drain filter for inertial prune |
| |
13 #![feature(drain_filter)] |
| |
14 |
| |
15 use clap::Parser; |
| |
16 use itertools::Itertools; |
| |
17 use serde_json; |
| |
18 use alg_tools::iterate::Verbose; |
| |
19 use alg_tools::parallelism::{ |
| |
20 set_num_threads, |
| |
21 set_max_threads, |
| |
22 }; |
| |
23 use std::num::NonZeroUsize; |
| |
24 |
| |
25 pub mod types; |
| |
26 pub mod measures; |
| |
27 pub mod fourier; |
| |
28 pub mod kernels; |
| |
29 pub mod seminorms; |
| |
30 pub mod forward_model; |
| |
31 pub mod plot; |
| |
32 pub mod subproblem; |
| |
33 pub mod tolerance; |
| |
34 pub mod fb; |
| |
35 pub mod frank_wolfe; |
| |
36 pub mod pdps; |
| |
37 pub mod run; |
| |
38 pub mod rand_distr; |
| |
39 pub mod experiments; |
| |
40 |
| |
41 use types::{float, ClapFloat}; |
| |
42 use run::{ |
| |
43 DefaultAlgorithm, |
| |
44 Configuration, |
| |
45 PlotLevel, |
| |
46 Named, |
| |
47 AlgorithmConfig, |
| |
48 }; |
| |
49 use experiments::DefaultExperiment; |
| |
50 use measures::merging::SpikeMergingMethod; |
| |
51 use DefaultExperiment::*; |
| |
52 use DefaultAlgorithm::*; |
| |
53 |
| |
54 /// Command line parameters |
| |
55 #[derive(Parser, Debug)] |
| |
56 #[clap( |
| |
57 about = env!("CARGO_PKG_DESCRIPTION"), |
| |
58 author = env!("CARGO_PKG_AUTHORS"), |
| |
59 version = env!("CARGO_PKG_VERSION"), |
| |
60 after_help = "Pass --help for longer descriptions.", |
| |
61 after_long_help = "", |
| |
62 )] |
| |
63 pub struct CommandLineArgs { |
| |
64 #[arg(long, short = 'm', value_name = "M")] |
| |
65 /// Maximum iteration count |
| |
66 max_iter : Option<usize>, |
| |
67 |
| |
68 #[arg(long, short = 'n', value_name = "N")] |
| |
69 /// Output status every N iterations. Set to 0 to disable. |
| |
70 verbose_iter : Option<usize>, |
| |
71 |
| |
72 #[arg(long, short = 'q')] |
| |
73 /// Don't display iteration progress |
| |
74 quiet : bool, |
| |
75 |
| |
76 /// List of experiments to perform. |
| |
77 #[arg(value_enum, value_name = "EXPERIMENT", |
| |
78 default_values_t = [Experiment1D, Experiment1DFast, |
| |
79 Experiment2D, Experiment2DFast, |
| |
80 Experiment1D_L1])] |
| |
81 experiments : Vec<DefaultExperiment>, |
| |
82 |
| |
83 /// Default algorithm configration(s) to use on the experiments. |
| |
84 /// |
| |
85 /// Not all algorithms are available for all the experiments. |
| |
86 /// In particular, only PDPS is available for the experiments with L¹ data term. |
| |
87 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', |
| |
88 default_values_t = [FB, FISTA, PDPS, FW, FWRelax])] |
| |
89 algorithm : Vec<DefaultAlgorithm>, |
| |
90 |
| |
91 /// Saved algorithm configration(s) to use on the experiments |
| |
92 #[arg(value_name = "JSON_FILE", long)] |
| |
93 saved_algorithm : Vec<String>, |
| |
94 |
| |
95 /// Write plots for every verbose iteration |
| |
96 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] |
| |
97 plot : PlotLevel, |
| |
98 |
| |
99 /// Directory for saving results |
| |
100 #[arg(long, short = 'o', default_value = "out")] |
| |
101 outdir : String, |
| |
102 |
| |
103 #[arg(long, help_heading = "Multi-threading", default_value = "4")] |
| |
104 /// Maximum number of threads |
| |
105 max_threads : usize, |
| |
106 |
| |
107 #[arg(long, help_heading = "Multi-threading")] |
| |
108 /// Number of threads. Overrides the maximum number. |
| |
109 num_threads : Option<usize>, |
| |
110 |
| |
111 #[clap(flatten, next_help_heading = "Experiment overrides")] |
| |
112 /// Experiment setup overrides |
| |
113 experiment_overrides : ExperimentOverrides<float>, |
| |
114 |
| |
115 #[clap(flatten, next_help_heading = "Algorithm overrides")] |
| |
116 /// Algorithm parametrisation overrides |
| |
117 algoritm_overrides : AlgorithmOverrides<float>, |
| |
118 } |
| |
119 |
| |
120 /// Command line experiment setup overrides |
| |
121 #[derive(Parser, Debug)] |
| |
122 pub struct ExperimentOverrides<F : ClapFloat> { |
| |
123 #[arg(long)] |
| |
124 /// Regularisation parameter override. |
| |
125 /// |
| |
126 /// Only use if running just a single experiment, as different experiments have different |
| |
127 /// regularisation parameters. |
| |
128 alpha : Option<F>, |
| |
129 |
| |
130 #[arg(long)] |
| |
131 /// Gaussian noise variance override |
| |
132 variance : Option<F>, |
| |
133 |
| |
134 #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])] |
| |
135 /// Salt and pepper noise override. |
| |
136 salt_and_pepper : Option<Vec<F>>, |
| |
137 |
| |
138 #[arg(long)] |
| |
139 /// Noise seed |
| |
140 noise_seed : Option<u64>, |
| |
141 } |
| |
142 |
| |
143 /// Command line algorithm parametrisation overrides |
| |
144 #[derive(Parser, Debug)] |
| |
145 pub struct AlgorithmOverrides<F : ClapFloat> { |
| |
146 #[arg(long, value_names = &["COUNT", "EACH"])] |
| |
147 /// Override bootstrap insertion iterations for --algorithm. |
| |
148 /// |
| |
149 /// The first parameter is the number of bootstrap insertion iterations, and the second |
| |
150 /// the maximum number of iterations on each of them. |
| |
151 bootstrap_insertions : Option<Vec<usize>>, |
| |
152 |
| |
153 #[arg(long, requires = "algorithm")] |
| |
154 /// Primal step length parameter override for --algorithm. |
| |
155 /// |
| |
156 /// Only use if running just a single algorithm, as different algorithms have different |
| |
157 /// regularisation parameters. Does not affect the algorithms fw and fwrelax. |
| |
158 tau0 : Option<F>, |
| |
159 |
| |
160 #[arg(long, requires = "algorithm")] |
| |
161 /// Dual step length parameter override for --algorithm. |
| |
162 /// |
| |
163 /// Only use if running just a single algorithm, as different algorithms have different |
| |
164 /// regularisation parameters. Only affects PDPS. |
| |
165 sigma0 : Option<F>, |
| |
166 |
| |
167 #[arg(value_enum, long)] |
| |
168 /// PDPS acceleration, when available. |
| |
169 acceleration : Option<pdps::Acceleration>, |
| |
170 |
| |
171 #[arg(long)] |
| |
172 /// Perform postprocess weight optimisation for saved iterations |
| |
173 /// |
| |
174 /// Only affects FB, FISTA, and PDPS. |
| |
175 postprocessing : Option<bool>, |
| |
176 |
| |
177 #[arg(value_name = "n", long)] |
| |
178 /// Merging frequency, if merging enabled (every n iterations) |
| |
179 /// |
| |
180 /// Only affects FB, FISTA, and PDPS. |
| |
181 merge_every : Option<usize>, |
| |
182 |
| |
183 #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] |
| |
184 /// Merging strategy |
| |
185 /// |
| |
186 /// Either the string "none", or a radius value for heuristic merging. |
| |
187 merging : Option<SpikeMergingMethod<F>>, |
| |
188 |
| |
189 #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] |
| |
190 /// Final merging strategy |
| |
191 /// |
| |
192 /// Either the string "none", or a radius value for heuristic merging. |
| |
193 /// Only affects FB, FISTA, and PDPS. |
| |
194 final_merging : Option<SpikeMergingMethod<F>>, |
| |
195 } |
| |
196 |
| |
197 /// The entry point for the program. |
| |
198 pub fn main() { |
| |
199 let cli = CommandLineArgs::parse(); |
| |
200 |
| |
201 if let Some(n_threads) = cli.num_threads { |
| |
202 let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); |
| |
203 set_num_threads(n); |
| |
204 } else { |
| |
205 let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); |
| |
206 set_max_threads(m); |
| |
207 } |
| |
208 |
| |
209 for experiment_shorthand in cli.experiments.iter().unique() { |
| |
210 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); |
| |
211 let mut config : Configuration<float> = experiment.default_config(); |
| |
212 let mut algs : Vec<Named<AlgorithmConfig<float>>> |
| |
213 = cli.algorithm.iter() |
| |
214 .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides)) |
| |
215 .collect(); |
| |
216 for filename in cli.saved_algorithm.iter() { |
| |
217 let f = std::fs::File::open(filename).unwrap(); |
| |
218 let alg = serde_json::from_reader(f).unwrap(); |
| |
219 algs.push(alg); |
| |
220 } |
| |
221 cli.max_iter.map(|m| config.iterator_options.max_iter = m); |
| |
222 cli.verbose_iter.map(|n| config.iterator_options.verbose_iter = Verbose::Every(n)); |
| |
223 config.plot = cli.plot; |
| |
224 config.iterator_options.quiet = cli.quiet; |
| |
225 config.outdir = cli.outdir.clone(); |
| |
226 if !algs.is_empty() { |
| |
227 config.algorithms = algs.clone(); |
| |
228 } |
| |
229 |
| |
230 experiment.runall(config) |
| |
231 .unwrap() |
| |
232 } |
| |
233 } |