--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/run.rs Thu Dec 01 23:07:35 2022 +0200 @@ -0,0 +1,602 @@ +/*! +This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment. +*/ + +use numeric_literals::replace_float_literals; +use colored::Colorize; +use serde::{Serialize, Deserialize}; +use serde_json; +use nalgebra::base::DVector; +use std::hash::Hash; +use chrono::{DateTime, Utc}; +use cpu_time::ProcessTime; +use clap::ValueEnum; +use std::collections::HashMap; +use std::time::Instant; + +use rand::prelude::{ + StdRng, + SeedableRng +}; +use rand_distr::Distribution; + +use alg_tools::bisection_tree::*; +use alg_tools::iterate::{ + Timed, + AlgIteratorOptions, + Verbose, + AlgIteratorFactory, +}; +use alg_tools::logger::Logger; +use alg_tools::error::DynError; +use alg_tools::tabledump::TableDump; +use alg_tools::sets::Cube; +use alg_tools::mapping::RealMapping; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::euclidean::Euclidean; +use alg_tools::norms::{Norm, L1}; +use alg_tools::lingrid::lingrid; +use alg_tools::sets::SetOrd; + +use crate::kernels::*; +use crate::types::*; +use crate::measures::*; +use crate::measures::merging::SpikeMerging; +use crate::forward_model::*; +use crate::fb::{ + FBConfig, + pointsource_fb, + FBMetaAlgorithm, FBGenericConfig, +}; +use crate::pdps::{ + PDPSConfig, + L2Squared, + pointsource_pdps, +}; +use crate::frank_wolfe::{ + FWConfig, + FWVariant, + pointsource_fw, + prepare_optimise_weights, + optimise_weights, +}; +use crate::subproblem::InnerSettings; +use crate::seminorms::*; +use crate::plot::*; +use crate::AlgorithmOverrides; + +/// Available algorithms and their configurations +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +pub enum AlgorithmConfig<F : Float> { + FB(FBConfig<F>), + FW(FWConfig<F>), + PDPS(PDPSConfig<F>), +} + +impl<F : ClapFloat> AlgorithmConfig<F> { + /// Override supported parameters based on the command line. + pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self { + let override_fb_generic = |g : FBGenericConfig<F>| { + FBGenericConfig { + bootstrap_insertions : cli.bootstrap_insertions + .as_ref() + .map_or(g.bootstrap_insertions, + |n| Some((n[0], n[1]))), + merge_every : cli.merge_every.unwrap_or(g.merge_every), + merging : cli.merging.clone().unwrap_or(g.merging), + final_merging : cli.final_merging.clone().unwrap_or(g.final_merging), + .. g + } + }; + + use AlgorithmConfig::*; + match self { + FB(fb) => FB(FBConfig { + τ0 : cli.tau0.unwrap_or(fb.τ0), + insertion : override_fb_generic(fb.insertion), + .. fb + }), + PDPS(pdps) => PDPS(PDPSConfig { + τ0 : cli.tau0.unwrap_or(pdps.τ0), + σ0 : cli.sigma0.unwrap_or(pdps.σ0), + acceleration : cli.acceleration.unwrap_or(pdps.acceleration), + insertion : override_fb_generic(pdps.insertion), + .. pdps + }), + FW(fw) => FW(FWConfig { + merging : cli.merging.clone().unwrap_or(fw.merging), + .. fw + }) + } + } +} + +/// Helper struct for tagging and [`AlgorithmConfig`] or [`Experiment`] with a name. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Named<Data> { + pub name : String, + #[serde(flatten)] + pub data : Data, +} + +/// Shorthand algorithm configurations, to be used with the command line parser +#[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub enum DefaultAlgorithm { + /// The μFB forward-backward method + #[clap(name = "fb")] + FB, + /// The μFISTA inertial forward-backward method + #[clap(name = "fista")] + FISTA, + /// The “fully corrective” conditional gradient method + #[clap(name = "fw")] + FW, + /// The “relaxed conditional gradient method + #[clap(name = "fwrelax")] + FWRelax, + /// The μPDPS primal-dual proximal splitting method + #[clap(name = "pdps")] + PDPS, +} + +impl DefaultAlgorithm { + /// Returns the algorithm configuration corresponding to the algorithm shorthand + pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { + use DefaultAlgorithm::*; + match *self { + FB => AlgorithmConfig::FB(Default::default()), + FISTA => AlgorithmConfig::FB(FBConfig{ + meta : FBMetaAlgorithm::InertiaFISTA, + .. Default::default() + }), + FW => AlgorithmConfig::FW(Default::default()), + FWRelax => AlgorithmConfig::FW(FWConfig{ + variant : FWVariant::Relaxed, + .. Default::default() + }), + PDPS => AlgorithmConfig::PDPS(Default::default()), + } + } + + /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand + pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { + self.to_named(self.default_config()) + } + + pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { + let name = self.to_possible_value().unwrap().get_name().to_string(); + Named{ name , data : alg } + } +} + + +// // Floats cannot be hashed directly, so just hash the debug formatting +// // for use as file identifier. +// impl<F : Float> Hash for AlgorithmConfig<F> { +// fn hash<H: Hasher>(&self, state: &mut H) { +// format!("{:?}", self).hash(state); +// } +// } + +/// Plotting level configuration +#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, ValueEnum, Debug)] +pub enum PlotLevel { + /// Plot nothing + #[clap(name = "none")] + None, + /// Plot problem data + #[clap(name = "data")] + Data, + /// Plot iterationwise state + #[clap(name = "iter")] + Iter, +} + +/// Algorithm and iterator config for the experiments + +#[derive(Clone, Debug, Serialize)] +#[serde(default)] +pub struct Configuration<F : Float> { + /// Algorithms to run + pub algorithms : Vec<Named<AlgorithmConfig<F>>>, + /// Options for algorithm step iteration (verbosity, etc.) + pub iterator_options : AlgIteratorOptions, + /// Plotting level + pub plot : PlotLevel, + /// Directory where to save results + pub outdir : String, + /// Bisection tree depth + pub bt_depth : DynamicDepth, +} + +type DefaultBT<F, const N : usize> = BT< + DynamicDepth, + F, + usize, + Bounds<F>, + N +>; +type DefaultSeminormOp<F, K, const N : usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>; +type DefaultSG<F, Sensor, Spread, const N : usize> = SensorGrid::< + F, + Sensor, + Spread, + DefaultBT<F, N>, + N +>; + +/// This is a dirty workaround to rust-csv not supporting struct flattening etc. +#[derive(Serialize)] +struct CSVLog<F> { + iter : usize, + cpu_time : f64, + value : F, + post_value : F, + n_spikes : usize, + inner_iters : usize, + merged : usize, + pruned : usize, + this_iters : usize, +} + +/// Collected experiment statistics +#[derive(Clone, Debug, Serialize)] +struct ExperimentStats<F : Float> { + /// Signal-to-noise ratio in decibels + ssnr : F, + /// Proportion of noise in the signal as a number in $[0, 1]$. + noise_ratio : F, + /// When the experiment was run (UTC) + when : DateTime<Utc>, +} + +#[replace_float_literals(F::cast_from(literal))] +impl<F : Float> ExperimentStats<F> { + /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal. + fn new<E : Euclidean<F>>(signal : &E, noise : &E) -> Self { + let s = signal.norm2_squared(); + let n = noise.norm2_squared(); + let noise_ratio = (n / s).sqrt(); + let ssnr = 10.0 * (s / n).log10(); + ExperimentStats { + ssnr, + noise_ratio, + when : Utc::now(), + } + } +} +/// Collected algorithm statistics +#[derive(Clone, Debug, Serialize)] +struct AlgorithmStats<F : Float> { + /// Overall CPU time spent + cpu_time : F, + /// Real time spent + elapsed : F +} + + +/// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input +/// and outputs a [`DynError`]. +fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { + serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; + Ok(()) +} + + +/// Struct for experiment configurations +#[derive(Debug, Clone, Serialize)] +pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize> +where F : Float, + [usize; N] : Serialize, + NoiseDistr : Distribution<F>, + S : Sensor<F, N>, + P : Spread<F, N>, + K : SimpleConvolutionKernel<F, N>, +{ + /// Domain $Ω$. + pub domain : Cube<F, N>, + /// Number of sensors along each dimension + pub sensor_count : [usize; N], + /// Noise distribution + pub noise_distr : NoiseDistr, + /// Seed for random noise generation (for repeatable experiments) + pub noise_seed : u64, + /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. + pub sensor : S, + /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. + pub spread : P, + /// Kernel $ρ$ of $𝒟$. + pub kernel : K, + /// True point sources + pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, + /// Regularisation parameter + pub α : F, + /// For plotting : how wide should the kernels be plotted + pub kernel_plot_width : F, + /// Data term + pub dataterm : DataTerm, + /// A map of default configurations for algorithms + #[serde(skip)] + pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, +} + +/// Trait for runnable experiments +pub trait RunnableExperiment<F : ClapFloat> { + /// Run all algorithms of the [`Configuration`] `config` on the experiment. + fn runall(&self, config : Configuration<F>) -> DynError; + + /// Returns the default configuration + fn default_config(&self) -> Configuration<F>; + + /// Return algorithm default config + fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) + -> Named<AlgorithmConfig<F>>; +} + +impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for +Named<Experiment<F, NoiseDistr, S, K, P, N>> +where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, + [usize; N] : Serialize, + S : Sensor<F, N> + Copy + Serialize, + P : Spread<F, N> + Copy + Serialize, + Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy, + AutoConvolution<P> : BoundedBy<F, K>, + K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> + Copy + Serialize, + Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, + PlotLookup : Plotting<N>, + DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, + BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, + DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, + NoiseDistr : Distribution<F> + Serialize { + + fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) + -> Named<AlgorithmConfig<F>> { + alg.to_named( + self.data + .algorithm_defaults + .get(&alg) + .map_or_else(|| alg.default_config(), + |config| config.clone()) + .cli_override(cli) + ) + } + + fn default_config(&self) -> Configuration<F> { + let default_alg = match self.data.dataterm { + DataTerm::L2Squared => DefaultAlgorithm::FB.get_named(), + DataTerm::L1 => DefaultAlgorithm::PDPS.get_named(), + }; + + Configuration{ + algorithms : vec![default_alg], + iterator_options : AlgIteratorOptions{ + max_iter : 2000, + verbose_iter : Verbose::Logarithmic(10), + quiet : false, + }, + plot : PlotLevel::Data, + outdir : "out".to_string(), + bt_depth : DynamicDepth(8), + } + } + + fn runall(&self, config : Configuration<F>) -> DynError { + let &Named { + name : ref experiment_name, + data : Experiment { + domain, sensor_count, ref noise_distr, sensor, spread, kernel, + ref μ_hat, α, kernel_plot_width, dataterm, noise_seed, + .. + } + } = self; + + // Set path + let prefix = format!("{}/{}/", config.outdir, experiment_name); + + // Set up operators + let depth = config.bt_depth; + let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); + let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); + + // Set up random number generator. + let mut rng = StdRng::seed_from_u64(noise_seed); + + // Generate the data and calculate SSNR statistic + let b_hat = opA.apply(μ_hat); + let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); + let b = &b_hat + &noise; + // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField + // overloading log10 and conflicting with standard NumTraits one. + let stats = ExperimentStats::new(&b, &noise); + + // Save experiment configuration and statistics + let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); + std::fs::create_dir_all(&prefix)?; + write_json(mkname_e("experiment"), self)?; + write_json(mkname_e("config"), &config)?; + write_json(mkname_e("stats"), &stats)?; + + plotall(&config, &prefix, &domain, &sensor, &kernel, &spread, + &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; + + // Run the algorithm(s) + for named @ Named { name : alg_name, data : alg } in config.algorithms.iter() { + let this_prefix = format!("{}{}/", prefix, alg_name); + + let running = || { + println!("{}\n{}\n{}", + format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), + format!("{:?}", config.iterator_options).bright_black(), + format!("{:?}", alg).bright_black()); + }; + + // Create Logger and IteratorFactory + let mut logger = Logger::new(); + let findim_data = prepare_optimise_weights(&opA); + let inner_config : InnerSettings<F> = Default::default(); + let inner_it = inner_config.iterator_options; + let logmap = |iter, Timed { cpu_time, data }| { + let IterInfo { + value, + n_spikes, + inner_iters, + merged, + pruned, + postprocessing, + this_iters, + .. + } = data; + let post_value = match postprocessing { + None => value, + Some(mut μ) => { + match dataterm { + DataTerm::L2Squared => { + optimise_weights( + &mut μ, &opA, &b, α, &findim_data, &inner_config, + inner_it + ); + dataterm.value_at_residual(opA.apply(&μ) - &b) + α * μ.norm(Radon) + }, + _ => value, + } + } + }; + CSVLog { + iter, + value, + post_value, + n_spikes, + cpu_time : cpu_time.as_secs_f64(), + inner_iters, + merged, + pruned, + this_iters + } + }; + let iterator = config.iterator_options + .instantiate() + .timed() + .mapped(logmap) + .into_log(&mut logger); + let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); + + // Create plotter and directory if needed. + let plot_count = if config.plot >= PlotLevel::Iter { 2000 } else { 0 }; + let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid); + + // Run the algorithm + let start = Instant::now(); + let start_cpu = ProcessTime::now(); + let μ : DiscreteMeasure<Loc<F, N>, F> = match (alg, dataterm) { + (AlgorithmConfig::FB(ref algconfig), DataTerm::L2Squared) => { + running(); + pointsource_fb(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter) + }, + (AlgorithmConfig::FW(ref algconfig), DataTerm::L2Squared) => { + running(); + pointsource_fw(&opA, &b, α, &algconfig, iterator, plotter) + }, + (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L2Squared) => { + running(); + pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L2Squared) + }, + (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L1) => { + running(); + pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1) + }, + _ => { + let msg = format!("Algorithm “{}” not implemented for dataterm {:?}. Skipping.", + alg_name, dataterm).red(); + eprintln!("{}", msg); + continue + } + }; + let elapsed = start.elapsed().as_secs_f64(); + let cpu_time = start_cpu.elapsed().as_secs_f64(); + + println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); + + // Save results + println!("{}", "Saving results…".green()); + + let mkname = | + t| format!("{p}{n}_{t}", p = prefix, n = alg_name, t = t); + + write_json(mkname("config.json"), &named)?; + write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; + μ.write_csv(mkname("reco.txt"))?; + logger.write_csv(mkname("log.txt"))?; + } + + Ok(()) + } +} + +/// Plot experiment setup +#[replace_float_literals(F::cast_from(literal))] +fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( + config : &Configuration<F>, + prefix : &String, + domain : &Cube<F, N>, + sensor : &Sensor, + kernel : &Kernel, + spread : &Spread, + μ_hat : &DiscreteMeasure<Loc<F, N>, F>, + op𝒟 : &𝒟, + opA : &A, + b_hat : &A::Observable, + b : &A::Observable, + kernel_plot_width : F, +) -> DynError +where F : Float + ToNalgebraRealField, + Sensor : RealMapping<F, N> + Support<F, N> + Clone, + Spread : RealMapping<F, N> + Support<F, N> + Clone, + Kernel : RealMapping<F, N> + Support<F, N>, + Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, + 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, + 𝒟::Codomain : RealMapping<F, N>, + A : ForwardModel<Loc<F, N>, F>, + A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, + PlotLookup : Plotting<N>, + Cube<F, N> : SetOrd { + + if config.plot < PlotLevel::Data { + return Ok(()) + } + + let base = Convolution(sensor.clone(), spread.clone()); + + let resolution = if N==1 { 100 } else { 40 }; + let pfx = |n| format!("{}{}", prefix, n); + let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); + + PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); + PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); + PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); + PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); + + let plotgrid2 = lingrid(&domain, &[resolution; N]); + + let ω_hat = op𝒟.apply(μ_hat); + let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); + PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string()); + PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"), + "noise Aᵀ(Aμ̂ - b)".to_string()); + + let preadj_b = opA.preadjoint().apply(b); + let preadj_b_hat = opA.preadjoint().apply(b_hat); + //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); + PlotLookup::plot_into_file_spikes( + "Aᵀb".to_string(), &preadj_b, + "Aᵀb̂".to_string(), Some(&preadj_b_hat), + plotgrid2, None, &μ_hat, + pfx("omega_b") + ); + + // Save true solution and observables + let pfx = |n| format!("{}{}", prefix, n); + μ_hat.write_csv(pfx("orig.txt"))?; + opA.write_observable(&b_hat, pfx("b_hat"))?; + opA.write_observable(&b, pfx("b_noisy")) +} +