src/lib.rs

changeset 70
ed16d0f10d08
parent 67
95bb12bdb6ac
equal deleted inserted replaced
58:6099ba025aac 70:ed16d0f10d08
1 // The main documentation is in the README.
2 // We need to uglify it in build.rs because rustdoc is stuck in the past.
3 #![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))]
4 // We use unicode. We would like to use much more of it than Rust allows.
5 // Live with it. Embrace it.
6 #![allow(uncommon_codepoints)]
7 #![allow(mixed_script_confusables)]
8 #![allow(confusable_idents)]
9 // Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical
10 // convention while referring to the type (trait) of the operator as `A`.
11 #![allow(non_snake_case)]
12
13 use alg_tools::error::DynResult;
14 use alg_tools::parallelism::{set_max_threads, set_num_threads};
15 use clap::Parser;
16 use serde::{Deserialize, Serialize};
17 use serde_json;
18 use serde_with::skip_serializing_none;
19 use std::num::NonZeroUsize;
20
21 //#[cfg(feature = "pyo3")]
22 //use pyo3::pyclass;
23
24 pub mod dataterm;
25 pub mod experiments;
26 pub mod fb;
27 pub mod forward_model;
28 pub mod forward_pdps;
29 pub mod fourier;
30 pub mod frank_wolfe;
31 pub mod kernels;
32 pub mod pdps;
33 pub mod plot;
34 pub mod preadjoint_helper;
35 pub mod prox_penalty;
36 pub mod rand_distr;
37 pub mod regularisation;
38 pub mod run;
39 pub mod seminorms;
40 pub mod sliding_fb;
41 pub mod sliding_pdps;
42 pub mod subproblem;
43 pub mod tolerance;
44 pub mod types;
45
46 pub mod measures {
47 pub use measures::*;
48 }
49
50 use run::{AlgorithmConfig, DefaultAlgorithm, Named, PlotLevel, RunnableExperiment};
51 use subproblem::InnerMethod;
52 use types::{ClapFloat, Float};
53 use DefaultAlgorithm::*;
54
55 /// Trait for customising the experiments available from the command line
56 pub trait ExperimentSetup:
57 clap::Args + Send + Sync + 'static + Serialize + for<'a> Deserialize<'a>
58 {
59 /// Type of floating point numbers to be used.
60 type FloatType: Float + ClapFloat + for<'b> Deserialize<'b>;
61
62 fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>>;
63 }
64
65 /// Command line parameters
66 #[skip_serializing_none]
67 #[derive(Parser, Debug, Serialize, Default, Clone)]
68 pub struct CommandLineArgs {
69 #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)]
70 /// Maximum iteration count
71 max_iter: usize,
72
73 #[arg(long, short = 'n', value_name = "N")]
74 /// Output status every N iterations. Set to 0 to disable.
75 ///
76 /// The default is to output status based on logarithmic increments.
77 verbose_iter: Option<usize>,
78
79 #[arg(long, short = 'q')]
80 /// Don't display iteration progress
81 quiet: bool,
82
83 /// Default algorithm configration(s) to use on the experiments.
84 ///
85 /// Not all algorithms are available for all the experiments.
86 /// In particular, only PDPS is available for the experiments with L¹ data term.
87 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a',
88 default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])]
89 algorithm: Vec<DefaultAlgorithm>,
90
91 /// Saved algorithm configration(s) to use on the experiments
92 #[arg(value_name = "JSON_FILE", long)]
93 saved_algorithm: Vec<String>,
94
95 /// Plot saving scheme
96 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)]
97 plot: PlotLevel,
98
99 /// Directory for saving results
100 #[arg(long, short = 'o', required = true, default_value = "out")]
101 outdir: String,
102
103 #[arg(long, help_heading = "Multi-threading", default_value = "4")]
104 /// Maximum number of threads
105 max_threads: usize,
106
107 #[arg(long, help_heading = "Multi-threading")]
108 /// Number of threads. Overrides the maximum number.
109 num_threads: Option<usize>,
110
111 #[arg(long, default_value_t = false)]
112 /// Load saved value ranges (if exists) to do partial update.
113 load_valuerange: bool,
114 }
115
116 #[derive(Parser, Debug, Serialize, Default, Clone)]
117 #[clap(
118 about = env!("CARGO_PKG_DESCRIPTION"),
119 author = env!("CARGO_PKG_AUTHORS"),
120 version = env!("CARGO_PKG_VERSION"),
121 after_help = "Pass --help for longer descriptions.",
122 after_long_help = "",
123 )]
124 struct FusedCommandLineArgs<E: ExperimentSetup> {
125 /// List of experiments to perform.
126 #[clap(flatten, next_help_heading = "Experiment setup")]
127 experiment_setup: E,
128
129 #[clap(flatten, next_help_heading = "General parameters")]
130 general: CommandLineArgs,
131
132 #[clap(flatten, next_help_heading = "Algorithm overrides")]
133 /// Algorithm parametrisation overrides
134 algorithm_overrides: AlgorithmOverrides<E::FloatType>,
135 }
136
137 /// Command line algorithm parametrisation overrides
138 #[skip_serializing_none]
139 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
140 //#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))]
141 pub struct AlgorithmOverrides<F: ClapFloat> {
142 #[arg(long, value_names = &["COUNT", "EACH"])]
143 /// Override bootstrap insertion iterations for --algorithm.
144 ///
145 /// The first parameter is the number of bootstrap insertion iterations, and the second
146 /// the maximum number of iterations on each of them.
147 bootstrap_insertions: Option<Vec<usize>>,
148
149 #[arg(long, requires = "algorithm")]
150 /// Primal step length parameter override for --algorithm.
151 ///
152 /// Only use if running just a single algorithm, as different algorithms have different
153 /// regularisation parameters. Does not affect the algorithms fw and fwrelax.
154 tau0: Option<F>,
155
156 #[arg(long, requires = "algorithm")]
157 /// Second primal step length parameter override for SlidingPDPS.
158 ///
159 /// Only use if running just a single algorithm, as different algorithms have different
160 /// regularisation parameters.
161 sigmap0: Option<F>,
162
163 #[arg(long, requires = "algorithm")]
164 /// Dual step length parameter override for --algorithm.
165 ///
166 /// Only use if running just a single algorithm, as different algorithms have different
167 /// regularisation parameters. Only affects PDPS.
168 sigma0: Option<F>,
169
170 #[arg(long)]
171 /// Normalised transport step length for sliding methods.
172 theta0: Option<F>,
173
174 #[arg(long)]
175 /// A posteriori transport tolerance multiplier (C_pos)
176 transport_tolerance_pos: Option<F>,
177
178 #[arg(long)]
179 /// Transport adaptation factor. Must be in (0, 1).
180 transport_adaptation: Option<F>,
181
182 #[arg(long)]
183 /// Minimal step length parameter for sliding methods.
184 tau0_min: Option<F>,
185
186 #[arg(value_enum, long)]
187 /// PDPS acceleration, when available.
188 acceleration: Option<pdps::Acceleration>,
189
190 // #[arg(long)]
191 // /// Perform postprocess weight optimisation for saved iterations
192 // ///
193 // /// Only affects FB, FISTA, and PDPS.
194 // postprocessing : Option<bool>,
195 #[arg(value_name = "n", long)]
196 /// Merging frequency, if merging enabled (every n iterations)
197 ///
198 /// Only affects FB, FISTA, and PDPS.
199 merge_every: Option<usize>,
200
201 #[arg(long)]
202 /// Enable merging (default: determined by algorithm)
203 merge: Option<bool>,
204
205 #[arg(long)]
206 /// Merging radius (default: determined by experiment)
207 merge_radius: Option<F>,
208
209 #[arg(long)]
210 /// Interpolate when merging (default : determined by algorithm)
211 merge_interp: Option<bool>,
212
213 #[arg(long)]
214 /// Enable final merging (default: determined by algorithm)
215 final_merging: Option<bool>,
216
217 #[arg(long)]
218 /// Enable fitness-based merging for relevant FB-type methods.
219 /// This has worse convergence guarantees that merging based on optimality conditions.
220 fitness_merging: Option<bool>,
221
222 #[arg(long, value_names = &["ε", "θ", "p"])]
223 /// Set the tolerance to ε_k = ε/(1+θk)^p
224 tolerance: Option<Vec<F>>,
225
226 #[arg(long)]
227 /// Method for solving inner optimisation problems
228 inner_method: Option<InnerMethod>,
229
230 #[arg(long)]
231 /// Step length parameter for inner problem
232 inner_τ0: Option<F>,
233
234 #[arg(long, value_names = &["τ0", "σ0"])]
235 /// Dual step length parameter for inner problem
236 inner_pdps_τσ0: Option<Vec<F>>,
237
238 #[arg(long, value_names = &["τ", "growth"])]
239 /// Inner proximal point method step length and its growth
240 inner_pp_τ: Option<Vec<F>>,
241
242 #[arg(long)]
243 /// Inner tolerance multiplier
244 inner_tol: Option<F>,
245 }
246
247 /// A generic entry point for binaries based on this library
248 pub fn common_main<E: ExperimentSetup>() -> DynResult<()> {
249 let full_cli = FusedCommandLineArgs::<E>::parse();
250 let cli = &full_cli.general;
251
252 #[cfg(debug_assertions)]
253 {
254 use colored::Colorize;
255 println!(
256 "{}",
257 format!(
258 "\n\
259 ********\n\
260 WARNING: Compiled without optimisations; {}\n\
261 Please recompile with `--release` flag.\n\
262 ********\n\
263 ",
264 "performance will be poor!".blink()
265 )
266 .red()
267 );
268 }
269
270 if let Some(n_threads) = cli.num_threads {
271 let n = NonZeroUsize::new(n_threads).expect("Invalid thread count");
272 set_num_threads(n);
273 } else {
274 let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count");
275 set_max_threads(m);
276 }
277
278 for experiment in full_cli.experiment_setup.runnables()? {
279 let mut algs: Vec<Named<AlgorithmConfig<E::FloatType>>> = cli
280 .algorithm
281 .iter()
282 .map(|alg| {
283 let cfg = alg
284 .default_config()
285 .cli_override(&experiment.algorithm_overrides(*alg))
286 .cli_override(&full_cli.algorithm_overrides);
287 alg.to_named(cfg)
288 })
289 .collect();
290 for filename in cli.saved_algorithm.iter() {
291 let f = std::fs::File::open(filename)?;
292 let alg = serde_json::from_reader(f)?;
293 algs.push(alg);
294 }
295 experiment.runall(&cli, (!algs.is_empty()).then_some(algs))?;
296 }
297
298 Ok(())
299 }

mercurial