|
1 /*! |
|
2 This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment. |
|
3 */ |
|
4 |
|
5 use numeric_literals::replace_float_literals; |
|
6 use colored::Colorize; |
|
7 use serde::{Serialize, Deserialize}; |
|
8 use serde_json; |
|
9 use nalgebra::base::DVector; |
|
10 use std::hash::Hash; |
|
11 use chrono::{DateTime, Utc}; |
|
12 use cpu_time::ProcessTime; |
|
13 use clap::ValueEnum; |
|
14 use std::collections::HashMap; |
|
15 use std::time::Instant; |
|
16 |
|
17 use rand::prelude::{ |
|
18 StdRng, |
|
19 SeedableRng |
|
20 }; |
|
21 use rand_distr::Distribution; |
|
22 |
|
23 use alg_tools::bisection_tree::*; |
|
24 use alg_tools::iterate::{ |
|
25 Timed, |
|
26 AlgIteratorOptions, |
|
27 Verbose, |
|
28 AlgIteratorFactory, |
|
29 }; |
|
30 use alg_tools::logger::Logger; |
|
31 use alg_tools::error::DynError; |
|
32 use alg_tools::tabledump::TableDump; |
|
33 use alg_tools::sets::Cube; |
|
34 use alg_tools::mapping::RealMapping; |
|
35 use alg_tools::nalgebra_support::ToNalgebraRealField; |
|
36 use alg_tools::euclidean::Euclidean; |
|
37 use alg_tools::norms::{Norm, L1}; |
|
38 use alg_tools::lingrid::lingrid; |
|
39 use alg_tools::sets::SetOrd; |
|
40 |
|
41 use crate::kernels::*; |
|
42 use crate::types::*; |
|
43 use crate::measures::*; |
|
44 use crate::measures::merging::SpikeMerging; |
|
45 use crate::forward_model::*; |
|
46 use crate::fb::{ |
|
47 FBConfig, |
|
48 pointsource_fb, |
|
49 FBMetaAlgorithm, FBGenericConfig, |
|
50 }; |
|
51 use crate::pdps::{ |
|
52 PDPSConfig, |
|
53 L2Squared, |
|
54 pointsource_pdps, |
|
55 }; |
|
56 use crate::frank_wolfe::{ |
|
57 FWConfig, |
|
58 FWVariant, |
|
59 pointsource_fw, |
|
60 prepare_optimise_weights, |
|
61 optimise_weights, |
|
62 }; |
|
63 use crate::subproblem::InnerSettings; |
|
64 use crate::seminorms::*; |
|
65 use crate::plot::*; |
|
66 use crate::AlgorithmOverrides; |
|
67 |
|
68 /// Available algorithms and their configurations |
|
69 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
|
70 pub enum AlgorithmConfig<F : Float> { |
|
71 FB(FBConfig<F>), |
|
72 FW(FWConfig<F>), |
|
73 PDPS(PDPSConfig<F>), |
|
74 } |
|
75 |
|
76 impl<F : ClapFloat> AlgorithmConfig<F> { |
|
77 /// Override supported parameters based on the command line. |
|
78 pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self { |
|
79 let override_fb_generic = |g : FBGenericConfig<F>| { |
|
80 FBGenericConfig { |
|
81 bootstrap_insertions : cli.bootstrap_insertions |
|
82 .as_ref() |
|
83 .map_or(g.bootstrap_insertions, |
|
84 |n| Some((n[0], n[1]))), |
|
85 merge_every : cli.merge_every.unwrap_or(g.merge_every), |
|
86 merging : cli.merging.clone().unwrap_or(g.merging), |
|
87 final_merging : cli.final_merging.clone().unwrap_or(g.final_merging), |
|
88 .. g |
|
89 } |
|
90 }; |
|
91 |
|
92 use AlgorithmConfig::*; |
|
93 match self { |
|
94 FB(fb) => FB(FBConfig { |
|
95 τ0 : cli.tau0.unwrap_or(fb.τ0), |
|
96 insertion : override_fb_generic(fb.insertion), |
|
97 .. fb |
|
98 }), |
|
99 PDPS(pdps) => PDPS(PDPSConfig { |
|
100 τ0 : cli.tau0.unwrap_or(pdps.τ0), |
|
101 σ0 : cli.sigma0.unwrap_or(pdps.σ0), |
|
102 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
|
103 insertion : override_fb_generic(pdps.insertion), |
|
104 .. pdps |
|
105 }), |
|
106 FW(fw) => FW(FWConfig { |
|
107 merging : cli.merging.clone().unwrap_or(fw.merging), |
|
108 .. fw |
|
109 }) |
|
110 } |
|
111 } |
|
112 } |
|
113 |
|
114 /// Helper struct for tagging and [`AlgorithmConfig`] or [`Experiment`] with a name. |
|
115 #[derive(Clone, Debug, Serialize, Deserialize)] |
|
116 pub struct Named<Data> { |
|
117 pub name : String, |
|
118 #[serde(flatten)] |
|
119 pub data : Data, |
|
120 } |
|
121 |
|
122 /// Shorthand algorithm configurations, to be used with the command line parser |
|
123 #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash)] |
|
124 pub enum DefaultAlgorithm { |
|
125 /// The μFB forward-backward method |
|
126 #[clap(name = "fb")] |
|
127 FB, |
|
128 /// The μFISTA inertial forward-backward method |
|
129 #[clap(name = "fista")] |
|
130 FISTA, |
|
131 /// The “fully corrective” conditional gradient method |
|
132 #[clap(name = "fw")] |
|
133 FW, |
|
134 /// The “relaxed conditional gradient method |
|
135 #[clap(name = "fwrelax")] |
|
136 FWRelax, |
|
137 /// The μPDPS primal-dual proximal splitting method |
|
138 #[clap(name = "pdps")] |
|
139 PDPS, |
|
140 } |
|
141 |
|
142 impl DefaultAlgorithm { |
|
143 /// Returns the algorithm configuration corresponding to the algorithm shorthand |
|
144 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { |
|
145 use DefaultAlgorithm::*; |
|
146 match *self { |
|
147 FB => AlgorithmConfig::FB(Default::default()), |
|
148 FISTA => AlgorithmConfig::FB(FBConfig{ |
|
149 meta : FBMetaAlgorithm::InertiaFISTA, |
|
150 .. Default::default() |
|
151 }), |
|
152 FW => AlgorithmConfig::FW(Default::default()), |
|
153 FWRelax => AlgorithmConfig::FW(FWConfig{ |
|
154 variant : FWVariant::Relaxed, |
|
155 .. Default::default() |
|
156 }), |
|
157 PDPS => AlgorithmConfig::PDPS(Default::default()), |
|
158 } |
|
159 } |
|
160 |
|
161 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand |
|
162 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { |
|
163 self.to_named(self.default_config()) |
|
164 } |
|
165 |
|
166 pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { |
|
167 let name = self.to_possible_value().unwrap().get_name().to_string(); |
|
168 Named{ name , data : alg } |
|
169 } |
|
170 } |
|
171 |
|
172 |
|
173 // // Floats cannot be hashed directly, so just hash the debug formatting |
|
174 // // for use as file identifier. |
|
175 // impl<F : Float> Hash for AlgorithmConfig<F> { |
|
176 // fn hash<H: Hasher>(&self, state: &mut H) { |
|
177 // format!("{:?}", self).hash(state); |
|
178 // } |
|
179 // } |
|
180 |
|
181 /// Plotting level configuration |
|
182 #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, ValueEnum, Debug)] |
|
183 pub enum PlotLevel { |
|
184 /// Plot nothing |
|
185 #[clap(name = "none")] |
|
186 None, |
|
187 /// Plot problem data |
|
188 #[clap(name = "data")] |
|
189 Data, |
|
190 /// Plot iterationwise state |
|
191 #[clap(name = "iter")] |
|
192 Iter, |
|
193 } |
|
194 |
|
195 /// Algorithm and iterator config for the experiments |
|
196 |
|
197 #[derive(Clone, Debug, Serialize)] |
|
198 #[serde(default)] |
|
199 pub struct Configuration<F : Float> { |
|
200 /// Algorithms to run |
|
201 pub algorithms : Vec<Named<AlgorithmConfig<F>>>, |
|
202 /// Options for algorithm step iteration (verbosity, etc.) |
|
203 pub iterator_options : AlgIteratorOptions, |
|
204 /// Plotting level |
|
205 pub plot : PlotLevel, |
|
206 /// Directory where to save results |
|
207 pub outdir : String, |
|
208 /// Bisection tree depth |
|
209 pub bt_depth : DynamicDepth, |
|
210 } |
|
211 |
|
212 type DefaultBT<F, const N : usize> = BT< |
|
213 DynamicDepth, |
|
214 F, |
|
215 usize, |
|
216 Bounds<F>, |
|
217 N |
|
218 >; |
|
219 type DefaultSeminormOp<F, K, const N : usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>; |
|
220 type DefaultSG<F, Sensor, Spread, const N : usize> = SensorGrid::< |
|
221 F, |
|
222 Sensor, |
|
223 Spread, |
|
224 DefaultBT<F, N>, |
|
225 N |
|
226 >; |
|
227 |
|
228 /// This is a dirty workaround to rust-csv not supporting struct flattening etc. |
|
229 #[derive(Serialize)] |
|
230 struct CSVLog<F> { |
|
231 iter : usize, |
|
232 cpu_time : f64, |
|
233 value : F, |
|
234 post_value : F, |
|
235 n_spikes : usize, |
|
236 inner_iters : usize, |
|
237 merged : usize, |
|
238 pruned : usize, |
|
239 this_iters : usize, |
|
240 } |
|
241 |
|
242 /// Collected experiment statistics |
|
243 #[derive(Clone, Debug, Serialize)] |
|
244 struct ExperimentStats<F : Float> { |
|
245 /// Signal-to-noise ratio in decibels |
|
246 ssnr : F, |
|
247 /// Proportion of noise in the signal as a number in $[0, 1]$. |
|
248 noise_ratio : F, |
|
249 /// When the experiment was run (UTC) |
|
250 when : DateTime<Utc>, |
|
251 } |
|
252 |
|
253 #[replace_float_literals(F::cast_from(literal))] |
|
254 impl<F : Float> ExperimentStats<F> { |
|
255 /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal. |
|
256 fn new<E : Euclidean<F>>(signal : &E, noise : &E) -> Self { |
|
257 let s = signal.norm2_squared(); |
|
258 let n = noise.norm2_squared(); |
|
259 let noise_ratio = (n / s).sqrt(); |
|
260 let ssnr = 10.0 * (s / n).log10(); |
|
261 ExperimentStats { |
|
262 ssnr, |
|
263 noise_ratio, |
|
264 when : Utc::now(), |
|
265 } |
|
266 } |
|
267 } |
|
268 /// Collected algorithm statistics |
|
269 #[derive(Clone, Debug, Serialize)] |
|
270 struct AlgorithmStats<F : Float> { |
|
271 /// Overall CPU time spent |
|
272 cpu_time : F, |
|
273 /// Real time spent |
|
274 elapsed : F |
|
275 } |
|
276 |
|
277 |
|
278 /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input |
|
279 /// and outputs a [`DynError`]. |
|
280 fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { |
|
281 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; |
|
282 Ok(()) |
|
283 } |
|
284 |
|
285 |
|
286 /// Struct for experiment configurations |
|
287 #[derive(Debug, Clone, Serialize)] |
|
288 pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize> |
|
289 where F : Float, |
|
290 [usize; N] : Serialize, |
|
291 NoiseDistr : Distribution<F>, |
|
292 S : Sensor<F, N>, |
|
293 P : Spread<F, N>, |
|
294 K : SimpleConvolutionKernel<F, N>, |
|
295 { |
|
296 /// Domain $Ω$. |
|
297 pub domain : Cube<F, N>, |
|
298 /// Number of sensors along each dimension |
|
299 pub sensor_count : [usize; N], |
|
300 /// Noise distribution |
|
301 pub noise_distr : NoiseDistr, |
|
302 /// Seed for random noise generation (for repeatable experiments) |
|
303 pub noise_seed : u64, |
|
304 /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. |
|
305 pub sensor : S, |
|
306 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. |
|
307 pub spread : P, |
|
308 /// Kernel $ρ$ of $𝒟$. |
|
309 pub kernel : K, |
|
310 /// True point sources |
|
311 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, |
|
312 /// Regularisation parameter |
|
313 pub α : F, |
|
314 /// For plotting : how wide should the kernels be plotted |
|
315 pub kernel_plot_width : F, |
|
316 /// Data term |
|
317 pub dataterm : DataTerm, |
|
318 /// A map of default configurations for algorithms |
|
319 #[serde(skip)] |
|
320 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, |
|
321 } |
|
322 |
|
323 /// Trait for runnable experiments |
|
324 pub trait RunnableExperiment<F : ClapFloat> { |
|
325 /// Run all algorithms of the [`Configuration`] `config` on the experiment. |
|
326 fn runall(&self, config : Configuration<F>) -> DynError; |
|
327 |
|
328 /// Returns the default configuration |
|
329 fn default_config(&self) -> Configuration<F>; |
|
330 |
|
331 /// Return algorithm default config |
|
332 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
|
333 -> Named<AlgorithmConfig<F>>; |
|
334 } |
|
335 |
|
336 impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for |
|
337 Named<Experiment<F, NoiseDistr, S, K, P, N>> |
|
338 where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, |
|
339 [usize; N] : Serialize, |
|
340 S : Sensor<F, N> + Copy + Serialize, |
|
341 P : Spread<F, N> + Copy + Serialize, |
|
342 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy, |
|
343 AutoConvolution<P> : BoundedBy<F, K>, |
|
344 K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> + Copy + Serialize, |
|
345 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
|
346 PlotLookup : Plotting<N>, |
|
347 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
|
348 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
349 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
350 NoiseDistr : Distribution<F> + Serialize { |
|
351 |
|
352 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
|
353 -> Named<AlgorithmConfig<F>> { |
|
354 alg.to_named( |
|
355 self.data |
|
356 .algorithm_defaults |
|
357 .get(&alg) |
|
358 .map_or_else(|| alg.default_config(), |
|
359 |config| config.clone()) |
|
360 .cli_override(cli) |
|
361 ) |
|
362 } |
|
363 |
|
364 fn default_config(&self) -> Configuration<F> { |
|
365 let default_alg = match self.data.dataterm { |
|
366 DataTerm::L2Squared => DefaultAlgorithm::FB.get_named(), |
|
367 DataTerm::L1 => DefaultAlgorithm::PDPS.get_named(), |
|
368 }; |
|
369 |
|
370 Configuration{ |
|
371 algorithms : vec![default_alg], |
|
372 iterator_options : AlgIteratorOptions{ |
|
373 max_iter : 2000, |
|
374 verbose_iter : Verbose::Logarithmic(10), |
|
375 quiet : false, |
|
376 }, |
|
377 plot : PlotLevel::Data, |
|
378 outdir : "out".to_string(), |
|
379 bt_depth : DynamicDepth(8), |
|
380 } |
|
381 } |
|
382 |
|
383 fn runall(&self, config : Configuration<F>) -> DynError { |
|
384 let &Named { |
|
385 name : ref experiment_name, |
|
386 data : Experiment { |
|
387 domain, sensor_count, ref noise_distr, sensor, spread, kernel, |
|
388 ref μ_hat, α, kernel_plot_width, dataterm, noise_seed, |
|
389 .. |
|
390 } |
|
391 } = self; |
|
392 |
|
393 // Set path |
|
394 let prefix = format!("{}/{}/", config.outdir, experiment_name); |
|
395 |
|
396 // Set up operators |
|
397 let depth = config.bt_depth; |
|
398 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); |
|
399 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); |
|
400 |
|
401 // Set up random number generator. |
|
402 let mut rng = StdRng::seed_from_u64(noise_seed); |
|
403 |
|
404 // Generate the data and calculate SSNR statistic |
|
405 let b_hat = opA.apply(μ_hat); |
|
406 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); |
|
407 let b = &b_hat + &noise; |
|
408 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
|
409 // overloading log10 and conflicting with standard NumTraits one. |
|
410 let stats = ExperimentStats::new(&b, &noise); |
|
411 |
|
412 // Save experiment configuration and statistics |
|
413 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); |
|
414 std::fs::create_dir_all(&prefix)?; |
|
415 write_json(mkname_e("experiment"), self)?; |
|
416 write_json(mkname_e("config"), &config)?; |
|
417 write_json(mkname_e("stats"), &stats)?; |
|
418 |
|
419 plotall(&config, &prefix, &domain, &sensor, &kernel, &spread, |
|
420 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; |
|
421 |
|
422 // Run the algorithm(s) |
|
423 for named @ Named { name : alg_name, data : alg } in config.algorithms.iter() { |
|
424 let this_prefix = format!("{}{}/", prefix, alg_name); |
|
425 |
|
426 let running = || { |
|
427 println!("{}\n{}\n{}", |
|
428 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), |
|
429 format!("{:?}", config.iterator_options).bright_black(), |
|
430 format!("{:?}", alg).bright_black()); |
|
431 }; |
|
432 |
|
433 // Create Logger and IteratorFactory |
|
434 let mut logger = Logger::new(); |
|
435 let findim_data = prepare_optimise_weights(&opA); |
|
436 let inner_config : InnerSettings<F> = Default::default(); |
|
437 let inner_it = inner_config.iterator_options; |
|
438 let logmap = |iter, Timed { cpu_time, data }| { |
|
439 let IterInfo { |
|
440 value, |
|
441 n_spikes, |
|
442 inner_iters, |
|
443 merged, |
|
444 pruned, |
|
445 postprocessing, |
|
446 this_iters, |
|
447 .. |
|
448 } = data; |
|
449 let post_value = match postprocessing { |
|
450 None => value, |
|
451 Some(mut μ) => { |
|
452 match dataterm { |
|
453 DataTerm::L2Squared => { |
|
454 optimise_weights( |
|
455 &mut μ, &opA, &b, α, &findim_data, &inner_config, |
|
456 inner_it |
|
457 ); |
|
458 dataterm.value_at_residual(opA.apply(&μ) - &b) + α * μ.norm(Radon) |
|
459 }, |
|
460 _ => value, |
|
461 } |
|
462 } |
|
463 }; |
|
464 CSVLog { |
|
465 iter, |
|
466 value, |
|
467 post_value, |
|
468 n_spikes, |
|
469 cpu_time : cpu_time.as_secs_f64(), |
|
470 inner_iters, |
|
471 merged, |
|
472 pruned, |
|
473 this_iters |
|
474 } |
|
475 }; |
|
476 let iterator = config.iterator_options |
|
477 .instantiate() |
|
478 .timed() |
|
479 .mapped(logmap) |
|
480 .into_log(&mut logger); |
|
481 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); |
|
482 |
|
483 // Create plotter and directory if needed. |
|
484 let plot_count = if config.plot >= PlotLevel::Iter { 2000 } else { 0 }; |
|
485 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid); |
|
486 |
|
487 // Run the algorithm |
|
488 let start = Instant::now(); |
|
489 let start_cpu = ProcessTime::now(); |
|
490 let μ : DiscreteMeasure<Loc<F, N>, F> = match (alg, dataterm) { |
|
491 (AlgorithmConfig::FB(ref algconfig), DataTerm::L2Squared) => { |
|
492 running(); |
|
493 pointsource_fb(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter) |
|
494 }, |
|
495 (AlgorithmConfig::FW(ref algconfig), DataTerm::L2Squared) => { |
|
496 running(); |
|
497 pointsource_fw(&opA, &b, α, &algconfig, iterator, plotter) |
|
498 }, |
|
499 (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L2Squared) => { |
|
500 running(); |
|
501 pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L2Squared) |
|
502 }, |
|
503 (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L1) => { |
|
504 running(); |
|
505 pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1) |
|
506 }, |
|
507 _ => { |
|
508 let msg = format!("Algorithm “{}” not implemented for dataterm {:?}. Skipping.", |
|
509 alg_name, dataterm).red(); |
|
510 eprintln!("{}", msg); |
|
511 continue |
|
512 } |
|
513 }; |
|
514 let elapsed = start.elapsed().as_secs_f64(); |
|
515 let cpu_time = start_cpu.elapsed().as_secs_f64(); |
|
516 |
|
517 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); |
|
518 |
|
519 // Save results |
|
520 println!("{}", "Saving results…".green()); |
|
521 |
|
522 let mkname = | |
|
523 t| format!("{p}{n}_{t}", p = prefix, n = alg_name, t = t); |
|
524 |
|
525 write_json(mkname("config.json"), &named)?; |
|
526 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; |
|
527 μ.write_csv(mkname("reco.txt"))?; |
|
528 logger.write_csv(mkname("log.txt"))?; |
|
529 } |
|
530 |
|
531 Ok(()) |
|
532 } |
|
533 } |
|
534 |
|
535 /// Plot experiment setup |
|
536 #[replace_float_literals(F::cast_from(literal))] |
|
537 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( |
|
538 config : &Configuration<F>, |
|
539 prefix : &String, |
|
540 domain : &Cube<F, N>, |
|
541 sensor : &Sensor, |
|
542 kernel : &Kernel, |
|
543 spread : &Spread, |
|
544 μ_hat : &DiscreteMeasure<Loc<F, N>, F>, |
|
545 op𝒟 : &𝒟, |
|
546 opA : &A, |
|
547 b_hat : &A::Observable, |
|
548 b : &A::Observable, |
|
549 kernel_plot_width : F, |
|
550 ) -> DynError |
|
551 where F : Float + ToNalgebraRealField, |
|
552 Sensor : RealMapping<F, N> + Support<F, N> + Clone, |
|
553 Spread : RealMapping<F, N> + Support<F, N> + Clone, |
|
554 Kernel : RealMapping<F, N> + Support<F, N>, |
|
555 Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, |
|
556 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, |
|
557 𝒟::Codomain : RealMapping<F, N>, |
|
558 A : ForwardModel<Loc<F, N>, F>, |
|
559 A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, |
|
560 PlotLookup : Plotting<N>, |
|
561 Cube<F, N> : SetOrd { |
|
562 |
|
563 if config.plot < PlotLevel::Data { |
|
564 return Ok(()) |
|
565 } |
|
566 |
|
567 let base = Convolution(sensor.clone(), spread.clone()); |
|
568 |
|
569 let resolution = if N==1 { 100 } else { 40 }; |
|
570 let pfx = |n| format!("{}{}", prefix, n); |
|
571 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); |
|
572 |
|
573 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); |
|
574 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); |
|
575 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); |
|
576 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); |
|
577 |
|
578 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
|
579 |
|
580 let ω_hat = op𝒟.apply(μ_hat); |
|
581 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
|
582 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string()); |
|
583 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"), |
|
584 "noise Aᵀ(Aμ̂ - b)".to_string()); |
|
585 |
|
586 let preadj_b = opA.preadjoint().apply(b); |
|
587 let preadj_b_hat = opA.preadjoint().apply(b_hat); |
|
588 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); |
|
589 PlotLookup::plot_into_file_spikes( |
|
590 "Aᵀb".to_string(), &preadj_b, |
|
591 "Aᵀb̂".to_string(), Some(&preadj_b_hat), |
|
592 plotgrid2, None, &μ_hat, |
|
593 pfx("omega_b") |
|
594 ); |
|
595 |
|
596 // Save true solution and observables |
|
597 let pfx = |n| format!("{}{}", prefix, n); |
|
598 μ_hat.write_csv(pfx("orig.txt"))?; |
|
599 opA.write_observable(&b_hat, pfx("b_hat"))?; |
|
600 opA.write_observable(&b, pfx("b_noisy")) |
|
601 } |
|
602 |