--- a/src/main.rs Tue Aug 01 10:25:09 2023 +0300 +++ b/src/main.rs Mon Feb 17 13:54:53 2025 -0500 @@ -10,12 +10,13 @@ // Linear operators may be written e.g. as `opA`, to keep the capital letters of mathematical // convention while referring to the type (trait) of the operator as `A`. #![allow(non_snake_case)] -// We need the drain filter for inertial prune. -#![feature(drain_filter)] +// Need to create parse errors +#![feature(dec2flt)] use clap::Parser; use serde::{Serialize, Deserialize}; use serde_json; +use serde_with::skip_serializing_none; use itertools::Itertools; use std::num::NonZeroUsize; @@ -30,11 +31,17 @@ pub mod kernels; pub mod seminorms; pub mod forward_model; +pub mod preadjoint_helper; pub mod plot; pub mod subproblem; pub mod tolerance; pub mod regularisation; +pub mod dataterm; +pub mod prox_penalty; pub mod fb; +pub mod sliding_fb; +pub mod sliding_pdps; +pub mod forward_pdps; pub mod frank_wolfe; pub mod pdps; pub mod run; @@ -49,12 +56,12 @@ AlgorithmConfig, }; use experiments::DefaultExperiment; -use measures::merging::SpikeMergingMethod; use DefaultExperiment::*; use DefaultAlgorithm::*; /// Command line parameters -#[derive(Parser, Debug, Serialize)] +#[skip_serializing_none] +#[derive(Parser, Debug, Serialize, Default, Clone)] #[clap( about = env!("CARGO_PKG_DESCRIPTION"), author = env!("CARGO_PKG_AUTHORS"), @@ -89,7 +96,7 @@ /// Not all algorithms are available for all the experiments. /// In particular, only PDPS is available for the experiments with L¹ data term. #[arg(value_enum, value_name = "ALGORITHM", long, short = 'a', - default_values_t = [FB, FISTA, PDPS, FW, FWRelax])] + default_values_t = [FB, PDPS, SlidingFB, FW, RadonFB])] algorithm : Vec<DefaultAlgorithm>, /// Saved algorithm configration(s) to use on the experiments @@ -112,6 +119,10 @@ /// Number of threads. Overrides the maximum number. num_threads : Option<usize>, + #[arg(long, default_value_t = false)] + /// Load saved value ranges (if exists) to do partial update. + load_valuerange : bool, + #[clap(flatten, next_help_heading = "Experiment overrides")] /// Experiment setup overrides experiment_overrides : ExperimentOverrides<float>, @@ -122,7 +133,8 @@ } /// Command line experiment setup overrides -#[derive(Parser, Debug, Serialize, Deserialize)] +#[skip_serializing_none] +#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] pub struct ExperimentOverrides<F : ClapFloat> { #[arg(long)] /// Regularisation parameter override. @@ -145,7 +157,8 @@ } /// Command line algorithm parametrisation overrides -#[derive(Parser, Debug, Serialize, Deserialize)] +#[skip_serializing_none] +#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] pub struct AlgorithmOverrides<F : ClapFloat> { #[arg(long, value_names = &["COUNT", "EACH"])] /// Override bootstrap insertion iterations for --algorithm. @@ -162,21 +175,44 @@ tau0 : Option<F>, #[arg(long, requires = "algorithm")] + /// Second primal step length parameter override for SlidingPDPS. + /// + /// Only use if running just a single algorithm, as different algorithms have different + /// regularisation parameters. + sigmap0 : Option<F>, + + #[arg(long, requires = "algorithm")] /// Dual step length parameter override for --algorithm. /// /// Only use if running just a single algorithm, as different algorithms have different /// regularisation parameters. Only affects PDPS. sigma0 : Option<F>, + #[arg(long)] + /// Normalised transport step length for sliding methods. + theta0 : Option<F>, + + #[arg(long)] + /// A posteriori transport tolerance multiplier (C_pos) + transport_tolerance_pos : Option<F>, + + #[arg(long)] + /// Transport adaptation factor. Must be in (0, 1). + transport_adaptation : Option<F>, + + #[arg(long)] + /// Minimal step length parameter for sliding methods. + tau0_min : Option<F>, + #[arg(value_enum, long)] /// PDPS acceleration, when available. acceleration : Option<pdps::Acceleration>, - #[arg(long)] - /// Perform postprocess weight optimisation for saved iterations - /// - /// Only affects FB, FISTA, and PDPS. - postprocessing : Option<bool>, + // #[arg(long)] + // /// Perform postprocess weight optimisation for saved iterations + // /// + // /// Only affects FB, FISTA, and PDPS. + // postprocessing : Option<bool>, #[arg(value_name = "n", long)] /// Merging frequency, if merging enabled (every n iterations) @@ -184,18 +220,26 @@ /// Only affects FB, FISTA, and PDPS. merge_every : Option<usize>, - #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] - /// Merging strategy - /// - /// Either the string "none", or a radius value for heuristic merging. - merging : Option<SpikeMergingMethod<F>>, + #[arg(long)] + /// Enable merging (default: determined by algorithm) + merge : Option<bool>, + + #[arg(long)] + /// Merging radius (default: determined by experiment) + merge_radius : Option<F>, - #[arg(value_enum, long)]//, value_parser = SpikeMergingMethod::<float>::value_parser())] - /// Final merging strategy - /// - /// Either the string "none", or a radius value for heuristic merging. - /// Only affects FB, FISTA, and PDPS. - final_merging : Option<SpikeMergingMethod<F>>, + #[arg(long)] + /// Interpolate when merging (default : determined by algorithm) + merge_interp : Option<bool>, + + #[arg(long)] + /// Enable final merging (default: determined by algorithm) + final_merging : Option<bool>, + + #[arg(long)] + /// Enable fitness-based merging for relevant FB-type methods. + /// This has worse convergence guarantees that merging based on optimality conditions. + fitness_merging : Option<bool>, #[arg(long, value_names = &["ε", "θ", "p"])] /// Set the tolerance to ε_k = ε/(1+θk)^p @@ -230,9 +274,15 @@ for experiment_shorthand in cli.experiments.iter().unique() { let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap(); let mut algs : Vec<Named<AlgorithmConfig<float>>> - = cli.algorithm.iter() - .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides)) - .collect(); + = cli.algorithm + .iter() + .map(|alg| { + let cfg = alg.default_config() + .cli_override(&experiment.algorithm_overrides(*alg)) + .cli_override(&cli.algoritm_overrides); + alg.to_named(cfg) + }) + .collect(); for filename in cli.saved_algorithm.iter() { let f = std::fs::File::open(filename).unwrap(); let alg = serde_json::from_reader(f).unwrap();