Thu, 23 Jan 2025 23:34:05 +0100
Merging adjustments, parameter tuning, etc.
/*! This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment. */ use numeric_literals::replace_float_literals; use colored::Colorize; use serde::{Serialize, Deserialize}; use serde_json; use nalgebra::base::DVector; use std::hash::Hash; use chrono::{DateTime, Utc}; use cpu_time::ProcessTime; use clap::ValueEnum; use std::collections::HashMap; use std::time::Instant; use rand::prelude::{ StdRng, SeedableRng }; use rand_distr::Distribution; use alg_tools::bisection_tree::*; use alg_tools::iterate::{ Timed, AlgIteratorOptions, Verbose, AlgIteratorFactory, LoggingIteratorFactory, TimingIteratorFactory, BasicAlgIteratorFactory, }; use alg_tools::logger::Logger; use alg_tools::error::{ DynError, DynResult, }; use alg_tools::tabledump::TableDump; use alg_tools::sets::Cube; use alg_tools::mapping::{ RealMapping, DifferentiableMapping, DifferentiableRealMapping, Instance }; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::euclidean::Euclidean; use alg_tools::lingrid::{lingrid, LinSpace}; use alg_tools::sets::SetOrd; use alg_tools::linops::{RowOp, IdOp /*, ZeroOp*/}; use alg_tools::discrete_gradient::{Grad, ForwardNeumann}; use alg_tools::convex::Zero; use alg_tools::maputil::map3; use alg_tools::direct_product::Pair; use crate::kernels::*; use crate::types::*; use crate::measures::*; use crate::measures::merging::{SpikeMerging,SpikeMergingMethod}; use crate::forward_model::*; use crate::forward_model::sensor_grid::{ SensorGrid, SensorGridBT, //SensorGridBTFN, Sensor, Spread, }; use crate::fb::{ FBConfig, FBGenericConfig, pointsource_fb_reg, pointsource_fista_reg, }; use crate::sliding_fb::{ SlidingFBConfig, TransportConfig, pointsource_sliding_fb_reg }; use crate::sliding_pdps::{ SlidingPDPSConfig, pointsource_sliding_pdps_pair }; use crate::forward_pdps::{ ForwardPDPSConfig, pointsource_forward_pdps_pair }; use crate::pdps::{ PDPSConfig, pointsource_pdps_reg, }; use crate::frank_wolfe::{ FWConfig, FWVariant, 
pointsource_fw_reg, //WeightOptim, }; use crate::subproblem::{InnerSettings, InnerMethod}; use crate::seminorms::*; use crate::plot::*; use crate::{AlgorithmOverrides, CommandLineArgs}; use crate::tolerance::Tolerance; use crate::regularisation::{ Regularisation, RadonRegTerm, NonnegRadonRegTerm }; use crate::dataterm::{ L1, L2Squared, }; use crate::prox_penalty::{ RadonSquared, //ProxPenalty, }; use alg_tools::norms::{L2, NormExponent}; use alg_tools::operator_arithmetic::Weighted; use anyhow::anyhow; /// Available proximal terms #[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub enum ProxTerm { /// Partial-to-wave operator 𝒟. Wave, /// Radon-norm squared RadonSquared } /// Available algorithms and their configurations #[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub enum AlgorithmConfig<F : Float> { FB(FBConfig<F>, ProxTerm), FISTA(FBConfig<F>, ProxTerm), FW(FWConfig<F>), PDPS(PDPSConfig<F>, ProxTerm), SlidingFB(SlidingFBConfig<F>, ProxTerm), ForwardPDPS(ForwardPDPSConfig<F>, ProxTerm), SlidingPDPS(SlidingPDPSConfig<F>, ProxTerm), } fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { assert!(v.len() == 3); Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } } impl<F : ClapFloat> AlgorithmConfig<F> { /// Override supported parameters based on the command line. 
pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self { let override_merging = |g : SpikeMergingMethod<F>| { SpikeMergingMethod { enabled : cli.merge.unwrap_or(g.enabled), radius : cli.merge_radius.unwrap_or(g.radius), interp : cli.merge_interp.unwrap_or(g.interp), } }; let override_fb_generic = |g : FBGenericConfig<F>| { FBGenericConfig { bootstrap_insertions : cli.bootstrap_insertions .as_ref() .map_or(g.bootstrap_insertions, |n| Some((n[0], n[1]))), merge_every : cli.merge_every.unwrap_or(g.merge_every), merging : override_merging(g.merging), final_merging : cli.final_merging.unwrap_or(g.final_merging), fitness_merging : cli.fitness_merging.unwrap_or(g.fitness_merging), tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance), .. g } }; let override_transport = |g : TransportConfig<F>| { TransportConfig { θ0 : cli.theta0.unwrap_or(g.θ0), tolerance_mult_pos: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_pos), tolerance_mult_pri: cli.transport_tolerance_pri.unwrap_or(g.tolerance_mult_pri), adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), .. g } }; use AlgorithmConfig::*; match self { FB(fb, prox) => FB(FBConfig { τ0 : cli.tau0.unwrap_or(fb.τ0), generic : override_fb_generic(fb.generic), .. fb }, prox), FISTA(fb, prox) => FISTA(FBConfig { τ0 : cli.tau0.unwrap_or(fb.τ0), generic : override_fb_generic(fb.generic), .. fb }, prox), PDPS(pdps, prox) => PDPS(PDPSConfig { τ0 : cli.tau0.unwrap_or(pdps.τ0), σ0 : cli.sigma0.unwrap_or(pdps.σ0), acceleration : cli.acceleration.unwrap_or(pdps.acceleration), generic : override_fb_generic(pdps.generic), .. pdps }, prox), FW(fw) => FW(FWConfig { merging : override_merging(fw.merging), tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), .. fw }), SlidingFB(sfb, prox) => SlidingFB(SlidingFBConfig { τ0 : cli.tau0.unwrap_or(sfb.τ0), transport : override_transport(sfb.transport), insertion : override_fb_generic(sfb.insertion), .. 
sfb }, prox), SlidingPDPS(spdps, prox) => SlidingPDPS(SlidingPDPSConfig { τ0 : cli.tau0.unwrap_or(spdps.τ0), σp0 : cli.sigmap0.unwrap_or(spdps.σp0), σd0 : cli.sigma0.unwrap_or(spdps.σd0), //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), transport : override_transport(spdps.transport), insertion : override_fb_generic(spdps.insertion), .. spdps }, prox), ForwardPDPS(fpdps, prox) => ForwardPDPS(ForwardPDPSConfig { τ0 : cli.tau0.unwrap_or(fpdps.τ0), σp0 : cli.sigmap0.unwrap_or(fpdps.σp0), σd0 : cli.sigma0.unwrap_or(fpdps.σd0), //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), insertion : override_fb_generic(fpdps.insertion), .. fpdps }, prox), } } } /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Named<Data> { pub name : String, #[serde(flatten)] pub data : Data, } /// Shorthand algorithm configurations, to be used with the command line parser #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] pub enum DefaultAlgorithm { /// The μFB forward-backward method #[clap(name = "fb")] FB, /// The μFISTA inertial forward-backward method #[clap(name = "fista")] FISTA, /// The “fully corrective” conditional gradient method #[clap(name = "fw")] FW, /// The “relaxed conditional gradient method #[clap(name = "fwrelax")] FWRelax, /// The μPDPS primal-dual proximal splitting method #[clap(name = "pdps")] PDPS, /// The sliding FB method #[clap(name = "sliding_fb", alias = "sfb")] SlidingFB, /// The sliding PDPS method #[clap(name = "sliding_pdps", alias = "spdps")] SlidingPDPS, /// The PDPS method with a forward step for the smooth function #[clap(name = "forward_pdps", alias = "fpdps")] ForwardPDPS, // Radon variants /// The μFB forward-backward method with radon-norm squared proximal term #[clap(name = "radon_fb")] RadonFB, /// The μFISTA inertial forward-backward method with radon-norm squared proximal term #[clap(name = 
"radon_fista")] RadonFISTA, /// The μPDPS primal-dual proximal splitting method with radon-norm squared proximal term #[clap(name = "radon_pdps")] RadonPDPS, /// The sliding FB method with radon-norm squared proximal term #[clap(name = "radon_sliding_fb", alias = "radon_sfb")] RadonSlidingFB, /// The sliding PDPS method with radon-norm squared proximal term #[clap(name = "radon_sliding_pdps", alias = "radon_spdps")] RadonSlidingPDPS, /// The PDPS method with a forward step for the smooth function with radon-norm squared proximal term #[clap(name = "radon_forward_pdps", alias = "radon_fpdps")] RadonForwardPDPS, } impl DefaultAlgorithm { /// Returns the algorithm configuration corresponding to the algorithm shorthand pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { use DefaultAlgorithm::*; let radon_insertion = FBGenericConfig { merging : SpikeMergingMethod{ interp : false, .. Default::default() }, inner : InnerSettings { method : InnerMethod::PDPS, // SSN not implemented .. Default::default() }, .. Default::default() }; match *self { FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave), FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave), FW => AlgorithmConfig::FW(Default::default()), FWRelax => AlgorithmConfig::FW(FWConfig{ variant : FWVariant::Relaxed, .. 
Default::default() }), PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave), SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave), SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave), ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave), // Radon variants RadonFB => AlgorithmConfig::FB( FBConfig{ generic : radon_insertion, ..Default::default() }, ProxTerm::RadonSquared ), RadonFISTA => AlgorithmConfig::FISTA( FBConfig{ generic : radon_insertion, ..Default::default() }, ProxTerm::RadonSquared ), RadonPDPS => AlgorithmConfig::PDPS( PDPSConfig{ generic : radon_insertion, ..Default::default() }, ProxTerm::RadonSquared ), RadonSlidingFB => AlgorithmConfig::SlidingFB( SlidingFBConfig{ insertion : radon_insertion, ..Default::default() }, ProxTerm::RadonSquared ), RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS( SlidingPDPSConfig{ insertion : radon_insertion, ..Default::default() }, ProxTerm::RadonSquared ), RadonForwardPDPS => AlgorithmConfig::ForwardPDPS( ForwardPDPSConfig{ insertion : radon_insertion, ..Default::default() }, ProxTerm::RadonSquared ), } } /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { self.to_named(self.default_config()) } pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { let name = self.to_possible_value().unwrap().get_name().to_string(); Named{ name , data : alg } } } // // Floats cannot be hashed directly, so just hash the debug formatting // // for use as file identifier. 
// impl<F : Float> Hash for AlgorithmConfig<F> { // fn hash<H: Hasher>(&self, state: &mut H) { // format!("{:?}", self).hash(state); // } // } /// Plotting level configuration #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, ValueEnum, Debug)] pub enum PlotLevel { /// Plot nothing #[clap(name = "none")] None, /// Plot problem data #[clap(name = "data")] Data, /// Plot iterationwise state #[clap(name = "iter")] Iter, } impl Default for PlotLevel { fn default() -> Self { Self::Data } } type DefaultBT<F, const N : usize> = BT< DynamicDepth, F, usize, Bounds<F>, N >; type DefaultSeminormOp<F, K, const N : usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>; type DefaultSG<F, Sensor, Spread, const N : usize> = SensorGrid::< F, Sensor, Spread, DefaultBT<F, N>, N >; /// This is a dirty workaround to rust-csv not supporting struct flattening etc. #[derive(Serialize)] struct CSVLog<F> { iter : usize, cpu_time : f64, value : F, relative_value : F, //post_value : F, n_spikes : usize, inner_iters : usize, merged : usize, pruned : usize, this_iters : usize, } /// Collected experiment statistics #[derive(Clone, Debug, Serialize)] struct ExperimentStats<F : Float> { /// Signal-to-noise ratio in decibels ssnr : F, /// Proportion of noise in the signal as a number in $[0, 1]$. noise_ratio : F, /// When the experiment was run (UTC) when : DateTime<Utc>, } #[replace_float_literals(F::cast_from(literal))] impl<F : Float> ExperimentStats<F> { /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal. 
fn new<E : Euclidean<F>>(signal : &E, noise : &E) -> Self { let s = signal.norm2_squared(); let n = noise.norm2_squared(); let noise_ratio = (n / s).sqrt(); let ssnr = 10.0 * (s / n).log10(); ExperimentStats { ssnr, noise_ratio, when : Utc::now(), } } } /// Collected algorithm statistics #[derive(Clone, Debug, Serialize)] struct AlgorithmStats<F : Float> { /// Overall CPU time spent cpu_time : F, /// Real time spent elapsed : F } /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input /// and outputs a [`DynError`]. fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; Ok(()) } /// Struct for experiment configurations #[derive(Debug, Clone, Serialize)] pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize> where F : Float + ClapFloat, [usize; N] : Serialize, NoiseDistr : Distribution<F>, S : Sensor<F, N>, P : Spread<F, N>, K : SimpleConvolutionKernel<F, N>, { /// Domain $Ω$. pub domain : Cube<F, N>, /// Number of sensors along each dimension pub sensor_count : [usize; N], /// Noise distribution pub noise_distr : NoiseDistr, /// Seed for random noise generation (for repeatable experiments) pub noise_seed : u64, /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. pub sensor : S, /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. pub spread : P, /// Kernel $ρ$ of $𝒟$. 
pub kernel : K, /// True point sources pub μ_hat : RNDM<F, N>, /// Regularisation term and parameter pub regularisation : Regularisation<F>, /// For plotting : how wide should the kernels be plotted pub kernel_plot_width : F, /// Data term pub dataterm : DataTerm, /// A map of default configurations for algorithms pub algorithm_overrides : HashMap<DefaultAlgorithm, AlgorithmOverrides<F>>, /// Default merge radius pub default_merge_radius : F, } #[derive(Debug, Clone, Serialize)] pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N : usize> where F : Float + ClapFloat, [usize; N] : Serialize, NoiseDistr : Distribution<F>, S : Sensor<F, N>, P : Spread<F, N>, K : SimpleConvolutionKernel<F, N>, B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, { /// Basic setup pub base : ExperimentV2<F, NoiseDistr, S, K, P, N>, /// Weight of TV term pub λ : F, /// Bias function pub bias : B, } /// Trait for runnable experiments pub trait RunnableExperiment<F : ClapFloat> { /// Run all algorithms provided, or default algorithms if none provided, on the experiment. fn runall(&self, cli : &CommandLineArgs, algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; /// Return algorithm default config fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F>; } /// Helper function to print experiment start message and save setup. /// Returns saving prefix. 
fn start_experiment<E, S>( experiment : &Named<E>, cli : &CommandLineArgs, stats : S, ) -> DynResult<String> where E : Serialize + std::fmt::Debug, S : Serialize, { let Named { name : experiment_name, data } = experiment; println!("{}\n{}", format!("Performing experiment {}…", experiment_name).cyan(), format!("Experiment settings: {}", serde_json::to_string(&data)?).bright_black()); // Set up output directory let prefix = format!("{}/{}/", cli.outdir, experiment_name); // Save experiment configuration and statistics let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); std::fs::create_dir_all(&prefix)?; write_json(mkname_e("experiment"), experiment)?; write_json(mkname_e("config"), cli)?; write_json(mkname_e("stats"), &stats)?; Ok(prefix) } /// Error codes for running an algorithm on an experiment. enum RunError { /// Algorithm not implemented for this experiment NotImplemented, } use RunError::*; type DoRunAllIt<'a, F, const N : usize> = LoggingIteratorFactory< 'a, Timed<IterInfo<F, N>>, TimingIteratorFactory<BasicAlgIteratorFactory<IterInfo<F, N>>> >; /// Helper function to run all algorithms on an experiment. 
///
/// For every entry of `algorithms`, this function:
///
/// 1. Creates a fresh `Logger` and a timed, logging iterator factory.
/// 2. Builds a `SeqPlotter`, with 2000 iteration plots when
///    `cli.plot >= PlotLevel::Iter` and none otherwise.
/// 3. Calls `do_alg` with the configuration, the iterator, the plotter, and a
///    "running" banner. The banner is an empty string when `cli.quiet` is set.
///
/// If `do_alg` returns `RunError::NotImplemented`, the algorithm is skipped and a
/// message is printed to stderr. Otherwise the wall-clock and CPU times are
/// reported, and the following files are written under `{prefix}{alg_name}_`:
/// `config.json`, `stats.json` and `reco.txt`. `save_extra` is then called with the
/// `{prefix}{alg_name}_` stem and the extra output `z`.
///
/// Finally, the collected logs are passed to `save_logs` together with
/// `{prefix}valuerange.json`. `prefix` is concatenated directly with file names, so it
/// should end with `/`, as the prefix returned by `start_experiment` does.
fn do_runall<F : Float + for<'b> Deserialize<'b>, Z, const N : usize>(
    experiment_name : &String,
    prefix : &String,
    cli : &CommandLineArgs,
    algorithms : Vec<Named<AlgorithmConfig<F>>>,
    plotgrid : LinSpace<Loc<F, N>, [usize; N]>,
    mut save_extra : impl FnMut(String, Z) -> DynError,
    mut do_alg : impl FnMut(
        &AlgorithmConfig<F>,
        DoRunAllIt<F, N>,
        SeqPlotter<F, N>,
        String,
    ) -> Result<(RNDM<F, N>, Z), RunError>,
) -> DynError
where
    PlotLookup : Plotting<N>,
{
    // (log file name, logger) pairs, saved collectively at the end.
    let mut logs = Vec::new();

    // Iteration options shared by all algorithms. Verbosity defaults to a capped
    // logarithmic schedule unless the command line asks for every n-th iteration.
    let iterator_options = AlgIteratorOptions{
        max_iter : cli.max_iter,
        verbose_iter : cli.verbose_iter
                          .map_or(Verbose::LogarithmicCap{base : 10, cap : 2},
                                  |n| Verbose::Every(n)),
        quiet : cli.quiet,
    };

    // Run the algorithm(s)
    for named @ Named { name : alg_name, data : alg } in algorithms.iter() {
        // Location passed to the sequence plotter of this algorithm.
        let this_prefix = format!("{}{}/", prefix, alg_name);

        // Create Logger and IteratorFactory
        let mut logger = Logger::new();
        let iterator = iterator_options.instantiate()
                                       .timed()
                                       .into_log(&mut logger);

        // Banner handed to `do_alg` for printing; empty in quiet mode.
        let running = if !cli.quiet {
            format!("{}\n{}\n{}\n",
                    format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
                    format!("Iteration settings: {}", serde_json::to_string(&iterator_options)?).bright_black(),
                    format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black())
        } else {
            "".to_string()
        };
        //
        // The following is for postprocessing, which has been disabled anyway.
        //
        // let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation {
        //     Regularisation::Radon(α) => Box::new(RadonRegTerm(α)),
        //     Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)),
        // };
        //let findim_data = reg.prepare_optimise_weights(&opA, &b);
        //let inner_config : InnerSettings<F> = Default::default();
        //let inner_it = inner_config.iterator_options;

        // Create plotter and directory if needed.
        let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };
        let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid.clone());

        let start = Instant::now();
        let start_cpu = ProcessTime::now();

        // Run the algorithm; combinations that are not implemented are reported and skipped.
        let (μ, z) = match do_alg(alg, iterator, plotter, running) {
            Ok(μ) => μ,
            Err(RunError::NotImplemented) => {
                let msg = format!("Algorithm “{alg_name}” not implemented for {experiment_name}. \
                                   Skipping.").red();
                eprintln!("{}", msg);
                continue
            }
        };

        // Wall-clock and CPU time, in seconds.
        let elapsed = start.elapsed().as_secs_f64();
        let cpu_time = start_cpu.elapsed().as_secs_f64();

        println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow());

        // Save results
        println!("{}", "Saving results …".green());

        let mkname = |t| format!("{prefix}{alg_name}_{t}");

        write_json(mkname("config.json"), &named)?;
        write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
        μ.write_csv(mkname("reco.txt"))?;
        save_extra(mkname(""), z)?;
        //logger.write_csv(mkname("log.txt"))?;
        // Logs are not written here; they are collected and handed to `save_logs` below.
        logs.push((mkname("log.txt"), logger));
    }

    save_logs(logs, format!("{prefix}valuerange.json"), cli.load_valuerange)
}

#[replace_float_literals(F::cast_from(literal))]
impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N : usize> RunnableExperiment<F> for
Named<ExperimentV2<F, NoiseDistr, S, K, P, N>>
where
    F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>
        + Default + for<'b> Deserialize<'b>,
    [usize; N] : Serialize,
    S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
    P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
    Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy
                        // TODO: should not have differentiability as a requirement, but
                        // decide availability of sliding based on it.
                        //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
                        // TODO: very weird that rust only compiles with Differentiable
                        // instead of the above one on references, which is required by
                        // pointsource_sliding_fb_reg.
                        + DifferentiableRealMapping<F, N>
                        + Lipschitz<L2, FloatType=F>,
    for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>,
    // TODO: should not be required generally, only for sliding_fb.
    AutoConvolution<P> : BoundedBy<F, K>,
    K : SimpleConvolutionKernel<F, N>
        + LocalAnalysis<F, Bounds<F>, N>
        + Copy + Serialize + std::fmt::Debug,
    Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
    PlotLookup : Plotting<N>,
    DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
    BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
    RNDM<F, N> : SpikeMerging<F>,
    NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug,
    // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>,
    // PreadjointCodomain : Space + Bounded<F> + DifferentiableRealMapping<F, N>,
    // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
    // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
    // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
    // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
{

    // The experiment's per-algorithm overrides, with the merge radius always set to
    // the experiment's default merge radius.
    fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> {
        AlgorithmOverrides {
            merge_radius : Some(self.data.default_merge_radius),
            .. self.data.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default())
        }
    }

    // Generates noisy data b = 𝒜μ̂ + noise (noise seeded by `noise_seed`), saves the
    // setup and plots, and then runs each algorithm through `do_runall`.
    fn runall(&self, cli : &CommandLineArgs,
              algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError {
        // Get experiment configuration
        let &Named {
            name : ref experiment_name,
            data : ExperimentV2 {
                domain, sensor_count, ref noise_distr, sensor, spread, kernel,
                ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed,
                ..
            }
        } = self;

        // Set up algorithms: if none are given, use FB for the L2² data term
        // and PDPS for the L1 data term.
        let algorithms = match (algs, dataterm) {
            (Some(algs), _) => algs,
            (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()],
            (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()],
        };

        // Set up operators
        let depth = DynamicDepth(8);
        let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
        let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel);

        // Set up random number generator.
        let mut rng = StdRng::seed_from_u64(noise_seed);

        // Generate the data and calculate SSNR statistic
        let b_hat = opA.apply(μ_hat);
        let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
        let b = &b_hat + &noise;
        // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
        // overloading log10 and conflicting with standard NumTraits one.
        let stats = ExperimentStats::new(&b, &noise);

        let prefix = start_experiment(&self, cli, stats)?;

        plotall(cli, &prefix, &domain, &sensor, &kernel, &spread,
                &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?;

        let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]);

        // This experiment has no extra outputs to save.
        let save_extra = |_, ()| Ok(());

        // Dispatch each algorithm on (regularisation, data term, proximal term);
        // unsupported combinations yield `NotImplemented` and are skipped by `do_runall`.
        do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra,
            |alg, iterator, plotter, running|
        {
            let μ = match alg {
                AlgorithmConfig::FB(ref algconfig, prox) => {
                    match (regularisation, dataterm, prox) {
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
                            print!("{running}");
                            pointsource_fb_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter
                            )
                        }),
                        (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
                            print!("{running}");
                            pointsource_fb_reg(
                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter
                            )
                        }),
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
                            print!("{running}");
                            pointsource_fb_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter
                            )
                        }),
                        (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
                            print!("{running}");
                            pointsource_fb_reg(
                                &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter
                            )
                        }),
                        _ => Err(NotImplemented)
                    }
                },
                AlgorithmConfig::FISTA(ref algconfig, prox) => {
                    match (regularisation, dataterm, prox) {
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
                            print!("{running}");
                            pointsource_fista_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter
                            )
                        }),
                        (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
                            print!("{running}");
                            pointsource_fista_reg(
                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter
                            )
                        }),
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
                            print!("{running}");
                            pointsource_fista_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter
                            )
                        }),
                        (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
                            print!("{running}");
                            pointsource_fista_reg(
                                &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter
                            )
                        }),
                        _ => Err(NotImplemented),
                    }
                },
                AlgorithmConfig::SlidingFB(ref algconfig, prox) => {
                    match (regularisation, dataterm, prox) {
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
                            print!("{running}");
                            pointsource_sliding_fb_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter
                            )
                        }),
                        (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
                            print!("{running}");
                            pointsource_sliding_fb_reg(
                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter
                            )
                        }),
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
                            print!("{running}");
                            pointsource_sliding_fb_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter
                            )
                        }),
                        (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
                            print!("{running}");
                            pointsource_sliding_fb_reg(
                                &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter
                            )
                        }),
                        _ => Err(NotImplemented),
                    }
                },
                AlgorithmConfig::PDPS(ref algconfig, prox) => {
                    // Every (regularisation, data term, prox) combination is handled
                    // below, so the banner can be printed up front.
                    print!("{running}");
                    match (regularisation, dataterm, prox) {
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
                            pointsource_pdps_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter, L2Squared
                            )
                        }),
                        (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::Wave) => Ok({
                            pointsource_pdps_reg(
                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter, L2Squared
                            )
                        }),
                        (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::Wave) => Ok({
                            pointsource_pdps_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter, L1
                            )
                        }),
                        (Regularisation::Radon(α), DataTerm::L1, ProxTerm::Wave) => Ok({
                            pointsource_pdps_reg(
                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter, L1
                            )
                        }),
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
                            pointsource_pdps_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter, L2Squared
                            )
                        }),
                        (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
                            pointsource_pdps_reg(
                                &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter, L2Squared
                            )
                        }),
                        (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({
                            pointsource_pdps_reg(
                                &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter, L1
                            )
                        }),
                        (Regularisation::Radon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({
                            pointsource_pdps_reg(
                                &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter, L1
                            )
                        }),
                        // _ => Err(NotImplemented),
                    }
                },
                AlgorithmConfig::FW(ref algconfig) => {
                    match (regularisation, dataterm) {
                        (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({
                            print!("{running}");
                            pointsource_fw_reg(&opA, &b, RadonRegTerm(α),
                                               algconfig, iterator, plotter)
                        }),
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({
                            print!("{running}");
                            pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α),
                                               algconfig, iterator, plotter)
                        }),
                        _ => Err(NotImplemented),
                    }
                },
                // Pair-based methods (SlidingPDPS, ForwardPDPS) require the biased experiment.
                _ => Err(NotImplemented),
            }?;
            Ok((μ, ()))
        })
    }
}


#[replace_float_literals(F::cast_from(literal))]
impl<F, NoiseDistr, S, K, P, B, /*PreadjointCodomain,*/ const N : usize> RunnableExperiment<F> for
Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>>
where
    F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>
        + Default + for<'b> Deserialize<'b>,
    [usize; N] : Serialize,
    S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
    P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
    Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy
                        // TODO: should not have differentiability as a requirement, but
                        // decide availability of sliding based on it.
                        //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
                        // TODO: very weird that rust only compiles with Differentiable
                        // instead of the above one on references, which is required by
                        // pointsource_sliding_fb_reg.
                        + DifferentiableRealMapping<F, N>
                        + Lipschitz<L2, FloatType=F>,
    for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>,
    // TODO: should not be required generally, only for sliding_fb.
    AutoConvolution<P> : BoundedBy<F, K>,
    K : SimpleConvolutionKernel<F, N>
        + LocalAnalysis<F, Bounds<F>, N>
        + Copy + Serialize + std::fmt::Debug,
    Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
    PlotLookup : Plotting<N>,
    DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
    BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
    RNDM<F, N> : SpikeMerging<F>,
    NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug,
    B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug,
    // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>,
    // PreadjointCodomain : Bounded<F> + DifferentiableRealMapping<F, N>,
    // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
    // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
    // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
    // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
{

    // Same as for `ExperimentV2`, using the overrides of the `base` setup.
    fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> {
        AlgorithmOverrides {
            merge_radius : Some(self.data.base.default_merge_radius),
            .. self.data.base.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default())
        }
    }

    // Generates noisy data b = 𝒜μ̂ + bias + noise, where the bias is `bias` evaluated on
    // the sensor grid, then runs the pair-based (μ, z) methods via `do_runall`. The
    // recovered `z` is saved by `save_extra`.
    fn runall(&self, cli : &CommandLineArgs,
              algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError {
        // Get experiment configuration
        let &Named {
            name : ref experiment_name,
            data : ExperimentBiased {
                λ,
                ref bias,
                base : ExperimentV2 {
                    domain, sensor_count, ref noise_distr, sensor, spread, kernel,
                    ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed,
                    ..
                }
            }
        } = self;

        // Set up algorithms: default to sliding PDPS regardless of the data term.
        let algorithms = match (algs, dataterm) {
            (Some(algs), _) => algs,
            _ => vec![DefaultAlgorithm::SlidingPDPS.get_named()],
        };

        // Set up operators
        let depth = DynamicDepth(8);
        let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
        let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel);
        // Extended forward operator acting on (μ, z) as [𝒜, I].
        let opAext = RowOp(opA.clone(), IdOp::new());
        let fnR = Zero::new();
        // Grid spacing for the discrete gradient: the largest of (end - start)/count
        // over the dimensions.
        let h = map3(domain.span_start(), domain.span_end(), sensor_count,
                     |a, b, n| (b-a)/F::cast_from(n))
                    .into_iter()
                    .reduce(NumTraitsFloat::max)
                    .unwrap();
        let z = DVector::zeros(sensor_count.iter().product());
        let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap();
        let y = opKz.apply(&z);
        let fnH = Weighted{ base_fn : L1.as_mapping(), weight : λ};  // TODO: L_{2,1}
        // let zero_y = y.clone();
        // let zeroBTFN = opA.preadjoint().apply(&zero_y);
        // let opKμ = ZeroOp::new(&zero_y, zeroBTFN);

        // Set up random number generator.
        let mut rng = StdRng::seed_from_u64(noise_seed);

        // Generate the data and calculate SSNR statistic
        let bias_vec = DVector::from_vec(opA.grid()
                                            .into_iter()
                                            .map(|v| bias.apply(v))
                                            .collect::<Vec<F>>());
        let b_hat : DVector<_> = opA.apply(μ_hat) + &bias_vec;
        let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
        let b = &b_hat + &noise;
        // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
        // overloading log10 and conflicting with standard NumTraits one.
        let stats = ExperimentStats::new(&b, &noise);

        let prefix = start_experiment(&self, cli, stats)?;

        plotall(cli, &prefix, &domain, &sensor, &kernel, &spread,
                &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?;

        opA.write_observable(&bias_vec, format!("{prefix}bias"))?;

        let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]);

        // Save the recovered `z` as an observable with file stem `{prefix}z`.
        let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z"));

        // Run the algorithms
        do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra,
            |alg, iterator, plotter, running|
        {
            let Pair(μ, z) = match alg {
                AlgorithmConfig::ForwardPDPS(ref algconfig, prox) => {
                    match (regularisation, dataterm, prox) {
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
                            print!("{running}");
                            pointsource_forward_pdps_pair(
                                &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter,
                                /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
                            )
                        }),
                        (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
                            print!("{running}");
                            pointsource_forward_pdps_pair(
                                &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter,
                                /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
                            )
                        }),
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
                            print!("{running}");
                            pointsource_forward_pdps_pair(
                                &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter,
                                /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
                            )
                        }),
                        (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
                            print!("{running}");
                            pointsource_forward_pdps_pair(
                                &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter,
                                /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
                            )
                        }),
                        _ => Err(NotImplemented)
                    }
                },
                AlgorithmConfig::SlidingPDPS(ref algconfig, prox) => {
                    match (regularisation, dataterm, prox) {
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
                            print!("{running}");
                            pointsource_sliding_pdps_pair(
                                &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter,
                                /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
                            )
                        }),
                        (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
                            print!("{running}");
                            pointsource_sliding_pdps_pair(
                                &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                iterator, plotter,
                                /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
                            )
                        }),
                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
                            print!("{running}");
                            pointsource_sliding_pdps_pair(
                                &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter,
                                /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
                            )
                        }),
                        (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
                            print!("{running}");
                            pointsource_sliding_pdps_pair(
                                &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig,
                                iterator, plotter,
                                /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
                            )
                        }),
                        _ => Err(NotImplemented)
                    }
                },
                // Only the pair-based methods are supported on the biased experiment.
                _ => Err(NotImplemented)
            }?;
            Ok((μ, z))
        })
    }
}

/// Range of logged values. When ranges are combined with
/// [`ValueRange::expand_with`], `ini` takes the maximum and `min` takes the minimum.
/// NOTE(review): `ini` presumably holds the initial value of a run; confirm against
/// `save_logs`.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
struct ValueRange<F : Float> {
    ini : F,
    min : F,
}

impl<F : Float> ValueRange<F> {
    /// Combine two ranges: the larger `ini` and the smaller `min`.
    fn expand_with(self, other : Self) -> Self {
        ValueRange {
            ini : self.ini.max(other.ini),
            min : self.min.min(other.min),
        }
    }
}

/// Calculate minimum and maximum values of all the `logs`, and save them into
/// corresponding file names given as the first elements of the tuples in the vectors.
fn save_logs<F : Float + for<'b> Deserialize<'b>, const N : usize>( logs : Vec<(String, Logger<Timed<IterInfo<F, N>>>)>, valuerange_file : String, load_valuerange : bool, ) -> DynError { // Process logs for relative values println!("{}", "Processing logs…"); // Find minimum value and initial value within a single log let proc_single_log = |log : &Logger<Timed<IterInfo<F, N>>>| { let d = log.data(); let mi = d.iter() .map(|i| i.data.value) .reduce(NumTraitsFloat::min); d.first() .map(|i| i.data.value) .zip(mi) .map(|(ini, min)| ValueRange{ ini, min }) }; // Find minimum and maximum value over all logs let mut v = logs.iter() .filter_map(|&(_, ref log)| proc_single_log(log)) .reduce(|v1, v2| v1.expand_with(v2)) .ok_or(anyhow!("No algorithms found"))?; // Load existing range if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() { let data = std::fs::read_to_string(&valuerange_file)?; v = v.expand_with(serde_json::from_str(&data)?); } let logmap = |Timed { cpu_time, iter, data }| { let IterInfo { value, n_spikes, inner_iters, merged, pruned, //postprocessing, this_iters, .. } = data; // let post_value = match (postprocessing, dataterm) { // (Some(mut μ), DataTerm::L2Squared) => { // // Comparison postprocessing is only implemented for the case handled // // by the FW variants. 
// reg.optimise_weights( // &mut μ, &opA, &b, &findim_data, &inner_config, // inner_it // ); // dataterm.value_at_residual(opA.apply(&μ) - &b) // + regularisation.apply(&μ) // }, // _ => value, // }; let relative_value = (value - v.min)/(v.ini - v.min); CSVLog { iter, value, relative_value, //post_value, n_spikes, cpu_time : cpu_time.as_secs_f64(), inner_iters, merged, pruned, this_iters } }; println!("{}", "Saving logs …".green()); serde_json::to_writer_pretty(std::fs::File::create(&valuerange_file)?, &v)?; for (name, logger) in logs { logger.map(logmap).write_csv(name)?; } Ok(()) } /// Plot experiment setup #[replace_float_literals(F::cast_from(literal))] fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( cli : &CommandLineArgs, prefix : &String, domain : &Cube<F, N>, sensor : &Sensor, kernel : &Kernel, spread : &Spread, μ_hat : &RNDM<F, N>, op𝒟 : &𝒟, opA : &A, b_hat : &A::Observable, b : &A::Observable, kernel_plot_width : F, ) -> DynError where F : Float + ToNalgebraRealField, Sensor : RealMapping<F, N> + Support<F, N> + Clone, Spread : RealMapping<F, N> + Support<F, N> + Clone, Kernel : RealMapping<F, N> + Support<F, N>, Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>, 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, 𝒟::Codomain : RealMapping<F, N>, A : ForwardModel<RNDM<F, N>, F>, for<'a> &'a A::Observable : Instance<A::Observable>, A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>, PlotLookup : Plotting<N>, Cube<F, N> : SetOrd { if cli.plot < PlotLevel::Data { return Ok(()) } let base = Convolution(sensor.clone(), spread.clone()); let resolution = if N==1 { 100 } else { 40 }; let pfx = |n| format!("{prefix}{n}"); let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor")); PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel")); PlotLookup::plot_into_file(spread, plotgrid, pfx("spread")); 
PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor")); let plotgrid2 = lingrid(&domain, &[resolution; N]); let ω_hat = op𝒟.apply(μ_hat); let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat")); PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise")); let preadj_b = opA.preadjoint().apply(b); let preadj_b_hat = opA.preadjoint().apply(b_hat); //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); PlotLookup::plot_into_file_spikes( Some(&preadj_b), Some(&preadj_b_hat), plotgrid2, &μ_hat, pfx("omega_b") ); PlotLookup::plot_into_file(&preadj_b, plotgrid2, pfx("preadj_b")); PlotLookup::plot_into_file(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat")); // Save true solution and observables μ_hat.write_csv(pfx("orig.txt"))?; opA.write_observable(&b_hat, pfx("b_hat"))?; opA.write_observable(&b, pfx("b_noisy")) }