src/main.rs

Sun, 11 Dec 2022 23:19:17 +0200

author
Tuomo Valkonen <tuomov@iki.fi>
date
Sun, 11 Dec 2022 23:19:17 +0200
changeset 23
9869fa1e0ccd
parent 20
90f77ad9a98d
child 24
d29d1fcf5423
permissions
-rw-r--r--

Print out experiment information when running it

// The main documentation is in the README.
// We need to uglify it in build.rs because rustdoc is stuck in the past.
#![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))]

// We use unicode. We would like to use much more of it than Rust allows.
// Live with it. Embrace it.
#![allow(uncommon_codepoints)]
#![allow(mixed_script_confusables)]
#![allow(confusable_idents)]
// Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical
// convention while referring to the type (trait) of the operator as `A`.
#![allow(non_snake_case)]
// We need the drain filter for inertial prune.
#![feature(drain_filter)]

use clap::Parser;
use serde::{Serialize, Deserialize};
use serde_json;
use itertools::Itertools;
use std::num::NonZeroUsize;

use alg_tools::parallelism::{
    set_num_threads,
    set_max_threads,
};

pub mod types;
pub mod measures;
pub mod fourier;
pub mod kernels;
pub mod seminorms;
pub mod forward_model;
pub mod plot;
pub mod subproblem;
pub mod tolerance;
pub mod fb;
pub mod frank_wolfe;
pub mod pdps;
pub mod run;
pub mod rand_distr;
pub mod experiments;

use types::{float, ClapFloat};
use run::{
    DefaultAlgorithm,
    PlotLevel,
    Named,
    AlgorithmConfig,
};
use experiments::DefaultExperiment;
use measures::merging::SpikeMergingMethod;
use DefaultExperiment::*;
use DefaultAlgorithm::*;

/// Command line parameters
#[derive(Parser, Debug, Serialize)]
#[clap(
    about = env!("CARGO_PKG_DESCRIPTION"),
    author = env!("CARGO_PKG_AUTHORS"),
    version = env!("CARGO_PKG_VERSION"),
    after_help = "Pass --help for longer descriptions.",
    after_long_help = "",
)]
pub struct CommandLineArgs {
    #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)]
    /// Maximum iteration count
    max_iter : usize,

    #[arg(long, short = 'n', value_name = "N")]
    /// Output status every N iterations. Set to 0 to disable.
    ///
    /// The default is to output status based on logarithmic increments.
    verbose_iter : Option<usize>,

    #[arg(long, short = 'q')]
    /// Don't display iteration progress
    quiet : bool,

    /// List of experiments to perform.
    #[arg(value_enum, value_name = "EXPERIMENT",
           default_values_t = [Experiment1D, Experiment1DFast,
                               Experiment2D, Experiment2DFast,
                               Experiment1D_L1])]
    experiments : Vec<DefaultExperiment>,

    /// Default algorithm configuration(s) to use on the experiments.
    ///
    /// Not all algorithms are available for all the experiments.
    /// In particular, only PDPS is available for the experiments with L¹ data term.
    #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a',
           default_values_t = [FB, FISTA, PDPS, FW, FWRelax])]
    algorithm : Vec<DefaultAlgorithm>,

    /// Saved algorithm configuration(s) to use on the experiments
    #[arg(value_name = "JSON_FILE", long)]
    saved_algorithm : Vec<String>,

    /// Plot saving scheme
    #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)]
    plot : PlotLevel,

    /// Directory for saving results
    // Deliberately not `required = true`: the default value always supplies a directory, so
    // marking the argument required merely contradicts the default (and may make clap present
    // it as mandatory, or reject the combination in debug builds).
    #[arg(long, short = 'o', default_value = "out")]
    outdir : String,

    // Typed default, consistent with `max_iter`; the rendered `[default: 4]` is unchanged.
    #[arg(long, help_heading = "Multi-threading", default_value_t = 4)]
    /// Maximum number of threads
    max_threads : usize,

    #[arg(long, help_heading = "Multi-threading")]
    /// Number of threads. Overrides the maximum number.
    num_threads : Option<usize>,

    #[clap(flatten, next_help_heading = "Experiment overrides")]
    /// Experiment setup overrides
    experiment_overrides : ExperimentOverrides<float>,

    #[clap(flatten, next_help_heading = "Algorithm overrides")]
    /// Algorithm parametrisation overrides
    // NOTE: the misspelt name is kept on purpose. `main` reads the field by name, and the
    // `Serialize` derive emits it as a key, so renaming it would change the serialised output.
    algoritm_overrides : AlgorithmOverrides<float>,
}

/// Command line experiment setup overrides
#[derive(Parser, Debug, Serialize, Deserialize)]
pub struct ExperimentOverrides<F : ClapFloat> {
    // Every field is an `Option`; presumably `None` means "keep the experiment's own setting"
    // (the overrides are consumed by `get_experiment` in the `experiments` module — confirm there).

    /// Regularisation parameter override.
    ///
    /// Only use if running just a single experiment, as different experiments have different
    /// regularisation parameters.
    #[arg(long)]
    alpha : Option<F>,

    /// Gaussian noise variance override
    #[arg(long)]
    variance : Option<F>,

    // The two value names label the values shown for this option in the help output.
    /// Salt and pepper noise override.
    #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])]
    salt_and_pepper : Option<Vec<F>>,

    /// Noise seed
    #[arg(long)]
    noise_seed : Option<u64>,
}

/// Command line algorithm parametrisation overrides
#[derive(Parser, Debug, Serialize, Deserialize)]
pub struct AlgorithmOverrides<F : ClapFloat> {
    #[arg(long, value_names = &["COUNT", "EACH"])]
    /// Override bootstrap insertion iterations for --algorithm.
    ///
    /// The first parameter is the number of bootstrap insertion iterations, and the second
    /// the maximum number of iterations on each of them.
    bootstrap_insertions : Option<Vec<usize>>,

    // `requires = "algorithm"` names the `--algorithm` argument of `CommandLineArgs`, into
    // which this struct is flattened.
    #[arg(long, requires = "algorithm")]
    /// Primal step length parameter override for --algorithm.
    ///
    /// Only use if running just a single algorithm, as different algorithms have different
    /// step length parameters. Does not affect the algorithms fw and fwrelax.
    tau0 : Option<F>,

    #[arg(long, requires = "algorithm")]
    /// Dual step length parameter override for --algorithm.
    ///
    /// Only use if running just a single algorithm, as different algorithms have different
    /// step length parameters. Only affects PDPS.
    sigma0 : Option<F>,

    #[arg(value_enum, long)]
    /// PDPS acceleration, when available.
    acceleration : Option<pdps::Acceleration>,

    #[arg(long)]
    /// Perform postprocess weight optimisation for saved iterations
    ///
    /// Only affects FB, FISTA, and PDPS.
    postprocessing : Option<bool>,

    #[arg(value_name = "n", long)]
    /// Merging frequency, if merging enabled (every n iterations)
    ///
    /// Only affects FB, FISTA, and PDPS.
    merge_every : Option<usize>,

    #[arg(value_enum, long)]
    /// Merging strategy
    ///
    /// Either the string "none", or a radius value for heuristic merging.
    merging : Option<SpikeMergingMethod<F>>,

    #[arg(value_enum, long)]
    /// Final merging strategy
    ///
    /// Either the string "none", or a radius value for heuristic merging.
    /// Only affects FB, FISTA, and PDPS.
    final_merging : Option<SpikeMergingMethod<F>>,

    #[arg(long, value_names = &["ε", "θ", "p"])]
    /// Set the tolerance to ε_k = ε/(1+θk)^p
    tolerance : Option<Vec<F>>,
}

/// Reads the JSON text of every `--saved-algorithm` file once, up front.
///
/// Returns `(filename, contents)` pairs in command-line order. The files are read once instead
/// of being re-opened for every experiment. This avoids repeated I/O, and every experiment
/// sees the same configurations even if a file changes on disk during a long run.
///
/// # Panics
///
/// Panics, naming the offending file, if a file cannot be read.
fn read_saved_algorithms(filenames : &[String]) -> Vec<(&str, String)> {
    filenames.iter()
             .map(|filename| {
                 let json = std::fs::read_to_string(filename).unwrap_or_else(|e| {
                     panic!("Failed to read saved algorithm file {filename}: {e}")
                 });
                 (filename.as_str(), json)
             })
             .collect()
}

/// The entry point for the program.
///
/// Steps:
/// 1. Parse the command line.
/// 2. Configure multi-threading.
/// 3. Run each distinct requested experiment with the configurations selected by
///    `--algorithm`, followed by any loaded from `--saved-algorithm`.
///
/// # Panics
///
/// Panics with a descriptive message if an experiment cannot be set up or run, or if a saved
/// algorithm file cannot be read or parsed.
pub fn main() {
    let cli = CommandLineArgs::parse();

    // Unoptimised builds are far too slow for real experiments; warn loudly.
    #[cfg(debug_assertions)]
    {
        use colored::Colorize;
        println!("{}", format!("\n\
            ********\n\
            WARNING: Compiled without optimisations; {}\n\
            Please recompile with `--release` flag.\n\
            ********\n\
            ", "performance will be poor!".blink()
        ).red());
    }

    // An explicit thread count takes precedence over the maximum.
    if let Some(n_threads) = cli.num_threads {
        let n = NonZeroUsize::new(n_threads).expect("Invalid thread count");
        set_num_threads(n);
    } else {
        let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count");
        set_max_threads(m);
    }

    let saved_algorithms = read_saved_algorithms(&cli.saved_algorithm);

    // `unique` skips experiments listed more than once, keeping the first occurrence.
    for experiment_shorthand in cli.experiments.iter().unique() {
        let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides)
                                             .expect("Failed to set up experiment");
        let mut algs : Vec<Named<AlgorithmConfig<float>>>
            = cli.algorithm.iter()
                            .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides))
                            .collect();
        // Parse afresh for each experiment: `algs` is moved into `runall` below.
        for (filename, json) in saved_algorithms.iter() {
            let alg = serde_json::from_str(json).unwrap_or_else(|e| {
                panic!("Failed to parse saved algorithm file {filename}: {e}")
            });
            algs.push(alg);
        }
        // An empty selection is passed as `None`; presumably `runall` then falls back to its
        // own defaults — confirm in the `run` module.
        experiment.runall(&cli, (!algs.is_empty()).then_some(algs))
                  .expect("Failed to run experiment");
    }
}

mercurial