Thu, 01 Dec 2022 23:46:09 +0200
Warn when trying to run an unoptimised executable
| 0 | 1 | /*! |
| 2 | This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment. | |
| 3 | */ | |
| 4 | ||
| 5 | use numeric_literals::replace_float_literals; | |
| 6 | use colored::Colorize; | |
| 7 | use serde::{Serialize, Deserialize}; | |
| 8 | use serde_json; | |
| 9 | use nalgebra::base::DVector; | |
| 10 | use std::hash::Hash; | |
| 11 | use chrono::{DateTime, Utc}; | |
| 12 | use cpu_time::ProcessTime; | |
| 13 | use clap::ValueEnum; | |
| 14 | use std::collections::HashMap; | |
| 15 | use std::time::Instant; | |
| 16 | ||
| 17 | use rand::prelude::{ | |
| 18 | StdRng, | |
| 19 | SeedableRng | |
| 20 | }; | |
| 21 | use rand_distr::Distribution; | |
| 22 | ||
| 23 | use alg_tools::bisection_tree::*; | |
| 24 | use alg_tools::iterate::{ | |
| 25 | Timed, | |
| 26 | AlgIteratorOptions, | |
| 27 | Verbose, | |
| 28 | AlgIteratorFactory, | |
| 29 | }; | |
| 30 | use alg_tools::logger::Logger; | |
| 31 | use alg_tools::error::DynError; | |
| 32 | use alg_tools::tabledump::TableDump; | |
| 33 | use alg_tools::sets::Cube; | |
| 34 | use alg_tools::mapping::RealMapping; | |
| 35 | use alg_tools::nalgebra_support::ToNalgebraRealField; | |
| 36 | use alg_tools::euclidean::Euclidean; | |
| 37 | use alg_tools::norms::{Norm, L1}; | |
| 38 | use alg_tools::lingrid::lingrid; | |
| 39 | use alg_tools::sets::SetOrd; | |
| 40 | ||
| 41 | use crate::kernels::*; | |
| 42 | use crate::types::*; | |
| 43 | use crate::measures::*; | |
| 44 | use crate::measures::merging::SpikeMerging; | |
| 45 | use crate::forward_model::*; | |
| 46 | use crate::fb::{ | |
| 47 | FBConfig, | |
| 48 | pointsource_fb, | |
| 49 | FBMetaAlgorithm, FBGenericConfig, | |
| 50 | }; | |
| 51 | use crate::pdps::{ | |
| 52 | PDPSConfig, | |
| 53 | L2Squared, | |
| 54 | pointsource_pdps, | |
| 55 | }; | |
| 56 | use crate::frank_wolfe::{ | |
| 57 | FWConfig, | |
| 58 | FWVariant, | |
| 59 | pointsource_fw, | |
| 60 | prepare_optimise_weights, | |
| 61 | optimise_weights, | |
| 62 | }; | |
| 63 | use crate::subproblem::InnerSettings; | |
| 64 | use crate::seminorms::*; | |
| 65 | use crate::plot::*; | |
| 66 | use crate::AlgorithmOverrides; | |
| 67 | ||
| 68 | /// Available algorithms and their configurations | |
| 69 | #[derive(Copy, Clone, Debug, Serialize, Deserialize)] | |
| 70 | pub enum AlgorithmConfig<F : Float> { | |
| 71 | FB(FBConfig<F>), | |
| 72 | FW(FWConfig<F>), | |
| 73 | PDPS(PDPSConfig<F>), | |
| 74 | } | |
| 75 | ||
| 76 | impl<F : ClapFloat> AlgorithmConfig<F> { | |
| 77 | /// Override supported parameters based on the command line. | |
| 78 | pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self { | |
| 79 | let override_fb_generic = |g : FBGenericConfig<F>| { | |
| 80 | FBGenericConfig { | |
| 81 | bootstrap_insertions : cli.bootstrap_insertions | |
| 82 | .as_ref() | |
| 83 | .map_or(g.bootstrap_insertions, | |
| 84 | |n| Some((n[0], n[1]))), | |
| 85 | merge_every : cli.merge_every.unwrap_or(g.merge_every), | |
| 86 | merging : cli.merging.clone().unwrap_or(g.merging), | |
| 87 | final_merging : cli.final_merging.clone().unwrap_or(g.final_merging), | |
| 88 | .. g | |
| 89 | } | |
| 90 | }; | |
| 91 | ||
| 92 | use AlgorithmConfig::*; | |
| 93 | match self { | |
| 94 | FB(fb) => FB(FBConfig { | |
| 95 | τ0 : cli.tau0.unwrap_or(fb.τ0), | |
| 96 | insertion : override_fb_generic(fb.insertion), | |
| 97 | .. fb | |
| 98 | }), | |
| 99 | PDPS(pdps) => PDPS(PDPSConfig { | |
| 100 | τ0 : cli.tau0.unwrap_or(pdps.τ0), | |
| 101 | σ0 : cli.sigma0.unwrap_or(pdps.σ0), | |
| 102 | acceleration : cli.acceleration.unwrap_or(pdps.acceleration), | |
| 103 | insertion : override_fb_generic(pdps.insertion), | |
| 104 | .. pdps | |
| 105 | }), | |
| 106 | FW(fw) => FW(FWConfig { | |
| 107 | merging : cli.merging.clone().unwrap_or(fw.merging), | |
| 108 | .. fw | |
| 109 | }) | |
| 110 | } | |
| 111 | } | |
| 112 | } | |
| 113 | ||
| 114 | /// Helper struct for tagging and [`AlgorithmConfig`] or [`Experiment`] with a name. | |
| 115 | #[derive(Clone, Debug, Serialize, Deserialize)] | |
| 116 | pub struct Named<Data> { | |
| 117 | pub name : String, | |
| 118 | #[serde(flatten)] | |
| 119 | pub data : Data, | |
| 120 | } | |
| 121 | ||
| 122 | /// Shorthand algorithm configurations, to be used with the command line parser | |
| 123 | #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash)] | |
| 124 | pub enum DefaultAlgorithm { | |
| 125 | /// The μFB forward-backward method | |
| 126 | #[clap(name = "fb")] | |
| 127 | FB, | |
| 128 | /// The μFISTA inertial forward-backward method | |
| 129 | #[clap(name = "fista")] | |
| 130 | FISTA, | |
| 131 | /// The “fully corrective” conditional gradient method | |
| 132 | #[clap(name = "fw")] | |
| 133 | FW, | |
| 134 | /// The “relaxed conditional gradient method | |
| 135 | #[clap(name = "fwrelax")] | |
| 136 | FWRelax, | |
| 137 | /// The μPDPS primal-dual proximal splitting method | |
| 138 | #[clap(name = "pdps")] | |
| 139 | PDPS, | |
| 140 | } | |
| 141 | ||
| 142 | impl DefaultAlgorithm { | |
| 143 | /// Returns the algorithm configuration corresponding to the algorithm shorthand | |
| 144 | pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { | |
| 145 | use DefaultAlgorithm::*; | |
| 146 | match *self { | |
| 147 | FB => AlgorithmConfig::FB(Default::default()), | |
| 148 | FISTA => AlgorithmConfig::FB(FBConfig{ | |
| 149 | meta : FBMetaAlgorithm::InertiaFISTA, | |
| 150 | .. Default::default() | |
| 151 | }), | |
| 152 | FW => AlgorithmConfig::FW(Default::default()), | |
| 153 | FWRelax => AlgorithmConfig::FW(FWConfig{ | |
| 154 | variant : FWVariant::Relaxed, | |
| 155 | .. Default::default() | |
| 156 | }), | |
| 157 | PDPS => AlgorithmConfig::PDPS(Default::default()), | |
| 158 | } | |
| 159 | } | |
| 160 | ||
| 161 | /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand | |
| 162 | pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { | |
| 163 | self.to_named(self.default_config()) | |
| 164 | } | |
| 165 | ||
| 166 | pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { | |
| 167 | let name = self.to_possible_value().unwrap().get_name().to_string(); | |
| 168 | Named{ name , data : alg } | |
| 169 | } | |
| 170 | } | |
| 171 | ||
| 172 | ||
| 173 | // // Floats cannot be hashed directly, so just hash the debug formatting | |
| 174 | // // for use as file identifier. | |
| 175 | // impl<F : Float> Hash for AlgorithmConfig<F> { | |
| 176 | // fn hash<H: Hasher>(&self, state: &mut H) { | |
| 177 | // format!("{:?}", self).hash(state); | |
| 178 | // } | |
| 179 | // } | |
| 180 | ||
| 181 | /// Plotting level configuration | |
| 182 | #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, ValueEnum, Debug)] | |
| 183 | pub enum PlotLevel { | |
| 184 | /// Plot nothing | |
| 185 | #[clap(name = "none")] | |
| 186 | None, | |
| 187 | /// Plot problem data | |
| 188 | #[clap(name = "data")] | |
| 189 | Data, | |
| 190 | /// Plot iterationwise state | |
| 191 | #[clap(name = "iter")] | |
| 192 | Iter, | |
| 193 | } | |
| 194 | ||
| 195 | /// Algorithm and iterator config for the experiments | |
| 196 | ||
| 197 | #[derive(Clone, Debug, Serialize)] | |
| 198 | #[serde(default)] | |
| 199 | pub struct Configuration<F : Float> { | |
| 200 | /// Algorithms to run | |
| 201 | pub algorithms : Vec<Named<AlgorithmConfig<F>>>, | |
| 202 | /// Options for algorithm step iteration (verbosity, etc.) | |
| 203 | pub iterator_options : AlgIteratorOptions, | |
| 204 | /// Plotting level | |
| 205 | pub plot : PlotLevel, | |
| 206 | /// Directory where to save results | |
| 207 | pub outdir : String, | |
| 208 | /// Bisection tree depth | |
| 209 | pub bt_depth : DynamicDepth, | |
| 210 | } | |
| 211 | ||
| 212 | type DefaultBT<F, const N : usize> = BT< | |
| 213 | DynamicDepth, | |
| 214 | F, | |
| 215 | usize, | |
| 216 | Bounds<F>, | |
| 217 | N | |
| 218 | >; | |
| 219 | type DefaultSeminormOp<F, K, const N : usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>; | |
| 220 | type DefaultSG<F, Sensor, Spread, const N : usize> = SensorGrid::< | |
| 221 | F, | |
| 222 | Sensor, | |
| 223 | Spread, | |
| 224 | DefaultBT<F, N>, | |
| 225 | N | |
| 226 | >; | |
| 227 | ||
| 228 | /// This is a dirty workaround to rust-csv not supporting struct flattening etc. | |
| 229 | #[derive(Serialize)] | |
| 230 | struct CSVLog<F> { | |
| 231 | iter : usize, | |
| 232 | cpu_time : f64, | |
| 233 | value : F, | |
| 234 | post_value : F, | |
| 235 | n_spikes : usize, | |
| 236 | inner_iters : usize, | |
| 237 | merged : usize, | |
| 238 | pruned : usize, | |
| 239 | this_iters : usize, | |
| 240 | } | |
| 241 | ||
| 242 | /// Collected experiment statistics | |
| 243 | #[derive(Clone, Debug, Serialize)] | |
| 244 | struct ExperimentStats<F : Float> { | |
| 245 | /// Signal-to-noise ratio in decibels | |
| 246 | ssnr : F, | |
| 247 | /// Proportion of noise in the signal as a number in $[0, 1]$. | |
| 248 | noise_ratio : F, | |
| 249 | /// When the experiment was run (UTC) | |
| 250 | when : DateTime<Utc>, | |
| 251 | } | |
| 252 | ||
| 253 | #[replace_float_literals(F::cast_from(literal))] | |
| 254 | impl<F : Float> ExperimentStats<F> { | |
| 255 | /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal. | |
| 256 | fn new<E : Euclidean<F>>(signal : &E, noise : &E) -> Self { | |
| 257 | let s = signal.norm2_squared(); | |
| 258 | let n = noise.norm2_squared(); | |
| 259 | let noise_ratio = (n / s).sqrt(); | |
| 260 | let ssnr = 10.0 * (s / n).log10(); | |
| 261 | ExperimentStats { | |
| 262 | ssnr, | |
| 263 | noise_ratio, | |
| 264 | when : Utc::now(), | |
| 265 | } | |
| 266 | } | |
| 267 | } | |
| 268 | /// Collected algorithm statistics | |
| 269 | #[derive(Clone, Debug, Serialize)] | |
| 270 | struct AlgorithmStats<F : Float> { | |
| 271 | /// Overall CPU time spent | |
| 272 | cpu_time : F, | |
| 273 | /// Real time spent | |
| 274 | elapsed : F | |
| 275 | } | |
| 276 | ||
| 277 | ||
| 278 | /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input | |
| 279 | /// and outputs a [`DynError`]. | |
| 280 | fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { | |
| 281 | serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; | |
| 282 | Ok(()) | |
| 283 | } | |
| 284 | ||
| 285 | ||
| 286 | /// Struct for experiment configurations | |
| 287 | #[derive(Debug, Clone, Serialize)] | |
| 288 | pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize> | |
| 289 | where F : Float, | |
| 290 | [usize; N] : Serialize, | |
| 291 | NoiseDistr : Distribution<F>, | |
| 292 | S : Sensor<F, N>, | |
| 293 | P : Spread<F, N>, | |
| 294 | K : SimpleConvolutionKernel<F, N>, | |
| 295 | { | |
| 296 | /// Domain $Ω$. | |
| 297 | pub domain : Cube<F, N>, | |
| 298 | /// Number of sensors along each dimension | |
| 299 | pub sensor_count : [usize; N], | |
| 300 | /// Noise distribution | |
| 301 | pub noise_distr : NoiseDistr, | |
| 302 | /// Seed for random noise generation (for repeatable experiments) | |
| 303 | pub noise_seed : u64, | |
| 304 | /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. | |
| 305 | pub sensor : S, | |
| 306 | /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. | |
| 307 | pub spread : P, | |
| 308 | /// Kernel $ρ$ of $𝒟$. | |
| 309 | pub kernel : K, | |
| 310 | /// True point sources | |
| 311 | pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, | |
| 312 | /// Regularisation parameter | |
| 313 | pub α : F, | |
| 314 | /// For plotting : how wide should the kernels be plotted | |
| 315 | pub kernel_plot_width : F, | |
| 316 | /// Data term | |
| 317 | pub dataterm : DataTerm, | |
| 318 | /// A map of default configurations for algorithms | |
| 319 | #[serde(skip)] | |
| 320 | pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, | |
| 321 | } | |
| 322 | ||
| 323 | /// Trait for runnable experiments | |
| 324 | pub trait RunnableExperiment<F : ClapFloat> { | |
| 325 | /// Run all algorithms of the [`Configuration`] `config` on the experiment. | |
| 326 | fn runall(&self, config : Configuration<F>) -> DynError; | |
| 327 | ||
| 328 | /// Returns the default configuration | |
| 329 | fn default_config(&self) -> Configuration<F>; | |
| 330 | ||
| 331 | /// Return algorithm default config | |
| 332 | fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) | |
| 333 | -> Named<AlgorithmConfig<F>>; | |
| 334 | } | |
| 335 | ||
| 336 | impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for | |
| 337 | Named<Experiment<F, NoiseDistr, S, K, P, N>> | |
| 338 | where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, | |
| 339 | [usize; N] : Serialize, | |
| 340 | S : Sensor<F, N> + Copy + Serialize, | |
| 341 | P : Spread<F, N> + Copy + Serialize, | |
| 342 | Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy, | |
| 343 | AutoConvolution<P> : BoundedBy<F, K>, | |
| 344 | K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> + Copy + Serialize, | |
| 345 | Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, | |
| 346 | PlotLookup : Plotting<N>, | |
| 347 | DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, | |
| 348 | BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, | |
| 349 | DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, | |
| 350 | NoiseDistr : Distribution<F> + Serialize { | |
| 351 | ||
| 352 | fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) | |
| 353 | -> Named<AlgorithmConfig<F>> { | |
| 354 | alg.to_named( | |
| 355 | self.data | |
| 356 | .algorithm_defaults | |
| 357 | .get(&alg) | |
| 358 | .map_or_else(|| alg.default_config(), | |
| 359 | |config| config.clone()) | |
| 360 | .cli_override(cli) | |
| 361 | ) | |
| 362 | } | |
| 363 | ||
| 364 | fn default_config(&self) -> Configuration<F> { | |
| 365 | let default_alg = match self.data.dataterm { | |
| 366 | DataTerm::L2Squared => DefaultAlgorithm::FB.get_named(), | |
| 367 | DataTerm::L1 => DefaultAlgorithm::PDPS.get_named(), | |
| 368 | }; | |
| 369 | ||
| 370 | Configuration{ | |
| 371 | algorithms : vec![default_alg], | |
| 372 | iterator_options : AlgIteratorOptions{ | |
| 373 | max_iter : 2000, | |
| 374 | verbose_iter : Verbose::Logarithmic(10), | |
| 375 | quiet : false, | |
| 376 | }, | |
| 377 | plot : PlotLevel::Data, | |
| 378 | outdir : "out".to_string(), | |
| 379 | bt_depth : DynamicDepth(8), | |
| 380 | } | |
| 381 | } | |
| 382 | ||
| 383 | fn runall(&self, config : Configuration<F>) -> DynError { | |
| 384 | let &Named { | |
| 385 | name : ref experiment_name, | |
| 386 | data : Experiment { | |
| 387 | domain, sensor_count, ref noise_distr, sensor, spread, kernel, | |
| 388 | ref μ_hat, α, kernel_plot_width, dataterm, noise_seed, | |
| 389 | .. | |
| 390 | } | |
| 391 | } = self; | |
| 392 | ||
| 393 | // Set path | |
| 394 | let prefix = format!("{}/{}/", config.outdir, experiment_name); | |
| 395 | ||
| 396 | // Set up operators | |
| 397 | let depth = config.bt_depth; | |
| 398 | let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); | |
| 399 | let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); | |
| 400 | ||
| 401 | // Set up random number generator. | |
| 402 | let mut rng = StdRng::seed_from_u64(noise_seed); | |
| 403 | ||
| 404 | // Generate the data and calculate SSNR statistic | |
| 405 | let b_hat = opA.apply(μ_hat); | |
| 406 | let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); | |
| 407 | let b = &b_hat + &noise; | |
| 408 | // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField | |
| 409 | // overloading log10 and conflicting with standard NumTraits one. | |
| 410 | let stats = ExperimentStats::new(&b, &noise); | |
| 411 | ||
| 412 | // Save experiment configuration and statistics | |
| 413 | let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); | |
| 414 | std::fs::create_dir_all(&prefix)?; | |
| 415 | write_json(mkname_e("experiment"), self)?; | |
| 416 | write_json(mkname_e("config"), &config)?; | |
| 417 | write_json(mkname_e("stats"), &stats)?; | |
| 418 | ||
| 419 | plotall(&config, &prefix, &domain, &sensor, &kernel, &spread, | |
| 420 | &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; | |
| 421 | ||
| 422 | // Run the algorithm(s) | |
| 423 | for named @ Named { name : alg_name, data : alg } in config.algorithms.iter() { | |
| 424 | let this_prefix = format!("{}{}/", prefix, alg_name); | |
| 425 | ||
| 426 | let running = || { | |
| 427 | println!("{}\n{}\n{}", | |
| 428 | format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), | |
| 429 | format!("{:?}", config.iterator_options).bright_black(), | |
| 430 | format!("{:?}", alg).bright_black()); | |
| 431 | }; | |
| 432 | ||
| 433 | // Create Logger and IteratorFactory | |
| 434 | let mut logger = Logger::new(); | |
| 435 | let findim_data = prepare_optimise_weights(&opA); | |
| 436 | let inner_config : InnerSettings<F> = Default::default(); | |
| 437 | let inner_it = inner_config.iterator_options; | |
| 438 | let logmap = |iter, Timed { cpu_time, data }| { | |
| 439 | let IterInfo { | |
| 440 | value, | |
| 441 | n_spikes, | |
| 442 | inner_iters, | |
| 443 | merged, | |
| 444 | pruned, | |
| 445 | postprocessing, | |
| 446 | this_iters, | |
| 447 | .. | |
| 448 | } = data; | |
| 449 | let post_value = match postprocessing { | |
| 450 | None => value, | |
| 451 | Some(mut μ) => { | |
| 452 | match dataterm { | |
| 453 | DataTerm::L2Squared => { | |
| 454 | optimise_weights( | |
| 455 | &mut μ, &opA, &b, α, &findim_data, &inner_config, | |
| 456 | inner_it | |
| 457 | ); | |
| 458 | dataterm.value_at_residual(opA.apply(&μ) - &b) + α * μ.norm(Radon) | |
| 459 | }, | |
| 460 | _ => value, | |
| 461 | } | |
| 462 | } | |
| 463 | }; | |
| 464 | CSVLog { | |
| 465 | iter, | |
| 466 | value, | |
| 467 | post_value, | |
| 468 | n_spikes, | |
| 469 | cpu_time : cpu_time.as_secs_f64(), | |
| 470 | inner_iters, | |
| 471 | merged, | |
| 472 | pruned, | |
| 473 | this_iters | |
| 474 | } | |
| 475 | }; | |
| 476 | let iterator = config.iterator_options | |
| 477 | .instantiate() | |
| 478 | .timed() | |
| 479 | .mapped(logmap) | |
| 480 | .into_log(&mut logger); | |
| 481 | let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); | |
| 482 | ||
| 483 | // Create plotter and directory if needed. | |
| 484 | let plot_count = if config.plot >= PlotLevel::Iter { 2000 } else { 0 }; | |
| 485 | let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid); | |
| 486 | ||
| 487 | // Run the algorithm | |
| 488 | let start = Instant::now(); | |
| 489 | let start_cpu = ProcessTime::now(); | |
| 490 | let μ : DiscreteMeasure<Loc<F, N>, F> = match (alg, dataterm) { | |
| 491 | (AlgorithmConfig::FB(ref algconfig), DataTerm::L2Squared) => { | |
| 492 | running(); | |
| 493 | pointsource_fb(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter) | |
| 494 | }, | |
| 495 | (AlgorithmConfig::FW(ref algconfig), DataTerm::L2Squared) => { | |
| 496 | running(); | |
| 497 | pointsource_fw(&opA, &b, α, &algconfig, iterator, plotter) | |
| 498 | }, | |
| 499 | (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L2Squared) => { | |
| 500 | running(); | |
| 501 | pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L2Squared) | |
| 502 | }, | |
| 503 | (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L1) => { | |
| 504 | running(); | |
| 505 | pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1) | |
| 506 | }, | |
| 507 | _ => { | |
| 508 | let msg = format!("Algorithm “{}” not implemented for dataterm {:?}. Skipping.", | |
| 509 | alg_name, dataterm).red(); | |
| 510 | eprintln!("{}", msg); | |
| 511 | continue | |
| 512 | } | |
| 513 | }; | |
| 514 | let elapsed = start.elapsed().as_secs_f64(); | |
| 515 | let cpu_time = start_cpu.elapsed().as_secs_f64(); | |
| 516 | ||
| 517 | println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); | |
| 518 | ||
| 519 | // Save results | |
| 520 | println!("{}", "Saving results…".green()); | |
| 521 | ||
| 522 | let mkname = | | |
| 523 | t| format!("{p}{n}_{t}", p = prefix, n = alg_name, t = t); | |
| 524 | ||
| 525 | write_json(mkname("config.json"), &named)?; | |
| 526 | write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; | |
| 527 | μ.write_csv(mkname("reco.txt"))?; | |
| 528 | logger.write_csv(mkname("log.txt"))?; | |
| 529 | } | |
| 530 | ||
| 531 | Ok(()) | |
| 532 | } | |
| 533 | } | |
| 534 | ||
| 535 | /// Plot experiment setup | |
| 536 | #[replace_float_literals(F::cast_from(literal))] | |
| 537 | fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( | |
| 538 | config : &Configuration<F>, | |
| 539 | prefix : &String, | |
| 540 | domain : &Cube<F, N>, | |
| 541 | sensor : &Sensor, | |
| 542 | kernel : &Kernel, | |
| 543 | spread : &Spread, | |
| 544 | μ_hat : &DiscreteMeasure<Loc<F, N>, F>, | |
| 545 | op𝒟 : &𝒟, | |
| 546 | opA : &A, | |
| 547 | b_hat : &A::Observable, | |
| 548 | b : &A::Observable, | |
| 549 | kernel_plot_width : F, | |
| 550 | ) -> DynError | |
| 551 | where F : Float + ToNalgebraRealField, | |
| 552 | Sensor : RealMapping<F, N> + Support<F, N> + Clone, | |
| 553 | Spread : RealMapping<F, N> + Support<F, N> + Clone, | |
| 554 | Kernel : RealMapping<F, N> + Support<F, N>, | |
| 555 | Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, | |
| 556 | 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, | |
| 557 | 𝒟::Codomain : RealMapping<F, N>, | |
| 558 | A : ForwardModel<Loc<F, N>, F>, | |
| 559 | A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, | |
| 560 | PlotLookup : Plotting<N>, | |
| 561 | Cube<F, N> : SetOrd { | |
| 562 | ||
| 563 | if config.plot < PlotLevel::Data { | |
| 564 | return Ok(()) | |
| 565 | } | |
| 566 | ||
| 567 | let base = Convolution(sensor.clone(), spread.clone()); | |
| 568 | ||
| 569 | let resolution = if N==1 { 100 } else { 40 }; | |
| 570 | let pfx = |n| format!("{}{}", prefix, n); | |
| 571 | let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); | |
| 572 | ||
| 573 | PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); | |
| 574 | PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); | |
| 575 | PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); | |
| 576 | PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); | |
| 577 | ||
| 578 | let plotgrid2 = lingrid(&domain, &[resolution; N]); | |
| 579 | ||
| 580 | let ω_hat = op𝒟.apply(μ_hat); | |
| 581 | let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); | |
| 582 | PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string()); | |
| 583 | PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"), | |
| 584 | "noise Aᵀ(Aμ̂ - b)".to_string()); | |
| 585 | ||
| 586 | let preadj_b = opA.preadjoint().apply(b); | |
| 587 | let preadj_b_hat = opA.preadjoint().apply(b_hat); | |
| 588 | //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); | |
| 589 | PlotLookup::plot_into_file_spikes( | |
| 590 | "Aᵀb".to_string(), &preadj_b, | |
| 591 | "Aᵀb̂".to_string(), Some(&preadj_b_hat), | |
| 592 | plotgrid2, None, &μ_hat, | |
| 593 | pfx("omega_b") | |
| 594 | ); | |
| 595 | ||
| 596 | // Save true solution and observables | |
| 597 | let pfx = |n| format!("{}{}", prefix, n); | |
| 598 | μ_hat.write_csv(pfx("orig.txt"))?; | |
| 599 | opA.write_observable(&b_hat, pfx("b_hat"))?; | |
| 600 | opA.write_observable(&b, pfx("b_noisy")) | |
| 601 | } | |
| 602 |