src/main.rs

Tue, 29 Nov 2022 15:36:12 +0200

author
Tuomo Valkonen <tuomov@iki.fi>
date
Tue, 29 Nov 2022 15:36:12 +0200
changeset 2
7a953a87b6c1
parent 0
eb3c7813b67a
permissions
-rw-r--r--

fubar

// The main documentation is in the README.
#![doc = include_str!("../README.md")]

// We use unicode. We would like to use much more of it than Rust allows.
// Live with it. Embrace it.
#![allow(uncommon_codepoints)]
#![allow(mixed_script_confusables)]
#![allow(confusable_idents)]
// Linear operators may be written e.g. as `opA` to resemble
// mathematical convention.
#![allow(non_snake_case)]
// We need the drain filter for inertial prune
#![feature(drain_filter)]

use clap::Parser;
use itertools::Itertools;
use serde_json;
use alg_tools::iterate::Verbose;
use alg_tools::parallelism::{
    set_num_threads,
    set_max_threads,
};
use std::num::NonZeroUsize;

pub mod types;
pub mod measures;
pub mod fourier;
pub mod kernels;
pub mod seminorms;
pub mod forward_model;
pub mod plot;
pub mod subproblem;
pub mod weight_optim;
pub mod tolerance;
pub mod fb;
pub mod frank_wolfe;
pub mod pdps;
pub mod run;
pub mod rand_distr;
pub mod experiments;

use types::{float, ClapFloat};
use run::{
    DefaultAlgorithm,
    Configuration,
    PlotLevel,
    Named,
    AlgorithmConfig,
};
use experiments::DefaultExperiment;
use measures::merging::SpikeMergingMethod;
use DefaultExperiment::*;
use DefaultAlgorithm::*;

/// Command line parameters
// NOTE: clap's derive macro turns the `///` comments on the fields below into the
// `--help` texts of the corresponding options, so they are user-visible output.
#[derive(Parser, Debug)]
#[clap(
    about = env!("CARGO_PKG_DESCRIPTION"),
    author = env!("CARGO_PKG_AUTHORS"),
    version = env!("CARGO_PKG_VERSION"),
    after_help = "Pass --help for longer descriptions.",
    after_long_help = "",
)]
pub struct CommandLineArgs {
    #[arg(long, short = 'm', value_name = "M")]
    /// Maximum iteration count
    max_iter : Option<usize>,

    #[arg(long, short = 'n', value_name = "N")]
    /// Output status every N iterations. Set to 0 to disable.
    verbose_iter : Option<usize>,

    #[arg(long, short = 'q')]
    /// Don't display iteration progress
    quiet : bool,

    /// List of experiments to perform.
    #[arg(value_enum, value_name = "EXPERIMENT",
           default_values_t = [Experiment1D, Experiment1DFast,
                               Experiment2D, Experiment2DFast,
                               Experiment1D_L1])]
    experiments : Vec<DefaultExperiment>,

    /// Default algorithm configuration(s) to use on the experiments.
    ///
    /// Not all algorithms are available for all the experiments.
    /// In particular, only PDPS is available for the experiments with L¹ data term.
    #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a',
           default_values_t = [FB, FISTA, PDPS, FW, FWRelax])]
    algorithm : Vec<DefaultAlgorithm>,

    /// Saved algorithm configuration(s) to use on the experiments
    #[arg(value_name = "JSON_FILE", long)]
    saved_algorithm : Vec<String>,

    /// Write plots for every verbose iteration
    #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)]
    plot : PlotLevel,

    /// Directory for saving results
    #[arg(long, short = 'o', default_value = "out")]
    outdir : String,

    #[arg(long, help_heading = "Multi-threading", default_value = "4")]
    /// Maximum number of threads
    max_threads : usize,

    #[arg(long, help_heading = "Multi-threading")]
    /// Number of threads. Overrides the maximum number.
    num_threads : Option<usize>,

    #[clap(flatten, next_help_heading = "Experiment overrides")]
    /// Experiment setup overrides
    experiment_overrides : ExperimentOverrides<float>,

    // The spelling `algoritm` [sic] is kept: `main` refers to the field by this name.
    #[clap(flatten, next_help_heading = "Algorithm overrides")]
    /// Algorithm parametrisation overrides
    algoritm_overrides : AlgorithmOverrides<float>,
}

/// Command line experiment setup overrides
// The `///` texts on the fields double as clap help strings; they are kept verbatim.
#[derive(Debug, Parser)]
pub struct ExperimentOverrides<F>
where
    F: ClapFloat,
{
    /// Regularisation parameter override.
    ///
    /// Only use if running just a single experiment, as different experiments have different
    /// regularisation parameters.
    #[arg(long)]
    alpha: Option<F>,

    /// Gaussian noise variance override
    #[arg(long)]
    variance: Option<F>,

    /// Salt and pepper noise override.
    #[arg(
        long,
        value_names = &["MAGNITUDE", "PROBABILITY"],
    )]
    salt_and_pepper: Option<Vec<F>>,

    /// Noise seed
    #[arg(long)]
    noise_seed: Option<u64>,
}

/// Command line algorithm parametrisation overrides
// The `///` texts on the fields double as clap help strings; they are kept verbatim.
#[derive(Debug, Parser)]
pub struct AlgorithmOverrides<F>
where
    F: ClapFloat,
{
    /// Override bootstrap insertion iterations for --algorithm.
    ///
    /// The first parameter is the number of bootstrap insertion iterations, and the second
    /// the maximum number of iterations on each of them.
    #[arg(
        long,
        value_names = &["COUNT", "EACH"],
    )]
    bootstrap_insertions: Option<Vec<usize>>,

    /// Primal step length parameter override for --algorithm.
    ///
    /// Only use if running just a single algorithm, as different algorithms have different
    /// regularisation parameters. Does not affect the algorithms fw and fwrelax.
    #[arg(long, requires = "algorithm")]
    tau0: Option<F>,

    /// Dual step length parameter override for --algorithm.
    ///
    /// Only use if running just a single algorithm, as different algorithms have different
    /// regularisation parameters. Only affects PDPS.
    #[arg(long, requires = "algorithm")]
    sigma0: Option<F>,

    /// PDPS acceleration, when available.
    #[arg(value_enum, long)]
    acceleration: Option<pdps::Acceleration>,

    /// Perform postprocess weight optimisation for saved iterations
    ///
    /// Only affects FB, FISTA, and PDPS.
    #[arg(long)]
    postprocessing: Option<bool>,

    /// Merging frequency, if merging enabled (every n iterations)
    ///
    /// Only affects FB, FISTA, and PDPS.
    #[arg(value_name = "n", long)]
    merge_every: Option<usize>,

    /// Merging strategy
    ///
    /// Either the string "none", or a radius value for heuristic merging.
    #[arg(value_enum, long)]
    merging: Option<SpikeMergingMethod<F>>,

    /// Final merging strategy
    ///
    /// Either the string "none", or a radius value for heuristic merging.
    /// Only affects FB, FISTA, and PDPS.
    #[arg(value_enum, long)]
    final_merging: Option<SpikeMergingMethod<F>>,
}

/// The entry point for the program.
///
/// Parses the command line, sets the thread count, loads any saved algorithm
/// configurations, and then runs each distinct requested experiment with the
/// command-line overrides applied to its default configuration.
///
/// # Panics
///
/// Panics if a thread count of zero is given, if a saved algorithm file cannot be
/// opened or parsed as JSON, or if `get_experiment` or `runall` report a failure.
pub fn main() {
    let cli = CommandLineArgs::parse();

    // An explicitly requested thread count takes precedence over the maximum.
    if let Some(n_threads) = cli.num_threads {
        let n = NonZeroUsize::new(n_threads).expect("Invalid thread count");
        set_num_threads(n);
    } else {
        let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count");
        set_max_threads(m);
    }

    // Saved algorithm configurations do not depend on the experiment, so read and parse
    // them once up front instead of once per experiment. The file is wrapped in a
    // `BufReader`, as `serde_json::from_reader` performs many small reads.
    let saved_algs : Vec<Named<AlgorithmConfig<float>>>
        = cli.saved_algorithm.iter()
             .map(|filename| {
                 let f = std::fs::File::open(filename).unwrap_or_else(|e| {
                     panic!("Cannot open saved algorithm file '{}': {}", filename, e)
                 });
                 serde_json::from_reader(std::io::BufReader::new(f)).unwrap_or_else(|e| {
                     panic!("Cannot parse saved algorithm file '{}': {}", filename, e)
                 })
             })
             .collect();

    // `unique()` ensures an experiment listed several times is only run once.
    for experiment_shorthand in cli.experiments.iter().unique() {
        let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap();
        let mut config : Configuration<float> = experiment.default_config();

        // Per-experiment defaults for the selected algorithms come first, followed by
        // the saved configurations.
        let mut algs : Vec<Named<AlgorithmConfig<float>>>
            = cli.algorithm.iter()
                 .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides))
                 .collect();
        algs.extend(saved_algs.iter().cloned());

        // Only override the experiment's defaults when the option was actually given.
        if let Some(m) = cli.max_iter {
            config.iterator_options.max_iter = m;
        }
        if let Some(n) = cli.verbose_iter {
            config.iterator_options.verbose_iter = Verbose::Every(n);
        }
        config.plot = cli.plot;
        config.iterator_options.quiet = cli.quiet;
        config.outdir = cli.outdir.clone();
        // Keep the algorithms of the default configuration if none were selected.
        if !algs.is_empty() {
            config.algorithms = algs;
        }

        experiment.runall(config)
                  .unwrap()
    }
}

mercurial