src/lib.rs

Thu, 26 Feb 2026 13:05:07 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 26 Feb 2026 13:05:07 -0500
branch
dev
changeset 66
fe47ad484deb
parent 63
7a8a55fd41c0
child 67
95bb12bdb6ac
permissions
-rw-r--r--

Allow fitness merge when forward_pdps and sliding_pdps are used as forward-backward with aux variable.

// The main documentation is in the README.
// We need to uglify it in build.rs because rustdoc is stuck in the past.
#![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))]
// We use unicode. We would like to use much more of it than Rust allows.
// Live with it. Embrace it.
#![allow(uncommon_codepoints)]
#![allow(mixed_script_confusables)]
#![allow(confusable_idents)]
// Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical
// convention while referring to the type (trait) of the operator as `A`.
#![allow(non_snake_case)]
// Needed (nightly-only feature) to construct float parse errors.
#![feature(dec2flt)]

use alg_tools::error::DynResult;
use alg_tools::parallelism::{set_max_threads, set_num_threads};
use clap::Parser;
use serde::{Deserialize, Serialize};
use serde_json;
use serde_with::skip_serializing_none;
use std::num::NonZeroUsize;

//#[cfg(feature = "pyo3")]
//use pyo3::pyclass;

pub mod dataterm;
pub mod experiments;
pub mod fb;
pub mod forward_model;
pub mod forward_pdps;
pub mod fourier;
pub mod frank_wolfe;
pub mod kernels;
pub mod pdps;
pub mod plot;
pub mod preadjoint_helper;
pub mod prox_penalty;
pub mod rand_distr;
pub mod regularisation;
pub mod run;
pub mod seminorms;
pub mod sliding_fb;
pub mod sliding_pdps;
pub mod subproblem;
pub mod tolerance;
pub mod types;

/// Re-export of everything in the `measures` crate, so that it is reachable
/// through this crate as `measures::…`.
pub mod measures {
    pub use ::measures::*;
}

use run::{AlgorithmConfig, DefaultAlgorithm, Named, PlotLevel, RunnableExperiment};
use subproblem::InnerMethod;
use types::{ClapFloat, Float};
use DefaultAlgorithm::*;

/// Customisation point for the set of experiments offered on the command line.
///
/// An implementor is itself a group of command line arguments (`clap::Args`),
/// must be (de)serialisable, and must be shareable across threads.
pub trait ExperimentSetup:
    clap::Args + Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static
{
    /// The floating point type the experiments compute with.
    type FloatType: Float + ClapFloat + for<'de> Deserialize<'de>;

    /// Builds the list of runnable experiments described by this setup.
    fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>>;
}

// General (experiment-independent) command line parameters. The `///` comments on the
// fields are used by clap as the `--help` text.
/// Command line parameters
#[skip_serializing_none]
#[derive(Parser, Debug, Serialize, Default, Clone)]
pub struct CommandLineArgs {
    #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)]
    /// Maximum iteration count
    max_iter: usize,

    #[arg(long, short = 'n', value_name = "N")]
    /// Output status every N iterations. Set to 0 to disable.
    ///
    /// The default is to output status based on logarithmic increments.
    verbose_iter: Option<usize>,

    #[arg(long, short = 'q')]
    /// Don't display iteration progress
    quiet: bool,

    /// Default algorithm configuration(s) to use on the experiments.
    ///
    /// Not all algorithms are available for all the experiments.
    /// In particular, only PDPS is available for the experiments with L¹ data term.
    #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a',
           default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])]
    algorithm: Vec<DefaultAlgorithm>,

    /// Saved algorithm configuration(s) to use on the experiments
    #[arg(value_name = "JSON_FILE", long)]
    saved_algorithm: Vec<String>,

    /// Plot saving scheme
    #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)]
    plot: PlotLevel,

    // No `required = true` here: the argument has a default value, which makes
    // `required` contradictory (clap debug-asserts against the combination).
    /// Directory for saving results
    #[arg(long, short = 'o', default_value = "out")]
    outdir: String,

    #[arg(long, help_heading = "Multi-threading", default_value_t = 4)]
    /// Maximum number of threads
    max_threads: usize,

    #[arg(long, help_heading = "Multi-threading")]
    /// Number of threads. Overrides the maximum number.
    num_threads: Option<usize>,

    #[arg(long, default_value_t = false)]
    /// Load saved value ranges (if they exist) to do partial update.
    load_valuerange: bool,
}

#[derive(Parser, Debug, Serialize, Default, Clone)]
#[clap(
    about = env!("CARGO_PKG_DESCRIPTION"),
    author = env!("CARGO_PKG_AUTHORS"),
    version = env!("CARGO_PKG_VERSION"),
    after_help = "Pass --help for longer descriptions.",
    after_long_help = "",
)]
struct FusedCommandLineArgs<E: ExperimentSetup> {
    /// List of experiments to perform.
    #[clap(flatten, next_help_heading = "Experiment setup")]
    experiment_setup: E,

    #[clap(flatten, next_help_heading = "General parameters")]
    general: CommandLineArgs,

    #[clap(flatten, next_help_heading = "Algorithm overrides")]
    /// Algorithm parametrisation overrides
    algorithm_overrides: AlgorithmOverrides<E::FloatType>,
}

// Every field is optional: `None` means "keep the value from the algorithm's (or the
// experiment's) configuration". The `///` comments on the fields are the `--help` text.
/// Command line algorithm parametrisation overrides
#[skip_serializing_none]
#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
//#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))]
pub struct AlgorithmOverrides<F: ClapFloat> {
    #[arg(long, value_names = &["COUNT", "EACH"])]
    /// Override bootstrap insertion iterations for --algorithm.
    ///
    /// The first parameter is the number of bootstrap insertion iterations, and the second
    /// the maximum number of iterations on each of them.
    bootstrap_insertions: Option<Vec<usize>>,

    #[arg(long, requires = "algorithm")]
    /// Primal step length parameter override for --algorithm.
    ///
    /// Only use if running just a single algorithm, as different algorithms have different
    /// regularisation parameters. Does not affect the algorithms fw and fwrelax.
    tau0: Option<F>,

    #[arg(long, requires = "algorithm")]
    /// Second primal step length parameter override for SlidingPDPS.
    ///
    /// Only use if running just a single algorithm, as different algorithms have different
    /// regularisation parameters.
    sigmap0: Option<F>,

    #[arg(long, requires = "algorithm")]
    /// Dual step length parameter override for --algorithm.
    ///
    /// Only use if running just a single algorithm, as different algorithms have different
    /// regularisation parameters. Only affects PDPS.
    sigma0: Option<F>,

    #[arg(long)]
    /// Normalised transport step length for sliding methods.
    theta0: Option<F>,

    #[arg(long)]
    /// A posteriori transport tolerance multiplier (C_pos)
    transport_tolerance_pos: Option<F>,

    #[arg(long)]
    /// Transport adaptation factor. Must be in (0, 1).
    transport_adaptation: Option<F>,

    #[arg(long)]
    /// Minimal step length parameter for sliding methods.
    tau0_min: Option<F>,

    #[arg(value_enum, long)]
    /// PDPS acceleration, when available.
    acceleration: Option<pdps::Acceleration>,

    // #[arg(long)]
    // /// Perform postprocess weight optimisation for saved iterations
    // ///
    // /// Only affects FB, FISTA, and PDPS.
    // postprocessing : Option<bool>,
    #[arg(value_name = "n", long)]
    /// Merging frequency, if merging enabled (every n iterations)
    ///
    /// Only affects FB, FISTA, and PDPS.
    merge_every: Option<usize>,

    #[arg(long)]
    /// Enable merging (default: determined by algorithm)
    merge: Option<bool>,

    #[arg(long)]
    /// Merging radius (default: determined by experiment)
    merge_radius: Option<F>,

    #[arg(long)]
    /// Interpolate when merging (default: determined by algorithm)
    merge_interp: Option<bool>,

    #[arg(long)]
    /// Enable final merging (default: determined by algorithm)
    final_merging: Option<bool>,

    #[arg(long)]
    /// Enable fitness-based merging for relevant FB-type methods.
    ///
    /// This has worse convergence guarantees than merging based on optimality conditions.
    fitness_merging: Option<bool>,

    #[arg(long, value_names = &["ε", "θ", "p"])]
    /// Set the tolerance to ε_k = ε/(1+θk)^p
    tolerance: Option<Vec<F>>,

    #[arg(long)]
    /// Method for solving inner optimisation problems
    inner_method: Option<InnerMethod>,

    #[arg(long)]
    /// Step length parameter for inner problem
    inner_τ0: Option<F>,

    #[arg(long, value_names = &["τ0", "σ0"])]
    /// Dual step length parameter for inner problem
    inner_pdps_τσ0: Option<Vec<F>>,

    #[arg(long, value_names = &["τ", "growth"])]
    /// Inner proximal point method step length and its growth
    inner_pp_τ: Option<Vec<F>>,

    #[arg(long)]
    /// Inner tolerance multiplier
    inner_tol: Option<F>,
}

/// A generic entry point for binaries based on this library.
///
/// Parses the command line into the experiment setup `E`, the general parameters and the
/// algorithm overrides, configures the number of threads, and then runs every experiment
/// produced by [`ExperimentSetup::runnables`] with the selected algorithm configurations
/// (the `--algorithm` defaults plus any `--saved-algorithm` JSON files).
///
/// # Errors
///
/// Returns an error if a thread count given on the command line is zero, if constructing
/// the experiments fails, if a saved algorithm file cannot be opened or parsed, or if
/// running an experiment fails.
pub fn common_main<E: ExperimentSetup>() -> DynResult<()> {
    let full_cli = FusedCommandLineArgs::<E>::parse();
    let cli = &full_cli.general;

    #[cfg(debug_assertions)]
    {
        use colored::Colorize;
        println!(
            "{}",
            format!(
                "\n\
            ********\n\
            WARNING: Compiled without optimisations; {}\n\
            Please recompile with `--release` flag.\n\
            ********\n\
            ",
                "performance will be poor!".blink()
            )
            .red()
        );
    }

    // Thread counts are user input: report zero as an error instead of panicking.
    if let Some(n_threads) = cli.num_threads {
        let n = NonZeroUsize::new(n_threads).ok_or("Invalid thread count")?;
        set_num_threads(n);
    } else {
        let m = NonZeroUsize::new(cli.max_threads).ok_or("Invalid maximum thread count")?;
        set_max_threads(m);
    }

    for experiment in full_cli.experiment_setup.runnables()? {
        // Start from each selected algorithm's default configuration, then apply the
        // experiment-specific overrides, then the command line overrides.
        let mut algs: Vec<Named<AlgorithmConfig<E::FloatType>>> = cli
            .algorithm
            .iter()
            .map(|alg| {
                let cfg = alg
                    .default_config()
                    .cli_override(&experiment.algorithm_overrides(*alg))
                    .cli_override(&full_cli.algorithm_overrides);
                alg.to_named(cfg)
            })
            .collect();
        for filename in cli.saved_algorithm.iter() {
            // serde_json does not buffer reads from the source itself.
            let f = std::io::BufReader::new(std::fs::File::open(filename)?);
            let alg = serde_json::from_reader(f)?;
            algs.push(alg);
        }
        experiment.runall(cli, (!algs.is_empty()).then_some(algs))?;
    }

    Ok(())
}

mercurial