src/run.rs

changeset 0
eb3c7813b67a
child 1
d4fd5f32d10e
child 2
7a953a87b6c1
child 9
21b0e537ac0e
equal deleted inserted replaced
-1:000000000000 0:eb3c7813b67a
1 /*!
2 This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment.
3 */
4
5 use numeric_literals::replace_float_literals;
6 use colored::Colorize;
7 use serde::{Serialize, Deserialize};
8 use serde_json;
9 use nalgebra::base::DVector;
10 use std::hash::Hash;
11 use chrono::{DateTime, Utc};
12 use cpu_time::ProcessTime;
13 use clap::ValueEnum;
14 use std::collections::HashMap;
15 use std::time::Instant;
16
17 use rand::prelude::{
18 StdRng,
19 SeedableRng
20 };
21 use rand_distr::Distribution;
22
23 use alg_tools::bisection_tree::*;
24 use alg_tools::iterate::{
25 Timed,
26 AlgIteratorOptions,
27 Verbose,
28 AlgIteratorFactory,
29 };
30 use alg_tools::logger::Logger;
31 use alg_tools::error::DynError;
32 use alg_tools::tabledump::TableDump;
33 use alg_tools::sets::Cube;
34 use alg_tools::mapping::RealMapping;
35 use alg_tools::nalgebra_support::ToNalgebraRealField;
36 use alg_tools::euclidean::Euclidean;
37 use alg_tools::norms::{Norm, L1};
38 use alg_tools::lingrid::lingrid;
39 use alg_tools::sets::SetOrd;
40
41 use crate::kernels::*;
42 use crate::types::*;
43 use crate::measures::*;
44 use crate::measures::merging::SpikeMerging;
45 use crate::forward_model::*;
46 use crate::fb::{
47 FBConfig,
48 pointsource_fb,
49 FBMetaAlgorithm, FBGenericConfig,
50 };
51 use crate::pdps::{
52 PDPSConfig,
53 L2Squared,
54 pointsource_pdps,
55 };
56 use crate::frank_wolfe::{
57 FWConfig,
58 FWVariant,
59 pointsource_fw,
60 prepare_optimise_weights,
61 optimise_weights,
62 };
63 use crate::subproblem::InnerSettings;
64 use crate::seminorms::*;
65 use crate::plot::*;
66 use crate::AlgorithmOverrides;
67
/// Available algorithms and their configurations
///
/// Each variant wraps the configuration struct of the corresponding solver;
/// see [`FBConfig`], [`FWConfig`], and [`PDPSConfig`].
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub enum AlgorithmConfig<F : Float> {
    /// Forward-backward type methods (also used for μFISTA, see [`DefaultAlgorithm`])
    FB(FBConfig<F>),
    /// Conditional gradient (Frank–Wolfe) methods
    FW(FWConfig<F>),
    /// Primal-dual proximal splitting
    PDPS(PDPSConfig<F>),
}
75
76 impl<F : ClapFloat> AlgorithmConfig<F> {
77 /// Override supported parameters based on the command line.
78 pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self {
79 let override_fb_generic = |g : FBGenericConfig<F>| {
80 FBGenericConfig {
81 bootstrap_insertions : cli.bootstrap_insertions
82 .as_ref()
83 .map_or(g.bootstrap_insertions,
84 |n| Some((n[0], n[1]))),
85 merge_every : cli.merge_every.unwrap_or(g.merge_every),
86 merging : cli.merging.clone().unwrap_or(g.merging),
87 final_merging : cli.final_merging.clone().unwrap_or(g.final_merging),
88 .. g
89 }
90 };
91
92 use AlgorithmConfig::*;
93 match self {
94 FB(fb) => FB(FBConfig {
95 τ0 : cli.tau0.unwrap_or(fb.τ0),
96 insertion : override_fb_generic(fb.insertion),
97 .. fb
98 }),
99 PDPS(pdps) => PDPS(PDPSConfig {
100 τ0 : cli.tau0.unwrap_or(pdps.τ0),
101 σ0 : cli.sigma0.unwrap_or(pdps.σ0),
102 acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
103 insertion : override_fb_generic(pdps.insertion),
104 .. pdps
105 }),
106 FW(fw) => FW(FWConfig {
107 merging : cli.merging.clone().unwrap_or(fw.merging),
108 .. fw
109 })
110 }
111 }
112 }
113
/// Helper struct for tagging an [`AlgorithmConfig`] or [`Experiment`] with a name.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Named<Data> {
    /// Human-readable identifier, also used in output file names.
    pub name : String,
    /// The wrapped value; serialised inline alongside `name` via `flatten`.
    #[serde(flatten)]
    pub data : Data,
}
121
/// Shorthand algorithm configurations, to be used with the command line parser
///
/// Each variant maps to a full [`AlgorithmConfig`] via
/// [`DefaultAlgorithm::default_config`].
#[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum DefaultAlgorithm {
    /// The μFB forward-backward method
    #[clap(name = "fb")]
    FB,
    /// The μFISTA inertial forward-backward method
    #[clap(name = "fista")]
    FISTA,
    /// The “fully corrective” conditional gradient method
    #[clap(name = "fw")]
    FW,
    /// The “relaxed” conditional gradient method
    #[clap(name = "fwrelax")]
    FWRelax,
    /// The μPDPS primal-dual proximal splitting method
    #[clap(name = "pdps")]
    PDPS,
}
141
142 impl DefaultAlgorithm {
143 /// Returns the algorithm configuration corresponding to the algorithm shorthand
144 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> {
145 use DefaultAlgorithm::*;
146 match *self {
147 FB => AlgorithmConfig::FB(Default::default()),
148 FISTA => AlgorithmConfig::FB(FBConfig{
149 meta : FBMetaAlgorithm::InertiaFISTA,
150 .. Default::default()
151 }),
152 FW => AlgorithmConfig::FW(Default::default()),
153 FWRelax => AlgorithmConfig::FW(FWConfig{
154 variant : FWVariant::Relaxed,
155 .. Default::default()
156 }),
157 PDPS => AlgorithmConfig::PDPS(Default::default()),
158 }
159 }
160
161 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand
162 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> {
163 self.to_named(self.default_config())
164 }
165
166 pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> {
167 let name = self.to_possible_value().unwrap().get_name().to_string();
168 Named{ name , data : alg }
169 }
170 }
171
172
173 // // Floats cannot be hashed directly, so just hash the debug formatting
174 // // for use as file identifier.
175 // impl<F : Float> Hash for AlgorithmConfig<F> {
176 // fn hash<H: Hasher>(&self, state: &mut H) {
177 // format!("{:?}", self).hash(state);
178 // }
179 // }
180
/// Plotting level configuration
///
/// Variants are declared in increasing order of verbosity, so the derived
/// [`Ord`] supports comparisons such as `config.plot >= PlotLevel::Iter`.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, ValueEnum, Debug)]
pub enum PlotLevel {
    /// Plot nothing
    #[clap(name = "none")]
    None,
    /// Plot problem data
    #[clap(name = "data")]
    Data,
    /// Plot iterationwise state
    #[clap(name = "iter")]
    Iter,
}
194
/// Algorithm and iterator config for the experiments
// NOTE(review): `#[serde(default)]` only affects deserialisation, but only
// `Serialize` is derived here — presumably a leftover; confirm before removing.
#[derive(Clone, Debug, Serialize)]
#[serde(default)]
pub struct Configuration<F : Float> {
    /// Algorithms to run
    pub algorithms : Vec<Named<AlgorithmConfig<F>>>,
    /// Options for algorithm step iteration (verbosity, etc.)
    pub iterator_options : AlgIteratorOptions,
    /// Plotting level
    pub plot : PlotLevel,
    /// Directory where to save results
    pub outdir : String,
    /// Bisection tree depth
    pub bt_depth : DynamicDepth,
}
211
/// Default bisection tree: dynamic depth, `usize` data indices, and
/// [`Bounds`] aggregation over `N`-dimensional space.
type DefaultBT<F, const N : usize> = BT<
    DynamicDepth,
    F,
    usize,
    Bounds<F>,
    N
>;
/// Default seminorm operator $𝒟$: convolution with kernel `K`, backed by [`DefaultBT`].
type DefaultSeminormOp<F, K, const N : usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>;
/// Default forward operator $𝒜$: a sensor grid backed by [`DefaultBT`].
type DefaultSG<F, Sensor, Spread, const N : usize> = SensorGrid::<
    F,
    Sensor,
    Spread,
    DefaultBT<F, N>,
    N
>;
227
/// This is a dirty workaround to rust-csv not supporting struct flattening etc.
///
/// One flat row of the per-iteration log; most fields mirror `IterInfo`.
#[derive(Serialize)]
struct CSVLog<F> {
    /// Iteration number
    iter : usize,
    /// CPU time reported by the timed iterator, in seconds
    cpu_time : f64,
    /// Objective function value
    value : F,
    /// Objective value after weight post-optimisation (equals `value` when
    /// no postprocessing measure was available)
    post_value : F,
    /// Number of spikes in the iterate
    n_spikes : usize,
    inner_iters : usize,
    merged : usize,
    pruned : usize,
    this_iters : usize,
}
241
/// Collected experiment statistics
///
/// Written to `stats.json` alongside the experiment configuration.
#[derive(Clone, Debug, Serialize)]
struct ExperimentStats<F : Float> {
    /// Signal-to-noise ratio in decibels
    ssnr : F,
    /// Proportion of noise in the signal as a number in $[0, 1]$.
    noise_ratio : F,
    /// When the experiment was run (UTC)
    when : DateTime<Utc>,
}
252
253 #[replace_float_literals(F::cast_from(literal))]
254 impl<F : Float> ExperimentStats<F> {
255 /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal.
256 fn new<E : Euclidean<F>>(signal : &E, noise : &E) -> Self {
257 let s = signal.norm2_squared();
258 let n = noise.norm2_squared();
259 let noise_ratio = (n / s).sqrt();
260 let ssnr = 10.0 * (s / n).log10();
261 ExperimentStats {
262 ssnr,
263 noise_ratio,
264 when : Utc::now(),
265 }
266 }
267 }
/// Collected algorithm statistics
///
/// Written to `<algorithm>_stats.json` after each run; both times in seconds.
#[derive(Clone, Debug, Serialize)]
struct AlgorithmStats<F : Float> {
    /// Overall CPU time spent
    cpu_time : F,
    /// Real time spent
    elapsed : F
}
276
277
278 /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input
279 /// and outputs a [`DynError`].
280 fn write_json<T : Serialize>(filename : String, data : &T) -> DynError {
281 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?;
282 Ok(())
283 }
284
285
/// Struct for experiment configurations
///
/// Type parameters: `NoiseDistr` is the measurement-noise distribution,
/// `S` the sensor, `P` the spread, `K` the kernel of $𝒟$, and `N` the
/// spatial dimension.
#[derive(Debug, Clone, Serialize)]
pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize>
where F : Float,
      [usize; N] : Serialize,
      NoiseDistr : Distribution<F>,
      S : Sensor<F, N>,
      P : Spread<F, N>,
      K : SimpleConvolutionKernel<F, N>,
{
    /// Domain $Ω$.
    pub domain : Cube<F, N>,
    /// Number of sensors along each dimension
    pub sensor_count : [usize; N],
    /// Noise distribution
    pub noise_distr : NoiseDistr,
    /// Seed for random noise generation (for repeatable experiments)
    pub noise_seed : u64,
    /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$.
    pub sensor : S,
    /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$.
    pub spread : P,
    /// Kernel $ρ$ of $𝒟$.
    pub kernel : K,
    /// True point sources
    pub μ_hat : DiscreteMeasure<Loc<F, N>, F>,
    /// Regularisation parameter
    pub α : F,
    /// For plotting : how wide should the kernels be plotted
    pub kernel_plot_width : F,
    /// Data term
    pub dataterm : DataTerm,
    /// A map of default configurations for algorithms
    /// (not serialised into `experiment.json`)
    #[serde(skip)]
    pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>,
}
322
/// Trait for runnable experiments
pub trait RunnableExperiment<F : ClapFloat> {
    /// Run all algorithms of the [`Configuration`] `config` on the experiment.
    ///
    /// Results are written under `config.outdir`.
    fn runall(&self, config : Configuration<F>) -> DynError;

    /// Returns the default configuration
    fn default_config(&self) -> Configuration<F>;

    /// Return algorithm default config for `alg`, with the command line
    /// overrides from `cli` applied on top.
    fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
    -> Named<AlgorithmConfig<F>>;
}
335
// Implementation of [`RunnableExperiment`] for [`Named`] experiments: ties
// together data generation, plotting, algorithm dispatch, and result output.
impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for
Named<Experiment<F, NoiseDistr, S, K, P, N>>
where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
      [usize; N] : Serialize,
      S : Sensor<F, N> + Copy + Serialize,
      P : Spread<F, N> + Copy + Serialize,
      Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy,
      AutoConvolution<P> : BoundedBy<F, K>,
      K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> + Copy + Serialize,
      Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
      PlotLookup : Plotting<N>,
      DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      NoiseDistr : Distribution<F> + Serialize {

    // Look up the experiment-specific default for `alg`, falling back to the
    // algorithm's built-in default, then apply command line overrides.
    fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
    -> Named<AlgorithmConfig<F>> {
        alg.to_named(
            self.data
                .algorithm_defaults
                .get(&alg)
                .map_or_else(|| alg.default_config(),
                             |config| config.clone())
                .cli_override(cli)
        )
    }

    // Default runner configuration; the default algorithm depends on the
    // experiment's data term (L²-squared → FB, L¹ → PDPS).
    fn default_config(&self) -> Configuration<F> {
        let default_alg = match self.data.dataterm {
            DataTerm::L2Squared => DefaultAlgorithm::FB.get_named(),
            DataTerm::L1 => DefaultAlgorithm::PDPS.get_named(),
        };

        Configuration{
            algorithms : vec![default_alg],
            iterator_options : AlgIteratorOptions{
                max_iter : 2000,
                verbose_iter : Verbose::Logarithmic(10),
                quiet : false,
            },
            plot : PlotLevel::Data,
            outdir : "out".to_string(),
            bt_depth : DynamicDepth(8),
        }
    }

    fn runall(&self, config : Configuration<F>) -> DynError {
        // Destructure the named experiment into its components; `Copy` fields
        // are moved out, the rest are taken by reference.
        let &Named {
            name : ref experiment_name,
            data : Experiment {
                domain, sensor_count, ref noise_distr, sensor, spread, kernel,
                ref μ_hat, α, kernel_plot_width, dataterm, noise_seed,
                ..
            }
        } = self;

        // Set path
        let prefix = format!("{}/{}/", config.outdir, experiment_name);

        // Set up operators: the forward operator 𝒜 and the seminorm operator 𝒟.
        let depth = config.bt_depth;
        let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
        let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel);

        // Set up random number generator.
        let mut rng = StdRng::seed_from_u64(noise_seed);

        // Generate the data and calculate SSNR statistic:
        // b = 𝒜μ̂ + noise, with noise drawn componentwise from `noise_distr`.
        let b_hat = opA.apply(μ_hat);
        let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
        let b = &b_hat + &noise;
        // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
        // overloading log10 and conflicting with standard NumTraits one.
        let stats = ExperimentStats::new(&b, &noise);

        // Save experiment configuration and statistics
        let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t);
        std::fs::create_dir_all(&prefix)?;
        write_json(mkname_e("experiment"), self)?;
        write_json(mkname_e("config"), &config)?;
        write_json(mkname_e("stats"), &stats)?;

        // Plot the experiment setup (sensor/kernel/spread and data), if enabled.
        plotall(&config, &prefix, &domain, &sensor, &kernel, &spread,
                &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?;

        // Run the algorithm(s)
        for named @ Named { name : alg_name, data : alg } in config.algorithms.iter() {
            let this_prefix = format!("{}{}/", prefix, alg_name);

            // Banner printed only once an algorithm actually starts (skipped
            // combinations print an error instead).
            let running = || {
                println!("{}\n{}\n{}",
                         format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
                         format!("{:?}", config.iterator_options).bright_black(),
                         format!("{:?}", alg).bright_black());
            };

            // Create Logger and IteratorFactory
            let mut logger = Logger::new();
            let findim_data = prepare_optimise_weights(&opA);
            let inner_config : InnerSettings<F> = Default::default();
            let inner_it = inner_config.iterator_options;
            // Map each timed iteration record into a flat `CSVLog` row. If the
            // record carries a postprocessing measure, re-optimise its weights
            // (L²-squared data term only) to also log a post-processed objective.
            let logmap = |iter, Timed { cpu_time, data }| {
                let IterInfo {
                    value,
                    n_spikes,
                    inner_iters,
                    merged,
                    pruned,
                    postprocessing,
                    this_iters,
                    ..
                } = data;
                let post_value = match postprocessing {
                    None => value,
                    Some(mut μ) => {
                        match dataterm {
                            DataTerm::L2Squared => {
                                // Fit spike weights to the data, keeping locations fixed.
                                optimise_weights(
                                    &mut μ, &opA, &b, α, &findim_data, &inner_config,
                                    inner_it
                                );
                                // Full objective: data term + α‖μ‖_ℳ.
                                dataterm.value_at_residual(opA.apply(&μ) - &b) + α * μ.norm(Radon)
                            },
                            _ => value,
                        }
                    }
                };
                CSVLog {
                    iter,
                    value,
                    post_value,
                    n_spikes,
                    cpu_time : cpu_time.as_secs_f64(),
                    inner_iters,
                    merged,
                    pruned,
                    this_iters
                }
            };
            let iterator = config.iterator_options
                                 .instantiate()
                                 .timed()
                                 .mapped(logmap)
                                 .into_log(&mut logger);
            let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]);

            // Create plotter and directory if needed.
            let plot_count = if config.plot >= PlotLevel::Iter { 2000 } else { 0 };
            let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid);

            // Run the algorithm, dispatching on (algorithm, data term);
            // unsupported combinations are reported and skipped.
            let start = Instant::now();
            let start_cpu = ProcessTime::now();
            let μ : DiscreteMeasure<Loc<F, N>, F> = match (alg, dataterm) {
                (AlgorithmConfig::FB(ref algconfig), DataTerm::L2Squared) => {
                    running();
                    pointsource_fb(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter)
                },
                (AlgorithmConfig::FW(ref algconfig), DataTerm::L2Squared) => {
                    running();
                    pointsource_fw(&opA, &b, α, &algconfig, iterator, plotter)
                },
                (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L2Squared) => {
                    running();
                    pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L2Squared)
                },
                (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L1) => {
                    running();
                    pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1)
                },
                _ => {
                    let msg = format!("Algorithm “{}” not implemented for dataterm {:?}. Skipping.",
                                      alg_name, dataterm).red();
                    eprintln!("{}", msg);
                    continue
                }
            };
            let elapsed = start.elapsed().as_secs_f64();
            let cpu_time = start_cpu.elapsed().as_secs_f64();

            println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow());

            // Save results
            println!("{}", "Saving results…".green());

            let mkname = |t| format!("{p}{n}_{t}", p = prefix, n = alg_name, t = t);

            write_json(mkname("config.json"), &named)?;
            write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
            μ.write_csv(mkname("reco.txt"))?;
            logger.write_csv(mkname("log.txt"))?;
        }

        Ok(())
    }
}
534
/// Plot experiment setup
///
/// Plots the sensor, kernel, and spread functions, the convolved base sensor,
/// ω̂ = 𝒟μ̂, the “noise” Aᵀ(Aμ̂ − b), and the preadjoint images Aᵀb, Aᵀb̂.
/// Also writes the true solution μ̂ and both observables to disk.
/// Does nothing unless `config.plot` is at least [`PlotLevel::Data`].
#[replace_float_literals(F::cast_from(literal))]
fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>(
    config : &Configuration<F>,
    prefix : &String,
    domain : &Cube<F, N>,
    sensor : &Sensor,
    kernel : &Kernel,
    spread : &Spread,
    μ_hat : &DiscreteMeasure<Loc<F, N>, F>,
    op𝒟 : &𝒟,
    opA : &A,
    b_hat : &A::Observable,
    b : &A::Observable,
    kernel_plot_width : F,
) -> DynError
where F : Float + ToNalgebraRealField,
      Sensor : RealMapping<F, N> + Support<F, N> + Clone,
      Spread : RealMapping<F, N> + Support<F, N> + Clone,
      Kernel : RealMapping<F, N> + Support<F, N>,
      Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F>,
      𝒟::Codomain : RealMapping<F, N>,
      A : ForwardModel<Loc<F, N>, F>,
      A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>,
      PlotLookup : Plotting<N>,
      Cube<F, N> : SetOrd {

    // Nothing to do unless at least data-level plotting was requested.
    if config.plot < PlotLevel::Data {
        return Ok(())
    }

    let base = Convolution(sensor.clone(), spread.clone());

    // Coarser grid in higher dimensions to keep plot sizes manageable.
    let resolution = if N==1 { 100 } else { 40 };
    let pfx = |n| format!("{}{}", prefix, n);
    // Origin-centred grid of width 2·kernel_plot_width for plotting the kernels.
    let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]);

    PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string());
    PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string());
    PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string());
    PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string());

    // Grid over the problem domain for the data-dependent plots.
    let plotgrid2 = lingrid(&domain, &[resolution; N]);

    let ω_hat = op𝒟.apply(μ_hat);
    let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b);
    PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string());
    PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"),
                               "noise Aᵀ(Aμ̂ - b)".to_string());

    let preadj_b = opA.preadjoint().apply(b);
    let preadj_b_hat = opA.preadjoint().apply(b_hat);
    //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds());
    PlotLookup::plot_into_file_spikes(
        "Aᵀb".to_string(), &preadj_b,
        "Aᵀb̂".to_string(), Some(&preadj_b_hat),
        plotgrid2, None, &μ_hat,
        pfx("omega_b")
    );

    // Save true solution and observables
    // NOTE(review): this `pfx` shadows the identical closure defined above —
    // presumably a leftover; harmless, but one of the two could be removed.
    let pfx = |n| format!("{}{}", prefix, n);
    μ_hat.write_csv(pfx("orig.txt"))?;
    opA.write_observable(&b_hat, pfx("b_hat"))?;
    opA.write_observable(&b, pfx("b_noisy"))
}
602

mercurial