| 1 // The main documentation is in the README. |
1 // The main documentation is in the README. |
| 2 // We need to uglify it in build.rs because rustdoc is stuck in the past. |
2 // We need to uglify it in build.rs because rustdoc is stuck in the past. |
| 3 #![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))] |
3 #![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))] |
| 4 |
4 |
| 5 // We use unicode. We would like to use much more of it than Rust allows. |
5 use pointsource_algs::{common_main, experiments::DefaultExperimentSetup}; |
| 6 // Live with it. Embrace it. |
|
| 7 #![allow(uncommon_codepoints)] |
|
| 8 #![allow(mixed_script_confusables)] |
|
| 9 #![allow(confusable_idents)] |
|
| 10 // Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical |
|
| 11 // convention while referring to the type (trait) of the operator as `A`. |
|
| 12 #![allow(non_snake_case)] |
|
| 13 // Need to create parse errors |
|
| 14 #![feature(dec2flt)] |
|
| 15 |
|
| 16 use clap::Parser; |
|
| 17 use serde::{Serialize, Deserialize}; |
|
| 18 use serde_json; |
|
| 19 use serde_with::skip_serializing_none; |
|
| 20 use itertools::Itertools; |
|
| 21 use std::num::NonZeroUsize; |
|
| 22 |
|
| 23 use alg_tools::parallelism::{ |
|
| 24 set_num_threads, |
|
| 25 set_max_threads, |
|
| 26 }; |
|
| 27 |
|
| 28 pub mod types; |
|
| 29 pub mod measures; |
|
| 30 pub mod fourier; |
|
| 31 pub mod kernels; |
|
| 32 pub mod seminorms; |
|
| 33 pub mod forward_model; |
|
| 34 pub mod preadjoint_helper; |
|
| 35 pub mod plot; |
|
| 36 pub mod subproblem; |
|
| 37 pub mod tolerance; |
|
| 38 pub mod regularisation; |
|
| 39 pub mod dataterm; |
|
| 40 pub mod prox_penalty; |
|
| 41 pub mod fb; |
|
| 42 pub mod sliding_fb; |
|
| 43 pub mod sliding_pdps; |
|
| 44 pub mod forward_pdps; |
|
| 45 pub mod frank_wolfe; |
|
| 46 pub mod pdps; |
|
| 47 pub mod run; |
|
| 48 pub mod rand_distr; |
|
| 49 pub mod experiments; |
|
| 50 |
|
| 51 use types::{float, ClapFloat}; |
|
| 52 use run::{ |
|
| 53 DefaultAlgorithm, |
|
| 54 PlotLevel, |
|
| 55 Named, |
|
| 56 AlgorithmConfig, |
|
| 57 }; |
|
| 58 use experiments::DefaultExperiment; |
|
| 59 use DefaultExperiment::*; |
|
| 60 use DefaultAlgorithm::*; |
|
| 61 |
|
| 62 /// Command line parameters |
|
| 63 #[skip_serializing_none] |
|
| 64 #[derive(Parser, Debug, Serialize, Default, Clone)] |
|
| 65 #[clap( |
|
| 66 about = env!("CARGO_PKG_DESCRIPTION"), |
|
| 67 author = env!("CARGO_PKG_AUTHORS"), |
|
| 68 version = env!("CARGO_PKG_VERSION"), |
|
| 69 after_help = "Pass --help for longer descriptions.", |
|
| 70 after_long_help = "", |
|
| 71 )] |
|
| 72 pub struct CommandLineArgs { |
|
| 73 #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)] |
|
| 74 /// Maximum iteration count |
|
| 75 max_iter : usize, |
|
| 76 |
|
| 77 #[arg(long, short = 'n', value_name = "N")] |
|
| 78 /// Output status every N iterations. Set to 0 to disable. |
|
| 79 /// |
|
| 80 /// The default is to output status based on logarithmic increments. |
|
| 81 verbose_iter : Option<usize>, |
|
| 82 |
|
| 83 #[arg(long, short = 'q')] |
|
| 84 /// Don't display iteration progress |
|
| 85 quiet : bool, |
|
| 86 |
|
| 87 /// List of experiments to perform. |
|
| 88 #[arg(value_enum, value_name = "EXPERIMENT", |
|
| 89 default_values_t = [Experiment1D, Experiment1DFast, |
|
| 90 Experiment2D, Experiment2DFast, |
|
| 91 Experiment1D_L1])] |
|
| 92 experiments : Vec<DefaultExperiment>, |
|
| 93 |
|
| 94 /// Default algorithm configration(s) to use on the experiments. |
|
| 95 /// |
|
| 96 /// Not all algorithms are available for all the experiments. |
|
| 97 /// In particular, only PDPS is available for the experiments with L¹ data term. |
|
| 98 #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', |
|
| 99 default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])] |
|
| 100 algorithm : Vec<DefaultAlgorithm>, |
|
| 101 |
|
| 102 /// Saved algorithm configration(s) to use on the experiments |
|
| 103 #[arg(value_name = "JSON_FILE", long)] |
|
| 104 saved_algorithm : Vec<String>, |
|
| 105 |
|
| 106 /// Plot saving scheme |
|
| 107 #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] |
|
| 108 plot : PlotLevel, |
|
| 109 |
|
| 110 /// Directory for saving results |
|
| 111 #[arg(long, short = 'o', required = true, default_value = "out")] |
|
| 112 outdir : String, |
|
| 113 |
|
| 114 #[arg(long, help_heading = "Multi-threading", default_value = "4")] |
|
| 115 /// Maximum number of threads |
|
| 116 max_threads : usize, |
|
| 117 |
|
| 118 #[arg(long, help_heading = "Multi-threading")] |
|
| 119 /// Number of threads. Overrides the maximum number. |
|
| 120 num_threads : Option<usize>, |
|
| 121 |
|
| 122 #[arg(long, default_value_t = false)] |
|
| 123 /// Load saved value ranges (if exists) to do partial update. |
|
| 124 load_valuerange : bool, |
|
| 125 |
|
| 126 #[clap(flatten, next_help_heading = "Experiment overrides")] |
|
| 127 /// Experiment setup overrides |
|
| 128 experiment_overrides : ExperimentOverrides<float>, |
|
| 129 |
|
| 130 #[clap(flatten, next_help_heading = "Algorithm overrides")] |
|
| 131 /// Algorithm parametrisation overrides |
|
| 132 algoritm_overrides : AlgorithmOverrides<float>, |
|
| 133 } |
|
| 134 |
|
| 135 /// Command line experiment setup overrides |
|
| 136 #[skip_serializing_none] |
|
| 137 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] |
|
| 138 pub struct ExperimentOverrides<F : ClapFloat> { |
|
| 139 #[arg(long)] |
|
| 140 /// Regularisation parameter override. |
|
| 141 /// |
|
| 142 /// Only use if running just a single experiment, as different experiments have different |
|
| 143 /// regularisation parameters. |
|
| 144 alpha : Option<F>, |
|
| 145 |
|
| 146 #[arg(long)] |
|
| 147 /// Gaussian noise variance override |
|
| 148 variance : Option<F>, |
|
| 149 |
|
| 150 #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])] |
|
| 151 /// Salt and pepper noise override. |
|
| 152 salt_and_pepper : Option<Vec<F>>, |
|
| 153 |
|
| 154 #[arg(long)] |
|
| 155 /// Noise seed |
|
| 156 noise_seed : Option<u64>, |
|
| 157 } |
|
| 158 |
|
| 159 /// Command line algorithm parametrisation overrides |
|
| 160 #[skip_serializing_none] |
|
| 161 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] |
|
| 162 pub struct AlgorithmOverrides<F : ClapFloat> { |
|
| 163 #[arg(long, value_names = &["COUNT", "EACH"])] |
|
| 164 /// Override bootstrap insertion iterations for --algorithm. |
|
| 165 /// |
|
| 166 /// The first parameter is the number of bootstrap insertion iterations, and the second |
|
| 167 /// the maximum number of iterations on each of them. |
|
| 168 bootstrap_insertions : Option<Vec<usize>>, |
|
| 169 |
|
| 170 #[arg(long, requires = "algorithm")] |
|
| 171 /// Primal step length parameter override for --algorithm. |
|
| 172 /// |
|
| 173 /// Only use if running just a single algorithm, as different algorithms have different |
|
| 174 /// regularisation parameters. Does not affect the algorithms fw and fwrelax. |
|
| 175 tau0 : Option<F>, |
|
| 176 |
|
| 177 #[arg(long, requires = "algorithm")] |
|
| 178 /// Second primal step length parameter override for SlidingPDPS. |
|
| 179 /// |
|
| 180 /// Only use if running just a single algorithm, as different algorithms have different |
|
| 181 /// regularisation parameters. |
|
| 182 sigmap0 : Option<F>, |
|
| 183 |
|
| 184 #[arg(long, requires = "algorithm")] |
|
| 185 /// Dual step length parameter override for --algorithm. |
|
| 186 /// |
|
| 187 /// Only use if running just a single algorithm, as different algorithms have different |
|
| 188 /// regularisation parameters. Only affects PDPS. |
|
| 189 sigma0 : Option<F>, |
|
| 190 |
|
| 191 #[arg(long)] |
|
| 192 /// Normalised transport step length for sliding methods. |
|
| 193 theta0 : Option<F>, |
|
| 194 |
|
| 195 #[arg(long)] |
|
| 196 /// A posteriori transport tolerance multiplier (C_pos) |
|
| 197 transport_tolerance_pos : Option<F>, |
|
| 198 |
|
| 199 #[arg(long)] |
|
| 200 /// Transport adaptation factor. Must be in (0, 1). |
|
| 201 transport_adaptation : Option<F>, |
|
| 202 |
|
| 203 #[arg(long)] |
|
| 204 /// Minimal step length parameter for sliding methods. |
|
| 205 tau0_min : Option<F>, |
|
| 206 |
|
| 207 #[arg(value_enum, long)] |
|
| 208 /// PDPS acceleration, when available. |
|
| 209 acceleration : Option<pdps::Acceleration>, |
|
| 210 |
|
| 211 // #[arg(long)] |
|
| 212 // /// Perform postprocess weight optimisation for saved iterations |
|
| 213 // /// |
|
| 214 // /// Only affects FB, FISTA, and PDPS. |
|
| 215 // postprocessing : Option<bool>, |
|
| 216 |
|
| 217 #[arg(value_name = "n", long)] |
|
| 218 /// Merging frequency, if merging enabled (every n iterations) |
|
| 219 /// |
|
| 220 /// Only affects FB, FISTA, and PDPS. |
|
| 221 merge_every : Option<usize>, |
|
| 222 |
|
| 223 #[arg(long)] |
|
| 224 /// Enable merging (default: determined by algorithm) |
|
| 225 merge : Option<bool>, |
|
| 226 |
|
| 227 #[arg(long)] |
|
| 228 /// Merging radius (default: determined by experiment) |
|
| 229 merge_radius : Option<F>, |
|
| 230 |
|
| 231 #[arg(long)] |
|
| 232 /// Interpolate when merging (default : determined by algorithm) |
|
| 233 merge_interp : Option<bool>, |
|
| 234 |
|
| 235 #[arg(long)] |
|
| 236 /// Enable final merging (default: determined by algorithm) |
|
| 237 final_merging : Option<bool>, |
|
| 238 |
|
| 239 #[arg(long)] |
|
| 240 /// Enable fitness-based merging for relevant FB-type methods. |
|
| 241 /// This has worse convergence guarantees that merging based on optimality conditions. |
|
| 242 fitness_merging : Option<bool>, |
|
| 243 |
|
| 244 #[arg(long, value_names = &["ε", "θ", "p"])] |
|
| 245 /// Set the tolerance to ε_k = ε/(1+θk)^p |
|
| 246 tolerance : Option<Vec<F>>, |
|
| 247 |
|
| 248 } |
|
| 249 |
6 |
| 250 /// The entry point for the program. |
7 /// The entry point for the program. |
| 251 pub fn main() { |
8 pub fn main() { |
| 252 let cli = CommandLineArgs::parse(); |
9 common_main::<DefaultExperimentSetup<f64>>().unwrap(); |
| 253 |
|
| 254 #[cfg(debug_assertions)] |
|
| 255 { |
|
| 256 use colored::Colorize; |
|
| 257 println!("{}", format!("\n\ |
|
| 258 ********\n\ |
|
| 259 WARNING: Compiled without optimisations; {}\n\ |
|
| 260 Please recompile with `--release` flag.\n\ |
|
| 261 ********\n\ |
|
| 262 ", "performance will be poor!".blink() |
|
| 263 ).red()); |
|
| 264 } |
|
| 265 |
|
| 266 if let Some(n_threads) = cli.num_threads { |
|
| 267 let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); |
|
| 268 set_num_threads(n); |
|
| 269 } else { |
|
| 270 let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); |
|
| 271 set_max_threads(m); |
|
| 272 } |
|
| 273 |
|
| 274 for experiment_shorthand in cli.experiments.iter().unique() { |
|
| 275 let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); |
|
| 276 let mut algs : Vec<Named<AlgorithmConfig<float>>> |
|
| 277 = cli.algorithm |
|
| 278 .iter() |
|
| 279 .map(|alg| { |
|
| 280 let cfg = alg.default_config() |
|
| 281 .cli_override(&experiment.algorithm_overrides(*alg)) |
|
| 282 .cli_override(&cli.algoritm_overrides); |
|
| 283 alg.to_named(cfg) |
|
| 284 }) |
|
| 285 .collect(); |
|
| 286 for filename in cli.saved_algorithm.iter() { |
|
| 287 let f = std::fs::File::open(filename).unwrap(); |
|
| 288 let alg = serde_json::from_reader(f).unwrap(); |
|
| 289 algs.push(alg); |
|
| 290 } |
|
| 291 experiment.runall(&cli, (!algs.is_empty()).then_some(algs)) |
|
| 292 .unwrap() |
|
| 293 } |
|
| 294 } |
10 } |