| 271 fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { |
273 fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { |
| 272 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; |
274 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; |
| 273 Ok(()) |
275 Ok(()) |
| 274 } |
276 } |
| 275 |
277 |
| |
278 |
| |
279 /// Struct for experiment configurations |
| |
280 #[derive(Debug, Clone, Serialize)] |
| |
281 pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize> |
| |
282 where F : Float, |
| |
283 [usize; N] : Serialize, |
| |
284 NoiseDistr : Distribution<F>, |
| |
285 S : Sensor<F, N>, |
| |
286 P : Spread<F, N>, |
| |
287 K : SimpleConvolutionKernel<F, N>, |
| |
288 { |
| |
289 /// Domain $Ω$. |
| |
290 pub domain : Cube<F, N>, |
| |
291 /// Number of sensors along each dimension |
| |
292 pub sensor_count : [usize; N], |
| |
293 /// Noise distribution |
| |
294 pub noise_distr : NoiseDistr, |
| |
295 /// Seed for random noise generation (for repeatable experiments) |
| |
296 pub noise_seed : u64, |
| |
297 /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. |
| |
298 pub sensor : S, |
| |
299 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. |
| |
300 pub spread : P, |
| |
301 /// Kernel $ρ$ of $𝒟$. |
| |
302 pub kernel : K, |
| |
303 /// True point sources |
| |
304 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, |
| |
305 /// Regularisation term and parameter |
| |
306 pub regularisation : Regularisation<F>, |
| |
307 /// For plotting : how wide should the kernels be plotted |
| |
308 pub kernel_plot_width : F, |
| |
309 /// Data term |
| |
310 pub dataterm : DataTerm, |
| |
311 /// A map of default configurations for algorithms |
| |
312 #[serde(skip)] |
| |
313 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, |
| |
314 } |
| |
315 |
| |
316 /// Trait for runnable experiments |
| |
317 pub trait RunnableExperiment<F : ClapFloat> { |
| |
318 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. |
| |
319 fn runall(&self, cli : &CommandLineArgs, |
| |
320 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; |
| |
321 |
| |
322 /// Return algorithm default config |
| |
323 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
| |
324 -> Named<AlgorithmConfig<F>>; |
| |
325 } |
| |
326 |
| |
327 // *** macro boilerplate *** |
| |
328 macro_rules! impl_experiment { |
| |
329 ($type:ident, $reg_field:ident, $reg_convert:path) => { |
| |
330 // *** macro *** |
| |
331 impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for |
| |
332 Named<$type<F, NoiseDistr, S, K, P, N>> |
| |
333 where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, |
| |
334 [usize; N] : Serialize, |
| |
335 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
| |
336 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
| |
337 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy, |
| |
338 AutoConvolution<P> : BoundedBy<F, K>, |
| |
339 K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> |
| |
340 + Copy + Serialize + std::fmt::Debug, |
| |
341 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
| |
342 PlotLookup : Plotting<N>, |
| |
343 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
| |
344 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
| |
345 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
| |
346 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug { |
| |
347 |
| |
348 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
| |
349 -> Named<AlgorithmConfig<F>> { |
| |
350 alg.to_named( |
| |
351 self.data |
| |
352 .algorithm_defaults |
| |
353 .get(&alg) |
| |
354 .map_or_else(|| alg.default_config(), |
| |
355 |config| config.clone()) |
| |
356 .cli_override(cli) |
| |
357 ) |
| |
358 } |
| |
359 |
| |
360 fn runall(&self, cli : &CommandLineArgs, |
| |
361 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { |
| |
362 // Get experiment configuration |
| |
363 let &Named { |
| |
364 name : ref experiment_name, |
| |
365 data : $type { |
| |
366 domain, sensor_count, ref noise_distr, sensor, spread, kernel, |
| |
367 ref μ_hat, /*regularisation,*/ kernel_plot_width, dataterm, noise_seed, |
| |
368 .. |
| |
369 } |
| |
370 } = self; |
| |
371 #[allow(deprecated)] |
| |
372 let regularisation = $reg_convert(self.data.$reg_field); |
| |
373 |
| |
374 println!("{}\n{}", |
| |
375 format!("Performing experiment {}…", experiment_name).cyan(), |
| |
376 format!("{:?}", &self.data).bright_black()); |
| |
377 |
| |
378 // Set up output directory |
| |
379 let prefix = format!("{}/{}/", cli.outdir, self.name); |
| |
380 |
| |
381 // Set up algorithms |
| |
382 let iterator_options = AlgIteratorOptions{ |
| |
383 max_iter : cli.max_iter, |
| |
384 verbose_iter : cli.verbose_iter |
| |
385 .map_or(Verbose::Logarithmic(10), |
| |
386 |n| Verbose::Every(n)), |
| |
387 quiet : cli.quiet, |
| |
388 }; |
| |
389 let algorithms = match (algs, self.data.dataterm) { |
| |
390 (Some(algs), _) => algs, |
| |
391 (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], |
| |
392 (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], |
| |
393 }; |
| |
394 |
| |
395 // Set up operators |
| |
396 let depth = DynamicDepth(8); |
| |
397 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); |
| |
398 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); |
| |
399 |
| |
400 // Set up random number generator. |
| |
401 let mut rng = StdRng::seed_from_u64(noise_seed); |
| |
402 |
| |
403 // Generate the data and calculate SSNR statistic |
| |
404 let b_hat = opA.apply(μ_hat); |
| |
405 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); |
| |
406 let b = &b_hat + &noise; |
| |
407 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
| |
408 // overloading log10 and conflicting with standard NumTraits one. |
| |
409 let stats = ExperimentStats::new(&b, &noise); |
| |
410 |
| |
411 // Save experiment configuration and statistics |
| |
412 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); |
| |
413 std::fs::create_dir_all(&prefix)?; |
| |
414 write_json(mkname_e("experiment"), self)?; |
| |
415 write_json(mkname_e("config"), cli)?; |
| |
416 write_json(mkname_e("stats"), &stats)?; |
| |
417 |
| |
418 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, |
| |
419 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; |
| |
420 |
| |
421 // Run the algorithm(s) |
| |
422 for named @ Named { name : alg_name, data : alg } in algorithms.iter() { |
| |
423 let this_prefix = format!("{}{}/", prefix, alg_name); |
| |
424 |
| |
425 let running = || if !cli.quiet { |
| |
426 println!("{}\n{}\n{}", |
| |
427 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), |
| |
428 format!("{:?}", iterator_options).bright_black(), |
| |
429 format!("{:?}", alg).bright_black()); |
| |
430 }; |
| |
431 let not_implemented = || { |
| |
432 let msg = format!("Algorithm “{alg_name}” not implemented for \ |
| |
433 dataterm {dataterm:?} and regularisation {regularisation:?}. \ |
| |
434 Skipping.").red(); |
| |
435 eprintln!("{}", msg); |
| |
436 }; |
| |
437 // Create Logger and IteratorFactory |
| |
438 let mut logger = Logger::new(); |
| |
439 let findim_data = prepare_optimise_weights(&opA); |
| |
440 let inner_config : InnerSettings<F> = Default::default(); |
| |
441 let inner_it = inner_config.iterator_options; |
| |
442 let logmap = |iter, Timed { cpu_time, data }| { |
| |
443 let IterInfo { |
| |
444 value, |
| |
445 n_spikes, |
| |
446 inner_iters, |
| |
447 merged, |
| |
448 pruned, |
| |
449 postprocessing, |
| |
450 this_iters, |
| |
451 .. |
| |
452 } = data; |
| |
453 let post_value = match (postprocessing, dataterm, regularisation) { |
| |
454 (Some(mut μ), DataTerm::L2Squared, Regularisation::Radon(α)) => { |
| |
455 // Comparison postprocessing is only implemented for the case handled |
| |
456 // by the FW variants. |
| |
457 optimise_weights( |
| |
458 &mut μ, &opA, &b, α, &findim_data, &inner_config, |
| |
459 inner_it |
| |
460 ); |
| |
461 dataterm.value_at_residual(opA.apply(&μ) - &b) |
| |
462 + regularisation.apply(&μ) |
| |
463 }, |
| |
464 _ => value, |
| |
465 }; |
| |
466 CSVLog { |
| |
467 iter, |
| |
468 value, |
| |
469 post_value, |
| |
470 n_spikes, |
| |
471 cpu_time : cpu_time.as_secs_f64(), |
| |
472 inner_iters, |
| |
473 merged, |
| |
474 pruned, |
| |
475 this_iters |
| |
476 } |
| |
477 }; |
| |
478 let iterator = iterator_options.instantiate() |
| |
479 .timed() |
| |
480 .mapped(logmap) |
| |
481 .into_log(&mut logger); |
| |
482 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); |
| |
483 |
| |
484 // Create plotter and directory if needed. |
| |
485 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; |
| |
486 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid); |
| |
487 |
| |
488 // Run the algorithm |
| |
489 let start = Instant::now(); |
| |
490 let start_cpu = ProcessTime::now(); |
| |
491 let μ = match alg { |
| |
492 AlgorithmConfig::FB(ref algconfig) => { |
| |
493 match (regularisation, dataterm) { |
| |
494 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
| |
495 running(); |
| |
496 pointsource_fb_reg( |
| |
497 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
| |
498 iterator, plotter |
| |
499 ) |
| |
500 }, |
| |
501 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
| |
502 running(); |
| |
503 pointsource_fb_reg( |
| |
504 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
| |
505 iterator, plotter |
| |
506 ) |
| |
507 }, |
| |
508 _ => { |
| |
509 not_implemented(); |
| |
510 continue |
| |
511 } |
| |
512 } |
| |
513 }, |
| |
514 AlgorithmConfig::PDPS(ref algconfig) => { |
| |
515 running(); |
| |
516 match (regularisation, dataterm) { |
| |
517 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
| |
518 pointsource_pdps_reg( |
| |
519 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
| |
520 iterator, plotter, L2Squared |
| |
521 ) |
| |
522 }, |
| |
523 (Regularisation::Radon(α),DataTerm::L2Squared) => { |
| |
524 pointsource_pdps_reg( |
| |
525 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
| |
526 iterator, plotter, L2Squared |
| |
527 ) |
| |
528 }, |
| |
529 (Regularisation::NonnegRadon(α), DataTerm::L1) => { |
| |
530 pointsource_pdps_reg( |
| |
531 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
| |
532 iterator, plotter, L1 |
| |
533 ) |
| |
534 }, |
| |
535 (Regularisation::Radon(α), DataTerm::L1) => { |
| |
536 pointsource_pdps_reg( |
| |
537 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
| |
538 iterator, plotter, L1 |
| |
539 ) |
| |
540 }, |
| |
541 } |
| |
542 }, |
| |
543 AlgorithmConfig::FW(ref algconfig) => { |
| |
544 match (regularisation, dataterm) { |
| |
545 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
| |
546 running(); |
| |
547 pointsource_fw(&opA, &b, α, algconfig, iterator, plotter) |
| |
548 }, |
| |
549 _ => { |
| |
550 not_implemented(); |
| |
551 continue |
| |
552 } |
| |
553 } |
| |
554 } |
| |
555 }; |
| |
556 |
| |
557 let elapsed = start.elapsed().as_secs_f64(); |
| |
558 let cpu_time = start_cpu.elapsed().as_secs_f64(); |
| |
559 |
| |
560 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); |
| |
561 |
| |
562 // Save results |
| |
563 println!("{}", "Saving results…".green()); |
| |
564 |
| |
565 let mkname = |t| format!("{prefix}{alg_name}_{t}"); |
| |
566 |
| |
567 write_json(mkname("config.json"), &named)?; |
| |
568 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; |
| |
569 μ.write_csv(mkname("reco.txt"))?; |
| |
570 logger.write_csv(mkname("log.txt"))?; |
| |
571 } |
| |
572 |
| |
573 Ok(()) |
| |
574 } |
| |
575 } |
| |
576 // *** macro end boiler plate *** |
| |
577 }} |
| |
578 // *** actual code *** |
| |
579 |
| |
580 impl_experiment!(ExperimentV2, regularisation, std::convert::identity); |
| |
581 |
| |
582 /// Plot experiment setup |
| |
583 #[replace_float_literals(F::cast_from(literal))] |
| |
584 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( |
| |
585 cli : &CommandLineArgs, |
| |
586 prefix : &String, |
| |
587 domain : &Cube<F, N>, |
| |
588 sensor : &Sensor, |
| |
589 kernel : &Kernel, |
| |
590 spread : &Spread, |
| |
591 μ_hat : &DiscreteMeasure<Loc<F, N>, F>, |
| |
592 op𝒟 : &𝒟, |
| |
593 opA : &A, |
| |
594 b_hat : &A::Observable, |
| |
595 b : &A::Observable, |
| |
596 kernel_plot_width : F, |
| |
597 ) -> DynError |
| |
598 where F : Float + ToNalgebraRealField, |
| |
599 Sensor : RealMapping<F, N> + Support<F, N> + Clone, |
| |
600 Spread : RealMapping<F, N> + Support<F, N> + Clone, |
| |
601 Kernel : RealMapping<F, N> + Support<F, N>, |
| |
602 Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, |
| |
603 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, |
| |
604 𝒟::Codomain : RealMapping<F, N>, |
| |
605 A : ForwardModel<Loc<F, N>, F>, |
| |
606 A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, |
| |
607 PlotLookup : Plotting<N>, |
| |
608 Cube<F, N> : SetOrd { |
| |
609 |
| |
610 if cli.plot < PlotLevel::Data { |
| |
611 return Ok(()) |
| |
612 } |
| |
613 |
| |
614 let base = Convolution(sensor.clone(), spread.clone()); |
| |
615 |
| |
616 let resolution = if N==1 { 100 } else { 40 }; |
| |
617 let pfx = |n| format!("{}{}", prefix, n); |
| |
618 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); |
| |
619 |
| |
620 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); |
| |
621 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); |
| |
622 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); |
| |
623 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); |
| |
624 |
| |
625 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
| |
626 |
| |
627 let ω_hat = op𝒟.apply(μ_hat); |
| |
628 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
| |
629 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string()); |
| |
630 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"), |
| |
631 "noise Aᵀ(Aμ̂ - b)".to_string()); |
| |
632 |
| |
633 let preadj_b = opA.preadjoint().apply(b); |
| |
634 let preadj_b_hat = opA.preadjoint().apply(b_hat); |
| |
635 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); |
| |
636 PlotLookup::plot_into_file_spikes( |
| |
637 "Aᵀb".to_string(), &preadj_b, |
| |
638 "Aᵀb̂".to_string(), Some(&preadj_b_hat), |
| |
639 plotgrid2, None, &μ_hat, |
| |
640 pfx("omega_b") |
| |
641 ); |
| |
642 |
| |
643 // Save true solution and observables |
| |
644 let pfx = |n| format!("{}{}", prefix, n); |
| |
645 μ_hat.write_csv(pfx("orig.txt"))?; |
| |
646 opA.write_observable(&b_hat, pfx("b_hat"))?; |
| |
647 opA.write_observable(&b, pfx("b_noisy")) |
| |
648 } |
| |
649 |
| |
650 // |
| |
651 // Deprecated interface |
| |
652 // |
| 276 |
653 |
| 277 /// Struct for experiment configurations |
654 /// Struct for experiment configurations |
| 278 #[derive(Debug, Clone, Serialize)] |
655 #[derive(Debug, Clone, Serialize)] |
| 279 pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize> |
656 pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize> |
| 280 where F : Float, |
657 where F : Float, |
| 299 /// Kernel $ρ$ of $𝒟$. |
676 /// Kernel $ρ$ of $𝒟$. |
| 300 pub kernel : K, |
677 pub kernel : K, |
| 301 /// True point sources |
678 /// True point sources |
| 302 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, |
679 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, |
| 303 /// Regularisation parameter |
680 /// Regularisation parameter |
| |
681 #[deprecated(note = "Use [`ExperimentV2`], which replaces `α` by more generic `regularisation`")] |
| 304 pub α : F, |
682 pub α : F, |
| 305 /// For plotting : how wide should the kernels be plotted |
683 /// For plotting : how wide should the kernels be plotted |
| 306 pub kernel_plot_width : F, |
684 pub kernel_plot_width : F, |
| 307 /// Data term |
685 /// Data term |
| 308 pub dataterm : DataTerm, |
686 pub dataterm : DataTerm, |
| 309 /// A map of default configurations for algorithms |
687 /// A map of default configurations for algorithms |
| 310 #[serde(skip)] |
688 #[serde(skip)] |
| 311 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, |
689 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, |
| 312 } |
690 } |
| 313 |
691 |
| 314 /// Trait for runnable experiments |
692 impl_experiment!(Experiment, α, Regularisation::NonnegRadon); |
| 315 pub trait RunnableExperiment<F : ClapFloat> { |
|
| 316 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. |
|
| 317 fn runall(&self, cli : &CommandLineArgs, |
|
| 318 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; |
|
| 319 |
|
| 320 /// Return algorithm default config |
|
| 321 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
|
| 322 -> Named<AlgorithmConfig<F>>; |
|
| 323 } |
|
| 324 |
|
| 325 impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for |
|
| 326 Named<Experiment<F, NoiseDistr, S, K, P, N>> |
|
| 327 where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, |
|
| 328 [usize; N] : Serialize, |
|
| 329 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
|
| 330 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
|
| 331 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy, |
|
| 332 AutoConvolution<P> : BoundedBy<F, K>, |
|
| 333 K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> |
|
| 334 + Copy + Serialize + std::fmt::Debug, |
|
| 335 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
|
| 336 PlotLookup : Plotting<N>, |
|
| 337 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
|
| 338 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
| 339 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
| 340 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug { |
|
| 341 |
|
| 342 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
|
| 343 -> Named<AlgorithmConfig<F>> { |
|
| 344 alg.to_named( |
|
| 345 self.data |
|
| 346 .algorithm_defaults |
|
| 347 .get(&alg) |
|
| 348 .map_or_else(|| alg.default_config(), |
|
| 349 |config| config.clone()) |
|
| 350 .cli_override(cli) |
|
| 351 ) |
|
| 352 } |
|
| 353 |
|
| 354 fn runall(&self, cli : &CommandLineArgs, |
|
| 355 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { |
|
| 356 // Get experiment configuration |
|
| 357 let &Named { |
|
| 358 name : ref experiment_name, |
|
| 359 data : Experiment { |
|
| 360 domain, sensor_count, ref noise_distr, sensor, spread, kernel, |
|
| 361 ref μ_hat, α, kernel_plot_width, dataterm, noise_seed, |
|
| 362 .. |
|
| 363 } |
|
| 364 } = self; |
|
| 365 |
|
| 366 println!("{}\n{}", |
|
| 367 format!("Performing experiment {}…", experiment_name).cyan(), |
|
| 368 format!("{:?}", &self.data).bright_black()); |
|
| 369 |
|
| 370 // Set up output directory |
|
| 371 let prefix = format!("{}/{}/", cli.outdir, self.name); |
|
| 372 |
|
| 373 // Set up algorithms |
|
| 374 let iterator_options = AlgIteratorOptions{ |
|
| 375 max_iter : cli.max_iter, |
|
| 376 verbose_iter : cli.verbose_iter |
|
| 377 .map_or(Verbose::Logarithmic(10), |
|
| 378 |n| Verbose::Every(n)), |
|
| 379 quiet : cli.quiet, |
|
| 380 }; |
|
| 381 let algorithms = match (algs, self.data.dataterm) { |
|
| 382 (Some(algs), _) => algs, |
|
| 383 (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], |
|
| 384 (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], |
|
| 385 }; |
|
| 386 |
|
| 387 // Set up operators |
|
| 388 let depth = DynamicDepth(8); |
|
| 389 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); |
|
| 390 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); |
|
| 391 |
|
| 392 // Set up random number generator. |
|
| 393 let mut rng = StdRng::seed_from_u64(noise_seed); |
|
| 394 |
|
| 395 // Generate the data and calculate SSNR statistic |
|
| 396 let b_hat = opA.apply(μ_hat); |
|
| 397 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); |
|
| 398 let b = &b_hat + &noise; |
|
| 399 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
|
| 400 // overloading log10 and conflicting with standard NumTraits one. |
|
| 401 let stats = ExperimentStats::new(&b, &noise); |
|
| 402 |
|
| 403 // Save experiment configuration and statistics |
|
| 404 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); |
|
| 405 std::fs::create_dir_all(&prefix)?; |
|
| 406 write_json(mkname_e("experiment"), self)?; |
|
| 407 write_json(mkname_e("config"), cli)?; |
|
| 408 write_json(mkname_e("stats"), &stats)?; |
|
| 409 |
|
| 410 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, |
|
| 411 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; |
|
| 412 |
|
| 413 // Run the algorithm(s) |
|
| 414 for named @ Named { name : alg_name, data : alg } in algorithms.iter() { |
|
| 415 let this_prefix = format!("{}{}/", prefix, alg_name); |
|
| 416 |
|
| 417 let running = || if !cli.quiet { |
|
| 418 println!("{}\n{}\n{}", |
|
| 419 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), |
|
| 420 format!("{:?}", iterator_options).bright_black(), |
|
| 421 format!("{:?}", alg).bright_black()); |
|
| 422 }; |
|
| 423 |
|
| 424 // Create Logger and IteratorFactory |
|
| 425 let mut logger = Logger::new(); |
|
| 426 let findim_data = prepare_optimise_weights(&opA); |
|
| 427 let inner_config : InnerSettings<F> = Default::default(); |
|
| 428 let inner_it = inner_config.iterator_options; |
|
| 429 let logmap = |iter, Timed { cpu_time, data }| { |
|
| 430 let IterInfo { |
|
| 431 value, |
|
| 432 n_spikes, |
|
| 433 inner_iters, |
|
| 434 merged, |
|
| 435 pruned, |
|
| 436 postprocessing, |
|
| 437 this_iters, |
|
| 438 .. |
|
| 439 } = data; |
|
| 440 let post_value = match postprocessing { |
|
| 441 None => value, |
|
| 442 Some(mut μ) => { |
|
| 443 match dataterm { |
|
| 444 DataTerm::L2Squared => { |
|
| 445 optimise_weights( |
|
| 446 &mut μ, &opA, &b, α, &findim_data, &inner_config, |
|
| 447 inner_it |
|
| 448 ); |
|
| 449 dataterm.value_at_residual(opA.apply(&μ) - &b) + α * μ.norm(Radon) |
|
| 450 }, |
|
| 451 _ => value, |
|
| 452 } |
|
| 453 } |
|
| 454 }; |
|
| 455 CSVLog { |
|
| 456 iter, |
|
| 457 value, |
|
| 458 post_value, |
|
| 459 n_spikes, |
|
| 460 cpu_time : cpu_time.as_secs_f64(), |
|
| 461 inner_iters, |
|
| 462 merged, |
|
| 463 pruned, |
|
| 464 this_iters |
|
| 465 } |
|
| 466 }; |
|
| 467 let iterator = iterator_options.instantiate() |
|
| 468 .timed() |
|
| 469 .mapped(logmap) |
|
| 470 .into_log(&mut logger); |
|
| 471 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); |
|
| 472 |
|
| 473 // Create plotter and directory if needed. |
|
| 474 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; |
|
| 475 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid); |
|
| 476 |
|
| 477 // Run the algorithm |
|
| 478 let start = Instant::now(); |
|
| 479 let start_cpu = ProcessTime::now(); |
|
| 480 let μ : DiscreteMeasure<Loc<F, N>, F> = match (alg, dataterm) { |
|
| 481 (AlgorithmConfig::FB(ref algconfig), DataTerm::L2Squared) => { |
|
| 482 running(); |
|
| 483 pointsource_fb(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter) |
|
| 484 }, |
|
| 485 (AlgorithmConfig::FW(ref algconfig), DataTerm::L2Squared) => { |
|
| 486 running(); |
|
| 487 pointsource_fw(&opA, &b, α, &algconfig, iterator, plotter) |
|
| 488 }, |
|
| 489 (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L2Squared) => { |
|
| 490 running(); |
|
| 491 pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L2Squared) |
|
| 492 }, |
|
| 493 (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L1) => { |
|
| 494 running(); |
|
| 495 pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1) |
|
| 496 }, |
|
| 497 _ => { |
|
| 498 let msg = format!("Algorithm “{alg_name}” not implemented for \ |
|
| 499 dataterm {dataterm:?}. Skipping.").red(); |
|
| 500 eprintln!("{}", msg); |
|
| 501 continue |
|
| 502 } |
|
| 503 }; |
|
| 504 let elapsed = start.elapsed().as_secs_f64(); |
|
| 505 let cpu_time = start_cpu.elapsed().as_secs_f64(); |
|
| 506 |
|
| 507 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); |
|
| 508 |
|
| 509 // Save results |
|
| 510 println!("{}", "Saving results…".green()); |
|
| 511 |
|
| 512 let mkname = |t| format!("{prefix}{alg_name}_{t}"); |
|
| 513 |
|
| 514 write_json(mkname("config.json"), &named)?; |
|
| 515 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; |
|
| 516 μ.write_csv(mkname("reco.txt"))?; |
|
| 517 logger.write_csv(mkname("log.txt"))?; |
|
| 518 } |
|
| 519 |
|
| 520 Ok(()) |
|
| 521 } |
|
| 522 } |
|
| 523 |
|
| 524 /// Plot experiment setup |
|
| 525 #[replace_float_literals(F::cast_from(literal))] |
|
| 526 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( |
|
| 527 cli : &CommandLineArgs, |
|
| 528 prefix : &String, |
|
| 529 domain : &Cube<F, N>, |
|
| 530 sensor : &Sensor, |
|
| 531 kernel : &Kernel, |
|
| 532 spread : &Spread, |
|
| 533 μ_hat : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 534 op𝒟 : &𝒟, |
|
| 535 opA : &A, |
|
| 536 b_hat : &A::Observable, |
|
| 537 b : &A::Observable, |
|
| 538 kernel_plot_width : F, |
|
| 539 ) -> DynError |
|
| 540 where F : Float + ToNalgebraRealField, |
|
| 541 Sensor : RealMapping<F, N> + Support<F, N> + Clone, |
|
| 542 Spread : RealMapping<F, N> + Support<F, N> + Clone, |
|
| 543 Kernel : RealMapping<F, N> + Support<F, N>, |
|
| 544 Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, |
|
| 545 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, |
|
| 546 𝒟::Codomain : RealMapping<F, N>, |
|
| 547 A : ForwardModel<Loc<F, N>, F>, |
|
| 548 A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, |
|
| 549 PlotLookup : Plotting<N>, |
|
| 550 Cube<F, N> : SetOrd { |
|
| 551 |
|
| 552 if cli.plot < PlotLevel::Data { |
|
| 553 return Ok(()) |
|
| 554 } |
|
| 555 |
|
| 556 let base = Convolution(sensor.clone(), spread.clone()); |
|
| 557 |
|
| 558 let resolution = if N==1 { 100 } else { 40 }; |
|
| 559 let pfx = |n| format!("{}{}", prefix, n); |
|
| 560 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); |
|
| 561 |
|
| 562 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); |
|
| 563 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); |
|
| 564 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); |
|
| 565 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); |
|
| 566 |
|
| 567 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
|
| 568 |
|
| 569 let ω_hat = op𝒟.apply(μ_hat); |
|
| 570 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
|
| 571 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string()); |
|
| 572 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"), |
|
| 573 "noise Aᵀ(Aμ̂ - b)".to_string()); |
|
| 574 |
|
| 575 let preadj_b = opA.preadjoint().apply(b); |
|
| 576 let preadj_b_hat = opA.preadjoint().apply(b_hat); |
|
| 577 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); |
|
| 578 PlotLookup::plot_into_file_spikes( |
|
| 579 "Aᵀb".to_string(), &preadj_b, |
|
| 580 "Aᵀb̂".to_string(), Some(&preadj_b_hat), |
|
| 581 plotgrid2, None, &μ_hat, |
|
| 582 pfx("omega_b") |
|
| 583 ); |
|
| 584 |
|
| 585 // Save true solution and observables |
|
| 586 let pfx = |n| format!("{}{}", prefix, n); |
|
| 587 μ_hat.write_csv(pfx("orig.txt"))?; |
|
| 588 opA.write_observable(&b_hat, pfx("b_hat"))?; |
|
| 589 opA.write_observable(&b, pfx("b_noisy")) |
|
| 590 } |
|
| 591 |
|