| |
1 // The main documentation is in the README. |
| |
2 // We need to uglify it in build.rs because rustdoc is stuck in the past. |
| |
3 #![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))] |
| |
4 // We use unicode. We would like to use much more of it than Rust allows. |
| |
5 // Live with it. Embrace it. |
| |
6 #![allow(uncommon_codepoints)] |
| |
7 #![allow(mixed_script_confusables)] |
| |
8 #![allow(confusable_idents)] |
| |
9 // Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical |
| |
10 // convention while referring to the type (trait) of the operator as `A`. |
| |
11 #![allow(non_snake_case)] |
| |
12 // Need to create parse errors |
| |
13 #![feature(dec2flt)] |
| |
14 |
| |
15 use alg_tools::error::DynResult; |
| |
16 use alg_tools::parallelism::{set_max_threads, set_num_threads}; |
| |
17 use clap::Parser; |
| |
18 use serde::{Deserialize, Serialize}; |
| |
19 use serde_json; |
| |
20 use serde_with::skip_serializing_none; |
| |
21 use std::num::NonZeroUsize; |
| |
22 |
| |
23 //#[cfg(feature = "pyo3")] |
| |
24 //use pyo3::pyclass; |
| |
25 |
| |
26 pub mod dataterm; |
| |
27 pub mod experiments; |
| |
28 pub mod fb; |
| |
29 pub mod forward_model; |
| |
30 pub mod forward_pdps; |
| |
31 pub mod fourier; |
| |
32 pub mod frank_wolfe; |
| |
33 pub mod kernels; |
| |
34 pub mod pdps; |
| |
35 pub mod plot; |
| |
36 pub mod preadjoint_helper; |
| |
37 pub mod prox_penalty; |
| |
38 pub mod rand_distr; |
| |
39 pub mod regularisation; |
| |
40 pub mod run; |
| |
41 pub mod seminorms; |
| |
42 pub mod sliding_fb; |
| |
43 pub mod sliding_pdps; |
| |
44 pub mod subproblem; |
| |
45 pub mod tolerance; |
| |
46 pub mod types; |
| |
47 |
| |
48 pub mod measures { |
| |
49 pub use measures::*; |
| |
50 } |
| |
51 |
| |
52 use run::{AlgorithmConfig, DefaultAlgorithm, Named, PlotLevel, RunnableExperiment}; |
| |
53 use types::{ClapFloat, Float}; |
| |
54 use DefaultAlgorithm::*; |
| |
55 |
| |
56 /// Trait for customising the experiments available from the command line |
| |
57 pub trait ExperimentSetup: |
| |
58 clap::Args + Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> |
| |
59 { |
| |
60 /// Type of floating point numbers to be used. |
| |
61 type FloatType: Float + ClapFloat + for<'b> Deserialize<'b>; |
| |
62 |
| |
63 fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>>; |
| |
64 } |
| |
65 |
| |
66 /// Command line parameters |
| |
67 #[skip_serializing_none] |
| |
68 #[derive(Parser, Debug, Serialize, Default, Clone)] |
| |
69 pub struct CommandLineArgs { |
| |
70 #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)] |
| |
71 /// Maximum iteration count |
| |
72 max_iter: usize, |
| |
73 |
| |
74 #[arg(long, short = 'n', value_name = "N")] |
| |
75 /// Output status every N iterations. Set to 0 to disable. |
| |
76 /// |
| |
77 /// The default is to output status based on logarithmic increments. |
| |
78 verbose_iter: Option<usize>, |
| |
79 |
| |
80 #[arg(long, short = 'q')] |
| |
81 /// Don't display iteration progress |
| |
82 quiet: bool, |
| |
83 |
| |
84 /// Default algorithm configration(s) to use on the experiments. |
| |
85 /// |
| |
86 /// Not all algorithms are available for all the experiments. |
| |
87 /// In particular, only PDPS is available for the experiments with L¹ data term. |
| |
88 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', |
| |
89 default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])] |
| |
90 algorithm: Vec<DefaultAlgorithm>, |
| |
91 |
| |
92 /// Saved algorithm configration(s) to use on the experiments |
| |
93 #[arg(value_name = "JSON_FILE", long)] |
| |
94 saved_algorithm: Vec<String>, |
| |
95 |
| |
96 /// Plot saving scheme |
| |
97 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] |
| |
98 plot: PlotLevel, |
| |
99 |
| |
100 /// Directory for saving results |
| |
101 #[arg(long, short = 'o', required = true, default_value = "out")] |
| |
102 outdir: String, |
| |
103 |
| |
104 #[arg(long, help_heading = "Multi-threading", default_value = "4")] |
| |
105 /// Maximum number of threads |
| |
106 max_threads: usize, |
| |
107 |
| |
108 #[arg(long, help_heading = "Multi-threading")] |
| |
109 /// Number of threads. Overrides the maximum number. |
| |
110 num_threads: Option<usize>, |
| |
111 |
| |
112 #[arg(long, default_value_t = false)] |
| |
113 /// Load saved value ranges (if exists) to do partial update. |
| |
114 load_valuerange: bool, |
| |
115 } |
| |
116 |
| |
117 #[derive(Parser, Debug, Serialize, Default, Clone)] |
| |
118 #[clap( |
| |
119 about = env!("CARGO_PKG_DESCRIPTION"), |
| |
120 author = env!("CARGO_PKG_AUTHORS"), |
| |
121 version = env!("CARGO_PKG_VERSION"), |
| |
122 after_help = "Pass --help for longer descriptions.", |
| |
123 after_long_help = "", |
| |
124 )] |
| |
125 struct FusedCommandLineArgs<E: ExperimentSetup> { |
| |
126 /// List of experiments to perform. |
| |
127 #[clap(flatten, next_help_heading = "Experiment setup")] |
| |
128 experiment_setup: E, |
| |
129 |
| |
130 #[clap(flatten, next_help_heading = "General parameters")] |
| |
131 general: CommandLineArgs, |
| |
132 |
| |
133 #[clap(flatten, next_help_heading = "Algorithm overrides")] |
| |
134 /// Algorithm parametrisation overrides |
| |
135 algorithm_overrides: AlgorithmOverrides<E::FloatType>, |
| |
136 } |
| |
137 |
| |
138 /// Command line algorithm parametrisation overrides |
| |
139 #[skip_serializing_none] |
| |
140 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] |
| |
141 //#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))] |
| |
142 pub struct AlgorithmOverrides<F: ClapFloat> { |
| |
143 #[arg(long, value_names = &["COUNT", "EACH"])] |
| |
144 /// Override bootstrap insertion iterations for --algorithm. |
| |
145 /// |
| |
146 /// The first parameter is the number of bootstrap insertion iterations, and the second |
| |
147 /// the maximum number of iterations on each of them. |
| |
148 bootstrap_insertions: Option<Vec<usize>>, |
| |
149 |
| |
150 #[arg(long, requires = "algorithm")] |
| |
151 /// Primal step length parameter override for --algorithm. |
| |
152 /// |
| |
153 /// Only use if running just a single algorithm, as different algorithms have different |
| |
154 /// regularisation parameters. Does not affect the algorithms fw and fwrelax. |
| |
155 tau0: Option<F>, |
| |
156 |
| |
157 #[arg(long, requires = "algorithm")] |
| |
158 /// Second primal step length parameter override for SlidingPDPS. |
| |
159 /// |
| |
160 /// Only use if running just a single algorithm, as different algorithms have different |
| |
161 /// regularisation parameters. |
| |
162 sigmap0: Option<F>, |
| |
163 |
| |
164 #[arg(long, requires = "algorithm")] |
| |
165 /// Dual step length parameter override for --algorithm. |
| |
166 /// |
| |
167 /// Only use if running just a single algorithm, as different algorithms have different |
| |
168 /// regularisation parameters. Only affects PDPS. |
| |
169 sigma0: Option<F>, |
| |
170 |
| |
171 #[arg(long)] |
| |
172 /// Normalised transport step length for sliding methods. |
| |
173 theta0: Option<F>, |
| |
174 |
| |
175 #[arg(long)] |
| |
176 /// A posteriori transport tolerance multiplier (C_pos) |
| |
177 transport_tolerance_pos: Option<F>, |
| |
178 |
| |
179 #[arg(long)] |
| |
180 /// Transport adaptation factor. Must be in (0, 1). |
| |
181 transport_adaptation: Option<F>, |
| |
182 |
| |
183 #[arg(long)] |
| |
184 /// Minimal step length parameter for sliding methods. |
| |
185 tau0_min: Option<F>, |
| |
186 |
| |
187 #[arg(value_enum, long)] |
| |
188 /// PDPS acceleration, when available. |
| |
189 acceleration: Option<pdps::Acceleration>, |
| |
190 |
| |
191 // #[arg(long)] |
| |
192 // /// Perform postprocess weight optimisation for saved iterations |
| |
193 // /// |
| |
194 // /// Only affects FB, FISTA, and PDPS. |
| |
195 // postprocessing : Option<bool>, |
| |
196 #[arg(value_name = "n", long)] |
| |
197 /// Merging frequency, if merging enabled (every n iterations) |
| |
198 /// |
| |
199 /// Only affects FB, FISTA, and PDPS. |
| |
200 merge_every: Option<usize>, |
| |
201 |
| |
202 #[arg(long)] |
| |
203 /// Enable merging (default: determined by algorithm) |
| |
204 merge: Option<bool>, |
| |
205 |
| |
206 #[arg(long)] |
| |
207 /// Merging radius (default: determined by experiment) |
| |
208 merge_radius: Option<F>, |
| |
209 |
| |
210 #[arg(long)] |
| |
211 /// Interpolate when merging (default : determined by algorithm) |
| |
212 merge_interp: Option<bool>, |
| |
213 |
| |
214 #[arg(long)] |
| |
215 /// Enable final merging (default: determined by algorithm) |
| |
216 final_merging: Option<bool>, |
| |
217 |
| |
218 #[arg(long)] |
| |
219 /// Enable fitness-based merging for relevant FB-type methods. |
| |
220 /// This has worse convergence guarantees that merging based on optimality conditions. |
| |
221 fitness_merging: Option<bool>, |
| |
222 |
| |
223 #[arg(long, value_names = &["ε", "θ", "p"])] |
| |
224 /// Set the tolerance to ε_k = ε/(1+θk)^p |
| |
225 tolerance: Option<Vec<F>>, |
| |
226 } |
| |
227 |
| |
228 /// A generic entry point for binaries based on this library |
| |
229 pub fn common_main<E: ExperimentSetup>() -> DynResult<()> { |
| |
230 let full_cli = FusedCommandLineArgs::<E>::parse(); |
| |
231 let cli = &full_cli.general; |
| |
232 |
| |
233 #[cfg(debug_assertions)] |
| |
234 { |
| |
235 use colored::Colorize; |
| |
236 println!( |
| |
237 "{}", |
| |
238 format!( |
| |
239 "\n\ |
| |
240 ********\n\ |
| |
241 WARNING: Compiled without optimisations; {}\n\ |
| |
242 Please recompile with `--release` flag.\n\ |
| |
243 ********\n\ |
| |
244 ", |
| |
245 "performance will be poor!".blink() |
| |
246 ) |
| |
247 .red() |
| |
248 ); |
| |
249 } |
| |
250 |
| |
251 if let Some(n_threads) = cli.num_threads { |
| |
252 let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); |
| |
253 set_num_threads(n); |
| |
254 } else { |
| |
255 let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); |
| |
256 set_max_threads(m); |
| |
257 } |
| |
258 |
| |
259 for experiment in full_cli.experiment_setup.runnables()? { |
| |
260 let mut algs: Vec<Named<AlgorithmConfig<E::FloatType>>> = cli |
| |
261 .algorithm |
| |
262 .iter() |
| |
263 .map(|alg| { |
| |
264 let cfg = alg |
| |
265 .default_config() |
| |
266 .cli_override(&experiment.algorithm_overrides(*alg)) |
| |
267 .cli_override(&full_cli.algorithm_overrides); |
| |
268 alg.to_named(cfg) |
| |
269 }) |
| |
270 .collect(); |
| |
271 for filename in cli.saved_algorithm.iter() { |
| |
272 let f = std::fs::File::open(filename)?; |
| |
273 let alg = serde_json::from_reader(f)?; |
| |
274 algs.push(alg); |
| |
275 } |
| |
276 experiment.runall(&cli, (!algs.is_empty()).then_some(algs))?; |
| |
277 } |
| |
278 |
| |
279 Ok(()) |
| |
280 } |