diff -r 9738b51d90d7 -r 4f468d35fa29 src/lib.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/lib.rs Thu Feb 26 11:38:43 2026 -0500 @@ -0,0 +1,280 @@ +// The main documentation is in the README. +// We need to uglify it in build.rs because rustdoc is stuck in the past. +#![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))] +// We use unicode. We would like to use much more of it than Rust allows. +// Live with it. Embrace it. +#![allow(uncommon_codepoints)] +#![allow(mixed_script_confusables)] +#![allow(confusable_idents)] +// Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical +// convention while referring to the type (trait) of the operator as `A`. +#![allow(non_snake_case)] +// Need to create parse errors +#![feature(dec2flt)] + +use alg_tools::error::DynResult; +use alg_tools::parallelism::{set_max_threads, set_num_threads}; +use clap::Parser; +use serde::{Deserialize, Serialize}; +use serde_json; +use serde_with::skip_serializing_none; +use std::num::NonZeroUsize; + +//#[cfg(feature = "pyo3")] +//use pyo3::pyclass; + +pub mod dataterm; +pub mod experiments; +pub mod fb; +pub mod forward_model; +pub mod forward_pdps; +pub mod fourier; +pub mod frank_wolfe; +pub mod kernels; +pub mod pdps; +pub mod plot; +pub mod preadjoint_helper; +pub mod prox_penalty; +pub mod rand_distr; +pub mod regularisation; +pub mod run; +pub mod seminorms; +pub mod sliding_fb; +pub mod sliding_pdps; +pub mod subproblem; +pub mod tolerance; +pub mod types; + +pub mod measures { + pub use measures::*; +} + +use run::{AlgorithmConfig, DefaultAlgorithm, Named, PlotLevel, RunnableExperiment}; +use types::{ClapFloat, Float}; +use DefaultAlgorithm::*; + +/// Trait for customising the experiments available from the command line +pub trait ExperimentSetup: + clap::Args + Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> +{ + /// Type of floating point numbers to be used. + type FloatType: Float + ClapFloat + for<'b> Deserialize<'b>; + + fn runnables(&self) -> DynResult>>>; +} + +/// Command line parameters +#[skip_serializing_none] +#[derive(Parser, Debug, Serialize, Default, Clone)] +pub struct CommandLineArgs { + #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)] + /// Maximum iteration count + max_iter: usize, + + #[arg(long, short = 'n', value_name = "N")] + /// Output status every N iterations. Set to 0 to disable. + /// + /// The default is to output status based on logarithmic increments. + verbose_iter: Option, + + #[arg(long, short = 'q')] + /// Don't display iteration progress + quiet: bool, + + /// Default algorithm configration(s) to use on the experiments. + /// + /// Not all algorithms are available for all the experiments. + /// In particular, only PDPS is available for the experiments with L¹ data term. + #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', + default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])] + algorithm: Vec, + + /// Saved algorithm configration(s) to use on the experiments + #[arg(value_name = "JSON_FILE", long)] + saved_algorithm: Vec, + + /// Plot saving scheme + #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] + plot: PlotLevel, + + /// Directory for saving results + #[arg(long, short = 'o', required = true, default_value = "out")] + outdir: String, + + #[arg(long, help_heading = "Multi-threading", default_value = "4")] + /// Maximum number of threads + max_threads: usize, + + #[arg(long, help_heading = "Multi-threading")] + /// Number of threads. Overrides the maximum number. + num_threads: Option, + + #[arg(long, default_value_t = false)] + /// Load saved value ranges (if exists) to do partial update. + load_valuerange: bool, +} + +#[derive(Parser, Debug, Serialize, Default, Clone)] +#[clap( + about = env!("CARGO_PKG_DESCRIPTION"), + author = env!("CARGO_PKG_AUTHORS"), + version = env!("CARGO_PKG_VERSION"), + after_help = "Pass --help for longer descriptions.", + after_long_help = "", +)] +struct FusedCommandLineArgs { + /// List of experiments to perform. + #[clap(flatten, next_help_heading = "Experiment setup")] + experiment_setup: E, + + #[clap(flatten, next_help_heading = "General parameters")] + general: CommandLineArgs, + + #[clap(flatten, next_help_heading = "Algorithm overrides")] + /// Algorithm parametrisation overrides + algorithm_overrides: AlgorithmOverrides, +} + +/// Command line algorithm parametrisation overrides +#[skip_serializing_none] +#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] +//#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))] +pub struct AlgorithmOverrides { + #[arg(long, value_names = &["COUNT", "EACH"])] + /// Override bootstrap insertion iterations for --algorithm. + /// + /// The first parameter is the number of bootstrap insertion iterations, and the second + /// the maximum number of iterations on each of them. + bootstrap_insertions: Option>, + + #[arg(long, requires = "algorithm")] + /// Primal step length parameter override for --algorithm. + /// + /// Only use if running just a single algorithm, as different algorithms have different + /// regularisation parameters. Does not affect the algorithms fw and fwrelax. + tau0: Option, + + #[arg(long, requires = "algorithm")] + /// Second primal step length parameter override for SlidingPDPS. + /// + /// Only use if running just a single algorithm, as different algorithms have different + /// regularisation parameters. + sigmap0: Option, + + #[arg(long, requires = "algorithm")] + /// Dual step length parameter override for --algorithm. + /// + /// Only use if running just a single algorithm, as different algorithms have different + /// regularisation parameters. Only affects PDPS. + sigma0: Option, + + #[arg(long)] + /// Normalised transport step length for sliding methods. + theta0: Option, + + #[arg(long)] + /// A posteriori transport tolerance multiplier (C_pos) + transport_tolerance_pos: Option, + + #[arg(long)] + /// Transport adaptation factor. Must be in (0, 1). + transport_adaptation: Option, + + #[arg(long)] + /// Minimal step length parameter for sliding methods. + tau0_min: Option, + + #[arg(value_enum, long)] + /// PDPS acceleration, when available. + acceleration: Option, + + // #[arg(long)] + // /// Perform postprocess weight optimisation for saved iterations + // /// + // /// Only affects FB, FISTA, and PDPS. + // postprocessing : Option, + #[arg(value_name = "n", long)] + /// Merging frequency, if merging enabled (every n iterations) + /// + /// Only affects FB, FISTA, and PDPS. + merge_every: Option, + + #[arg(long)] + /// Enable merging (default: determined by algorithm) + merge: Option, + + #[arg(long)] + /// Merging radius (default: determined by experiment) + merge_radius: Option, + + #[arg(long)] + /// Interpolate when merging (default : determined by algorithm) + merge_interp: Option, + + #[arg(long)] + /// Enable final merging (default: determined by algorithm) + final_merging: Option, + + #[arg(long)] + /// Enable fitness-based merging for relevant FB-type methods. + /// This has worse convergence guarantees that merging based on optimality conditions. + fitness_merging: Option, + + #[arg(long, value_names = &["ε", "θ", "p"])] + /// Set the tolerance to ε_k = ε/(1+θk)^p + tolerance: Option>, +} + +/// A generic entry point for binaries based on this library +pub fn common_main() -> DynResult<()> { + let full_cli = FusedCommandLineArgs::::parse(); + let cli = &full_cli.general; + + #[cfg(debug_assertions)] + { + use colored::Colorize; + println!( + "{}", + format!( + "\n\ + ********\n\ + WARNING: Compiled without optimisations; {}\n\ + Please recompile with `--release` flag.\n\ + ********\n\ + ", + "performance will be poor!".blink() + ) + .red() + ); + } + + if let Some(n_threads) = cli.num_threads { + let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); + set_num_threads(n); + } else { + let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); + set_max_threads(m); + } + + for experiment in full_cli.experiment_setup.runnables()? { + let mut algs: Vec>> = cli + .algorithm + .iter() + .map(|alg| { + let cfg = alg + .default_config() + .cli_override(&experiment.algorithm_overrides(*alg)) + .cli_override(&full_cli.algorithm_overrides); + alg.to_named(cfg) + }) + .collect(); + for filename in cli.saved_algorithm.iter() { + let f = std::fs::File::open(filename)?; + let alg = serde_json::from_reader(f)?; + algs.push(alg); + } + experiment.runall(&cli, (!algs.is_empty()).then_some(algs))?; + } + + Ok(()) +}