src/main.rs

branch
dev
changeset 61
4f468d35fa29
parent 44
03251c546744
equal deleted inserted replaced
60:9738b51d90d7 61:4f468d35fa29
1 // The main documentation is in the README. 1 // The main documentation is in the README.
2 // We need to uglify it in build.rs because rustdoc is stuck in the past. 2 // We need to uglify it in build.rs because rustdoc is stuck in the past.
3 #![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))] 3 #![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))]
4 4
5 // We use unicode. We would like to use much more of it than Rust allows. 5 use pointsource_algs::{common_main, experiments::DefaultExperimentSetup};
6 // Live with it. Embrace it.
7 #![allow(uncommon_codepoints)]
8 #![allow(mixed_script_confusables)]
9 #![allow(confusable_idents)]
10 // Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical
11 // convention while referring to the type (trait) of the operator as `A`.
12 #![allow(non_snake_case)]
13 // Need to create parse errors
14 #![feature(dec2flt)]
15
16 use clap::Parser;
17 use serde::{Serialize, Deserialize};
18 use serde_json;
19 use serde_with::skip_serializing_none;
20 use itertools::Itertools;
21 use std::num::NonZeroUsize;
22
23 use alg_tools::parallelism::{
24 set_num_threads,
25 set_max_threads,
26 };
27
28 pub mod types;
29 pub mod measures;
30 pub mod fourier;
31 pub mod kernels;
32 pub mod seminorms;
33 pub mod forward_model;
34 pub mod preadjoint_helper;
35 pub mod plot;
36 pub mod subproblem;
37 pub mod tolerance;
38 pub mod regularisation;
39 pub mod dataterm;
40 pub mod prox_penalty;
41 pub mod fb;
42 pub mod sliding_fb;
43 pub mod sliding_pdps;
44 pub mod forward_pdps;
45 pub mod frank_wolfe;
46 pub mod pdps;
47 pub mod run;
48 pub mod rand_distr;
49 pub mod experiments;
50
51 use types::{float, ClapFloat};
52 use run::{
53 DefaultAlgorithm,
54 PlotLevel,
55 Named,
56 AlgorithmConfig,
57 };
58 use experiments::DefaultExperiment;
59 use DefaultExperiment::*;
60 use DefaultAlgorithm::*;
61
62 /// Command line parameters
63 #[skip_serializing_none]
64 #[derive(Parser, Debug, Serialize, Default, Clone)]
65 #[clap(
66 about = env!("CARGO_PKG_DESCRIPTION"),
67 author = env!("CARGO_PKG_AUTHORS"),
68 version = env!("CARGO_PKG_VERSION"),
69 after_help = "Pass --help for longer descriptions.",
70 after_long_help = "",
71 )]
72 pub struct CommandLineArgs {
73 #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)]
74 /// Maximum iteration count
75 max_iter : usize,
76
77 #[arg(long, short = 'n', value_name = "N")]
78 /// Output status every N iterations. Set to 0 to disable.
79 ///
80 /// The default is to output status based on logarithmic increments.
81 verbose_iter : Option<usize>,
82
83 #[arg(long, short = 'q')]
84 /// Don't display iteration progress
85 quiet : bool,
86
87 /// List of experiments to perform.
88 #[arg(value_enum, value_name = "EXPERIMENT",
89 default_values_t = [Experiment1D, Experiment1DFast,
90 Experiment2D, Experiment2DFast,
91 Experiment1D_L1])]
92 experiments : Vec<DefaultExperiment>,
93
94 /// Default algorithm configration(s) to use on the experiments.
95 ///
96 /// Not all algorithms are available for all the experiments.
97 /// In particular, only PDPS is available for the experiments with L¹ data term.
98 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a',
99 default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])]
100 algorithm : Vec<DefaultAlgorithm>,
101
102 /// Saved algorithm configration(s) to use on the experiments
103 #[arg(value_name = "JSON_FILE", long)]
104 saved_algorithm : Vec<String>,
105
106 /// Plot saving scheme
107 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)]
108 plot : PlotLevel,
109
110 /// Directory for saving results
111 #[arg(long, short = 'o', required = true, default_value = "out")]
112 outdir : String,
113
114 #[arg(long, help_heading = "Multi-threading", default_value = "4")]
115 /// Maximum number of threads
116 max_threads : usize,
117
118 #[arg(long, help_heading = "Multi-threading")]
119 /// Number of threads. Overrides the maximum number.
120 num_threads : Option<usize>,
121
122 #[arg(long, default_value_t = false)]
123 /// Load saved value ranges (if exists) to do partial update.
124 load_valuerange : bool,
125
126 #[clap(flatten, next_help_heading = "Experiment overrides")]
127 /// Experiment setup overrides
128 experiment_overrides : ExperimentOverrides<float>,
129
130 #[clap(flatten, next_help_heading = "Algorithm overrides")]
131 /// Algorithm parametrisation overrides
132 algoritm_overrides : AlgorithmOverrides<float>,
133 }
134
135 /// Command line experiment setup overrides
136 #[skip_serializing_none]
137 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
138 pub struct ExperimentOverrides<F : ClapFloat> {
139 #[arg(long)]
140 /// Regularisation parameter override.
141 ///
142 /// Only use if running just a single experiment, as different experiments have different
143 /// regularisation parameters.
144 alpha : Option<F>,
145
146 #[arg(long)]
147 /// Gaussian noise variance override
148 variance : Option<F>,
149
150 #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])]
151 /// Salt and pepper noise override.
152 salt_and_pepper : Option<Vec<F>>,
153
154 #[arg(long)]
155 /// Noise seed
156 noise_seed : Option<u64>,
157 }
158
159 /// Command line algorithm parametrisation overrides
160 #[skip_serializing_none]
161 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
162 pub struct AlgorithmOverrides<F : ClapFloat> {
163 #[arg(long, value_names = &["COUNT", "EACH"])]
164 /// Override bootstrap insertion iterations for --algorithm.
165 ///
166 /// The first parameter is the number of bootstrap insertion iterations, and the second
167 /// the maximum number of iterations on each of them.
168 bootstrap_insertions : Option<Vec<usize>>,
169
170 #[arg(long, requires = "algorithm")]
171 /// Primal step length parameter override for --algorithm.
172 ///
173 /// Only use if running just a single algorithm, as different algorithms have different
174 /// regularisation parameters. Does not affect the algorithms fw and fwrelax.
175 tau0 : Option<F>,
176
177 #[arg(long, requires = "algorithm")]
178 /// Second primal step length parameter override for SlidingPDPS.
179 ///
180 /// Only use if running just a single algorithm, as different algorithms have different
181 /// regularisation parameters.
182 sigmap0 : Option<F>,
183
184 #[arg(long, requires = "algorithm")]
185 /// Dual step length parameter override for --algorithm.
186 ///
187 /// Only use if running just a single algorithm, as different algorithms have different
188 /// regularisation parameters. Only affects PDPS.
189 sigma0 : Option<F>,
190
191 #[arg(long)]
192 /// Normalised transport step length for sliding methods.
193 theta0 : Option<F>,
194
195 #[arg(long)]
196 /// A posteriori transport tolerance multiplier (C_pos)
197 transport_tolerance_pos : Option<F>,
198
199 #[arg(long)]
200 /// Transport adaptation factor. Must be in (0, 1).
201 transport_adaptation : Option<F>,
202
203 #[arg(long)]
204 /// Minimal step length parameter for sliding methods.
205 tau0_min : Option<F>,
206
207 #[arg(value_enum, long)]
208 /// PDPS acceleration, when available.
209 acceleration : Option<pdps::Acceleration>,
210
211 // #[arg(long)]
212 // /// Perform postprocess weight optimisation for saved iterations
213 // ///
214 // /// Only affects FB, FISTA, and PDPS.
215 // postprocessing : Option<bool>,
216
217 #[arg(value_name = "n", long)]
218 /// Merging frequency, if merging enabled (every n iterations)
219 ///
220 /// Only affects FB, FISTA, and PDPS.
221 merge_every : Option<usize>,
222
223 #[arg(long)]
224 /// Enable merging (default: determined by algorithm)
225 merge : Option<bool>,
226
227 #[arg(long)]
228 /// Merging radius (default: determined by experiment)
229 merge_radius : Option<F>,
230
231 #[arg(long)]
232 /// Interpolate when merging (default : determined by algorithm)
233 merge_interp : Option<bool>,
234
235 #[arg(long)]
236 /// Enable final merging (default: determined by algorithm)
237 final_merging : Option<bool>,
238
239 #[arg(long)]
240 /// Enable fitness-based merging for relevant FB-type methods.
241 /// This has worse convergence guarantees that merging based on optimality conditions.
242 fitness_merging : Option<bool>,
243
244 #[arg(long, value_names = &["ε", "θ", "p"])]
245 /// Set the tolerance to ε_k = ε/(1+θk)^p
246 tolerance : Option<Vec<F>>,
247
248 }
249 6
250 /// The entry point for the program. 7 /// The entry point for the program.
251 pub fn main() { 8 pub fn main() {
// Left column (old revision): parse the command line, configure threading, then run every
// selected experiment with every selected algorithm. Right column (new revision): `main`
// only calls `common_main::<DefaultExperimentSetup<f64>>()` and unwraps its result;
// presumably that function takes over the steps below — verify in `pointsource_algs`.
252 let cli = CommandLineArgs::parse(); 9 common_main::<DefaultExperimentSetup<f64>>().unwrap();
253
// Debug builds only: print a red warning (with blinking text) that performance will be poor.
254 #[cfg(debug_assertions)]
255 {
256 use colored::Colorize;
257 println!("{}", format!("\n\
258 ********\n\
259 WARNING: Compiled without optimisations; {}\n\
260 Please recompile with `--release` flag.\n\
261 ********\n\
262 ", "performance will be poor!".blink()
263 ).red());
264 }
265
// An explicit thread count (`num_threads`) takes precedence over the maximum (`max_threads`).
// A count of 0 makes `NonZeroUsize::new` return `None`, so the `expect` panics with its message.
266 if let Some(n_threads) = cli.num_threads {
267 let n = NonZeroUsize::new(n_threads).expect("Invalid thread count");
268 set_num_threads(n);
269 } else {
270 let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count");
271 set_max_threads(m);
272 }
273
// `unique()` (from itertools) runs an experiment listed more than once only once.
274 for experiment_shorthand in cli.experiments.iter().unique() {
275 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap();
// Each selected algorithm starts from its default configuration, then has the experiment's
// algorithm overrides applied, then the command-line algorithm overrides.
276 let mut algs : Vec<Named<AlgorithmConfig<float>>>
277 = cli.algorithm
278 .iter()
279 .map(|alg| {
280 let cfg = alg.default_config()
281 .cli_override(&experiment.algorithm_overrides(*alg))
282 .cli_override(&cli.algoritm_overrides);
283 alg.to_named(cfg)
284 })
285 .collect();
// Algorithm configurations read from the saved JSON files are appended to the list.
// NOTE(review): both `unwrap`s panic on a missing/unreadable file or malformed JSON;
// consider reporting an error that names the offending file instead.
286 for filename in cli.saved_algorithm.iter() {
287 let f = std::fs::File::open(filename).unwrap();
288 let alg = serde_json::from_reader(f).unwrap();
289 algs.push(alg);
290 }
// `None` is passed when no algorithms were configured at all, leaving the choice to
// `runall` (presumably it falls back to defaults — TODO confirm in `run`).
291 experiment.runall(&cli, (!algs.is_empty()).then_some(algs))
292 .unwrap()
293 }
294 } 10 }

mercurial