src/lib.rs

branch
dev
changeset 61
4f468d35fa29
child 63
7a8a55fd41c0
equal deleted inserted replaced
60:9738b51d90d7 61:4f468d35fa29
1 // The main documentation is in the README.
2 // We need to uglify it in build.rs because rustdoc is stuck in the past.
3 #![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))]
4 // We use unicode. We would like to use much more of it than Rust allows.
5 // Live with it. Embrace it.
6 #![allow(uncommon_codepoints)]
7 #![allow(mixed_script_confusables)]
8 #![allow(confusable_idents)]
9 // Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical
10 // convention while referring to the type (trait) of the operator as `A`.
11 #![allow(non_snake_case)]
12 // Need to create parse errors
13 #![feature(dec2flt)]
14
15 use alg_tools::error::DynResult;
16 use alg_tools::parallelism::{set_max_threads, set_num_threads};
17 use clap::Parser;
18 use serde::{Deserialize, Serialize};
19 use serde_json;
20 use serde_with::skip_serializing_none;
21 use std::num::NonZeroUsize;
22
23 //#[cfg(feature = "pyo3")]
24 //use pyo3::pyclass;
25
26 pub mod dataterm;
27 pub mod experiments;
28 pub mod fb;
29 pub mod forward_model;
30 pub mod forward_pdps;
31 pub mod fourier;
32 pub mod frank_wolfe;
33 pub mod kernels;
34 pub mod pdps;
35 pub mod plot;
36 pub mod preadjoint_helper;
37 pub mod prox_penalty;
38 pub mod rand_distr;
39 pub mod regularisation;
40 pub mod run;
41 pub mod seminorms;
42 pub mod sliding_fb;
43 pub mod sliding_pdps;
44 pub mod subproblem;
45 pub mod tolerance;
46 pub mod types;
47
48 pub mod measures {
49 pub use measures::*;
50 }
51
52 use run::{AlgorithmConfig, DefaultAlgorithm, Named, PlotLevel, RunnableExperiment};
53 use types::{ClapFloat, Float};
54 use DefaultAlgorithm::*;
55
56 /// Trait for customising the experiments available from the command line
57 pub trait ExperimentSetup:
58 clap::Args + Send + Sync + 'static + Serialize + for<'a> Deserialize<'a>
59 {
60 /// Type of floating point numbers to be used.
61 type FloatType: Float + ClapFloat + for<'b> Deserialize<'b>;
62
63 fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>>;
64 }
65
66 /// Command line parameters
67 #[skip_serializing_none]
68 #[derive(Parser, Debug, Serialize, Default, Clone)]
69 pub struct CommandLineArgs {
70 #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)]
71 /// Maximum iteration count
72 max_iter: usize,
73
74 #[arg(long, short = 'n', value_name = "N")]
75 /// Output status every N iterations. Set to 0 to disable.
76 ///
77 /// The default is to output status based on logarithmic increments.
78 verbose_iter: Option<usize>,
79
80 #[arg(long, short = 'q')]
81 /// Don't display iteration progress
82 quiet: bool,
83
84 /// Default algorithm configration(s) to use on the experiments.
85 ///
86 /// Not all algorithms are available for all the experiments.
87 /// In particular, only PDPS is available for the experiments with L¹ data term.
88 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a',
89 default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])]
90 algorithm: Vec<DefaultAlgorithm>,
91
92 /// Saved algorithm configration(s) to use on the experiments
93 #[arg(value_name = "JSON_FILE", long)]
94 saved_algorithm: Vec<String>,
95
96 /// Plot saving scheme
97 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)]
98 plot: PlotLevel,
99
100 /// Directory for saving results
101 #[arg(long, short = 'o', required = true, default_value = "out")]
102 outdir: String,
103
104 #[arg(long, help_heading = "Multi-threading", default_value = "4")]
105 /// Maximum number of threads
106 max_threads: usize,
107
108 #[arg(long, help_heading = "Multi-threading")]
109 /// Number of threads. Overrides the maximum number.
110 num_threads: Option<usize>,
111
112 #[arg(long, default_value_t = false)]
113 /// Load saved value ranges (if exists) to do partial update.
114 load_valuerange: bool,
115 }
116
117 #[derive(Parser, Debug, Serialize, Default, Clone)]
118 #[clap(
119 about = env!("CARGO_PKG_DESCRIPTION"),
120 author = env!("CARGO_PKG_AUTHORS"),
121 version = env!("CARGO_PKG_VERSION"),
122 after_help = "Pass --help for longer descriptions.",
123 after_long_help = "",
124 )]
125 struct FusedCommandLineArgs<E: ExperimentSetup> {
126 /// List of experiments to perform.
127 #[clap(flatten, next_help_heading = "Experiment setup")]
128 experiment_setup: E,
129
130 #[clap(flatten, next_help_heading = "General parameters")]
131 general: CommandLineArgs,
132
133 #[clap(flatten, next_help_heading = "Algorithm overrides")]
134 /// Algorithm parametrisation overrides
135 algorithm_overrides: AlgorithmOverrides<E::FloatType>,
136 }
137
138 /// Command line algorithm parametrisation overrides
139 #[skip_serializing_none]
140 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
141 //#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))]
142 pub struct AlgorithmOverrides<F: ClapFloat> {
143 #[arg(long, value_names = &["COUNT", "EACH"])]
144 /// Override bootstrap insertion iterations for --algorithm.
145 ///
146 /// The first parameter is the number of bootstrap insertion iterations, and the second
147 /// the maximum number of iterations on each of them.
148 bootstrap_insertions: Option<Vec<usize>>,
149
150 #[arg(long, requires = "algorithm")]
151 /// Primal step length parameter override for --algorithm.
152 ///
153 /// Only use if running just a single algorithm, as different algorithms have different
154 /// regularisation parameters. Does not affect the algorithms fw and fwrelax.
155 tau0: Option<F>,
156
157 #[arg(long, requires = "algorithm")]
158 /// Second primal step length parameter override for SlidingPDPS.
159 ///
160 /// Only use if running just a single algorithm, as different algorithms have different
161 /// regularisation parameters.
162 sigmap0: Option<F>,
163
164 #[arg(long, requires = "algorithm")]
165 /// Dual step length parameter override for --algorithm.
166 ///
167 /// Only use if running just a single algorithm, as different algorithms have different
168 /// regularisation parameters. Only affects PDPS.
169 sigma0: Option<F>,
170
171 #[arg(long)]
172 /// Normalised transport step length for sliding methods.
173 theta0: Option<F>,
174
175 #[arg(long)]
176 /// A posteriori transport tolerance multiplier (C_pos)
177 transport_tolerance_pos: Option<F>,
178
179 #[arg(long)]
180 /// Transport adaptation factor. Must be in (0, 1).
181 transport_adaptation: Option<F>,
182
183 #[arg(long)]
184 /// Minimal step length parameter for sliding methods.
185 tau0_min: Option<F>,
186
187 #[arg(value_enum, long)]
188 /// PDPS acceleration, when available.
189 acceleration: Option<pdps::Acceleration>,
190
191 // #[arg(long)]
192 // /// Perform postprocess weight optimisation for saved iterations
193 // ///
194 // /// Only affects FB, FISTA, and PDPS.
195 // postprocessing : Option<bool>,
196 #[arg(value_name = "n", long)]
197 /// Merging frequency, if merging enabled (every n iterations)
198 ///
199 /// Only affects FB, FISTA, and PDPS.
200 merge_every: Option<usize>,
201
202 #[arg(long)]
203 /// Enable merging (default: determined by algorithm)
204 merge: Option<bool>,
205
206 #[arg(long)]
207 /// Merging radius (default: determined by experiment)
208 merge_radius: Option<F>,
209
210 #[arg(long)]
211 /// Interpolate when merging (default : determined by algorithm)
212 merge_interp: Option<bool>,
213
214 #[arg(long)]
215 /// Enable final merging (default: determined by algorithm)
216 final_merging: Option<bool>,
217
218 #[arg(long)]
219 /// Enable fitness-based merging for relevant FB-type methods.
220 /// This has worse convergence guarantees that merging based on optimality conditions.
221 fitness_merging: Option<bool>,
222
223 #[arg(long, value_names = &["ε", "θ", "p"])]
224 /// Set the tolerance to ε_k = ε/(1+θk)^p
225 tolerance: Option<Vec<F>>,
226 }
227
228 /// A generic entry point for binaries based on this library
229 pub fn common_main<E: ExperimentSetup>() -> DynResult<()> {
230 let full_cli = FusedCommandLineArgs::<E>::parse();
231 let cli = &full_cli.general;
232
233 #[cfg(debug_assertions)]
234 {
235 use colored::Colorize;
236 println!(
237 "{}",
238 format!(
239 "\n\
240 ********\n\
241 WARNING: Compiled without optimisations; {}\n\
242 Please recompile with `--release` flag.\n\
243 ********\n\
244 ",
245 "performance will be poor!".blink()
246 )
247 .red()
248 );
249 }
250
251 if let Some(n_threads) = cli.num_threads {
252 let n = NonZeroUsize::new(n_threads).expect("Invalid thread count");
253 set_num_threads(n);
254 } else {
255 let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count");
256 set_max_threads(m);
257 }
258
259 for experiment in full_cli.experiment_setup.runnables()? {
260 let mut algs: Vec<Named<AlgorithmConfig<E::FloatType>>> = cli
261 .algorithm
262 .iter()
263 .map(|alg| {
264 let cfg = alg
265 .default_config()
266 .cli_override(&experiment.algorithm_overrides(*alg))
267 .cli_override(&full_cli.algorithm_overrides);
268 alg.to_named(cfg)
269 })
270 .collect();
271 for filename in cli.saved_algorithm.iter() {
272 let f = std::fs::File::open(filename)?;
273 let alg = serde_json::from_reader(f)?;
274 algs.push(alg);
275 }
276 experiment.runall(&cli, (!algs.is_empty()).then_some(algs))?;
277 }
278
279 Ok(())
280 }

mercurial