src/lib.rs

branch
dev
changeset 61
4f468d35fa29
child 63
7a8a55fd41c0
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/lib.rs	Thu Feb 26 11:38:43 2026 -0500
@@ -0,0 +1,280 @@
+// The main documentation is in the README.
+// We need to uglify it in build.rs because rustdoc is stuck in the past.
+#![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))]
+// We use unicode. We would like to use much more of it than Rust allows.
+// Live with it. Embrace it.
+#![allow(uncommon_codepoints)]
+#![allow(mixed_script_confusables)]
+#![allow(confusable_idents)]
+// Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical
+// convention while referring to the type (trait) of the operator as `A`.
+#![allow(non_snake_case)]
+// Need to create parse errors
+#![feature(dec2flt)]
+
+use alg_tools::error::DynResult;
+use alg_tools::parallelism::{set_max_threads, set_num_threads};
+use clap::Parser;
+use serde::{Deserialize, Serialize};
+use serde_json;
+use serde_with::skip_serializing_none;
+use std::num::NonZeroUsize;
+
+//#[cfg(feature = "pyo3")]
+//use pyo3::pyclass;
+
+pub mod dataterm;
+pub mod experiments;
+pub mod fb;
+pub mod forward_model;
+pub mod forward_pdps;
+pub mod fourier;
+pub mod frank_wolfe;
+pub mod kernels;
+pub mod pdps;
+pub mod plot;
+pub mod preadjoint_helper;
+pub mod prox_penalty;
+pub mod rand_distr;
+pub mod regularisation;
+pub mod run;
+pub mod seminorms;
+pub mod sliding_fb;
+pub mod sliding_pdps;
+pub mod subproblem;
+pub mod tolerance;
+pub mod types;
+
+pub mod measures {
+    pub use measures::*;
+}
+
+use run::{AlgorithmConfig, DefaultAlgorithm, Named, PlotLevel, RunnableExperiment};
+use types::{ClapFloat, Float};
+use DefaultAlgorithm::*;
+
+/// Trait for customising the experiments available from the command line
+pub trait ExperimentSetup:
+    clap::Args + Send + Sync + 'static + Serialize + for<'a> Deserialize<'a>
+{
+    /// Type of floating point numbers to be used.
+    type FloatType: Float + ClapFloat + for<'b> Deserialize<'b>;
+
+    fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>>;
+}
+
+/// Command line parameters
+#[skip_serializing_none]
+#[derive(Parser, Debug, Serialize, Default, Clone)]
+pub struct CommandLineArgs {
+    #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)]
+    /// Maximum iteration count
+    max_iter: usize,
+
+    #[arg(long, short = 'n', value_name = "N")]
+    /// Output status every N iterations. Set to 0 to disable.
+    ///
+    /// The default is to output status based on logarithmic increments.
+    verbose_iter: Option<usize>,
+
+    #[arg(long, short = 'q')]
+    /// Don't display iteration progress
+    quiet: bool,
+
+    /// Default algorithm configration(s) to use on the experiments.
+    ///
+    /// Not all algorithms are available for all the experiments.
+    /// In particular, only PDPS is available for the experiments with L¹ data term.
+    #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a',
+           default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])]
+    algorithm: Vec<DefaultAlgorithm>,
+
+    /// Saved algorithm configration(s) to use on the experiments
+    #[arg(value_name = "JSON_FILE", long)]
+    saved_algorithm: Vec<String>,
+
+    /// Plot saving scheme
+    #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)]
+    plot: PlotLevel,
+
+    /// Directory for saving results
+    #[arg(long, short = 'o', required = true, default_value = "out")]
+    outdir: String,
+
+    #[arg(long, help_heading = "Multi-threading", default_value = "4")]
+    /// Maximum number of threads
+    max_threads: usize,
+
+    #[arg(long, help_heading = "Multi-threading")]
+    /// Number of threads. Overrides the maximum number.
+    num_threads: Option<usize>,
+
+    #[arg(long, default_value_t = false)]
+    /// Load saved value ranges (if exists) to do partial update.
+    load_valuerange: bool,
+}
+
+#[derive(Parser, Debug, Serialize, Default, Clone)]
+#[clap(
+    about = env!("CARGO_PKG_DESCRIPTION"),
+    author = env!("CARGO_PKG_AUTHORS"),
+    version = env!("CARGO_PKG_VERSION"),
+    after_help = "Pass --help for longer descriptions.",
+    after_long_help = "",
+)]
+struct FusedCommandLineArgs<E: ExperimentSetup> {
+    /// List of experiments to perform.
+    #[clap(flatten, next_help_heading = "Experiment setup")]
+    experiment_setup: E,
+
+    #[clap(flatten, next_help_heading = "General parameters")]
+    general: CommandLineArgs,
+
+    #[clap(flatten, next_help_heading = "Algorithm overrides")]
+    /// Algorithm parametrisation overrides
+    algorithm_overrides: AlgorithmOverrides<E::FloatType>,
+}
+
+/// Command line algorithm parametrisation overrides
+#[skip_serializing_none]
+#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
+//#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))]
+pub struct AlgorithmOverrides<F: ClapFloat> {
+    #[arg(long, value_names = &["COUNT", "EACH"])]
+    /// Override bootstrap insertion iterations for --algorithm.
+    ///
+    /// The first parameter is the number of bootstrap insertion iterations, and the second
+    /// the maximum number of iterations on each of them.
+    bootstrap_insertions: Option<Vec<usize>>,
+
+    #[arg(long, requires = "algorithm")]
+    /// Primal step length parameter override for --algorithm.
+    ///
+    /// Only use if running just a single algorithm, as different algorithms have different
+    /// regularisation parameters. Does not affect the algorithms fw and fwrelax.
+    tau0: Option<F>,
+
+    #[arg(long, requires = "algorithm")]
+    /// Second primal step length parameter override for SlidingPDPS.
+    ///
+    /// Only use if running just a single algorithm, as different algorithms have different
+    /// regularisation parameters.
+    sigmap0: Option<F>,
+
+    #[arg(long, requires = "algorithm")]
+    /// Dual step length parameter override for --algorithm.
+    ///
+    /// Only use if running just a single algorithm, as different algorithms have different
+    /// regularisation parameters. Only affects PDPS.
+    sigma0: Option<F>,
+
+    #[arg(long)]
+    /// Normalised transport step length for sliding methods.
+    theta0: Option<F>,
+
+    #[arg(long)]
+    /// A posteriori transport tolerance multiplier (C_pos)
+    transport_tolerance_pos: Option<F>,
+
+    #[arg(long)]
+    /// Transport adaptation factor. Must be in (0, 1).
+    transport_adaptation: Option<F>,
+
+    #[arg(long)]
+    /// Minimal step length parameter for sliding methods.
+    tau0_min: Option<F>,
+
+    #[arg(value_enum, long)]
+    /// PDPS acceleration, when available.
+    acceleration: Option<pdps::Acceleration>,
+
+    // #[arg(long)]
+    // /// Perform postprocess weight optimisation for saved iterations
+    // ///
+    // /// Only affects FB, FISTA, and PDPS.
+    // postprocessing : Option<bool>,
+    #[arg(value_name = "n", long)]
+    /// Merging frequency, if merging enabled (every n iterations)
+    ///
+    /// Only affects FB, FISTA, and PDPS.
+    merge_every: Option<usize>,
+
+    #[arg(long)]
+    /// Enable merging (default: determined by algorithm)
+    merge: Option<bool>,
+
+    #[arg(long)]
+    /// Merging radius (default: determined by experiment)
+    merge_radius: Option<F>,
+
+    #[arg(long)]
+    /// Interpolate when merging (default : determined by algorithm)
+    merge_interp: Option<bool>,
+
+    #[arg(long)]
+    /// Enable final merging (default: determined by algorithm)
+    final_merging: Option<bool>,
+
+    #[arg(long)]
+    /// Enable fitness-based merging for relevant FB-type methods.
+    /// This has worse convergence guarantees that merging based on optimality conditions.
+    fitness_merging: Option<bool>,
+
+    #[arg(long, value_names = &["ε", "θ", "p"])]
+    /// Set the tolerance to ε_k = ε/(1+θk)^p
+    tolerance: Option<Vec<F>>,
+}
+
+/// A generic entry point for binaries based on this library
+pub fn common_main<E: ExperimentSetup>() -> DynResult<()> {
+    let full_cli = FusedCommandLineArgs::<E>::parse();
+    let cli = &full_cli.general;
+
+    #[cfg(debug_assertions)]
+    {
+        use colored::Colorize;
+        println!(
+            "{}",
+            format!(
+                "\n\
+            ********\n\
+            WARNING: Compiled without optimisations; {}\n\
+            Please recompile with `--release` flag.\n\
+            ********\n\
+            ",
+                "performance will be poor!".blink()
+            )
+            .red()
+        );
+    }
+
+    if let Some(n_threads) = cli.num_threads {
+        let n = NonZeroUsize::new(n_threads).expect("Invalid thread count");
+        set_num_threads(n);
+    } else {
+        let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count");
+        set_max_threads(m);
+    }
+
+    for experiment in full_cli.experiment_setup.runnables()? {
+        let mut algs: Vec<Named<AlgorithmConfig<E::FloatType>>> = cli
+            .algorithm
+            .iter()
+            .map(|alg| {
+                let cfg = alg
+                    .default_config()
+                    .cli_override(&experiment.algorithm_overrides(*alg))
+                    .cli_override(&full_cli.algorithm_overrides);
+                alg.to_named(cfg)
+            })
+            .collect();
+        for filename in cli.saved_algorithm.iter() {
+            let f = std::fs::File::open(filename)?;
+            let alg = serde_json::from_reader(f)?;
+            algs.push(alg);
+        }
+        experiment.runall(&cli, (!algs.is_empty()).then_some(algs))?;
+    }
+
+    Ok(())
+}

mercurial