| |
1 // The main documentation is in the README. |
| |
2 // We need to uglify it in build.rs because rustdoc is stuck in the past. |
| |
3 #![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))] |
| |
4 // We use unicode. We would like to use much more of it than Rust allows. |
| |
5 // Live with it. Embrace it. |
| |
6 #![allow(uncommon_codepoints)] |
| |
7 #![allow(mixed_script_confusables)] |
| |
8 #![allow(confusable_idents)] |
| |
9 // Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical |
| |
10 // convention while referring to the type (trait) of the operator as `A`. |
| |
11 #![allow(non_snake_case)] |
| |
12 |
| |
13 use alg_tools::error::DynResult; |
| |
14 use alg_tools::parallelism::{set_max_threads, set_num_threads}; |
| |
15 use clap::Parser; |
| |
16 use serde::{Deserialize, Serialize}; |
| |
17 use serde_json; |
| |
18 use serde_with::skip_serializing_none; |
| |
19 use std::num::NonZeroUsize; |
| |
20 |
| |
21 //#[cfg(feature = "pyo3")] |
| |
22 //use pyo3::pyclass; |
| |
23 |
| |
24 pub mod dataterm; |
| |
25 pub mod experiments; |
| |
26 pub mod fb; |
| |
27 pub mod forward_model; |
| |
28 pub mod forward_pdps; |
| |
29 pub mod fourier; |
| |
30 pub mod frank_wolfe; |
| |
31 pub mod kernels; |
| |
32 pub mod pdps; |
| |
33 pub mod plot; |
| |
34 pub mod preadjoint_helper; |
| |
35 pub mod prox_penalty; |
| |
36 pub mod rand_distr; |
| |
37 pub mod regularisation; |
| |
38 pub mod run; |
| |
39 pub mod seminorms; |
| |
40 pub mod sliding_fb; |
| |
41 pub mod sliding_pdps; |
| |
42 pub mod subproblem; |
| |
43 pub mod tolerance; |
| |
44 pub mod types; |
| |
45 |
| |
46 pub mod measures { |
| |
47 pub use measures::*; |
| |
48 } |
| |
49 |
| |
50 use run::{AlgorithmConfig, DefaultAlgorithm, Named, PlotLevel, RunnableExperiment}; |
| |
51 use subproblem::InnerMethod; |
| |
52 use types::{ClapFloat, Float}; |
| |
53 use DefaultAlgorithm::*; |
| |
54 |
| |
55 /// Trait for customising the experiments available from the command line |
| |
56 pub trait ExperimentSetup: |
| |
57 clap::Args + Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> |
| |
58 { |
| |
59 /// Type of floating point numbers to be used. |
| |
60 type FloatType: Float + ClapFloat + for<'b> Deserialize<'b>; |
| |
61 |
| |
62 fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>>; |
| |
63 } |
| |
64 |
| |
65 /// Command line parameters |
| |
66 #[skip_serializing_none] |
| |
67 #[derive(Parser, Debug, Serialize, Default, Clone)] |
| |
68 pub struct CommandLineArgs { |
| |
69 #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)] |
| |
70 /// Maximum iteration count |
| |
71 max_iter: usize, |
| |
72 |
| |
73 #[arg(long, short = 'n', value_name = "N")] |
| |
74 /// Output status every N iterations. Set to 0 to disable. |
| |
75 /// |
| |
76 /// The default is to output status based on logarithmic increments. |
| |
77 verbose_iter: Option<usize>, |
| |
78 |
| |
79 #[arg(long, short = 'q')] |
| |
80 /// Don't display iteration progress |
| |
81 quiet: bool, |
| |
82 |
| |
83 /// Default algorithm configration(s) to use on the experiments. |
| |
84 /// |
| |
85 /// Not all algorithms are available for all the experiments. |
| |
86 /// In particular, only PDPS is available for the experiments with L¹ data term. |
| |
87 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', |
| |
88 default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])] |
| |
89 algorithm: Vec<DefaultAlgorithm>, |
| |
90 |
| |
91 /// Saved algorithm configration(s) to use on the experiments |
| |
92 #[arg(value_name = "JSON_FILE", long)] |
| |
93 saved_algorithm: Vec<String>, |
| |
94 |
| |
95 /// Plot saving scheme |
| |
96 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] |
| |
97 plot: PlotLevel, |
| |
98 |
| |
99 /// Directory for saving results |
| |
100 #[arg(long, short = 'o', required = true, default_value = "out")] |
| |
101 outdir: String, |
| |
102 |
| |
103 #[arg(long, help_heading = "Multi-threading", default_value = "4")] |
| |
104 /// Maximum number of threads |
| |
105 max_threads: usize, |
| |
106 |
| |
107 #[arg(long, help_heading = "Multi-threading")] |
| |
108 /// Number of threads. Overrides the maximum number. |
| |
109 num_threads: Option<usize>, |
| |
110 |
| |
111 #[arg(long, default_value_t = false)] |
| |
112 /// Load saved value ranges (if exists) to do partial update. |
| |
113 load_valuerange: bool, |
| |
114 } |
| |
115 |
| |
116 #[derive(Parser, Debug, Serialize, Default, Clone)] |
| |
117 #[clap( |
| |
118 about = env!("CARGO_PKG_DESCRIPTION"), |
| |
119 author = env!("CARGO_PKG_AUTHORS"), |
| |
120 version = env!("CARGO_PKG_VERSION"), |
| |
121 after_help = "Pass --help for longer descriptions.", |
| |
122 after_long_help = "", |
| |
123 )] |
| |
124 struct FusedCommandLineArgs<E: ExperimentSetup> { |
| |
125 /// List of experiments to perform. |
| |
126 #[clap(flatten, next_help_heading = "Experiment setup")] |
| |
127 experiment_setup: E, |
| |
128 |
| |
129 #[clap(flatten, next_help_heading = "General parameters")] |
| |
130 general: CommandLineArgs, |
| |
131 |
| |
132 #[clap(flatten, next_help_heading = "Algorithm overrides")] |
| |
133 /// Algorithm parametrisation overrides |
| |
134 algorithm_overrides: AlgorithmOverrides<E::FloatType>, |
| |
135 } |
| |
136 |
| |
137 /// Command line algorithm parametrisation overrides |
| |
138 #[skip_serializing_none] |
| |
139 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] |
| |
140 //#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))] |
| |
141 pub struct AlgorithmOverrides<F: ClapFloat> { |
| |
142 #[arg(long, value_names = &["COUNT", "EACH"])] |
| |
143 /// Override bootstrap insertion iterations for --algorithm. |
| |
144 /// |
| |
145 /// The first parameter is the number of bootstrap insertion iterations, and the second |
| |
146 /// the maximum number of iterations on each of them. |
| |
147 bootstrap_insertions: Option<Vec<usize>>, |
| |
148 |
| |
149 #[arg(long, requires = "algorithm")] |
| |
150 /// Primal step length parameter override for --algorithm. |
| |
151 /// |
| |
152 /// Only use if running just a single algorithm, as different algorithms have different |
| |
153 /// regularisation parameters. Does not affect the algorithms fw and fwrelax. |
| |
154 tau0: Option<F>, |
| |
155 |
| |
156 #[arg(long, requires = "algorithm")] |
| |
157 /// Second primal step length parameter override for SlidingPDPS. |
| |
158 /// |
| |
159 /// Only use if running just a single algorithm, as different algorithms have different |
| |
160 /// regularisation parameters. |
| |
161 sigmap0: Option<F>, |
| |
162 |
| |
163 #[arg(long, requires = "algorithm")] |
| |
164 /// Dual step length parameter override for --algorithm. |
| |
165 /// |
| |
166 /// Only use if running just a single algorithm, as different algorithms have different |
| |
167 /// regularisation parameters. Only affects PDPS. |
| |
168 sigma0: Option<F>, |
| |
169 |
| |
170 #[arg(long)] |
| |
171 /// Normalised transport step length for sliding methods. |
| |
172 theta0: Option<F>, |
| |
173 |
| |
174 #[arg(long)] |
| |
175 /// A posteriori transport tolerance multiplier (C_pos) |
| |
176 transport_tolerance_pos: Option<F>, |
| |
177 |
| |
178 #[arg(long)] |
| |
179 /// Transport adaptation factor. Must be in (0, 1). |
| |
180 transport_adaptation: Option<F>, |
| |
181 |
| |
182 #[arg(long)] |
| |
183 /// Minimal step length parameter for sliding methods. |
| |
184 tau0_min: Option<F>, |
| |
185 |
| |
186 #[arg(value_enum, long)] |
| |
187 /// PDPS acceleration, when available. |
| |
188 acceleration: Option<pdps::Acceleration>, |
| |
189 |
| |
190 // #[arg(long)] |
| |
191 // /// Perform postprocess weight optimisation for saved iterations |
| |
192 // /// |
| |
193 // /// Only affects FB, FISTA, and PDPS. |
| |
194 // postprocessing : Option<bool>, |
| |
195 #[arg(value_name = "n", long)] |
| |
196 /// Merging frequency, if merging enabled (every n iterations) |
| |
197 /// |
| |
198 /// Only affects FB, FISTA, and PDPS. |
| |
199 merge_every: Option<usize>, |
| |
200 |
| |
201 #[arg(long)] |
| |
202 /// Enable merging (default: determined by algorithm) |
| |
203 merge: Option<bool>, |
| |
204 |
| |
205 #[arg(long)] |
| |
206 /// Merging radius (default: determined by experiment) |
| |
207 merge_radius: Option<F>, |
| |
208 |
| |
209 #[arg(long)] |
| |
210 /// Interpolate when merging (default : determined by algorithm) |
| |
211 merge_interp: Option<bool>, |
| |
212 |
| |
213 #[arg(long)] |
| |
214 /// Enable final merging (default: determined by algorithm) |
| |
215 final_merging: Option<bool>, |
| |
216 |
| |
217 #[arg(long)] |
| |
218 /// Enable fitness-based merging for relevant FB-type methods. |
| |
219 /// This has worse convergence guarantees that merging based on optimality conditions. |
| |
220 fitness_merging: Option<bool>, |
| |
221 |
| |
222 #[arg(long, value_names = &["ε", "θ", "p"])] |
| |
223 /// Set the tolerance to ε_k = ε/(1+θk)^p |
| |
224 tolerance: Option<Vec<F>>, |
| |
225 |
| |
226 #[arg(long)] |
| |
227 /// Method for solving inner optimisation problems |
| |
228 inner_method: Option<InnerMethod>, |
| |
229 |
| |
230 #[arg(long)] |
| |
231 /// Step length parameter for inner problem |
| |
232 inner_τ0: Option<F>, |
| |
233 |
| |
234 #[arg(long, value_names = &["τ0", "σ0"])] |
| |
235 /// Dual step length parameter for inner problem |
| |
236 inner_pdps_τσ0: Option<Vec<F>>, |
| |
237 |
| |
238 #[arg(long, value_names = &["τ", "growth"])] |
| |
239 /// Inner proximal point method step length and its growth |
| |
240 inner_pp_τ: Option<Vec<F>>, |
| |
241 |
| |
242 #[arg(long)] |
| |
243 /// Inner tolerance multiplier |
| |
244 inner_tol: Option<F>, |
| |
245 } |
| |
246 |
| |
247 /// A generic entry point for binaries based on this library |
| |
248 pub fn common_main<E: ExperimentSetup>() -> DynResult<()> { |
| |
249 let full_cli = FusedCommandLineArgs::<E>::parse(); |
| |
250 let cli = &full_cli.general; |
| |
251 |
| |
252 #[cfg(debug_assertions)] |
| |
253 { |
| |
254 use colored::Colorize; |
| |
255 println!( |
| |
256 "{}", |
| |
257 format!( |
| |
258 "\n\ |
| |
259 ********\n\ |
| |
260 WARNING: Compiled without optimisations; {}\n\ |
| |
261 Please recompile with `--release` flag.\n\ |
| |
262 ********\n\ |
| |
263 ", |
| |
264 "performance will be poor!".blink() |
| |
265 ) |
| |
266 .red() |
| |
267 ); |
| |
268 } |
| |
269 |
| |
270 if let Some(n_threads) = cli.num_threads { |
| |
271 let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); |
| |
272 set_num_threads(n); |
| |
273 } else { |
| |
274 let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); |
| |
275 set_max_threads(m); |
| |
276 } |
| |
277 |
| |
278 for experiment in full_cli.experiment_setup.runnables()? { |
| |
279 let mut algs: Vec<Named<AlgorithmConfig<E::FloatType>>> = cli |
| |
280 .algorithm |
| |
281 .iter() |
| |
282 .map(|alg| { |
| |
283 let cfg = alg |
| |
284 .default_config() |
| |
285 .cli_override(&experiment.algorithm_overrides(*alg)) |
| |
286 .cli_override(&full_cli.algorithm_overrides); |
| |
287 alg.to_named(cfg) |
| |
288 }) |
| |
289 .collect(); |
| |
290 for filename in cli.saved_algorithm.iter() { |
| |
291 let f = std::fs::File::open(filename)?; |
| |
292 let alg = serde_json::from_reader(f)?; |
| |
293 algs.push(alg); |
| |
294 } |
| |
295 experiment.runall(&cli, (!algs.is_empty()).then_some(algs))?; |
| |
296 } |
| |
297 |
| |
298 Ok(()) |
| |
299 } |