Tue, 29 Nov 2022 15:36:12 +0200
fubar
0 | 1 | /*! |
2 | This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment. | |
3 | */ | |
4 | ||
5 | use numeric_literals::replace_float_literals; | |
6 | use colored::Colorize; | |
7 | use serde::{Serialize, Deserialize}; | |
8 | use serde_json; | |
9 | use nalgebra::base::DVector; | |
10 | use std::hash::Hash; | |
11 | use chrono::{DateTime, Utc}; | |
12 | use cpu_time::ProcessTime; | |
13 | use clap::ValueEnum; | |
14 | use std::collections::HashMap; | |
15 | use std::time::Instant; | |
16 | ||
17 | use rand::prelude::{ | |
18 | StdRng, | |
19 | SeedableRng | |
20 | }; | |
21 | use rand_distr::Distribution; | |
22 | ||
23 | use alg_tools::bisection_tree::*; | |
24 | use alg_tools::iterate::{ | |
25 | Timed, | |
26 | AlgIteratorOptions, | |
27 | Verbose, | |
28 | AlgIteratorFactory, | |
29 | }; | |
30 | use alg_tools::logger::Logger; | |
31 | use alg_tools::error::DynError; | |
32 | use alg_tools::tabledump::TableDump; | |
33 | use alg_tools::sets::Cube; | |
34 | use alg_tools::mapping::RealMapping; | |
35 | use alg_tools::nalgebra_support::ToNalgebraRealField; | |
36 | use alg_tools::euclidean::Euclidean; | |
37 | use alg_tools::norms::{Norm, L1}; | |
38 | use alg_tools::lingrid::lingrid; | |
39 | use alg_tools::sets::SetOrd; | |
40 | ||
41 | use crate::kernels::*; | |
42 | use crate::types::*; | |
43 | use crate::measures::*; | |
44 | use crate::measures::merging::SpikeMerging; | |
45 | use crate::forward_model::*; | |
46 | use crate::fb::{ | |
47 | FBConfig, | |
48 | pointsource_fb, | |
49 | FBMetaAlgorithm, FBGenericConfig, | |
50 | }; | |
51 | use crate::pdps::{ | |
52 | PDPSConfig, | |
53 | L2Squared, | |
54 | pointsource_pdps, | |
55 | }; | |
56 | use crate::frank_wolfe::{ | |
57 | FWConfig, | |
58 | FWVariant, | |
59 | pointsource_fw, | |
2 | 60 | }; |
61 | use crate::weight_optim::{ | |
0 | 62 | prepare_optimise_weights, |
2 | 63 | optimise_weights_l2, |
0 | 64 | }; |
65 | use crate::subproblem::InnerSettings; | |
66 | use crate::seminorms::*; | |
67 | use crate::plot::*; | |
68 | use crate::AlgorithmOverrides; | |
69 | ||
/// Available algorithms and their configurations
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub enum AlgorithmConfig<F : Float> {
    /// Forward-backward methods (μFB and, via [`FBMetaAlgorithm`], μFISTA)
    FB(FBConfig<F>),
    /// Conditional gradient (Frank–Wolfe) methods
    FW(FWConfig<F>),
    /// Primal-dual proximal splitting (μPDPS)
    PDPS(PDPSConfig<F>),
}
77 | ||
78 | impl<F : ClapFloat> AlgorithmConfig<F> { | |
79 | /// Override supported parameters based on the command line. | |
80 | pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self { | |
81 | let override_fb_generic = |g : FBGenericConfig<F>| { | |
82 | FBGenericConfig { | |
83 | bootstrap_insertions : cli.bootstrap_insertions | |
84 | .as_ref() | |
85 | .map_or(g.bootstrap_insertions, | |
86 | |n| Some((n[0], n[1]))), | |
87 | merge_every : cli.merge_every.unwrap_or(g.merge_every), | |
88 | merging : cli.merging.clone().unwrap_or(g.merging), | |
89 | final_merging : cli.final_merging.clone().unwrap_or(g.final_merging), | |
90 | .. g | |
91 | } | |
92 | }; | |
93 | ||
94 | use AlgorithmConfig::*; | |
95 | match self { | |
96 | FB(fb) => FB(FBConfig { | |
97 | τ0 : cli.tau0.unwrap_or(fb.τ0), | |
98 | insertion : override_fb_generic(fb.insertion), | |
99 | .. fb | |
100 | }), | |
101 | PDPS(pdps) => PDPS(PDPSConfig { | |
102 | τ0 : cli.tau0.unwrap_or(pdps.τ0), | |
103 | σ0 : cli.sigma0.unwrap_or(pdps.σ0), | |
104 | acceleration : cli.acceleration.unwrap_or(pdps.acceleration), | |
105 | insertion : override_fb_generic(pdps.insertion), | |
106 | .. pdps | |
107 | }), | |
108 | FW(fw) => FW(FWConfig { | |
109 | merging : cli.merging.clone().unwrap_or(fw.merging), | |
110 | .. fw | |
111 | }) | |
112 | } | |
113 | } | |
114 | } | |
115 | ||
/// Helper struct for tagging an [`AlgorithmConfig`] or [`Experiment`] with a name.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Named<Data> {
    /// Identifier, e.g. the algorithm's command line name or the experiment name
    pub name : String,
    /// The tagged data; `flatten` serialises its fields inline next to `name`.
    #[serde(flatten)]
    pub data : Data,
}
123 | ||
124 | /// Shorthand algorithm configurations, to be used with the command line parser | |
125 | #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash)] | |
126 | pub enum DefaultAlgorithm { | |
127 | /// The μFB forward-backward method | |
128 | #[clap(name = "fb")] | |
129 | FB, | |
130 | /// The μFISTA inertial forward-backward method | |
131 | #[clap(name = "fista")] | |
132 | FISTA, | |
133 | /// The “fully corrective” conditional gradient method | |
134 | #[clap(name = "fw")] | |
135 | FW, | |
136 | /// The “relaxed conditional gradient method | |
137 | #[clap(name = "fwrelax")] | |
138 | FWRelax, | |
139 | /// The μPDPS primal-dual proximal splitting method | |
140 | #[clap(name = "pdps")] | |
141 | PDPS, | |
142 | } | |
143 | ||
144 | impl DefaultAlgorithm { | |
145 | /// Returns the algorithm configuration corresponding to the algorithm shorthand | |
146 | pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { | |
147 | use DefaultAlgorithm::*; | |
148 | match *self { | |
149 | FB => AlgorithmConfig::FB(Default::default()), | |
150 | FISTA => AlgorithmConfig::FB(FBConfig{ | |
151 | meta : FBMetaAlgorithm::InertiaFISTA, | |
152 | .. Default::default() | |
153 | }), | |
154 | FW => AlgorithmConfig::FW(Default::default()), | |
155 | FWRelax => AlgorithmConfig::FW(FWConfig{ | |
156 | variant : FWVariant::Relaxed, | |
157 | .. Default::default() | |
158 | }), | |
159 | PDPS => AlgorithmConfig::PDPS(Default::default()), | |
160 | } | |
161 | } | |
162 | ||
163 | /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand | |
164 | pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { | |
165 | self.to_named(self.default_config()) | |
166 | } | |
167 | ||
168 | pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { | |
169 | let name = self.to_possible_value().unwrap().get_name().to_string(); | |
170 | Named{ name , data : alg } | |
171 | } | |
172 | } | |
173 | ||
174 | ||
175 | // // Floats cannot be hashed directly, so just hash the debug formatting | |
176 | // // for use as file identifier. | |
177 | // impl<F : Float> Hash for AlgorithmConfig<F> { | |
178 | // fn hash<H: Hasher>(&self, state: &mut H) { | |
179 | // format!("{:?}", self).hash(state); | |
180 | // } | |
181 | // } | |
182 | ||
/// Plotting level configuration
///
/// The variant order matters: the derived [`Ord`] is used for threshold
/// comparisons such as `config.plot >= PlotLevel::Iter`, so `None < Data < Iter`.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, ValueEnum, Debug)]
pub enum PlotLevel {
    /// Plot nothing
    #[clap(name = "none")]
    None,
    /// Plot problem data
    #[clap(name = "data")]
    Data,
    /// Plot iterationwise state
    #[clap(name = "iter")]
    Iter,
}
196 | ||
197 | /// Algorithm and iterator config for the experiments | |
198 | ||
199 | #[derive(Clone, Debug, Serialize)] | |
200 | #[serde(default)] | |
201 | pub struct Configuration<F : Float> { | |
202 | /// Algorithms to run | |
203 | pub algorithms : Vec<Named<AlgorithmConfig<F>>>, | |
204 | /// Options for algorithm step iteration (verbosity, etc.) | |
205 | pub iterator_options : AlgIteratorOptions, | |
206 | /// Plotting level | |
207 | pub plot : PlotLevel, | |
208 | /// Directory where to save results | |
209 | pub outdir : String, | |
210 | /// Bisection tree depth | |
211 | pub bt_depth : DynamicDepth, | |
212 | } | |
213 | ||
214 | type DefaultBT<F, const N : usize> = BT< | |
215 | DynamicDepth, | |
216 | F, | |
217 | usize, | |
218 | Bounds<F>, | |
219 | N | |
220 | >; | |
221 | type DefaultSeminormOp<F, K, const N : usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>; | |
222 | type DefaultSG<F, Sensor, Spread, const N : usize> = SensorGrid::< | |
223 | F, | |
224 | Sensor, | |
225 | Spread, | |
226 | DefaultBT<F, N>, | |
227 | N | |
228 | >; | |
229 | ||
/// This is a dirty workaround to rust-csv not supporting struct flattening etc.
///
/// NOTE: field order determines the CSV column order — do not reorder fields.
#[derive(Serialize)]
struct CSVLog<F> {
    /// Iteration number
    iter : usize,
    /// CPU time in seconds (from the iterator's [`Timed`] wrapper)
    cpu_time : f64,
    /// Functional value at the iterate
    value : F,
    /// Functional value after postprocessing weight optimisation;
    /// equals `value` when no postprocessing measure was produced.
    post_value : F,
    /// Number of spikes in the iterate measure
    n_spikes : usize,
    /// Inner-problem iteration count reported by the solver
    inner_iters : usize,
    /// Spike merge count reported by the solver
    merged : usize,
    /// Spike prune count reported by the solver
    pruned : usize,
    /// Number of iterations covered by this log row
    this_iters : usize,
}
243 | ||
/// Collected experiment statistics
///
/// Computed by [`ExperimentStats::new`] and written to `stats.json`.
#[derive(Clone, Debug, Serialize)]
struct ExperimentStats<F : Float> {
    /// Signal-to-noise ratio in decibels
    ssnr : F,
    /// Proportion of noise in the signal as a number in $[0, 1]$.
    noise_ratio : F,
    /// When the experiment was run (UTC)
    when : DateTime<Utc>,
}
254 | ||
255 | #[replace_float_literals(F::cast_from(literal))] | |
256 | impl<F : Float> ExperimentStats<F> { | |
257 | /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal. | |
258 | fn new<E : Euclidean<F>>(signal : &E, noise : &E) -> Self { | |
259 | let s = signal.norm2_squared(); | |
260 | let n = noise.norm2_squared(); | |
261 | let noise_ratio = (n / s).sqrt(); | |
262 | let ssnr = 10.0 * (s / n).log10(); | |
263 | ExperimentStats { | |
264 | ssnr, | |
265 | noise_ratio, | |
266 | when : Utc::now(), | |
267 | } | |
268 | } | |
269 | } | |
/// Collected algorithm statistics
#[derive(Clone, Debug, Serialize)]
struct AlgorithmStats<F : Float> {
    /// Overall CPU time spent, in seconds
    cpu_time : F,
    /// Real (wall-clock) time spent, in seconds
    elapsed : F
}
278 | ||
279 | ||
280 | /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input | |
281 | /// and outputs a [`DynError`]. | |
282 | fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { | |
283 | serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; | |
284 | Ok(()) | |
285 | } | |
286 | ||
287 | ||
/// Struct for experiment configurations
///
/// Type parameters: `F` is the scalar type, `NoiseDistr` the measurement
/// noise distribution, `S` the sensor profile, `P` the spread, `K` the
/// kernel of the seminorm operator $𝒟$, and `N` the spatial dimension.
#[derive(Debug, Clone, Serialize)]
pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize>
where F : Float,
      [usize; N] : Serialize,
      NoiseDistr : Distribution<F>,
      S : Sensor<F, N>,
      P : Spread<F, N>,
      K : SimpleConvolutionKernel<F, N>,
{
    /// Domain $Ω$.
    pub domain : Cube<F, N>,
    /// Number of sensors along each dimension
    pub sensor_count : [usize; N],
    /// Noise distribution
    pub noise_distr : NoiseDistr,
    /// Seed for random noise generation (for repeatable experiments)
    pub noise_seed : u64,
    /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$.
    pub sensor : S,
    /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$.
    pub spread : P,
    /// Kernel $ρ$ of $𝒟$.
    pub kernel : K,
    /// True point sources
    pub μ_hat : DiscreteMeasure<Loc<F, N>, F>,
    /// Regularisation parameter
    pub α : F,
    /// For plotting : how wide should the kernels be plotted
    pub kernel_plot_width : F,
    /// Data term
    pub dataterm : DataTerm,
    /// A map of default configurations for algorithms
    // Not serialised: algorithm configurations are saved separately per run.
    #[serde(skip)]
    pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>,
}
324 | ||
/// Trait for runnable experiments
///
/// Abstracts over the concrete type parameters of [`Experiment`], so callers
/// can drive any experiment through a uniform interface.
pub trait RunnableExperiment<F : ClapFloat> {
    /// Run all algorithms of the [`Configuration`] `config` on the experiment.
    fn runall(&self, config : Configuration<F>) -> DynError;

    /// Returns the default configuration
    fn default_config(&self) -> Configuration<F>;

    /// Return algorithm default config
    fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
    -> Named<AlgorithmConfig<F>>;
}
337 | ||
impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for
Named<Experiment<F, NoiseDistr, S, K, P, N>>
where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
      [usize; N] : Serialize,
      S : Sensor<F, N> + Copy + Serialize,
      P : Spread<F, N> + Copy + Serialize,
      Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy,
      AutoConvolution<P> : BoundedBy<F, K>,
      K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> + Copy + Serialize,
      Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
      PlotLookup : Plotting<N>,
      DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      NoiseDistr : Distribution<F> + Serialize {

    // Look up the experiment-specific default configuration for `alg`, fall
    // back to the algorithm's own default, then apply command line overrides.
    fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
    -> Named<AlgorithmConfig<F>> {
        alg.to_named(
            self.data
                .algorithm_defaults
                .get(&alg)
                .map_or_else(|| alg.default_config(),
                             |config| config.clone())
                .cli_override(cli)
        )
    }

    fn default_config(&self) -> Configuration<F> {
        // The data term decides the default algorithm; per the dispatch in
        // `runall` below, PDPS is the only algorithm implemented for L1.
        let default_alg = match self.data.dataterm {
            DataTerm::L2Squared => DefaultAlgorithm::FB.get_named(),
            DataTerm::L1 => DefaultAlgorithm::PDPS.get_named(),
        };

        Configuration{
            algorithms : vec![default_alg],
            iterator_options : AlgIteratorOptions{
                max_iter : 2000,
                verbose_iter : Verbose::Logarithmic(10),
                quiet : false,
            },
            plot : PlotLevel::Data,
            outdir : "out".to_string(),
            bt_depth : DynamicDepth(8),
        }
    }

    fn runall(&self, config : Configuration<F>) -> DynError {
        // Destructure the experiment; `algorithm_defaults` is skipped via `..`
        // (it is consulted separately through `algorithm_defaults` above).
        let &Named {
            name : ref experiment_name,
            data : Experiment {
                domain, sensor_count, ref noise_distr, sensor, spread, kernel,
                ref μ_hat, α, kernel_plot_width, dataterm, noise_seed,
                ..
            }
        } = self;

        // Set path: all outputs go under `<outdir>/<experiment_name>/`.
        let prefix = format!("{}/{}/", config.outdir, experiment_name);

        // Set up operators
        let depth = config.bt_depth;
        let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
        let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel);

        // Set up random number generator; fixed seed for repeatability.
        let mut rng = StdRng::seed_from_u64(noise_seed);

        // Generate the data and calculate SSNR statistic
        let b_hat = opA.apply(μ_hat);
        let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
        let b = &b_hat + &noise;
        // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
        // overloading log10 and conflicting with standard NumTraits one.
        let stats = ExperimentStats::new(&b, &noise);

        // Save experiment configuration and statistics
        let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t);
        std::fs::create_dir_all(&prefix)?;
        write_json(mkname_e("experiment"), self)?;
        write_json(mkname_e("config"), &config)?;
        write_json(mkname_e("stats"), &stats)?;

        // Plot problem data (no-op below PlotLevel::Data).
        plotall(&config, &prefix, &domain, &sensor, &kernel, &spread,
                &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?;

        // Run the algorithm(s)
        for named @ Named { name : alg_name, data : alg } in config.algorithms.iter() {
            let this_prefix = format!("{}{}/", prefix, alg_name);

            // Deferred banner printer, so nothing is printed for
            // algorithm/data-term combinations that are skipped below.
            let running = || {
                println!("{}\n{}\n{}",
                         format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
                         format!("{:?}", config.iterator_options).bright_black(),
                         format!("{:?}", alg).bright_black());
            };

            // Create Logger and IteratorFactory
            let mut logger = Logger::new();
            let findim_data = prepare_optimise_weights(&opA);
            let inner_config : InnerSettings<F> = Default::default();
            let inner_it = inner_config.iterator_options;
            // Flatten a timed per-iteration solver state into a CSV row. For
            // the L2² data term, a postprocessing measure (if any) first gets
            // its weights re-optimised before `post_value` is evaluated.
            let logmap = |iter, Timed { cpu_time, data }| {
                let IterInfo {
                    value,
                    n_spikes,
                    inner_iters,
                    merged,
                    pruned,
                    postprocessing,
                    this_iters,
                    ..
                } = data;
                let post_value = match postprocessing {
                    None => value,
                    Some(mut μ) => {
                        match dataterm {
                            DataTerm::L2Squared => {
                                // Re-optimise the weights of the postprocessed
                                // measure, then recompute the functional value.
                                optimise_weights_l2(
                                    &mut μ, &opA, &b, α, &findim_data, &inner_config,
                                    inner_it
                                );
                                dataterm.value_at_residual(opA.apply(&μ) - &b) + α * μ.norm(Radon)
                            },
                            _ => value,
                        }
                    }
                };
                CSVLog {
                    iter,
                    value,
                    post_value,
                    n_spikes,
                    cpu_time : cpu_time.as_secs_f64(),
                    inner_iters,
                    merged,
                    pruned,
                    this_iters
                }
            };
            let iterator = config.iterator_options
                                 .instantiate()
                                 .timed()
                                 .mapped(logmap)
                                 .into_log(&mut logger);
            // Finer plot grid in 1D, coarser in higher dimensions.
            let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]);

            // Create plotter and directory if needed.
            let plot_count = if config.plot >= PlotLevel::Iter { 2000 } else { 0 };
            let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid);

            // Run the algorithm, timing both wall-clock and CPU time.
            let start = Instant::now();
            let start_cpu = ProcessTime::now();
            let μ : DiscreteMeasure<Loc<F, N>, F> = match (alg, dataterm) {
                (AlgorithmConfig::FB(ref algconfig), DataTerm::L2Squared) => {
                    running();
                    pointsource_fb(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter)
                },
                (AlgorithmConfig::FW(ref algconfig), DataTerm::L2Squared) => {
                    running();
                    pointsource_fw(&opA, &b, α, &algconfig, iterator, plotter)
                },
                (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L2Squared) => {
                    running();
                    pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L2Squared)
                },
                (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L1) => {
                    running();
                    pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1)
                },
                // Unsupported algorithm/data-term combination: warn and move
                // on to the next configured algorithm.
                _ => {
                    let msg = format!("Algorithm “{}” not implemented for dataterm {:?}. Skipping.",
                                      alg_name, dataterm).red();
                    eprintln!("{}", msg);
                    continue
                }
            };
            let elapsed = start.elapsed().as_secs_f64();
            let cpu_time = start_cpu.elapsed().as_secs_f64();

            println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow());

            // Save results
            println!("{}", "Saving results…".green());

            let mkname = |t| format!("{p}{n}_{t}", p = prefix, n = alg_name, t = t);

            write_json(mkname("config.json"), &named)?;
            write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
            μ.write_csv(mkname("reco.txt"))?;
            logger.write_csv(mkname("log.txt"))?;
        }

        Ok(())
    }
}
536 | ||
537 | /// Plot experiment setup | |
538 | #[replace_float_literals(F::cast_from(literal))] | |
539 | fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( | |
540 | config : &Configuration<F>, | |
541 | prefix : &String, | |
542 | domain : &Cube<F, N>, | |
543 | sensor : &Sensor, | |
544 | kernel : &Kernel, | |
545 | spread : &Spread, | |
546 | μ_hat : &DiscreteMeasure<Loc<F, N>, F>, | |
547 | op𝒟 : &𝒟, | |
548 | opA : &A, | |
549 | b_hat : &A::Observable, | |
550 | b : &A::Observable, | |
551 | kernel_plot_width : F, | |
552 | ) -> DynError | |
553 | where F : Float + ToNalgebraRealField, | |
554 | Sensor : RealMapping<F, N> + Support<F, N> + Clone, | |
555 | Spread : RealMapping<F, N> + Support<F, N> + Clone, | |
556 | Kernel : RealMapping<F, N> + Support<F, N>, | |
557 | Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, | |
558 | 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, | |
559 | 𝒟::Codomain : RealMapping<F, N>, | |
560 | A : ForwardModel<Loc<F, N>, F>, | |
561 | A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, | |
562 | PlotLookup : Plotting<N>, | |
563 | Cube<F, N> : SetOrd { | |
564 | ||
565 | if config.plot < PlotLevel::Data { | |
566 | return Ok(()) | |
567 | } | |
568 | ||
569 | let base = Convolution(sensor.clone(), spread.clone()); | |
570 | ||
571 | let resolution = if N==1 { 100 } else { 40 }; | |
572 | let pfx = |n| format!("{}{}", prefix, n); | |
573 | let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); | |
574 | ||
575 | PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); | |
576 | PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); | |
577 | PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); | |
578 | PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); | |
579 | ||
580 | let plotgrid2 = lingrid(&domain, &[resolution; N]); | |
581 | ||
582 | let ω_hat = op𝒟.apply(μ_hat); | |
583 | let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); | |
584 | PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string()); | |
585 | PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"), | |
586 | "noise Aᵀ(Aμ̂ - b)".to_string()); | |
587 | ||
588 | let preadj_b = opA.preadjoint().apply(b); | |
589 | let preadj_b_hat = opA.preadjoint().apply(b_hat); | |
590 | //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); | |
591 | PlotLookup::plot_into_file_spikes( | |
592 | "Aᵀb".to_string(), &preadj_b, | |
593 | "Aᵀb̂".to_string(), Some(&preadj_b_hat), | |
594 | plotgrid2, None, &μ_hat, | |
595 | pfx("omega_b") | |
596 | ); | |
597 | ||
598 | // Save true solution and observables | |
599 | let pfx = |n| format!("{}{}", prefix, n); | |
600 | μ_hat.write_csv(pfx("orig.txt"))?; | |
601 | opA.write_observable(&b_hat, pfx("b_hat"))?; | |
602 | opA.write_observable(&b, pfx("b_noisy")) | |
603 | } | |
604 |