src/main.rs

Thu, 23 Jan 2025 23:34:05 +0100

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 23 Jan 2025 23:34:05 +0100
branch
dev
changeset 39
6316d68b58af
parent 37
c5d8bd1a7728
child 41
b6bdb6cb4d44
permissions
-rw-r--r--

Merging adjustments, parameter tuning, etc.

// The main documentation is in the README.
// We need to uglify it in build.rs because rustdoc is stuck in the past.
#![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))]

// We use unicode. We would like to use much more of it than Rust allows.
// Live with it. Embrace it.
#![allow(uncommon_codepoints)]
#![allow(mixed_script_confusables)]
#![allow(confusable_idents)]
// Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical
// convention while referring to the type (trait) of the operator as `A`.
#![allow(non_snake_case)]
// Need to create parse errors
#![feature(dec2flt)]

use clap::Parser;
use serde::{Serialize, Deserialize};
use serde_json;
use serde_with::skip_serializing_none;
use itertools::Itertools;
use std::num::NonZeroUsize;

use alg_tools::parallelism::{
    set_num_threads,
    set_max_threads,
};

pub mod types;
pub mod measures;
pub mod fourier;
pub mod kernels;
pub mod seminorms;
pub mod transport;
pub mod forward_model;
pub mod preadjoint_helper;
pub mod plot;
pub mod subproblem;
pub mod tolerance;
pub mod regularisation;
pub mod dataterm;
pub mod prox_penalty;
pub mod fb;
pub mod sliding_fb;
pub mod sliding_pdps;
pub mod forward_pdps;
pub mod frank_wolfe;
pub mod pdps;
pub mod run;
pub mod rand_distr;
pub mod experiments;

use types::{float, ClapFloat};
use run::{
    DefaultAlgorithm,
    PlotLevel,
    Named,
    AlgorithmConfig,
};
use experiments::DefaultExperiment;
use DefaultExperiment::*;
use DefaultAlgorithm::*;

/// Command line parameters
// Parsed with `clap` in `main`. The `///` comments on the fields below double as the
// `--help` text, so edits to them change the program's user-visible output.
// `#[skip_serializing_none]` must stay above the `Serialize` derive.
#[skip_serializing_none]
#[derive(Parser, Debug, Serialize, Default, Clone)]
#[clap(
    about = env!("CARGO_PKG_DESCRIPTION"),
    author = env!("CARGO_PKG_AUTHORS"),
    version = env!("CARGO_PKG_VERSION"),
    after_help = "Pass --help for longer descriptions.",
    after_long_help = "",
)]
pub struct CommandLineArgs {
    #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)]
    /// Maximum iteration count
    max_iter : usize,

    #[arg(long, short = 'n', value_name = "N")]
    /// Output status every N iterations. Set to 0 to disable.
    ///
    /// The default is to output status based on logarithmic increments.
    verbose_iter : Option<usize>,

    #[arg(long, short = 'q')]
    /// Don't display iteration progress
    quiet : bool,

    /// List of experiments to perform.
    #[arg(value_enum, value_name = "EXPERIMENT",
           default_values_t = [Experiment1D, Experiment1DFast,
                               Experiment2D, Experiment2DFast,
                               Experiment1D_L1])]
    experiments : Vec<DefaultExperiment>,

    /// Default algorithm configuration(s) to use on the experiments.
    ///
    /// Not all algorithms are available for all the experiments.
    /// In particular, only PDPS is available for the experiments with L¹ data term.
    #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a',
           default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])]
    algorithm : Vec<DefaultAlgorithm>,

    /// Saved algorithm configuration(s) to use on the experiments
    // Each entry is a path to a JSON file that `main` deserialises into a named
    // algorithm configuration.
    #[arg(value_name = "JSON_FILE", long)]
    saved_algorithm : Vec<String>,

    /// Plot saving scheme
    #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)]
    plot : PlotLevel,

    /// Directory for saving results
    // NOTE(review): `required = true` combined with `default_value` looks contradictory;
    // confirm which of the two is intended and drop the other.
    #[arg(long, short = 'o', required = true, default_value = "out")]
    outdir : String,

    #[arg(long, help_heading = "Multi-threading", default_value = "4")]
    /// Maximum number of threads
    // Must be non-zero: `main` converts it to a `NonZeroUsize` and panics otherwise.
    max_threads : usize,

    #[arg(long, help_heading = "Multi-threading")]
    /// Number of threads. Overrides the maximum number.
    // Must be non-zero when given, for the same reason as `max_threads`.
    num_threads : Option<usize>,

    #[arg(long, default_value_t = false)]
    /// Load saved value ranges (if exists) to do partial update.
    load_valuerange : bool,

    #[clap(flatten, next_help_heading = "Experiment overrides")]
    /// Experiment setup overrides
    experiment_overrides : ExperimentOverrides<float>,

    #[clap(flatten, next_help_heading = "Algorithm overrides")]
    /// Algorithm parametrisation overrides
    // The field name is misspelt ("algoritm"), but it is referenced by that name in
    // `main`, so it is kept as is.
    algoritm_overrides : AlgorithmOverrides<float>,
}

/// Command line experiment setup overrides
// Flattened into `CommandLineArgs` and handed to `get_experiment` in `main`.
// Every field is optional. The `///` comments on the fields are also the `--help` text.
#[skip_serializing_none]
#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
pub struct ExperimentOverrides<F: ClapFloat> {
    /// Regularisation parameter override.
    ///
    /// Only use if running just a single experiment, as different experiments have different
    /// regularisation parameters.
    #[arg(long)]
    alpha: Option<F>,

    /// Gaussian noise variance override
    #[arg(long)]
    variance: Option<F>,

    /// Salt and pepper noise override.
    // Takes two values on the command line, shown as MAGNITUDE and PROBABILITY.
    #[arg(long, value_names = &["MAGNITUDE", "PROBABILITY"])]
    salt_and_pepper: Option<Vec<F>>,

    /// Noise seed
    #[arg(long)]
    noise_seed: Option<u64>,
}

/// Command line algorithm parametrisation overrides
// Flattened into `CommandLineArgs` and applied in `main` through `cli_override` on top of
// each algorithm's default configuration. Every field is optional.
// The `///` comments on the fields are also the `--help` text.
#[skip_serializing_none]
#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
pub struct AlgorithmOverrides<F : ClapFloat> {
    #[arg(long, value_names = &["COUNT", "EACH"])]
    /// Override bootstrap insertion iterations for --algorithm.
    ///
    /// The first parameter is the number of bootstrap insertion iterations, and the second
    /// the maximum number of iterations on each of them.
    bootstrap_insertions : Option<Vec<usize>>,

    #[arg(long, requires = "algorithm")]
    /// Primal step length parameter override for --algorithm.
    ///
    /// Only use if running just a single algorithm, as different algorithms have different
    /// step length parameters. Does not affect the algorithms fw and fwrelax.
    tau0 : Option<F>,

    #[arg(long, requires = "algorithm")]
    /// Second primal step length parameter override for SlidingPDPS.
    ///
    /// Only use if running just a single algorithm, as different algorithms have different
    /// step length parameters.
    sigmap0 : Option<F>,

    #[arg(long, requires = "algorithm")]
    /// Dual step length parameter override for --algorithm.
    ///
    /// Only use if running just a single algorithm, as different algorithms have different
    /// step length parameters. Only affects PDPS.
    sigma0 : Option<F>,

    #[arg(long)]
    /// Normalised transport step length for sliding methods.
    theta0 : Option<F>,

    #[arg(long)]
    /// A priori transport tolerance multiplier (C_pri)
    transport_tolerance_pri : Option<F>,

    #[arg(long)]
    /// A posteriori transport tolerance multiplier (C_pos)
    transport_tolerance_pos : Option<F>,

    #[arg(long)]
    /// Transport adaptation factor. Must be in (0, 1).
    transport_adaptation : Option<F>,

    #[arg(long)]
    /// Minimal step length parameter for sliding methods.
    tau0_min : Option<F>,

    #[arg(value_enum, long)]
    /// PDPS acceleration, when available.
    acceleration : Option<pdps::Acceleration>,

    // #[arg(long)]
    // /// Perform postprocess weight optimisation for saved iterations
    // ///
    // /// Only affects FB, FISTA, and PDPS.
    // postprocessing : Option<bool>,

    #[arg(value_name = "n", long)]
    /// Merging frequency, if merging enabled (every n iterations)
    ///
    /// Only affects FB, FISTA, and PDPS.
    merge_every : Option<usize>,

    #[arg(long)]
    /// Enable merging (default: determined by algorithm)
    merge : Option<bool>,

    #[arg(long)]
    /// Merging radius (default: determined by experiment)
    merge_radius : Option<F>,

    #[arg(long)]
    /// Interpolate when merging (default: determined by algorithm)
    merge_interp : Option<bool>,

    #[arg(long)]
    /// Enable final merging (default: determined by algorithm)
    final_merging : Option<bool>,

    #[arg(long)]
    /// Enable fitness-based merging for relevant FB-type methods.
    /// This has worse convergence guarantees than merging based on optimality conditions.
    fitness_merging : Option<bool>,

    #[arg(long, value_names = &["ε", "θ", "p"])]
    /// Set the tolerance to ε_k = ε/(1+θk)^p
    tolerance : Option<Vec<F>>,

}

/// The entry point for the program.
///
/// Parses the command line, sets the thread count (or its maximum), and then runs every
/// distinct selected experiment with the selected default algorithms plus any algorithm
/// configurations loaded from `--saved-algorithm` JSON files.
///
/// # Panics
///
/// Panics if a thread count of zero is given, if an experiment cannot be constructed,
/// if a saved algorithm file cannot be opened or deserialised, or if running an
/// experiment returns an error.
pub fn main() {
    let cli = CommandLineArgs::parse();

    #[cfg(debug_assertions)]
    {
        use colored::Colorize;
        println!("{}", format!("\n\
            ********\n\
            WARNING: Compiled without optimisations; {}\n\
            Please recompile with `--release` flag.\n\
            ********\n\
            ", "performance will be poor!".blink()
        ).red());
    }

    // An explicit thread count takes precedence over the maximum.
    if let Some(n_threads) = cli.num_threads {
        let n = NonZeroUsize::new(n_threads).expect("Invalid thread count");
        set_num_threads(n);
    } else {
        let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count");
        set_max_threads(m);
    }

    // `unique()` ensures an experiment listed several times is only run once.
    for experiment_shorthand in cli.experiments.iter().unique() {
        let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap();
        // Experiment-specific overrides are applied first, so that the command line
        // overrides applied afterwards take precedence over them.
        let mut algs : Vec<Named<AlgorithmConfig<float>>>
            = cli.algorithm
                 .iter()
                 .map(|alg| {
                    let cfg = alg.default_config()
                                 .cli_override(&experiment.algorithm_overrides(*alg))
                                 .cli_override(&cli.algoritm_overrides);
                    alg.to_named(cfg)
                 })
                 .collect();
        for filename in cli.saved_algorithm.iter() {
            // The file name comes from the user, so report it on failure instead of
            // panicking with a bare `unwrap` message.
            let file = std::fs::File::open(filename).unwrap_or_else(|e| {
                panic!("Cannot open saved algorithm file '{}': {}", filename, e)
            });
            // `serde_json::from_reader` does not buffer its input, so buffer it here.
            let alg : Named<AlgorithmConfig<float>>
                = serde_json::from_reader(std::io::BufReader::new(file)).unwrap_or_else(|e| {
                    panic!("Cannot parse saved algorithm file '{}': {}", filename, e)
                });
            algs.push(alg);
        }
        // Pass `None` rather than an empty list when no algorithm was selected.
        experiment.runall(&cli, (!algs.is_empty()).then_some(algs))
                  .unwrap()
    }
}

mercurial