271 fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { |
273 fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { |
272 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; |
274 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; |
273 Ok(()) |
275 Ok(()) |
274 } |
276 } |
275 |
277 |
|
278 |
|
279 /// Struct for experiment configurations |
|
280 #[derive(Debug, Clone, Serialize)] |
|
281 pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize> |
|
282 where F : Float, |
|
283 [usize; N] : Serialize, |
|
284 NoiseDistr : Distribution<F>, |
|
285 S : Sensor<F, N>, |
|
286 P : Spread<F, N>, |
|
287 K : SimpleConvolutionKernel<F, N>, |
|
288 { |
|
289 /// Domain $Ω$. |
|
290 pub domain : Cube<F, N>, |
|
291 /// Number of sensors along each dimension |
|
292 pub sensor_count : [usize; N], |
|
293 /// Noise distribution |
|
294 pub noise_distr : NoiseDistr, |
|
295 /// Seed for random noise generation (for repeatable experiments) |
|
296 pub noise_seed : u64, |
|
297 /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. |
|
298 pub sensor : S, |
|
299 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. |
|
300 pub spread : P, |
|
301 /// Kernel $ρ$ of $𝒟$. |
|
302 pub kernel : K, |
|
303 /// True point sources |
|
304 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, |
|
305 /// Regularisation term and parameter |
|
306 pub regularisation : Regularisation<F>, |
|
307 /// For plotting : how wide should the kernels be plotted |
|
308 pub kernel_plot_width : F, |
|
309 /// Data term |
|
310 pub dataterm : DataTerm, |
|
311 /// A map of default configurations for algorithms |
|
312 #[serde(skip)] |
|
313 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, |
|
314 } |
|
315 |
|
316 /// Trait for runnable experiments |
|
317 pub trait RunnableExperiment<F : ClapFloat> { |
|
318 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. |
|
319 fn runall(&self, cli : &CommandLineArgs, |
|
320 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; |
|
321 |
|
322 /// Return algorithm default config |
|
323 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
|
324 -> Named<AlgorithmConfig<F>>; |
|
325 } |
|
326 |
|
/// Implements [`RunnableExperiment`] for `Named<$type<F, NoiseDistr, S, K, P, N>>`.
///
/// * `$type` – name of the experiment configuration struct;
/// * `$reg_field` – field of `$type` from which the regularisation is read;
/// * `$reg_convert` – path of a function mapping the value of `$reg_field` to a
///   [`Regularisation`].
///
/// The struct must provide the fields `domain`, `sensor_count`, `noise_distr`,
/// `noise_seed`, `sensor`, `spread`, `kernel`, `μ_hat`, `kernel_plot_width`,
/// `dataterm` and `algorithm_defaults`, which the generated code reads.
macro_rules! impl_experiment {
    ($type:ident, $reg_field:ident, $reg_convert:path) => {
        impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for
        Named<$type<F, NoiseDistr, S, K, P, N>>
        where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
              [usize; N] : Serialize,
              S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
              P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
              Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy,
              AutoConvolution<P> : BoundedBy<F, K>,
              K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N>
                  + Copy + Serialize + std::fmt::Debug,
              Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
              PlotLookup : Plotting<N>,
              DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
              BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
              DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
              NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug {

            fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
            -> Named<AlgorithmConfig<F>> {
                // Prefer an experiment-specific default, falling back to
                // `alg.default_config()`; command line overrides are applied on top.
                let config = self.data
                                 .algorithm_defaults
                                 .get(&alg)
                                 .cloned()
                                 .unwrap_or_else(|| alg.default_config())
                                 .cli_override(cli);
                alg.to_named(config)
            }

            fn runall(&self, cli : &CommandLineArgs,
                      algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError {
                // Get experiment configuration
                let &Named {
                    name : ref experiment_name,
                    data : $type {
                        domain, sensor_count, ref noise_distr, sensor, spread, kernel,
                        ref μ_hat, kernel_plot_width, dataterm, noise_seed,
                        ..
                    }
                } = self;
                // Convert the configured regularisation into a `Regularisation`; the
                // `allow` silences deprecation warnings from reading `$reg_field`.
                #[allow(deprecated)]
                let regularisation = $reg_convert(self.data.$reg_field);

                println!("{}\n{}",
                         format!("Performing experiment {}…", experiment_name).cyan(),
                         format!("{:?}", &self.data).bright_black());

                // Set up output directory
                let prefix = format!("{}/{}/", cli.outdir, self.name);

                // Set up algorithms
                let iterator_options = AlgIteratorOptions{
                    max_iter : cli.max_iter,
                    verbose_iter : cli.verbose_iter.map_or(Verbose::Logarithmic(10),
                                                           Verbose::Every),
                    quiet : cli.quiet,
                };
                // Without an explicit selection, run the default algorithm for the data term.
                let algorithms = match (algs, self.data.dataterm) {
                    (Some(algs), _) => algs,
                    (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()],
                    (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()],
                };

                // Set up operators
                let depth = DynamicDepth(8);
                let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
                let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel);

                // Set up random number generator.
                let mut rng = StdRng::seed_from_u64(noise_seed);

                // Generate the data and calculate SSNR statistic
                let b_hat = opA.apply(μ_hat);
                let noise = DVector::from_distribution(b_hat.len(), noise_distr, &mut rng);
                let b = &b_hat + &noise;
                // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
                // overloading log10 and conflicting with standard NumTraits one.
                let stats = ExperimentStats::new(&b, &noise);

                // Save experiment configuration and statistics
                let mkname_e = |t| format!("{prefix}{t}.json");
                std::fs::create_dir_all(&prefix)?;
                write_json(mkname_e("experiment"), self)?;
                write_json(mkname_e("config"), cli)?;
                write_json(mkname_e("stats"), &stats)?;

                plotall(cli, &prefix, &domain, &sensor, &kernel, &spread,
                        μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?;

                // Run the algorithm(s)
                for named @ Named { name : alg_name, data : alg } in algorithms.iter() {
                    let this_prefix = format!("{}{}/", prefix, alg_name);

                    let running = || if !cli.quiet {
                        println!("{}\n{}\n{}",
                                 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
                                 format!("{:?}", iterator_options).bright_black(),
                                 format!("{:?}", alg).bright_black());
                    };
                    let not_implemented = || {
                        let msg = format!("Algorithm “{alg_name}” not implemented for \
                                           dataterm {dataterm:?} and regularisation {regularisation:?}. \
                                           Skipping.").red();
                        eprintln!("{}", msg);
                    };
                    // Create Logger and IteratorFactory
                    let mut logger = Logger::new();
                    let findim_data = prepare_optimise_weights(&opA);
                    let inner_config : InnerSettings<F> = Default::default();
                    let inner_it = inner_config.iterator_options;
                    let logmap = |iter, Timed { cpu_time, data }| {
                        let IterInfo {
                            value,
                            n_spikes,
                            inner_iters,
                            merged,
                            pruned,
                            postprocessing,
                            this_iters,
                            ..
                        } = data;
                        let post_value = match (postprocessing, dataterm, regularisation) {
                            (Some(mut μ), DataTerm::L2Squared, Regularisation::Radon(α)) => {
                                // Comparison postprocessing is only implemented for the case handled
                                // by the FW variants.
                                optimise_weights(
                                    &mut μ, &opA, &b, α, &findim_data, &inner_config,
                                    inner_it
                                );
                                dataterm.value_at_residual(opA.apply(&μ) - &b)
                                    + regularisation.apply(&μ)
                            },
                            _ => value,
                        };
                        CSVLog {
                            iter,
                            value,
                            post_value,
                            n_spikes,
                            cpu_time : cpu_time.as_secs_f64(),
                            inner_iters,
                            merged,
                            pruned,
                            this_iters
                        }
                    };
                    let iterator = iterator_options.instantiate()
                                                   .timed()
                                                   .mapped(logmap)
                                                   .into_log(&mut logger);
                    let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]);

                    // Create plotter and directory if needed.
                    let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };
                    let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid);

                    // Run the algorithm
                    let start = Instant::now();
                    let start_cpu = ProcessTime::now();
                    let μ = match alg {
                        AlgorithmConfig::FB(ref algconfig) => {
                            match (regularisation, dataterm) {
                                (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
                                    running();
                                    pointsource_fb_reg(
                                        &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                        iterator, plotter
                                    )
                                },
                                (Regularisation::Radon(α), DataTerm::L2Squared) => {
                                    running();
                                    pointsource_fb_reg(
                                        &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                        iterator, plotter
                                    )
                                },
                                _ => {
                                    not_implemented();
                                    continue
                                }
                            }
                        },
                        AlgorithmConfig::PDPS(ref algconfig) => {
                            running();
                            match (regularisation, dataterm) {
                                (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
                                    pointsource_pdps_reg(
                                        &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                        iterator, plotter, L2Squared
                                    )
                                },
                                (Regularisation::Radon(α), DataTerm::L2Squared) => {
                                    pointsource_pdps_reg(
                                        &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                        iterator, plotter, L2Squared
                                    )
                                },
                                (Regularisation::NonnegRadon(α), DataTerm::L1) => {
                                    pointsource_pdps_reg(
                                        &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
                                        iterator, plotter, L1
                                    )
                                },
                                (Regularisation::Radon(α), DataTerm::L1) => {
                                    pointsource_pdps_reg(
                                        &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
                                        iterator, plotter, L1
                                    )
                                },
                            }
                        },
                        AlgorithmConfig::FW(ref algconfig) => {
                            match (regularisation, dataterm) {
                                (Regularisation::Radon(α), DataTerm::L2Squared) => {
                                    running();
                                    pointsource_fw(&opA, &b, α, algconfig, iterator, plotter)
                                },
                                _ => {
                                    not_implemented();
                                    continue
                                }
                            }
                        }
                    };

                    let elapsed = start.elapsed().as_secs_f64();
                    let cpu_time = start_cpu.elapsed().as_secs_f64();

                    println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow());

                    // Save results
                    println!("{}", "Saving results…".green());

                    let mkname = |t| format!("{prefix}{alg_name}_{t}");

                    write_json(mkname("config.json"), named)?;
                    write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
                    μ.write_csv(mkname("reco.txt"))?;
                    logger.write_csv(mkname("log.txt"))?;
                }

                Ok(())
            }
        }
    }
}
|
578 // *** actual code *** |
|
579 |
|
580 impl_experiment!(ExperimentV2, regularisation, std::convert::identity); |
|
581 |
|
582 /// Plot experiment setup |
|
583 #[replace_float_literals(F::cast_from(literal))] |
|
584 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( |
|
585 cli : &CommandLineArgs, |
|
586 prefix : &String, |
|
587 domain : &Cube<F, N>, |
|
588 sensor : &Sensor, |
|
589 kernel : &Kernel, |
|
590 spread : &Spread, |
|
591 μ_hat : &DiscreteMeasure<Loc<F, N>, F>, |
|
592 op𝒟 : &𝒟, |
|
593 opA : &A, |
|
594 b_hat : &A::Observable, |
|
595 b : &A::Observable, |
|
596 kernel_plot_width : F, |
|
597 ) -> DynError |
|
598 where F : Float + ToNalgebraRealField, |
|
599 Sensor : RealMapping<F, N> + Support<F, N> + Clone, |
|
600 Spread : RealMapping<F, N> + Support<F, N> + Clone, |
|
601 Kernel : RealMapping<F, N> + Support<F, N>, |
|
602 Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, |
|
603 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, |
|
604 𝒟::Codomain : RealMapping<F, N>, |
|
605 A : ForwardModel<Loc<F, N>, F>, |
|
606 A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, |
|
607 PlotLookup : Plotting<N>, |
|
608 Cube<F, N> : SetOrd { |
|
609 |
|
610 if cli.plot < PlotLevel::Data { |
|
611 return Ok(()) |
|
612 } |
|
613 |
|
614 let base = Convolution(sensor.clone(), spread.clone()); |
|
615 |
|
616 let resolution = if N==1 { 100 } else { 40 }; |
|
617 let pfx = |n| format!("{}{}", prefix, n); |
|
618 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); |
|
619 |
|
620 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); |
|
621 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); |
|
622 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); |
|
623 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); |
|
624 |
|
625 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
|
626 |
|
627 let ω_hat = op𝒟.apply(μ_hat); |
|
628 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
|
629 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string()); |
|
630 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"), |
|
631 "noise Aᵀ(Aμ̂ - b)".to_string()); |
|
632 |
|
633 let preadj_b = opA.preadjoint().apply(b); |
|
634 let preadj_b_hat = opA.preadjoint().apply(b_hat); |
|
635 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); |
|
636 PlotLookup::plot_into_file_spikes( |
|
637 "Aᵀb".to_string(), &preadj_b, |
|
638 "Aᵀb̂".to_string(), Some(&preadj_b_hat), |
|
639 plotgrid2, None, &μ_hat, |
|
640 pfx("omega_b") |
|
641 ); |
|
642 |
|
643 // Save true solution and observables |
|
644 let pfx = |n| format!("{}{}", prefix, n); |
|
645 μ_hat.write_csv(pfx("orig.txt"))?; |
|
646 opA.write_observable(&b_hat, pfx("b_hat"))?; |
|
647 opA.write_observable(&b, pfx("b_noisy")) |
|
648 } |
|
649 |
|
650 // |
|
651 // Deprecated interface |
|
652 // |
276 |
653 |
277 /// Struct for experiment configurations |
654 /// Struct for experiment configurations |
278 #[derive(Debug, Clone, Serialize)] |
655 #[derive(Debug, Clone, Serialize)] |
279 pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize> |
656 pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize> |
280 where F : Float, |
657 where F : Float, |
299 /// Kernel $ρ$ of $𝒟$. |
676 /// Kernel $ρ$ of $𝒟$. |
300 pub kernel : K, |
677 pub kernel : K, |
301 /// True point sources |
678 /// True point sources |
302 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, |
679 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, |
303 /// Regularisation parameter |
680 /// Regularisation parameter |
|
681 #[deprecated(note = "Use [`ExperimentV2`], which replaces `α` by more generic `regularisation`")] |
304 pub α : F, |
682 pub α : F, |
305 /// For plotting : how wide should the kernels be plotted |
683 /// For plotting : how wide should the kernels be plotted |
306 pub kernel_plot_width : F, |
684 pub kernel_plot_width : F, |
307 /// Data term |
685 /// Data term |
308 pub dataterm : DataTerm, |
686 pub dataterm : DataTerm, |
309 /// A map of default configurations for algorithms |
687 /// A map of default configurations for algorithms |
310 #[serde(skip)] |
688 #[serde(skip)] |
311 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, |
689 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, |
312 } |
690 } |
313 |
691 |
314 /// Trait for runnable experiments |
692 impl_experiment!(Experiment, α, Regularisation::NonnegRadon); |
315 pub trait RunnableExperiment<F : ClapFloat> { |
|
316 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. |
|
317 fn runall(&self, cli : &CommandLineArgs, |
|
318 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; |
|
319 |
|
320 /// Return algorithm default config |
|
321 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
|
322 -> Named<AlgorithmConfig<F>>; |
|
323 } |
|
324 |
|
325 impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for |
|
326 Named<Experiment<F, NoiseDistr, S, K, P, N>> |
|
327 where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, |
|
328 [usize; N] : Serialize, |
|
329 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
|
330 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
|
331 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy, |
|
332 AutoConvolution<P> : BoundedBy<F, K>, |
|
333 K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> |
|
334 + Copy + Serialize + std::fmt::Debug, |
|
335 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
|
336 PlotLookup : Plotting<N>, |
|
337 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
|
338 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
339 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
|
340 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug { |
|
341 |
|
342 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
|
343 -> Named<AlgorithmConfig<F>> { |
|
344 alg.to_named( |
|
345 self.data |
|
346 .algorithm_defaults |
|
347 .get(&alg) |
|
348 .map_or_else(|| alg.default_config(), |
|
349 |config| config.clone()) |
|
350 .cli_override(cli) |
|
351 ) |
|
352 } |
|
353 |
|
354 fn runall(&self, cli : &CommandLineArgs, |
|
355 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { |
|
356 // Get experiment configuration |
|
357 let &Named { |
|
358 name : ref experiment_name, |
|
359 data : Experiment { |
|
360 domain, sensor_count, ref noise_distr, sensor, spread, kernel, |
|
361 ref μ_hat, α, kernel_plot_width, dataterm, noise_seed, |
|
362 .. |
|
363 } |
|
364 } = self; |
|
365 |
|
366 println!("{}\n{}", |
|
367 format!("Performing experiment {}…", experiment_name).cyan(), |
|
368 format!("{:?}", &self.data).bright_black()); |
|
369 |
|
370 // Set up output directory |
|
371 let prefix = format!("{}/{}/", cli.outdir, self.name); |
|
372 |
|
373 // Set up algorithms |
|
374 let iterator_options = AlgIteratorOptions{ |
|
375 max_iter : cli.max_iter, |
|
376 verbose_iter : cli.verbose_iter |
|
377 .map_or(Verbose::Logarithmic(10), |
|
378 |n| Verbose::Every(n)), |
|
379 quiet : cli.quiet, |
|
380 }; |
|
381 let algorithms = match (algs, self.data.dataterm) { |
|
382 (Some(algs), _) => algs, |
|
383 (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], |
|
384 (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], |
|
385 }; |
|
386 |
|
387 // Set up operators |
|
388 let depth = DynamicDepth(8); |
|
389 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); |
|
390 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); |
|
391 |
|
392 // Set up random number generator. |
|
393 let mut rng = StdRng::seed_from_u64(noise_seed); |
|
394 |
|
395 // Generate the data and calculate SSNR statistic |
|
396 let b_hat = opA.apply(μ_hat); |
|
397 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); |
|
398 let b = &b_hat + &noise; |
|
399 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
|
400 // overloading log10 and conflicting with standard NumTraits one. |
|
401 let stats = ExperimentStats::new(&b, &noise); |
|
402 |
|
403 // Save experiment configuration and statistics |
|
404 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); |
|
405 std::fs::create_dir_all(&prefix)?; |
|
406 write_json(mkname_e("experiment"), self)?; |
|
407 write_json(mkname_e("config"), cli)?; |
|
408 write_json(mkname_e("stats"), &stats)?; |
|
409 |
|
410 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, |
|
411 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; |
|
412 |
|
413 // Run the algorithm(s) |
|
414 for named @ Named { name : alg_name, data : alg } in algorithms.iter() { |
|
415 let this_prefix = format!("{}{}/", prefix, alg_name); |
|
416 |
|
417 let running = || if !cli.quiet { |
|
418 println!("{}\n{}\n{}", |
|
419 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), |
|
420 format!("{:?}", iterator_options).bright_black(), |
|
421 format!("{:?}", alg).bright_black()); |
|
422 }; |
|
423 |
|
424 // Create Logger and IteratorFactory |
|
425 let mut logger = Logger::new(); |
|
426 let findim_data = prepare_optimise_weights(&opA); |
|
427 let inner_config : InnerSettings<F> = Default::default(); |
|
428 let inner_it = inner_config.iterator_options; |
|
429 let logmap = |iter, Timed { cpu_time, data }| { |
|
430 let IterInfo { |
|
431 value, |
|
432 n_spikes, |
|
433 inner_iters, |
|
434 merged, |
|
435 pruned, |
|
436 postprocessing, |
|
437 this_iters, |
|
438 .. |
|
439 } = data; |
|
440 let post_value = match postprocessing { |
|
441 None => value, |
|
442 Some(mut μ) => { |
|
443 match dataterm { |
|
444 DataTerm::L2Squared => { |
|
445 optimise_weights( |
|
446 &mut μ, &opA, &b, α, &findim_data, &inner_config, |
|
447 inner_it |
|
448 ); |
|
449 dataterm.value_at_residual(opA.apply(&μ) - &b) + α * μ.norm(Radon) |
|
450 }, |
|
451 _ => value, |
|
452 } |
|
453 } |
|
454 }; |
|
455 CSVLog { |
|
456 iter, |
|
457 value, |
|
458 post_value, |
|
459 n_spikes, |
|
460 cpu_time : cpu_time.as_secs_f64(), |
|
461 inner_iters, |
|
462 merged, |
|
463 pruned, |
|
464 this_iters |
|
465 } |
|
466 }; |
|
467 let iterator = iterator_options.instantiate() |
|
468 .timed() |
|
469 .mapped(logmap) |
|
470 .into_log(&mut logger); |
|
471 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); |
|
472 |
|
473 // Create plotter and directory if needed. |
|
474 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; |
|
475 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid); |
|
476 |
|
477 // Run the algorithm |
|
478 let start = Instant::now(); |
|
479 let start_cpu = ProcessTime::now(); |
|
480 let μ : DiscreteMeasure<Loc<F, N>, F> = match (alg, dataterm) { |
|
481 (AlgorithmConfig::FB(ref algconfig), DataTerm::L2Squared) => { |
|
482 running(); |
|
483 pointsource_fb(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter) |
|
484 }, |
|
485 (AlgorithmConfig::FW(ref algconfig), DataTerm::L2Squared) => { |
|
486 running(); |
|
487 pointsource_fw(&opA, &b, α, &algconfig, iterator, plotter) |
|
488 }, |
|
489 (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L2Squared) => { |
|
490 running(); |
|
491 pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L2Squared) |
|
492 }, |
|
493 (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L1) => { |
|
494 running(); |
|
495 pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1) |
|
496 }, |
|
497 _ => { |
|
498 let msg = format!("Algorithm “{alg_name}” not implemented for \ |
|
499 dataterm {dataterm:?}. Skipping.").red(); |
|
500 eprintln!("{}", msg); |
|
501 continue |
|
502 } |
|
503 }; |
|
504 let elapsed = start.elapsed().as_secs_f64(); |
|
505 let cpu_time = start_cpu.elapsed().as_secs_f64(); |
|
506 |
|
507 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); |
|
508 |
|
509 // Save results |
|
510 println!("{}", "Saving results…".green()); |
|
511 |
|
512 let mkname = |t| format!("{prefix}{alg_name}_{t}"); |
|
513 |
|
514 write_json(mkname("config.json"), &named)?; |
|
515 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; |
|
516 μ.write_csv(mkname("reco.txt"))?; |
|
517 logger.write_csv(mkname("log.txt"))?; |
|
518 } |
|
519 |
|
520 Ok(()) |
|
521 } |
|
522 } |
|
523 |
|
524 /// Plot experiment setup |
|
525 #[replace_float_literals(F::cast_from(literal))] |
|
526 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( |
|
527 cli : &CommandLineArgs, |
|
528 prefix : &String, |
|
529 domain : &Cube<F, N>, |
|
530 sensor : &Sensor, |
|
531 kernel : &Kernel, |
|
532 spread : &Spread, |
|
533 μ_hat : &DiscreteMeasure<Loc<F, N>, F>, |
|
534 op𝒟 : &𝒟, |
|
535 opA : &A, |
|
536 b_hat : &A::Observable, |
|
537 b : &A::Observable, |
|
538 kernel_plot_width : F, |
|
539 ) -> DynError |
|
540 where F : Float + ToNalgebraRealField, |
|
541 Sensor : RealMapping<F, N> + Support<F, N> + Clone, |
|
542 Spread : RealMapping<F, N> + Support<F, N> + Clone, |
|
543 Kernel : RealMapping<F, N> + Support<F, N>, |
|
544 Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, |
|
545 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, |
|
546 𝒟::Codomain : RealMapping<F, N>, |
|
547 A : ForwardModel<Loc<F, N>, F>, |
|
548 A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, |
|
549 PlotLookup : Plotting<N>, |
|
550 Cube<F, N> : SetOrd { |
|
551 |
|
552 if cli.plot < PlotLevel::Data { |
|
553 return Ok(()) |
|
554 } |
|
555 |
|
556 let base = Convolution(sensor.clone(), spread.clone()); |
|
557 |
|
558 let resolution = if N==1 { 100 } else { 40 }; |
|
559 let pfx = |n| format!("{}{}", prefix, n); |
|
560 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); |
|
561 |
|
562 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); |
|
563 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); |
|
564 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); |
|
565 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); |
|
566 |
|
567 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
|
568 |
|
569 let ω_hat = op𝒟.apply(μ_hat); |
|
570 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
|
571 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string()); |
|
572 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"), |
|
573 "noise Aᵀ(Aμ̂ - b)".to_string()); |
|
574 |
|
575 let preadj_b = opA.preadjoint().apply(b); |
|
576 let preadj_b_hat = opA.preadjoint().apply(b_hat); |
|
577 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); |
|
578 PlotLookup::plot_into_file_spikes( |
|
579 "Aᵀb".to_string(), &preadj_b, |
|
580 "Aᵀb̂".to_string(), Some(&preadj_b_hat), |
|
581 plotgrid2, None, &μ_hat, |
|
582 pfx("omega_b") |
|
583 ); |
|
584 |
|
585 // Save true solution and observables |
|
586 let pfx = |n| format!("{}{}", prefix, n); |
|
587 μ_hat.write_csv(pfx("orig.txt"))?; |
|
588 opA.write_observable(&b_hat, pfx("b_hat"))?; |
|
589 opA.write_observable(&b, pfx("b_noisy")) |
|
590 } |
|
591 |
|