Thu, 23 Jan 2025 23:34:05 +0100
Merging adjustments, parameter tuning, etc.
// The main documentation is in the README. // We need to uglify it in build.rs because rustdoc is stuck in the past. #![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))] // We use unicode. We would like to use much more of it than Rust allows. // Live with it. Embrace it. #![allow(uncommon_codepoints)] #![allow(mixed_script_confusables)] #![allow(confusable_idents)] // Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical // convention while referring to the type (trait) of the operator as `A`. #![allow(non_snake_case)] // Need to create parse errors #![feature(dec2flt)] use clap::Parser; use serde::{Serialize, Deserialize}; use serde_json; use serde_with::skip_serializing_none; use itertools::Itertools; use std::num::NonZeroUsize; use alg_tools::parallelism::{ set_num_threads, set_max_threads, }; pub mod types; pub mod measures; pub mod fourier; pub mod kernels; pub mod seminorms; pub mod transport; pub mod forward_model; pub mod preadjoint_helper; pub mod plot; pub mod subproblem; pub mod tolerance; pub mod regularisation; pub mod dataterm; pub mod prox_penalty; pub mod fb; pub mod sliding_fb; pub mod sliding_pdps; pub mod forward_pdps; pub mod frank_wolfe; pub mod pdps; pub mod run; pub mod rand_distr; pub mod experiments; use types::{float, ClapFloat}; use run::{ DefaultAlgorithm, PlotLevel, Named, AlgorithmConfig, }; use experiments::DefaultExperiment; use DefaultExperiment::*; use DefaultAlgorithm::*; /// Command line parameters #[skip_serializing_none] #[derive(Parser, Debug, Serialize, Default, Clone)] #[clap( about = env!("CARGO_PKG_DESCRIPTION"), author = env!("CARGO_PKG_AUTHORS"), version = env!("CARGO_PKG_VERSION"), after_help = "Pass --help for longer descriptions.", after_long_help = "", )] pub struct CommandLineArgs { #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)] /// Maximum iteration count max_iter : usize, #[arg(long, short = 'n', value_name = "N")] /// Output status every N iterations. Set to 0 to disable. /// /// The default is to output status based on logarithmic increments. verbose_iter : Option<usize>, #[arg(long, short = 'q')] /// Don't display iteration progress quiet : bool, /// List of experiments to perform. #[arg(value_enum, value_name = "EXPERIMENT", default_values_t = [Experiment1D, Experiment1DFast, Experiment2D, Experiment2DFast, Experiment1D_L1])] experiments : Vec<DefaultExperiment>, /// Default algorithm configration(s) to use on the experiments. /// /// Not all algorithms are available for all the experiments. /// In particular, only PDPS is available for the experiments with L¹ data term. #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])] algorithm : Vec<DefaultAlgorithm>, /// Saved algorithm configration(s) to use on the experiments #[arg(value_name = "JSON_FILE", long)] saved_algorithm : Vec<String>, /// Plot saving scheme #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] plot : PlotLevel, /// Directory for saving results #[arg(long, short = 'o', required = true, default_value = "out")] outdir : String, #[arg(long, help_heading = "Multi-threading", default_value = "4")] /// Maximum number of threads max_threads : usize, #[arg(long, help_heading = "Multi-threading")] /// Number of threads. Overrides the maximum number. num_threads : Option<usize>, #[arg(long, default_value_t = false)] /// Load saved value ranges (if exists) to do partial update. load_valuerange : bool, #[clap(flatten, next_help_heading = "Experiment overrides")] /// Experiment setup overrides experiment_overrides : ExperimentOverrides<float>, #[clap(flatten, next_help_heading = "Algorithm overrides")] /// Algorithm parametrisation overrides algoritm_overrides : AlgorithmOverrides<float>, } /// Command line experiment setup overrides #[skip_serializing_none] #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] pub struct ExperimentOverrides<F : ClapFloat> { #[arg(long)] /// Regularisation parameter override. /// /// Only use if running just a single experiment, as different experiments have different /// regularisation parameters. alpha : Option<F>, #[arg(long)] /// Gaussian noise variance override variance : Option<F>, #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])] /// Salt and pepper noise override. salt_and_pepper : Option<Vec<F>>, #[arg(long)] /// Noise seed noise_seed : Option<u64>, } /// Command line algorithm parametrisation overrides #[skip_serializing_none] #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] pub struct AlgorithmOverrides<F : ClapFloat> { #[arg(long, value_names = &["COUNT", "EACH"])] /// Override bootstrap insertion iterations for --algorithm. /// /// The first parameter is the number of bootstrap insertion iterations, and the second /// the maximum number of iterations on each of them. bootstrap_insertions : Option<Vec<usize>>, #[arg(long, requires = "algorithm")] /// Primal step length parameter override for --algorithm. /// /// Only use if running just a single algorithm, as different algorithms have different /// regularisation parameters. Does not affect the algorithms fw and fwrelax. tau0 : Option<F>, #[arg(long, requires = "algorithm")] /// Second primal step length parameter override for SlidingPDPS. /// /// Only use if running just a single algorithm, as different algorithms have different /// regularisation parameters. sigmap0 : Option<F>, #[arg(long, requires = "algorithm")] /// Dual step length parameter override for --algorithm. /// /// Only use if running just a single algorithm, as different algorithms have different /// regularisation parameters. Only affects PDPS. sigma0 : Option<F>, #[arg(long)] /// Normalised transport step length for sliding methods. theta0 : Option<F>, #[arg(long)] /// A priori transport tolerance multiplier (C_pri) transport_tolerance_pri : Option<F>, #[arg(long)] /// A posteriori transport tolerance multiplier (C_pos) transport_tolerance_pos : Option<F>, #[arg(long)] /// Transport adaptation factor. Must be in (0, 1). transport_adaptation : Option<F>, #[arg(long)] /// Minimal step length parameter for sliding methods. tau0_min : Option<F>, #[arg(value_enum, long)] /// PDPS acceleration, when available. acceleration : Option<pdps::Acceleration>, // #[arg(long)] // /// Perform postprocess weight optimisation for saved iterations // /// // /// Only affects FB, FISTA, and PDPS. // postprocessing : Option<bool>, #[arg(value_name = "n", long)] /// Merging frequency, if merging enabled (every n iterations) /// /// Only affects FB, FISTA, and PDPS. merge_every : Option<usize>, #[arg(long)] /// Enable merging (default: determined by algorithm) merge : Option<bool>, #[arg(long)] /// Merging radius (default: determined by experiment) merge_radius : Option<F>, #[arg(long)] /// Interpolate when merging (default : determined by algorithm) merge_interp : Option<bool>, #[arg(long)] /// Enable final merging (default: determined by algorithm) final_merging : Option<bool>, #[arg(long)] /// Enable fitness-based merging for relevant FB-type methods. /// This has worse convergence guarantees that merging based on optimality conditions. fitness_merging : Option<bool>, #[arg(long, value_names = &["ε", "θ", "p"])] /// Set the tolerance to ε_k = ε/(1+θk)^p tolerance : Option<Vec<F>>, } /// The entry point for the program. pub fn main() { let cli = CommandLineArgs::parse(); #[cfg(debug_assertions)] { use colored::Colorize; println!("{}", format!("\n\ ********\n\ WARNING: Compiled without optimisations; {}\n\ Please recompile with `--release` flag.\n\ ********\n\ ", "performance will be poor!".blink() ).red()); } if let Some(n_threads) = cli.num_threads { let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); set_num_threads(n); } else { let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); set_max_threads(m); } for experiment_shorthand in cli.experiments.iter().unique() { let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); let mut algs : Vec<Named<AlgorithmConfig<float>>> = cli.algorithm .iter() .map(|alg| { let cfg = alg.default_config() .cli_override(&experiment.algorithm_overrides(*alg)) .cli_override(&cli.algoritm_overrides); alg.to_named(cfg) }) .collect(); for filename in cli.saved_algorithm.iter() { let f = std::fs::File::open(filename).unwrap(); let alg = serde_json::from_reader(f).unwrap(); algs.push(alg); } experiment.runall(&cli, (!algs.is_empty()).then_some(algs)) .unwrap() } }