|
1 // The main documentation is in the README. |
|
2 #![doc = include_str!("../README.md")] |
|
3 |
|
4 // We use unicode. We would like to use much more of it than Rust allows. |
|
5 // Live with it. Embrace it. |
|
6 #![allow(uncommon_codepoints)] |
|
7 #![allow(mixed_script_confusables)] |
|
8 #![allow(confusable_idents)] |
|
// Linear operators may be written e.g. as `opA` for a resemblance
|
10 // to mathematical convention. |
|
11 #![allow(non_snake_case)] |
|
// We need the `drain_filter` feature for inertial pruning.
|
13 #![feature(drain_filter)] |
|
14 |
|
15 use clap::Parser; |
|
16 use itertools::Itertools; |
|
17 use serde_json; |
|
18 use alg_tools::iterate::Verbose; |
|
19 use alg_tools::parallelism::{ |
|
20 set_num_threads, |
|
21 set_max_threads, |
|
22 }; |
|
23 use std::num::NonZeroUsize; |
|
24 |
|
25 pub mod types; |
|
26 pub mod measures; |
|
27 pub mod fourier; |
|
28 pub mod kernels; |
|
29 pub mod seminorms; |
|
30 pub mod forward_model; |
|
31 pub mod plot; |
|
32 pub mod subproblem; |
|
33 pub mod tolerance; |
|
34 pub mod fb; |
|
35 pub mod frank_wolfe; |
|
36 pub mod pdps; |
|
37 pub mod run; |
|
38 pub mod rand_distr; |
|
39 pub mod experiments; |
|
40 |
|
41 use types::{float, ClapFloat}; |
|
42 use run::{ |
|
43 DefaultAlgorithm, |
|
44 Configuration, |
|
45 PlotLevel, |
|
46 Named, |
|
47 AlgorithmConfig, |
|
48 }; |
|
49 use experiments::DefaultExperiment; |
|
50 use measures::merging::SpikeMergingMethod; |
|
51 use DefaultExperiment::*; |
|
52 use DefaultAlgorithm::*; |
|
53 |
|
54 /// Command line parameters |
|
55 #[derive(Parser, Debug)] |
|
56 #[clap( |
|
57 about = env!("CARGO_PKG_DESCRIPTION"), |
|
58 author = env!("CARGO_PKG_AUTHORS"), |
|
59 version = env!("CARGO_PKG_VERSION"), |
|
60 after_help = "Pass --help for longer descriptions.", |
|
61 after_long_help = "", |
|
62 )] |
|
63 pub struct CommandLineArgs { |
|
64 #[arg(long, short = 'm', value_name = "M")] |
|
65 /// Maximum iteration count |
|
66 max_iter : Option<usize>, |
|
67 |
|
68 #[arg(long, short = 'n', value_name = "N")] |
|
69 /// Output status every N iterations. Set to 0 to disable. |
|
70 verbose_iter : Option<usize>, |
|
71 |
|
72 #[arg(long, short = 'q')] |
|
73 /// Don't display iteration progress |
|
74 quiet : bool, |
|
75 |
|
76 /// List of experiments to perform. |
|
77 #[arg(value_enum, value_name = "EXPERIMENT", |
|
78 default_values_t = [Experiment1D, Experiment1DFast, |
|
79 Experiment2D, Experiment2DFast, |
|
80 Experiment1D_L1])] |
|
81 experiments : Vec<DefaultExperiment>, |
|
82 |
|
83 /// Default algorithm configration(s) to use on the experiments. |
|
84 /// |
|
85 /// Not all algorithms are available for all the experiments. |
|
86 /// In particular, only PDPS is available for the experiments with L¹ data term. |
|
87 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', |
|
88 default_values_t = [FB, FISTA, PDPS, FW, FWRelax])] |
|
89 algorithm : Vec<DefaultAlgorithm>, |
|
90 |
|
91 /// Saved algorithm configration(s) to use on the experiments |
|
92 #[arg(value_name = "JSON_FILE", long)] |
|
93 saved_algorithm : Vec<String>, |
|
94 |
|
95 /// Write plots for every verbose iteration |
|
96 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] |
|
97 plot : PlotLevel, |
|
98 |
|
99 /// Directory for saving results |
|
100 #[arg(long, short = 'o', default_value = "out")] |
|
101 outdir : String, |
|
102 |
|
103 #[arg(long, help_heading = "Multi-threading", default_value = "4")] |
|
104 /// Maximum number of threads |
|
105 max_threads : usize, |
|
106 |
|
107 #[arg(long, help_heading = "Multi-threading")] |
|
108 /// Number of threads. Overrides the maximum number. |
|
109 num_threads : Option<usize>, |
|
110 |
|
111 #[clap(flatten, next_help_heading = "Experiment overrides")] |
|
112 /// Experiment setup overrides |
|
113 experiment_overrides : ExperimentOverrides<float>, |
|
114 |
|
115 #[clap(flatten, next_help_heading = "Algorithm overrides")] |
|
116 /// Algorithm parametrisation overrides |
|
117 algoritm_overrides : AlgorithmOverrides<float>, |
|
118 } |
|
119 |
|
/// Command line experiment setup overrides
// Flattened into the top-level command line arguments. Each `Some` field
// replaces the corresponding setting of the selected experiment(s).
// NOTE: the `///` comments on the fields are the `--help` texts (clap derive),
// so explanatory notes that should not appear in the help use `//` instead.
#[derive(Parser, Debug)]
pub struct ExperimentOverrides<F : ClapFloat> {
    #[arg(long)]
    /// Regularisation parameter override.
    ///
    /// Only use if running just a single experiment, as different experiments have different
    /// regularisation parameters.
    alpha : Option<F>,

    #[arg(long)]
    /// Gaussian noise variance override
    variance : Option<F>,

    // Two value names, so clap presumably expects exactly two values
    // (magnitude, then probability) — TODO confirm against the clap version in use.
    #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])]
    /// Salt and pepper noise override.
    salt_and_pepper : Option<Vec<F>>,

    #[arg(long)]
    /// Noise seed
    noise_seed : Option<u64>,
}
|
142 |
|
143 /// Command line algorithm parametrisation overrides |
|
144 #[derive(Parser, Debug)] |
|
145 pub struct AlgorithmOverrides<F : ClapFloat> { |
|
146 #[arg(long, value_names = &["COUNT", "EACH"])] |
|
147 /// Override bootstrap insertion iterations for --algorithm. |
|
148 /// |
|
149 /// The first parameter is the number of bootstrap insertion iterations, and the second |
|
150 /// the maximum number of iterations on each of them. |
|
151 bootstrap_insertions : Option<Vec<usize>>, |
|
152 |
|
153 #[arg(long, requires = "algorithm")] |
|
154 /// Primal step length parameter override for --algorithm. |
|
155 /// |
|
156 /// Only use if running just a single algorithm, as different algorithms have different |
|
157 /// regularisation parameters. Does not affect the algorithms fw and fwrelax. |
|
158 tau0 : Option<F>, |
|
159 |
|
160 #[arg(long, requires = "algorithm")] |
|
161 /// Dual step length parameter override for --algorithm. |
|
162 /// |
|
163 /// Only use if running just a single algorithm, as different algorithms have different |
|
164 /// regularisation parameters. Only affects PDPS. |
|
165 sigma0 : Option<F>, |
|
166 |
|
167 #[arg(value_enum, long)] |
|
168 /// PDPS acceleration, when available. |
|
169 acceleration : Option<pdps::Acceleration>, |
|
170 |
|
171 #[arg(long)] |
|
172 /// Perform postprocess weight optimisation for saved iterations |
|
173 /// |
|
174 /// Only affects FB, FISTA, and PDPS. |
|
175 postprocessing : Option<bool>, |
|
176 |
|
177 #[arg(value_name = "n", long)] |
|
178 /// Merging frequency, if merging enabled (every n iterations) |
|
179 /// |
|
180 /// Only affects FB, FISTA, and PDPS. |
|
181 merge_every : Option<usize>, |
|
182 |
|
183 #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] |
|
184 /// Merging strategy |
|
185 /// |
|
186 /// Either the string "none", or a radius value for heuristic merging. |
|
187 merging : Option<SpikeMergingMethod<F>>, |
|
188 |
|
189 #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] |
|
190 /// Final merging strategy |
|
191 /// |
|
192 /// Either the string "none", or a radius value for heuristic merging. |
|
193 /// Only affects FB, FISTA, and PDPS. |
|
194 final_merging : Option<SpikeMergingMethod<F>>, |
|
195 } |
|
196 |
|
197 /// The entry point for the program. |
|
198 pub fn main() { |
|
199 let cli = CommandLineArgs::parse(); |
|
200 |
|
201 if let Some(n_threads) = cli.num_threads { |
|
202 let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); |
|
203 set_num_threads(n); |
|
204 } else { |
|
205 let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); |
|
206 set_max_threads(m); |
|
207 } |
|
208 |
|
209 for experiment_shorthand in cli.experiments.iter().unique() { |
|
210 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); |
|
211 let mut config : Configuration<float> = experiment.default_config(); |
|
212 let mut algs : Vec<Named<AlgorithmConfig<float>>> |
|
213 = cli.algorithm.iter() |
|
214 .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides)) |
|
215 .collect(); |
|
216 for filename in cli.saved_algorithm.iter() { |
|
217 let f = std::fs::File::open(filename).unwrap(); |
|
218 let alg = serde_json::from_reader(f).unwrap(); |
|
219 algs.push(alg); |
|
220 } |
|
221 cli.max_iter.map(|m| config.iterator_options.max_iter = m); |
|
222 cli.verbose_iter.map(|n| config.iterator_options.verbose_iter = Verbose::Every(n)); |
|
223 config.plot = cli.plot; |
|
224 config.iterator_options.quiet = cli.quiet; |
|
225 config.outdir = cli.outdir.clone(); |
|
226 if !algs.is_empty() { |
|
227 config.algorithms = algs.clone(); |
|
228 } |
|
229 |
|
230 experiment.runall(config) |
|
231 .unwrap() |
|
232 } |
|
233 } |