--- a/src/run.rs Fri Dec 02 18:14:03 2022 +0200 +++ b/src/run.rs Fri Dec 02 21:20:04 2022 +0200 @@ -63,7 +63,7 @@ use crate::subproblem::InnerSettings; use crate::seminorms::*; use crate::plot::*; -use crate::AlgorithmOverrides; +use crate::{AlgorithmOverrides, CommandLineArgs}; /// Available algorithms and their configurations #[derive(Copy, Clone, Debug, Serialize, Deserialize)] @@ -120,7 +120,7 @@ } /// Shorthand algorithm configurations, to be used with the command line parser -#[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] pub enum DefaultAlgorithm { /// The μFB forward-backward method #[clap(name = "fb")] @@ -192,23 +192,6 @@ Iter, } -/// Algorithm and iterator config for the experiments - -#[derive(Clone, Debug, Serialize)] -#[serde(default)] -pub struct Configuration<F : Float> { - /// Algorithms to run - pub algorithms : Vec<Named<AlgorithmConfig<F>>>, - /// Options for algorithm step iteration (verbosity, etc.) - pub iterator_options : AlgIteratorOptions, - /// Plotting level - pub plot : PlotLevel, - /// Directory where to save results - pub outdir : String, - /// Bisection tree depth - pub bt_depth : DynamicDepth, -} - type DefaultBT<F, const N : usize> = BT< DynamicDepth, F, @@ -322,11 +305,9 @@ /// Trait for runnable experiments pub trait RunnableExperiment<F : ClapFloat> { - /// Run all algorithms of the [`Configuration`] `config` on the experiment. - fn runall(&self, config : Configuration<F>) -> DynError; - - /// Returns the default configuration - fn default_config(&self) -> Configuration<F>; + /// Run all algorithms provided, or default algorithms if none provided, on the experiment. + fn runall(&self, cli : &CommandLineArgs, + algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; /// Return algorithm default config fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) @@ -361,26 +342,9 @@ ) } - fn default_config(&self) -> Configuration<F> { - let default_alg = match self.data.dataterm { - DataTerm::L2Squared => DefaultAlgorithm::FB.get_named(), - DataTerm::L1 => DefaultAlgorithm::PDPS.get_named(), - }; - - Configuration{ - algorithms : vec![default_alg], - iterator_options : AlgIteratorOptions{ - max_iter : 2000, - verbose_iter : Verbose::Logarithmic(10), - quiet : false, - }, - plot : PlotLevel::Data, - outdir : "out".to_string(), - bt_depth : DynamicDepth(8), - } - } - - fn runall(&self, config : Configuration<F>) -> DynError { + fn runall(&self, cli : &CommandLineArgs, + algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { + // Get experiment configuration let &Named { name : ref experiment_name, data : Experiment { @@ -390,11 +354,25 @@ } } = self; - // Set path - let prefix = format!("{}/{}/", config.outdir, experiment_name); + // Set up output directory + let prefix = format!("{}/{}/", cli.outdir, self.name); + + // Set up algorithms + let iterator_options = AlgIteratorOptions{ + max_iter : cli.max_iter, + verbose_iter : cli.verbose_iter + .map_or(Verbose::Logarithmic(10), + |n| Verbose::Every(n)), + quiet : cli.quiet, + }; + let algorithms = match (algs, self.data.dataterm) { + (Some(algs), _) => algs, + (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], + (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], + }; // Set up operators - let depth = config.bt_depth; + let depth = DynamicDepth(8); let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); @@ -413,20 +391,20 @@ let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); std::fs::create_dir_all(&prefix)?; write_json(mkname_e("experiment"), self)?; - write_json(mkname_e("config"), &config)?; + write_json(mkname_e("config"), cli)?; write_json(mkname_e("stats"), &stats)?; - plotall(&config, &prefix, &domain, &sensor, &kernel, &spread, + plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; // Run the algorithm(s) - for named @ Named { name : alg_name, data : alg } in config.algorithms.iter() { + for named @ Named { name : alg_name, data : alg } in algorithms.iter() { let this_prefix = format!("{}{}/", prefix, alg_name); - let running = || { + let running = || if !cli.quiet { println!("{}\n{}\n{}", format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), - format!("{:?}", config.iterator_options).bright_black(), + format!("{:?}", iterator_options).bright_black(), format!("{:?}", alg).bright_black()); }; @@ -473,15 +451,14 @@ this_iters } }; - let iterator = config.iterator_options - .instantiate() - .timed() - .mapped(logmap) - .into_log(&mut logger); + let iterator = iterator_options.instantiate() + .timed() + .mapped(logmap) + .into_log(&mut logger); let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); // Create plotter and directory if needed. - let plot_count = if config.plot >= PlotLevel::Iter { 2000 } else { 0 }; + let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid); // Run the algorithm @@ -505,8 +482,8 @@ pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1) }, _ => { - let msg = format!("Algorithm “{}” not implemented for dataterm {:?}. Skipping.", - alg_name, dataterm).red(); + let msg = format!("Algorithm “{alg_name}” not implemented for \ + dataterm {dataterm:?}. Skipping.").red(); eprintln!("{}", msg); continue } @@ -519,8 +496,7 @@ // Save results println!("{}", "Saving results…".green()); - let mkname = | - t| format!("{p}{n}_{t}", p = prefix, n = alg_name, t = t); + let mkname = |t| format!("{prefix}{alg_name}_{t}"); write_json(mkname("config.json"), &named)?; write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; @@ -535,7 +511,7 @@ /// Plot experiment setup #[replace_float_literals(F::cast_from(literal))] fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( - config : &Configuration<F>, + cli : &CommandLineArgs, prefix : &String, domain : &Cube<F, N>, sensor : &Sensor, @@ -560,7 +536,7 @@ PlotLookup : Plotting<N>, Cube<F, N> : SetOrd { - if config.plot < PlotLevel::Data { + if cli.plot < PlotLevel::Data { return Ok(()) }