24 use alg_tools::iterate::{ |
24 use alg_tools::iterate::{ |
25 Timed, |
25 Timed, |
26 AlgIteratorOptions, |
26 AlgIteratorOptions, |
27 Verbose, |
27 Verbose, |
28 AlgIteratorFactory, |
28 AlgIteratorFactory, |
|
29 LoggingIteratorFactory, |
|
30 TimingIteratorFactory, |
|
31 BasicAlgIteratorFactory, |
29 }; |
32 }; |
30 use alg_tools::logger::Logger; |
33 use alg_tools::logger::Logger; |
31 use alg_tools::error::DynError; |
34 use alg_tools::error::{ |
|
35 DynError, |
|
36 DynResult, |
|
37 }; |
32 use alg_tools::tabledump::TableDump; |
38 use alg_tools::tabledump::TableDump; |
33 use alg_tools::sets::Cube; |
39 use alg_tools::sets::Cube; |
34 use alg_tools::mapping::RealMapping; |
40 use alg_tools::mapping::{ |
|
41 RealMapping, |
|
42 DifferentiableMapping, |
|
43 DifferentiableRealMapping, |
|
44 Instance |
|
45 }; |
35 use alg_tools::nalgebra_support::ToNalgebraRealField; |
46 use alg_tools::nalgebra_support::ToNalgebraRealField; |
36 use alg_tools::euclidean::Euclidean; |
47 use alg_tools::euclidean::Euclidean; |
37 use alg_tools::norms::L1; |
48 use alg_tools::lingrid::{lingrid, LinSpace}; |
38 use alg_tools::lingrid::lingrid; |
|
39 use alg_tools::sets::SetOrd; |
49 use alg_tools::sets::SetOrd; |
|
50 use alg_tools::linops::{RowOp, IdOp /*, ZeroOp*/}; |
|
51 use alg_tools::discrete_gradient::{Grad, ForwardNeumann}; |
|
52 use alg_tools::convex::Zero; |
|
53 use alg_tools::maputil::map3; |
|
54 use alg_tools::direct_product::Pair; |
40 |
55 |
41 use crate::kernels::*; |
56 use crate::kernels::*; |
42 use crate::types::*; |
57 use crate::types::*; |
43 use crate::measures::*; |
58 use crate::measures::*; |
44 use crate::measures::merging::SpikeMerging; |
59 use crate::measures::merging::{SpikeMerging,SpikeMergingMethod}; |
45 use crate::forward_model::*; |
60 use crate::forward_model::*; |
|
61 use crate::forward_model::sensor_grid::{ |
|
62 SensorGrid, |
|
63 SensorGridBT, |
|
64 //SensorGridBTFN, |
|
65 Sensor, |
|
66 Spread, |
|
67 }; |
|
68 |
46 use crate::fb::{ |
69 use crate::fb::{ |
47 FBConfig, |
70 FBConfig, |
|
71 FBGenericConfig, |
48 pointsource_fb_reg, |
72 pointsource_fb_reg, |
49 FBMetaAlgorithm, |
73 pointsource_fista_reg, |
50 FBGenericConfig, |
74 }; |
|
75 use crate::sliding_fb::{ |
|
76 SlidingFBConfig, |
|
77 TransportConfig, |
|
78 pointsource_sliding_fb_reg |
|
79 }; |
|
80 use crate::sliding_pdps::{ |
|
81 SlidingPDPSConfig, |
|
82 pointsource_sliding_pdps_pair |
|
83 }; |
|
84 use crate::forward_pdps::{ |
|
85 ForwardPDPSConfig, |
|
86 pointsource_forward_pdps_pair |
51 }; |
87 }; |
52 use crate::pdps::{ |
88 use crate::pdps::{ |
53 PDPSConfig, |
89 PDPSConfig, |
54 L2Squared, |
|
55 pointsource_pdps_reg, |
90 pointsource_pdps_reg, |
56 }; |
91 }; |
57 use crate::frank_wolfe::{ |
92 use crate::frank_wolfe::{ |
58 FWConfig, |
93 FWConfig, |
59 FWVariant, |
94 FWVariant, |
60 pointsource_fw_reg, |
95 pointsource_fw_reg, |
61 WeightOptim, |
96 //WeightOptim, |
62 }; |
97 }; |
63 use crate::subproblem::InnerSettings; |
98 use crate::subproblem::{InnerSettings, InnerMethod}; |
64 use crate::seminorms::*; |
99 use crate::seminorms::*; |
65 use crate::plot::*; |
100 use crate::plot::*; |
66 use crate::{AlgorithmOverrides, CommandLineArgs}; |
101 use crate::{AlgorithmOverrides, CommandLineArgs}; |
67 use crate::tolerance::Tolerance; |
102 use crate::tolerance::Tolerance; |
68 use crate::regularisation::{Regularisation, RadonRegTerm, NonnegRadonRegTerm}; |
103 use crate::regularisation::{ |
|
104 Regularisation, |
|
105 RadonRegTerm, |
|
106 NonnegRadonRegTerm |
|
107 }; |
|
108 use crate::dataterm::{ |
|
109 L1, |
|
110 L2Squared, |
|
111 }; |
|
112 use crate::prox_penalty::{ |
|
113 RadonSquared, |
|
114 //ProxPenalty, |
|
115 }; |
|
116 use alg_tools::norms::{L2, NormExponent}; |
|
117 use alg_tools::operator_arithmetic::Weighted; |
|
118 use anyhow::anyhow; |
|
119 |
|
120 /// Available proximal terms |
|
121 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
|
122 pub enum ProxTerm { |
|
123 /// Partial-to-wave operator 𝒟. |
|
124 Wave, |
|
125 /// Radon-norm squared |
|
126 RadonSquared |
|
127 } |
69 |
128 |
70 /// Available algorithms and their configurations |
129 /// Available algorithms and their configurations |
71 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
130 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
72 pub enum AlgorithmConfig<F : Float> { |
131 pub enum AlgorithmConfig<F : Float> { |
73 FB(FBConfig<F>), |
132 FB(FBConfig<F>, ProxTerm), |
|
133 FISTA(FBConfig<F>, ProxTerm), |
74 FW(FWConfig<F>), |
134 FW(FWConfig<F>), |
75 PDPS(PDPSConfig<F>), |
135 PDPS(PDPSConfig<F>, ProxTerm), |
|
136 SlidingFB(SlidingFBConfig<F>, ProxTerm), |
|
137 ForwardPDPS(ForwardPDPSConfig<F>, ProxTerm), |
|
138 SlidingPDPS(SlidingPDPSConfig<F>, ProxTerm), |
76 } |
139 } |
77 |
140 |
78 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { |
141 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { |
79 assert!(v.len() == 3); |
142 assert!(v.len() == 3); |
80 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } |
143 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } |
81 } |
144 } |
82 |
145 |
83 impl<F : ClapFloat> AlgorithmConfig<F> { |
146 impl<F : ClapFloat> AlgorithmConfig<F> { |
84 /// Override supported parameters based on the command line. |
147 /// Override supported parameters based on the command line. |
85 pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self { |
148 pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self { |
|
149 let override_merging = |g : SpikeMergingMethod<F>| { |
|
150 SpikeMergingMethod { |
|
151 enabled : cli.merge.unwrap_or(g.enabled), |
|
152 radius : cli.merge_radius.unwrap_or(g.radius), |
|
153 interp : cli.merge_interp.unwrap_or(g.interp), |
|
154 } |
|
155 }; |
86 let override_fb_generic = |g : FBGenericConfig<F>| { |
156 let override_fb_generic = |g : FBGenericConfig<F>| { |
87 FBGenericConfig { |
157 FBGenericConfig { |
88 bootstrap_insertions : cli.bootstrap_insertions |
158 bootstrap_insertions : cli.bootstrap_insertions |
89 .as_ref() |
159 .as_ref() |
90 .map_or(g.bootstrap_insertions, |
160 .map_or(g.bootstrap_insertions, |
91 |n| Some((n[0], n[1]))), |
161 |n| Some((n[0], n[1]))), |
92 merge_every : cli.merge_every.unwrap_or(g.merge_every), |
162 merge_every : cli.merge_every.unwrap_or(g.merge_every), |
93 merging : cli.merging.clone().unwrap_or(g.merging), |
163 merging : override_merging(g.merging), |
94 final_merging : cli.final_merging.clone().unwrap_or(g.final_merging), |
164 final_merging : cli.final_merging.unwrap_or(g.final_merging), |
|
165 fitness_merging : cli.fitness_merging.unwrap_or(g.fitness_merging), |
95 tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance), |
166 tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance), |
96 .. g |
167 .. g |
97 } |
168 } |
98 }; |
169 }; |
|
170 let override_transport = |g : TransportConfig<F>| { |
|
171 TransportConfig { |
|
172 θ0 : cli.theta0.unwrap_or(g.θ0), |
|
173 tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con), |
|
174 adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), |
|
175 .. g |
|
176 } |
|
177 }; |
99 |
178 |
100 use AlgorithmConfig::*; |
179 use AlgorithmConfig::*; |
101 match self { |
180 match self { |
102 FB(fb) => FB(FBConfig { |
181 FB(fb, prox) => FB(FBConfig { |
103 τ0 : cli.tau0.unwrap_or(fb.τ0), |
182 τ0 : cli.tau0.unwrap_or(fb.τ0), |
104 insertion : override_fb_generic(fb.insertion), |
183 generic : override_fb_generic(fb.generic), |
105 .. fb |
184 .. fb |
106 }), |
185 }, prox), |
107 PDPS(pdps) => PDPS(PDPSConfig { |
186 FISTA(fb, prox) => FISTA(FBConfig { |
|
187 τ0 : cli.tau0.unwrap_or(fb.τ0), |
|
188 generic : override_fb_generic(fb.generic), |
|
189 .. fb |
|
190 }, prox), |
|
191 PDPS(pdps, prox) => PDPS(PDPSConfig { |
108 τ0 : cli.tau0.unwrap_or(pdps.τ0), |
192 τ0 : cli.tau0.unwrap_or(pdps.τ0), |
109 σ0 : cli.sigma0.unwrap_or(pdps.σ0), |
193 σ0 : cli.sigma0.unwrap_or(pdps.σ0), |
110 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
194 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
111 insertion : override_fb_generic(pdps.insertion), |
195 generic : override_fb_generic(pdps.generic), |
112 .. pdps |
196 .. pdps |
113 }), |
197 }, prox), |
114 FW(fw) => FW(FWConfig { |
198 FW(fw) => FW(FWConfig { |
115 merging : cli.merging.clone().unwrap_or(fw.merging), |
199 merging : override_merging(fw.merging), |
116 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), |
200 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), |
117 .. fw |
201 .. fw |
118 }) |
202 }), |
|
203 SlidingFB(sfb, prox) => SlidingFB(SlidingFBConfig { |
|
204 τ0 : cli.tau0.unwrap_or(sfb.τ0), |
|
205 transport : override_transport(sfb.transport), |
|
206 insertion : override_fb_generic(sfb.insertion), |
|
207 .. sfb |
|
208 }, prox), |
|
209 SlidingPDPS(spdps, prox) => SlidingPDPS(SlidingPDPSConfig { |
|
210 τ0 : cli.tau0.unwrap_or(spdps.τ0), |
|
211 σp0 : cli.sigmap0.unwrap_or(spdps.σp0), |
|
212 σd0 : cli.sigma0.unwrap_or(spdps.σd0), |
|
213 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
|
214 transport : override_transport(spdps.transport), |
|
215 insertion : override_fb_generic(spdps.insertion), |
|
216 .. spdps |
|
217 }, prox), |
|
218 ForwardPDPS(fpdps, prox) => ForwardPDPS(ForwardPDPSConfig { |
|
219 τ0 : cli.tau0.unwrap_or(fpdps.τ0), |
|
220 σp0 : cli.sigmap0.unwrap_or(fpdps.σp0), |
|
221 σd0 : cli.sigma0.unwrap_or(fpdps.σd0), |
|
222 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
|
223 insertion : override_fb_generic(fpdps.insertion), |
|
224 .. fpdps |
|
225 }, prox), |
119 } |
226 } |
120 } |
227 } |
121 } |
228 } |
122 |
229 |
123 /// Helper struct for tagging and [`AlgorithmConfig`] or [`Experiment`] with a name. |
230 /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name. |
124 #[derive(Clone, Debug, Serialize, Deserialize)] |
231 #[derive(Clone, Debug, Serialize, Deserialize)] |
125 pub struct Named<Data> { |
232 pub struct Named<Data> { |
126 pub name : String, |
233 pub name : String, |
127 #[serde(flatten)] |
234 #[serde(flatten)] |
128 pub data : Data, |
235 pub data : Data, |
298 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. |
477 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. |
299 pub spread : P, |
478 pub spread : P, |
300 /// Kernel $ρ$ of $𝒟$. |
479 /// Kernel $ρ$ of $𝒟$. |
301 pub kernel : K, |
480 pub kernel : K, |
302 /// True point sources |
481 /// True point sources |
303 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, |
482 pub μ_hat : RNDM<F, N>, |
304 /// Regularisation term and parameter |
483 /// Regularisation term and parameter |
305 pub regularisation : Regularisation<F>, |
484 pub regularisation : Regularisation<F>, |
306 /// For plotting : how wide should the kernels be plotted |
485 /// For plotting : how wide should the kernels be plotted |
307 pub kernel_plot_width : F, |
486 pub kernel_plot_width : F, |
308 /// Data term |
487 /// Data term |
309 pub dataterm : DataTerm, |
488 pub dataterm : DataTerm, |
310 /// A map of default configurations for algorithms |
489 /// A map of default configurations for algorithms |
311 #[serde(skip)] |
490 pub algorithm_overrides : HashMap<DefaultAlgorithm, AlgorithmOverrides<F>>, |
312 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, |
491 /// Default merge radius |
|
492 pub default_merge_radius : F, |
|
493 } |
|
494 |
|
495 #[derive(Debug, Clone, Serialize)] |
|
496 pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N : usize> |
|
497 where F : Float + ClapFloat, |
|
498 [usize; N] : Serialize, |
|
499 NoiseDistr : Distribution<F>, |
|
500 S : Sensor<F, N>, |
|
501 P : Spread<F, N>, |
|
502 K : SimpleConvolutionKernel<F, N>, |
|
503 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, |
|
504 { |
|
505 /// Basic setup |
|
506 pub base : ExperimentV2<F, NoiseDistr, S, K, P, N>, |
|
507 /// Weight of TV term |
|
508 pub λ : F, |
|
509 /// Bias function |
|
510 pub bias : B, |
313 } |
511 } |
314 |
512 |
315 /// Trait for runnable experiments |
513 /// Trait for runnable experiments |
316 pub trait RunnableExperiment<F : ClapFloat> { |
514 pub trait RunnableExperiment<F : ClapFloat> { |
317 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. |
515 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. |
318 fn runall(&self, cli : &CommandLineArgs, |
516 fn runall(&self, cli : &CommandLineArgs, |
319 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; |
517 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; |
320 |
518 |
321 /// Return algorithm default config |
519 /// Return algorithm default config |
322 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
520 fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F>; |
323 -> Named<AlgorithmConfig<F>>; |
521 } |
324 } |
522 |
325 |
523 /// Helper function to print experiment start message and save setup. |
326 // *** macro boilerplate *** |
524 /// Returns saving prefix. |
327 macro_rules! impl_experiment { |
525 fn start_experiment<E, S>( |
328 ($type:ident, $reg_field:ident, $reg_convert:path) => { |
526 experiment : &Named<E>, |
329 // *** macro *** |
527 cli : &CommandLineArgs, |
330 impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for |
528 stats : S, |
331 Named<$type<F, NoiseDistr, S, K, P, N>> |
529 ) -> DynResult<String> |
332 where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, |
530 where |
333 [usize; N] : Serialize, |
531 E : Serialize + std::fmt::Debug, |
334 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
532 S : Serialize, |
335 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
533 { |
336 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy, |
534 let Named { name : experiment_name, data } = experiment; |
337 AutoConvolution<P> : BoundedBy<F, K>, |
535 |
338 K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> |
536 println!("{}\n{}", |
339 + Copy + Serialize + std::fmt::Debug, |
537 format!("Performing experiment {}…", experiment_name).cyan(), |
340 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
538 format!("Experiment settings: {}", serde_json::to_string(&data)?).bright_black()); |
341 PlotLookup : Plotting<N>, |
539 |
342 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
540 // Set up output directory |
343 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
541 let prefix = format!("{}/{}/", cli.outdir, experiment_name); |
344 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
542 |
345 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug { |
543 // Save experiment configuration and statistics |
346 |
544 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); |
347 fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>) |
545 std::fs::create_dir_all(&prefix)?; |
348 -> Named<AlgorithmConfig<F>> { |
546 write_json(mkname_e("experiment"), experiment)?; |
349 alg.to_named( |
547 write_json(mkname_e("config"), cli)?; |
350 self.data |
548 write_json(mkname_e("stats"), &stats)?; |
351 .algorithm_defaults |
549 |
352 .get(&alg) |
550 Ok(prefix) |
353 .map_or_else(|| alg.default_config(), |
551 } |
354 |config| config.clone()) |
552 |
355 .cli_override(cli) |
553 /// Error codes for running an algorithm on an experiment. |
356 ) |
554 enum RunError { |
|
555 /// Algorithm not implemented for this experiment |
|
556 NotImplemented, |
|
557 } |
|
558 |
|
559 use RunError::*; |
|
560 |
|
561 type DoRunAllIt<'a, F, const N : usize> = LoggingIteratorFactory< |
|
562 'a, |
|
563 Timed<IterInfo<F, N>>, |
|
564 TimingIteratorFactory<BasicAlgIteratorFactory<IterInfo<F, N>>> |
|
565 >; |
|
566 |
|
567 /// Helper function to run all algorithms on an experiment. |
|
568 fn do_runall<F : Float + for<'b> Deserialize<'b>, Z, const N : usize>( |
|
569 experiment_name : &String, |
|
570 prefix : &String, |
|
571 cli : &CommandLineArgs, |
|
572 algorithms : Vec<Named<AlgorithmConfig<F>>>, |
|
573 plotgrid : LinSpace<Loc<F, N>, [usize; N]>, |
|
574 mut save_extra : impl FnMut(String, Z) -> DynError, |
|
575 mut do_alg : impl FnMut( |
|
576 &AlgorithmConfig<F>, |
|
577 DoRunAllIt<F, N>, |
|
578 SeqPlotter<F, N>, |
|
579 String, |
|
580 ) -> Result<(RNDM<F, N>, Z), RunError>, |
|
581 ) -> DynError |
|
582 where |
|
583 PlotLookup : Plotting<N>, |
|
584 { |
|
585 let mut logs = Vec::new(); |
|
586 |
|
587 let iterator_options = AlgIteratorOptions{ |
|
588 max_iter : cli.max_iter, |
|
589 verbose_iter : cli.verbose_iter |
|
590 .map_or(Verbose::LogarithmicCap{base : 10, cap : 2}, |
|
591 |n| Verbose::Every(n)), |
|
592 quiet : cli.quiet, |
|
593 }; |
|
594 |
|
595 // Run the algorithm(s) |
|
596 for named @ Named { name : alg_name, data : alg } in algorithms.iter() { |
|
597 let this_prefix = format!("{}{}/", prefix, alg_name); |
|
598 |
|
599 // Create Logger and IteratorFactory |
|
600 let mut logger = Logger::new(); |
|
601 let iterator = iterator_options.instantiate() |
|
602 .timed() |
|
603 .into_log(&mut logger); |
|
604 |
|
605 let running = if !cli.quiet { |
|
606 format!("{}\n{}\n{}\n", |
|
607 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), |
|
608 format!("Iteration settings: {}", serde_json::to_string(&iterator_options)?).bright_black(), |
|
609 format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black()) |
|
610 } else { |
|
611 "".to_string() |
|
612 }; |
|
613 // |
|
614 // The following is for postprocessing, which has been disabled anyway. |
|
615 // |
|
616 // let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation { |
|
617 // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)), |
|
618 // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)), |
|
619 // }; |
|
620 //let findim_data = reg.prepare_optimise_weights(&opA, &b); |
|
621 //let inner_config : InnerSettings<F> = Default::default(); |
|
622 //let inner_it = inner_config.iterator_options; |
|
623 |
|
624 // Create plotter and directory if needed. |
|
625 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; |
|
626 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()); |
|
627 |
|
628 let start = Instant::now(); |
|
629 let start_cpu = ProcessTime::now(); |
|
630 |
|
631 let (μ, z) = match do_alg(alg, iterator, plotter, running) { |
|
632 Ok(μ) => μ, |
|
633 Err(RunError::NotImplemented) => { |
|
634 let msg = format!("Algorithm “{alg_name}” not implemented for {experiment_name}. \ |
|
635 Skipping.").red(); |
|
636 eprintln!("{}", msg); |
|
637 continue |
|
638 } |
|
639 }; |
|
640 |
|
641 let elapsed = start.elapsed().as_secs_f64(); |
|
642 let cpu_time = start_cpu.elapsed().as_secs_f64(); |
|
643 |
|
644 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); |
|
645 |
|
646 // Save results |
|
647 println!("{}", "Saving results …".green()); |
|
648 |
|
649 let mkname = |t| format!("{prefix}{alg_name}_{t}"); |
|
650 |
|
651 write_json(mkname("config.json"), &named)?; |
|
652 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; |
|
653 μ.write_csv(mkname("reco.txt"))?; |
|
654 save_extra(mkname(""), z)?; |
|
655 //logger.write_csv(mkname("log.txt"))?; |
|
656 logs.push((mkname("log.txt"), logger)); |
|
657 } |
|
658 |
|
659 save_logs(logs, format!("{prefix}valuerange.json"), cli.load_valuerange) |
|
660 } |
|
661 |
|
662 #[replace_float_literals(F::cast_from(literal))] |
|
663 impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N : usize> RunnableExperiment<F> for |
|
664 Named<ExperimentV2<F, NoiseDistr, S, K, P, N>> |
|
665 where |
|
666 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F> |
|
667 + Default + for<'b> Deserialize<'b>, |
|
668 [usize; N] : Serialize, |
|
669 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
|
670 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
|
671 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy |
|
672 // TODO: shold not have differentiability as a requirement, but |
|
673 // decide availability of sliding based on it. |
|
674 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
|
675 // TODO: very weird that rust only compiles with Differentiable |
|
676 // instead of the above one on references, which is required by |
|
677 // poitsource_sliding_fb_reg. |
|
678 + DifferentiableRealMapping<F, N> |
|
679 + Lipschitz<L2, FloatType=F>, |
|
680 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb. |
|
681 AutoConvolution<P> : BoundedBy<F, K>, |
|
682 K : SimpleConvolutionKernel<F, N> |
|
683 + LocalAnalysis<F, Bounds<F>, N> |
|
684 + Copy + Serialize + std::fmt::Debug, |
|
685 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
|
686 PlotLookup : Plotting<N>, |
|
687 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
|
688 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
689 RNDM<F, N> : SpikeMerging<F>, |
|
690 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, |
|
691 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>, |
|
692 // PreadjointCodomain : Space + Bounded<F> + DifferentiableRealMapping<F, N>, |
|
693 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, |
|
694 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, |
|
695 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, |
|
696 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, |
|
697 { |
|
698 |
|
699 fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> { |
|
700 AlgorithmOverrides { |
|
701 merge_radius : Some(self.data.default_merge_radius), |
|
702 .. self.data.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) |
|
703 } |
357 } |
704 } |
358 |
705 |
359 fn runall(&self, cli : &CommandLineArgs, |
706 fn runall(&self, cli : &CommandLineArgs, |
360 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { |
707 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { |
361 // Get experiment configuration |
708 // Get experiment configuration |
362 let &Named { |
709 let &Named { |
363 name : ref experiment_name, |
710 name : ref experiment_name, |
364 data : $type { |
711 data : ExperimentV2 { |
365 domain, sensor_count, ref noise_distr, sensor, spread, kernel, |
712 domain, sensor_count, ref noise_distr, sensor, spread, kernel, |
366 ref μ_hat, /*regularisation,*/ kernel_plot_width, dataterm, noise_seed, |
713 ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, |
367 .. |
714 .. |
368 } |
715 } |
369 } = self; |
716 } = self; |
370 #[allow(deprecated)] |
|
371 let regularisation = $reg_convert(self.data.$reg_field); |
|
372 |
|
373 println!("{}\n{}", |
|
374 format!("Performing experiment {}…", experiment_name).cyan(), |
|
375 format!("{:?}", &self.data).bright_black()); |
|
376 |
|
377 // Set up output directory |
|
378 let prefix = format!("{}/{}/", cli.outdir, self.name); |
|
379 |
717 |
380 // Set up algorithms |
718 // Set up algorithms |
381 let iterator_options = AlgIteratorOptions{ |
719 let algorithms = match (algs, dataterm) { |
382 max_iter : cli.max_iter, |
|
383 verbose_iter : cli.verbose_iter |
|
384 .map_or(Verbose::Logarithmic(10), |
|
385 |n| Verbose::Every(n)), |
|
386 quiet : cli.quiet, |
|
387 }; |
|
388 let algorithms = match (algs, self.data.dataterm) { |
|
389 (Some(algs), _) => algs, |
720 (Some(algs), _) => algs, |
390 (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], |
721 (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], |
391 (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], |
722 (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], |
392 }; |
723 }; |
393 |
724 |
405 let b = &b_hat + &noise; |
736 let b = &b_hat + &noise; |
406 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
737 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
407 // overloading log10 and conflicting with standard NumTraits one. |
738 // overloading log10 and conflicting with standard NumTraits one. |
408 let stats = ExperimentStats::new(&b, &noise); |
739 let stats = ExperimentStats::new(&b, &noise); |
409 |
740 |
410 // Save experiment configuration and statistics |
741 let prefix = start_experiment(&self, cli, stats)?; |
411 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); |
|
412 std::fs::create_dir_all(&prefix)?; |
|
413 write_json(mkname_e("experiment"), self)?; |
|
414 write_json(mkname_e("config"), cli)?; |
|
415 write_json(mkname_e("stats"), &stats)?; |
|
416 |
742 |
417 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, |
743 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, |
418 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; |
744 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; |
419 |
745 |
420 // Run the algorithm(s) |
746 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); |
421 for named @ Named { name : alg_name, data : alg } in algorithms.iter() { |
747 |
422 let this_prefix = format!("{}{}/", prefix, alg_name); |
748 let save_extra = |_, ()| Ok(()); |
423 |
749 |
424 let running = || if !cli.quiet { |
750 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, |
425 println!("{}\n{}\n{}", |
751 |alg, iterator, plotter, running| |
426 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), |
752 { |
427 format!("{:?}", iterator_options).bright_black(), |
|
428 format!("{:?}", alg).bright_black()); |
|
429 }; |
|
430 let not_implemented = || { |
|
431 let msg = format!("Algorithm “{alg_name}” not implemented for \ |
|
432 dataterm {dataterm:?} and regularisation {regularisation:?}. \ |
|
433 Skipping.").red(); |
|
434 eprintln!("{}", msg); |
|
435 }; |
|
436 // Create Logger and IteratorFactory |
|
437 let mut logger = Logger::new(); |
|
438 let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation { |
|
439 Regularisation::Radon(α) => Box::new(RadonRegTerm(α)), |
|
440 Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)), |
|
441 }; |
|
442 let findim_data = reg.prepare_optimise_weights(&opA, &b); |
|
443 let inner_config : InnerSettings<F> = Default::default(); |
|
444 let inner_it = inner_config.iterator_options; |
|
445 let logmap = |iter, Timed { cpu_time, data }| { |
|
446 let IterInfo { |
|
447 value, |
|
448 n_spikes, |
|
449 inner_iters, |
|
450 merged, |
|
451 pruned, |
|
452 postprocessing, |
|
453 this_iters, |
|
454 .. |
|
455 } = data; |
|
456 let post_value = match (postprocessing, dataterm) { |
|
457 (Some(mut μ), DataTerm::L2Squared) => { |
|
458 // Comparison postprocessing is only implemented for the case handled |
|
459 // by the FW variants. |
|
460 reg.optimise_weights( |
|
461 &mut μ, &opA, &b, &findim_data, &inner_config, |
|
462 inner_it |
|
463 ); |
|
464 dataterm.value_at_residual(opA.apply(&μ) - &b) |
|
465 + regularisation.apply(&μ) |
|
466 }, |
|
467 _ => value, |
|
468 }; |
|
469 CSVLog { |
|
470 iter, |
|
471 value, |
|
472 post_value, |
|
473 n_spikes, |
|
474 cpu_time : cpu_time.as_secs_f64(), |
|
475 inner_iters, |
|
476 merged, |
|
477 pruned, |
|
478 this_iters |
|
479 } |
|
480 }; |
|
481 let iterator = iterator_options.instantiate() |
|
482 .timed() |
|
483 .mapped(logmap) |
|
484 .into_log(&mut logger); |
|
485 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); |
|
486 |
|
487 // Create plotter and directory if needed. |
|
488 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; |
|
489 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid); |
|
490 |
|
491 // Run the algorithm |
|
492 let start = Instant::now(); |
|
493 let start_cpu = ProcessTime::now(); |
|
494 let μ = match alg { |
753 let μ = match alg { |
495 AlgorithmConfig::FB(ref algconfig) => { |
754 AlgorithmConfig::FB(ref algconfig, prox) => { |
496 match (regularisation, dataterm) { |
755 match (regularisation, dataterm, prox) { |
497 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
756 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
498 running(); |
757 print!("{running}"); |
499 pointsource_fb_reg( |
758 pointsource_fb_reg( |
500 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
759 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
501 iterator, plotter |
760 iterator, plotter |
502 ) |
761 ) |
503 }, |
762 }), |
504 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
763 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
505 running(); |
764 print!("{running}"); |
506 pointsource_fb_reg( |
765 pointsource_fb_reg( |
507 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
766 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
508 iterator, plotter |
767 iterator, plotter |
509 ) |
768 ) |
510 }, |
769 }), |
511 _ => { |
770 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
512 not_implemented(); |
771 print!("{running}"); |
513 continue |
772 pointsource_fb_reg( |
514 } |
773 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
|
774 iterator, plotter |
|
775 ) |
|
776 }), |
|
777 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
|
778 print!("{running}"); |
|
779 pointsource_fb_reg( |
|
780 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
|
781 iterator, plotter |
|
782 ) |
|
783 }), |
|
784 _ => Err(NotImplemented) |
515 } |
785 } |
516 }, |
786 }, |
517 AlgorithmConfig::PDPS(ref algconfig) => { |
787 AlgorithmConfig::FISTA(ref algconfig, prox) => { |
518 running(); |
788 match (regularisation, dataterm, prox) { |
519 match (regularisation, dataterm) { |
789 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
520 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
790 print!("{running}"); |
|
791 pointsource_fista_reg( |
|
792 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
|
793 iterator, plotter |
|
794 ) |
|
795 }), |
|
796 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
|
797 print!("{running}"); |
|
798 pointsource_fista_reg( |
|
799 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
|
800 iterator, plotter |
|
801 ) |
|
802 }), |
|
803 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
|
804 print!("{running}"); |
|
805 pointsource_fista_reg( |
|
806 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
|
807 iterator, plotter |
|
808 ) |
|
809 }), |
|
810 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
|
811 print!("{running}"); |
|
812 pointsource_fista_reg( |
|
813 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
|
814 iterator, plotter |
|
815 ) |
|
816 }), |
|
817 _ => Err(NotImplemented), |
|
818 } |
|
819 }, |
|
820 AlgorithmConfig::SlidingFB(ref algconfig, prox) => { |
|
821 match (regularisation, dataterm, prox) { |
|
822 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
|
823 print!("{running}"); |
|
824 pointsource_sliding_fb_reg( |
|
825 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
|
826 iterator, plotter |
|
827 ) |
|
828 }), |
|
829 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
|
830 print!("{running}"); |
|
831 pointsource_sliding_fb_reg( |
|
832 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
|
833 iterator, plotter |
|
834 ) |
|
835 }), |
|
836 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
|
837 print!("{running}"); |
|
838 pointsource_sliding_fb_reg( |
|
839 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
|
840 iterator, plotter |
|
841 ) |
|
842 }), |
|
843 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
|
844 print!("{running}"); |
|
845 pointsource_sliding_fb_reg( |
|
846 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
|
847 iterator, plotter |
|
848 ) |
|
849 }), |
|
850 _ => Err(NotImplemented), |
|
851 } |
|
852 }, |
|
853 AlgorithmConfig::PDPS(ref algconfig, prox) => { |
|
854 print!("{running}"); |
|
855 match (regularisation, dataterm, prox) { |
|
856 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
521 pointsource_pdps_reg( |
857 pointsource_pdps_reg( |
522 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
858 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
523 iterator, plotter, L2Squared |
859 iterator, plotter, L2Squared |
524 ) |
860 ) |
525 }, |
861 }), |
526 (Regularisation::Radon(α),DataTerm::L2Squared) => { |
862 (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
527 pointsource_pdps_reg( |
863 pointsource_pdps_reg( |
528 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
864 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
529 iterator, plotter, L2Squared |
865 iterator, plotter, L2Squared |
530 ) |
866 ) |
531 }, |
867 }), |
532 (Regularisation::NonnegRadon(α), DataTerm::L1) => { |
868 (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ |
533 pointsource_pdps_reg( |
869 pointsource_pdps_reg( |
534 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
870 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
535 iterator, plotter, L1 |
871 iterator, plotter, L1 |
536 ) |
872 ) |
537 }, |
873 }), |
538 (Regularisation::Radon(α), DataTerm::L1) => { |
874 (Regularisation::Radon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ |
539 pointsource_pdps_reg( |
875 pointsource_pdps_reg( |
540 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
876 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
541 iterator, plotter, L1 |
877 iterator, plotter, L1 |
542 ) |
878 ) |
543 }, |
879 }), |
|
880 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
|
881 pointsource_pdps_reg( |
|
882 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
|
883 iterator, plotter, L2Squared |
|
884 ) |
|
885 }), |
|
886 (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
|
887 pointsource_pdps_reg( |
|
888 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
|
889 iterator, plotter, L2Squared |
|
890 ) |
|
891 }), |
|
892 (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ |
|
893 pointsource_pdps_reg( |
|
894 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
|
895 iterator, plotter, L1 |
|
896 ) |
|
897 }), |
|
898 (Regularisation::Radon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ |
|
899 pointsource_pdps_reg( |
|
900 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
|
901 iterator, plotter, L1 |
|
902 ) |
|
903 }), |
|
904 // _ => Err(NotImplemented), |
544 } |
905 } |
545 }, |
906 }, |
546 AlgorithmConfig::FW(ref algconfig) => { |
907 AlgorithmConfig::FW(ref algconfig) => { |
547 match (regularisation, dataterm) { |
908 match (regularisation, dataterm) { |
548 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
909 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ |
549 running(); |
910 print!("{running}"); |
550 pointsource_fw_reg(&opA, &b, RadonRegTerm(α), |
911 pointsource_fw_reg(&opA, &b, RadonRegTerm(α), |
551 algconfig, iterator, plotter) |
912 algconfig, iterator, plotter) |
552 }, |
913 }), |
553 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
914 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ |
554 running(); |
915 print!("{running}"); |
555 pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α), |
916 pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α), |
556 algconfig, iterator, plotter) |
917 algconfig, iterator, plotter) |
557 }, |
918 }), |
558 _ => { |
919 _ => Err(NotImplemented), |
559 not_implemented(); |
|
560 continue |
|
561 } |
|
562 } |
920 } |
|
921 }, |
|
922 _ => Err(NotImplemented), |
|
923 }?; |
|
924 Ok((μ, ())) |
|
925 }) |
|
926 } |
|
927 } |
|
928 |
|
929 |
|
930 #[replace_float_literals(F::cast_from(literal))] |
|
931 impl<F, NoiseDistr, S, K, P, B, /*PreadjointCodomain,*/ const N : usize> RunnableExperiment<F> for |
|
932 Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>> |
|
933 where |
|
934 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F> |
|
935 + Default + for<'b> Deserialize<'b>, |
|
936 [usize; N] : Serialize, |
|
937 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
|
938 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
|
939 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy |
|
940 // TODO: shold not have differentiability as a requirement, but |
|
941 // decide availability of sliding based on it. |
|
942 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
|
943 // TODO: very weird that rust only compiles with Differentiable |
|
944 // instead of the above one on references, which is required by |
|
945 // poitsource_sliding_fb_reg. |
|
946 + DifferentiableRealMapping<F, N> |
|
947 + Lipschitz<L2, FloatType=F>, |
|
948 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb. |
|
949 AutoConvolution<P> : BoundedBy<F, K>, |
|
950 K : SimpleConvolutionKernel<F, N> |
|
951 + LocalAnalysis<F, Bounds<F>, N> |
|
952 + Copy + Serialize + std::fmt::Debug, |
|
953 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
|
954 PlotLookup : Plotting<N>, |
|
955 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
|
956 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
957 RNDM<F, N> : SpikeMerging<F>, |
|
958 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, |
|
959 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, |
|
960 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>, |
|
961 // PreadjointCodomain : Bounded<F> + DifferentiableRealMapping<F, N>, |
|
962 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, |
|
963 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, |
|
964 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, |
|
965 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, |
|
966 { |
|
967 |
|
968 fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> { |
|
969 AlgorithmOverrides { |
|
970 merge_radius : Some(self.data.base.default_merge_radius), |
|
971 .. self.data.base.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) |
|
972 } |
|
973 } |
|
974 |
|
975 fn runall(&self, cli : &CommandLineArgs, |
|
976 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { |
|
977 // Get experiment configuration |
|
978 let &Named { |
|
979 name : ref experiment_name, |
|
980 data : ExperimentBiased { |
|
981 λ, |
|
982 ref bias, |
|
983 base : ExperimentV2 { |
|
984 domain, sensor_count, ref noise_distr, sensor, spread, kernel, |
|
985 ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, |
|
986 .. |
563 } |
987 } |
564 }; |
988 } |
565 |
989 } = self; |
566 let elapsed = start.elapsed().as_secs_f64(); |
990 |
567 let cpu_time = start_cpu.elapsed().as_secs_f64(); |
991 // Set up algorithms |
568 |
992 let algorithms = match (algs, dataterm) { |
569 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); |
993 (Some(algs), _) => algs, |
570 |
994 _ => vec![DefaultAlgorithm::SlidingPDPS.get_named()], |
571 // Save results |
995 }; |
572 println!("{}", "Saving results…".green()); |
996 |
573 |
997 // Set up operators |
574 let mkname = |t| format!("{prefix}{alg_name}_{t}"); |
998 let depth = DynamicDepth(8); |
575 |
999 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); |
576 write_json(mkname("config.json"), &named)?; |
1000 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); |
577 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; |
1001 let opAext = RowOp(opA.clone(), IdOp::new()); |
578 μ.write_csv(mkname("reco.txt"))?; |
1002 let fnR = Zero::new(); |
579 logger.write_csv(mkname("log.txt"))?; |
1003 let h = map3(domain.span_start(), domain.span_end(), sensor_count, |
|
1004 |a, b, n| (b-a)/F::cast_from(n)) |
|
1005 .into_iter() |
|
1006 .reduce(NumTraitsFloat::max) |
|
1007 .unwrap(); |
|
1008 let z = DVector::zeros(sensor_count.iter().product()); |
|
1009 let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap(); |
|
1010 let y = opKz.apply(&z); |
|
1011 let fnH = Weighted{ base_fn : L1.as_mapping(), weight : λ}; // TODO: L_{2,1} |
|
1012 // let zero_y = y.clone(); |
|
1013 // let zeroBTFN = opA.preadjoint().apply(&zero_y); |
|
1014 // let opKμ = ZeroOp::new(&zero_y, zeroBTFN); |
|
1015 |
|
1016 // Set up random number generator. |
|
1017 let mut rng = StdRng::seed_from_u64(noise_seed); |
|
1018 |
|
1019 // Generate the data and calculate SSNR statistic |
|
1020 let bias_vec = DVector::from_vec(opA.grid() |
|
1021 .into_iter() |
|
1022 .map(|v| bias.apply(v)) |
|
1023 .collect::<Vec<F>>()); |
|
1024 let b_hat : DVector<_> = opA.apply(μ_hat) + &bias_vec; |
|
1025 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); |
|
1026 let b = &b_hat + &noise; |
|
1027 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
|
1028 // overloading log10 and conflicting with standard NumTraits one. |
|
1029 let stats = ExperimentStats::new(&b, &noise); |
|
1030 |
|
1031 let prefix = start_experiment(&self, cli, stats)?; |
|
1032 |
|
1033 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, |
|
1034 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; |
|
1035 |
|
1036 opA.write_observable(&bias_vec, format!("{prefix}bias"))?; |
|
1037 |
|
1038 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); |
|
1039 |
|
1040 let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z")); |
|
1041 |
|
1042 // Run the algorithms |
|
1043 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, |
|
1044 |alg, iterator, plotter, running| |
|
1045 { |
|
1046 let Pair(μ, z) = match alg { |
|
1047 AlgorithmConfig::ForwardPDPS(ref algconfig, prox) => { |
|
1048 match (regularisation, dataterm, prox) { |
|
1049 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
|
1050 print!("{running}"); |
|
1051 pointsource_forward_pdps_pair( |
|
1052 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
|
1053 iterator, plotter, |
|
1054 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
|
1055 ) |
|
1056 }), |
|
1057 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
|
1058 print!("{running}"); |
|
1059 pointsource_forward_pdps_pair( |
|
1060 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, |
|
1061 iterator, plotter, |
|
1062 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
|
1063 ) |
|
1064 }), |
|
1065 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
|
1066 print!("{running}"); |
|
1067 pointsource_forward_pdps_pair( |
|
1068 &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
|
1069 iterator, plotter, |
|
1070 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
|
1071 ) |
|
1072 }), |
|
1073 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
|
1074 print!("{running}"); |
|
1075 pointsource_forward_pdps_pair( |
|
1076 &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
|
1077 iterator, plotter, |
|
1078 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
|
1079 ) |
|
1080 }), |
|
1081 _ => Err(NotImplemented) |
|
1082 } |
|
1083 }, |
|
1084 AlgorithmConfig::SlidingPDPS(ref algconfig, prox) => { |
|
1085 match (regularisation, dataterm, prox) { |
|
1086 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
|
1087 print!("{running}"); |
|
1088 pointsource_sliding_pdps_pair( |
|
1089 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
|
1090 iterator, plotter, |
|
1091 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
|
1092 ) |
|
1093 }), |
|
1094 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
|
1095 print!("{running}"); |
|
1096 pointsource_sliding_pdps_pair( |
|
1097 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, |
|
1098 iterator, plotter, |
|
1099 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
|
1100 ) |
|
1101 }), |
|
1102 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
|
1103 print!("{running}"); |
|
1104 pointsource_sliding_pdps_pair( |
|
1105 &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
|
1106 iterator, plotter, |
|
1107 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
|
1108 ) |
|
1109 }), |
|
1110 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
|
1111 print!("{running}"); |
|
1112 pointsource_sliding_pdps_pair( |
|
1113 &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
|
1114 iterator, plotter, |
|
1115 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
|
1116 ) |
|
1117 }), |
|
1118 _ => Err(NotImplemented) |
|
1119 } |
|
1120 }, |
|
1121 _ => Err(NotImplemented) |
|
1122 }?; |
|
1123 Ok((μ, z)) |
|
1124 }) |
|
1125 } |
|
1126 } |
|
1127 |
|
1128 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
|
1129 struct ValueRange<F : Float> { |
|
1130 ini : F, |
|
1131 min : F, |
|
1132 } |
|
1133 |
|
1134 impl<F : Float> ValueRange<F> { |
|
1135 fn expand_with(self, other : Self) -> Self { |
|
1136 ValueRange { |
|
1137 ini : self.ini.max(other.ini), |
|
1138 min : self.min.min(other.min), |
580 } |
1139 } |
581 |
1140 } |
582 Ok(()) |
1141 } |
583 } |
1142 |
584 } |
1143 /// Calculative minimum and maximum values of all the `logs`, and save them into |
585 // *** macro end boiler plate *** |
1144 /// corresponding file names given as the first elements of the tuples in the vectors. |
586 }} |
1145 fn save_logs<F : Float + for<'b> Deserialize<'b>, const N : usize>( |
587 // *** actual code *** |
1146 logs : Vec<(String, Logger<Timed<IterInfo<F, N>>>)>, |
588 |
1147 valuerange_file : String, |
589 impl_experiment!(ExperimentV2, regularisation, std::convert::identity); |
1148 load_valuerange : bool, |
|
1149 ) -> DynError { |
|
1150 // Process logs for relative values |
|
1151 println!("{}", "Processing logs…"); |
|
1152 |
|
1153 // Find minimum value and initial value within a single log |
|
1154 let proc_single_log = |log : &Logger<Timed<IterInfo<F, N>>>| { |
|
1155 let d = log.data(); |
|
1156 let mi = d.iter() |
|
1157 .map(|i| i.data.value) |
|
1158 .reduce(NumTraitsFloat::min); |
|
1159 d.first() |
|
1160 .map(|i| i.data.value) |
|
1161 .zip(mi) |
|
1162 .map(|(ini, min)| ValueRange{ ini, min }) |
|
1163 }; |
|
1164 |
|
1165 // Find minimum and maximum value over all logs |
|
1166 let mut v = logs.iter() |
|
1167 .filter_map(|&(_, ref log)| proc_single_log(log)) |
|
1168 .reduce(|v1, v2| v1.expand_with(v2)) |
|
1169 .ok_or(anyhow!("No algorithms found"))?; |
|
1170 |
|
1171 // Load existing range |
|
1172 if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() { |
|
1173 let data = std::fs::read_to_string(&valuerange_file)?; |
|
1174 v = v.expand_with(serde_json::from_str(&data)?); |
|
1175 } |
|
1176 |
|
1177 let logmap = |Timed { cpu_time, iter, data }| { |
|
1178 let IterInfo { |
|
1179 value, |
|
1180 n_spikes, |
|
1181 inner_iters, |
|
1182 merged, |
|
1183 pruned, |
|
1184 //postprocessing, |
|
1185 this_iters, |
|
1186 .. |
|
1187 } = data; |
|
1188 // let post_value = match (postprocessing, dataterm) { |
|
1189 // (Some(mut μ), DataTerm::L2Squared) => { |
|
1190 // // Comparison postprocessing is only implemented for the case handled |
|
1191 // // by the FW variants. |
|
1192 // reg.optimise_weights( |
|
1193 // &mut μ, &opA, &b, &findim_data, &inner_config, |
|
1194 // inner_it |
|
1195 // ); |
|
1196 // dataterm.value_at_residual(opA.apply(&μ) - &b) |
|
1197 // + regularisation.apply(&μ) |
|
1198 // }, |
|
1199 // _ => value, |
|
1200 // }; |
|
1201 let relative_value = (value - v.min)/(v.ini - v.min); |
|
1202 CSVLog { |
|
1203 iter, |
|
1204 value, |
|
1205 relative_value, |
|
1206 //post_value, |
|
1207 n_spikes, |
|
1208 cpu_time : cpu_time.as_secs_f64(), |
|
1209 inner_iters, |
|
1210 merged, |
|
1211 pruned, |
|
1212 this_iters |
|
1213 } |
|
1214 }; |
|
1215 |
|
1216 println!("{}", "Saving logs …".green()); |
|
1217 |
|
1218 serde_json::to_writer_pretty(std::fs::File::create(&valuerange_file)?, &v)?; |
|
1219 |
|
1220 for (name, logger) in logs { |
|
1221 logger.map(logmap).write_csv(name)?; |
|
1222 } |
|
1223 |
|
1224 Ok(()) |
|
1225 } |
|
1226 |
590 |
1227 |
591 /// Plot experiment setup |
1228 /// Plot experiment setup |
592 #[replace_float_literals(F::cast_from(literal))] |
1229 #[replace_float_literals(F::cast_from(literal))] |
593 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( |
1230 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( |
594 cli : &CommandLineArgs, |
1231 cli : &CommandLineArgs, |
595 prefix : &String, |
1232 prefix : &String, |
596 domain : &Cube<F, N>, |
1233 domain : &Cube<F, N>, |
597 sensor : &Sensor, |
1234 sensor : &Sensor, |
598 kernel : &Kernel, |
1235 kernel : &Kernel, |
599 spread : &Spread, |
1236 spread : &Spread, |
600 μ_hat : &DiscreteMeasure<Loc<F, N>, F>, |
1237 μ_hat : &RNDM<F, N>, |
601 op𝒟 : &𝒟, |
1238 op𝒟 : &𝒟, |
602 opA : &A, |
1239 opA : &A, |
603 b_hat : &A::Observable, |
1240 b_hat : &A::Observable, |
604 b : &A::Observable, |
1241 b : &A::Observable, |
605 kernel_plot_width : F, |
1242 kernel_plot_width : F, |
606 ) -> DynError |
1243 ) -> DynError |
607 where F : Float + ToNalgebraRealField, |
1244 where F : Float + ToNalgebraRealField, |
608 Sensor : RealMapping<F, N> + Support<F, N> + Clone, |
1245 Sensor : RealMapping<F, N> + Support<F, N> + Clone, |
609 Spread : RealMapping<F, N> + Support<F, N> + Clone, |
1246 Spread : RealMapping<F, N> + Support<F, N> + Clone, |
610 Kernel : RealMapping<F, N> + Support<F, N>, |
1247 Kernel : RealMapping<F, N> + Support<F, N>, |
611 Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, |
1248 Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>, |
612 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, |
1249 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, |
613 𝒟::Codomain : RealMapping<F, N>, |
1250 𝒟::Codomain : RealMapping<F, N>, |
614 A : ForwardModel<Loc<F, N>, F>, |
1251 A : ForwardModel<RNDM<F, N>, F>, |
615 A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, |
1252 for<'a> &'a A::Observable : Instance<A::Observable>, |
|
1253 A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>, |
616 PlotLookup : Plotting<N>, |
1254 PlotLookup : Plotting<N>, |
617 Cube<F, N> : SetOrd { |
1255 Cube<F, N> : SetOrd { |
618 |
1256 |
619 if cli.plot < PlotLevel::Data { |
1257 if cli.plot < PlotLevel::Data { |
620 return Ok(()) |
1258 return Ok(()) |
621 } |
1259 } |
622 |
1260 |
623 let base = Convolution(sensor.clone(), spread.clone()); |
1261 let base = Convolution(sensor.clone(), spread.clone()); |
624 |
1262 |
625 let resolution = if N==1 { 100 } else { 40 }; |
1263 let resolution = if N==1 { 100 } else { 40 }; |
626 let pfx = |n| format!("{}{}", prefix, n); |
1264 let pfx = |n| format!("{prefix}{n}"); |
627 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); |
1265 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); |
628 |
1266 |
629 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); |
1267 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor")); |
630 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); |
1268 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel")); |
631 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); |
1269 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread")); |
632 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); |
1270 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor")); |
633 |
1271 |
634 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
1272 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
635 |
1273 |
636 let ω_hat = op𝒟.apply(μ_hat); |
1274 let ω_hat = op𝒟.apply(μ_hat); |
637 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
1275 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
638 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string()); |
1276 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat")); |
639 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"), |
1277 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise")); |
640 "noise Aᵀ(Aμ̂ - b)".to_string()); |
|
641 |
1278 |
642 let preadj_b = opA.preadjoint().apply(b); |
1279 let preadj_b = opA.preadjoint().apply(b); |
643 let preadj_b_hat = opA.preadjoint().apply(b_hat); |
1280 let preadj_b_hat = opA.preadjoint().apply(b_hat); |
644 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); |
1281 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); |
645 PlotLookup::plot_into_file_spikes( |
1282 PlotLookup::plot_into_file_spikes( |
646 "Aᵀb".to_string(), &preadj_b, |
1283 Some(&preadj_b), |
647 "Aᵀb̂".to_string(), Some(&preadj_b_hat), |
1284 Some(&preadj_b_hat), |
648 plotgrid2, None, &μ_hat, |
1285 plotgrid2, |
|
1286 &μ_hat, |
649 pfx("omega_b") |
1287 pfx("omega_b") |
650 ); |
1288 ); |
|
1289 PlotLookup::plot_into_file(&preadj_b, plotgrid2, pfx("preadj_b")); |
|
1290 PlotLookup::plot_into_file(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat")); |
651 |
1291 |
652 // Save true solution and observables |
1292 // Save true solution and observables |
653 let pfx = |n| format!("{}{}", prefix, n); |
|
654 μ_hat.write_csv(pfx("orig.txt"))?; |
1293 μ_hat.write_csv(pfx("orig.txt"))?; |
655 opA.write_observable(&b_hat, pfx("b_hat"))?; |
1294 opA.write_observable(&b_hat, pfx("b_hat"))?; |
656 opA.write_observable(&b, pfx("b_noisy")) |
1295 opA.write_observable(&b, pfx("b_noisy")) |
657 } |
1296 } |
658 |
|
659 // |
|
660 // Deprecated interface |
|
661 // |
|
662 |
|
663 /// Struct for experiment configurations |
|
664 #[derive(Debug, Clone, Serialize)] |
|
665 pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize> |
|
666 where F : Float, |
|
667 [usize; N] : Serialize, |
|
668 NoiseDistr : Distribution<F>, |
|
669 S : Sensor<F, N>, |
|
670 P : Spread<F, N>, |
|
671 K : SimpleConvolutionKernel<F, N>, |
|
672 { |
|
673 /// Domain $Ω$. |
|
674 pub domain : Cube<F, N>, |
|
675 /// Number of sensors along each dimension |
|
676 pub sensor_count : [usize; N], |
|
677 /// Noise distribution |
|
678 pub noise_distr : NoiseDistr, |
|
679 /// Seed for random noise generation (for repeatable experiments) |
|
680 pub noise_seed : u64, |
|
681 /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. |
|
682 pub sensor : S, |
|
683 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. |
|
684 pub spread : P, |
|
685 /// Kernel $ρ$ of $𝒟$. |
|
686 pub kernel : K, |
|
687 /// True point sources |
|
688 pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, |
|
689 /// Regularisation parameter |
|
690 #[deprecated(note = "Use [`ExperimentV2`], which replaces `α` by more generic `regularisation`")] |
|
691 pub α : F, |
|
692 /// For plotting : how wide should the kernels be plotted |
|
693 pub kernel_plot_width : F, |
|
694 /// Data term |
|
695 pub dataterm : DataTerm, |
|
696 /// A map of default configurations for algorithms |
|
697 #[serde(skip)] |
|
698 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, |
|
699 } |
|
700 |
|
701 impl_experiment!(Experiment, α, Regularisation::NonnegRadon); |
|