Thu, 26 Feb 2026 11:38:43 -0500
General forward operators, separation of measures into own crate, and other architecture improvements to support the pointsource_pde crate.
// The main documentation is in the README. // We need to uglify it in build.rs because rustdoc is stuck in the past. #![doc = include_str!(concat!(env!("OUT_DIR"), "/README_uglified.md"))] // We use unicode. We would like to use much more of it than Rust allows. // Live with it. Embrace it. #![allow(uncommon_codepoints)] #![allow(mixed_script_confusables)] #![allow(confusable_idents)] // Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical // convention while referring to the type (trait) of the operator as `A`. #![allow(non_snake_case)] // Need to create parse errors #![feature(dec2flt)] use alg_tools::error::DynResult; use alg_tools::parallelism::{set_max_threads, set_num_threads}; use clap::Parser; use serde::{Deserialize, Serialize}; use serde_json; use serde_with::skip_serializing_none; use std::num::NonZeroUsize; //#[cfg(feature = "pyo3")] //use pyo3::pyclass; pub mod dataterm; pub mod experiments; pub mod fb; pub mod forward_model; pub mod forward_pdps; pub mod fourier; pub mod frank_wolfe; pub mod kernels; pub mod pdps; pub mod plot; pub mod preadjoint_helper; pub mod prox_penalty; pub mod rand_distr; pub mod regularisation; pub mod run; pub mod seminorms; pub mod sliding_fb; pub mod sliding_pdps; pub mod subproblem; pub mod tolerance; pub mod types; pub mod measures { pub use measures::*; } use run::{AlgorithmConfig, DefaultAlgorithm, Named, PlotLevel, RunnableExperiment}; use types::{ClapFloat, Float}; use DefaultAlgorithm::*; /// Trait for customising the experiments available from the command line pub trait ExperimentSetup: clap::Args + Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> { /// Type of floating point numbers to be used. type FloatType: Float + ClapFloat + for<'b> Deserialize<'b>; fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>>; } /// Command line parameters #[skip_serializing_none] #[derive(Parser, Debug, Serialize, Default, Clone)] pub struct CommandLineArgs { #[arg(long, short = 'm', value_name = "M", default_value_t = 2000)] /// Maximum iteration count max_iter: usize, #[arg(long, short = 'n', value_name = "N")] /// Output status every N iterations. Set to 0 to disable. /// /// The default is to output status based on logarithmic increments. verbose_iter: Option<usize>, #[arg(long, short = 'q')] /// Don't display iteration progress quiet: bool, /// Default algorithm configration(s) to use on the experiments. /// /// Not all algorithms are available for all the experiments. /// In particular, only PDPS is available for the experiments with L¹ data term. #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])] algorithm: Vec<DefaultAlgorithm>, /// Saved algorithm configration(s) to use on the experiments #[arg(value_name = "JSON_FILE", long)] saved_algorithm: Vec<String>, /// Plot saving scheme #[arg(value_enum, long, short = 'p', default_value_t = PlotLevel::Data)] plot: PlotLevel, /// Directory for saving results #[arg(long, short = 'o', required = true, default_value = "out")] outdir: String, #[arg(long, help_heading = "Multi-threading", default_value = "4")] /// Maximum number of threads max_threads: usize, #[arg(long, help_heading = "Multi-threading")] /// Number of threads. Overrides the maximum number. num_threads: Option<usize>, #[arg(long, default_value_t = false)] /// Load saved value ranges (if exists) to do partial update. load_valuerange: bool, } #[derive(Parser, Debug, Serialize, Default, Clone)] #[clap( about = env!("CARGO_PKG_DESCRIPTION"), author = env!("CARGO_PKG_AUTHORS"), version = env!("CARGO_PKG_VERSION"), after_help = "Pass --help for longer descriptions.", after_long_help = "", )] struct FusedCommandLineArgs<E: ExperimentSetup> { /// List of experiments to perform. #[clap(flatten, next_help_heading = "Experiment setup")] experiment_setup: E, #[clap(flatten, next_help_heading = "General parameters")] general: CommandLineArgs, #[clap(flatten, next_help_heading = "Algorithm overrides")] /// Algorithm parametrisation overrides algorithm_overrides: AlgorithmOverrides<E::FloatType>, } /// Command line algorithm parametrisation overrides #[skip_serializing_none] #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] //#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))] pub struct AlgorithmOverrides<F: ClapFloat> { #[arg(long, value_names = &["COUNT", "EACH"])] /// Override bootstrap insertion iterations for --algorithm. /// /// The first parameter is the number of bootstrap insertion iterations, and the second /// the maximum number of iterations on each of them. bootstrap_insertions: Option<Vec<usize>>, #[arg(long, requires = "algorithm")] /// Primal step length parameter override for --algorithm. /// /// Only use if running just a single algorithm, as different algorithms have different /// regularisation parameters. Does not affect the algorithms fw and fwrelax. tau0: Option<F>, #[arg(long, requires = "algorithm")] /// Second primal step length parameter override for SlidingPDPS. /// /// Only use if running just a single algorithm, as different algorithms have different /// regularisation parameters. sigmap0: Option<F>, #[arg(long, requires = "algorithm")] /// Dual step length parameter override for --algorithm. /// /// Only use if running just a single algorithm, as different algorithms have different /// regularisation parameters. Only affects PDPS. sigma0: Option<F>, #[arg(long)] /// Normalised transport step length for sliding methods. theta0: Option<F>, #[arg(long)] /// A posteriori transport tolerance multiplier (C_pos) transport_tolerance_pos: Option<F>, #[arg(long)] /// Transport adaptation factor. Must be in (0, 1). transport_adaptation: Option<F>, #[arg(long)] /// Minimal step length parameter for sliding methods. tau0_min: Option<F>, #[arg(value_enum, long)] /// PDPS acceleration, when available. acceleration: Option<pdps::Acceleration>, // #[arg(long)] // /// Perform postprocess weight optimisation for saved iterations // /// // /// Only affects FB, FISTA, and PDPS. // postprocessing : Option<bool>, #[arg(value_name = "n", long)] /// Merging frequency, if merging enabled (every n iterations) /// /// Only affects FB, FISTA, and PDPS. merge_every: Option<usize>, #[arg(long)] /// Enable merging (default: determined by algorithm) merge: Option<bool>, #[arg(long)] /// Merging radius (default: determined by experiment) merge_radius: Option<F>, #[arg(long)] /// Interpolate when merging (default : determined by algorithm) merge_interp: Option<bool>, #[arg(long)] /// Enable final merging (default: determined by algorithm) final_merging: Option<bool>, #[arg(long)] /// Enable fitness-based merging for relevant FB-type methods. /// This has worse convergence guarantees that merging based on optimality conditions. fitness_merging: Option<bool>, #[arg(long, value_names = &["ε", "θ", "p"])] /// Set the tolerance to ε_k = ε/(1+θk)^p tolerance: Option<Vec<F>>, } /// A generic entry point for binaries based on this library pub fn common_main<E: ExperimentSetup>() -> DynResult<()> { let full_cli = FusedCommandLineArgs::<E>::parse(); let cli = &full_cli.general; #[cfg(debug_assertions)] { use colored::Colorize; println!( "{}", format!( "\n\ ********\n\ WARNING: Compiled without optimisations; {}\n\ Please recompile with `--release` flag.\n\ ********\n\ ", "performance will be poor!".blink() ) .red() ); } if let Some(n_threads) = cli.num_threads { let n = NonZeroUsize::new(n_threads).expect("Invalid thread count"); set_num_threads(n); } else { let m = NonZeroUsize::new(cli.max_threads).expect("Invalid maximum thread count"); set_max_threads(m); } for experiment in full_cli.experiment_setup.runnables()? { let mut algs: Vec<Named<AlgorithmConfig<E::FloatType>>> = cli .algorithm .iter() .map(|alg| { let cfg = alg .default_config() .cli_override(&experiment.algorithm_overrides(*alg)) .cli_override(&full_cli.algorithm_overrides); alg.to_named(cfg) }) .collect(); for filename in cli.saved_algorithm.iter() { let f = std::fs::File::open(filename)?; let alg = serde_json::from_reader(f)?; algs.push(alg); } experiment.runall(&cli, (!algs.is_empty()).then_some(algs))?; } Ok(()) }