src/run.rs

branch
dev
changeset 38
0f59c0d02e13
parent 37
c5d8bd1a7728
equal deleted inserted replaced
37:c5d8bd1a7728 38:0f59c0d02e13
325 // format!("{:?}", self).hash(state); 325 // format!("{:?}", self).hash(state);
326 // } 326 // }
327 // } 327 // }
328 328
329 /// Plotting level configuration 329 /// Plotting level configuration
330 #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, ValueEnum, Debug)] 330 #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, ValueEnum, Debug)]
331 pub enum PlotLevel { 331 pub enum PlotLevel {
332 /// Plot nothing 332 /// Plot nothing
333 #[clap(name = "none")] 333 #[clap(name = "none")]
334 None, 334 None,
335 /// Plot problem data 335 /// Plot problem data
355 DefaultBT<F, N>, 355 DefaultBT<F, N>,
356 N 356 N
357 >; 357 >;
358 358
359 /// This is a dirty workaround to rust-csv not supporting struct flattening etc. 359 /// This is a dirty workaround to rust-csv not supporting struct flattening etc.
360 #[derive(Serialize)] 360 #[derive(Serialize, Deserialize)]
361 struct CSVLog<F> { 361 struct CSVLog<F> {
362 iter : usize, 362 iter : usize,
363 cpu_time : f64, 363 cpu_time : f64,
364 value : F, 364 value : F,
365 relative_value : F, 365 relative_value : F,
370 pruned : usize, 370 pruned : usize,
371 this_iters : usize, 371 this_iters : usize,
372 } 372 }
373 373
374 /// Collected experiment statistics 374 /// Collected experiment statistics
375 #[derive(Clone, Debug, Serialize)] 375 #[derive(Clone, Debug, Serialize, Deserialize)]
376 struct ExperimentStats<F : Float> { 376 struct ExperimentStats<F : Float> {
377 /// Signal-to-noise ratio in decibels 377 /// Signal-to-noise ratio in decibels
378 ssnr : F, 378 ssnr : F,
379 /// Proportion of noise in the signal as a number in $[0, 1]$. 379 /// Proportion of noise in the signal as a number in $[0, 1]$.
380 noise_ratio : F, 380 noise_ratio : F,
396 when : Utc::now(), 396 when : Utc::now(),
397 } 397 }
398 } 398 }
399 } 399 }
400 /// Collected algorithm statistics 400 /// Collected algorithm statistics
401 #[derive(Clone, Debug, Serialize)] 401 #[derive(Clone, Debug, Serialize, Deserialize)]
402 struct AlgorithmStats<F : Float> { 402 struct AlgorithmStats<F : Float> {
403 /// Overall CPU time spent 403 /// Overall CPU time spent
404 cpu_time : F, 404 cpu_time : F,
405 /// Real time spent 405 /// Real time spent
406 elapsed : F 406 elapsed : F
414 Ok(()) 414 Ok(())
415 } 415 }
416 416
417 417
418 /// Struct for experiment configurations 418 /// Struct for experiment configurations
419 #[derive(Debug, Clone, Serialize)] 419 #[derive(Debug, Clone, Serialize, Deserialize)]
420 #[serde(bound(
421 serialize = "Cube<F, N> : Serialize,
422 NoiseDistr : Serialize,
423 [usize; N] : Serialize,
424 RNDM<F, N> : Serialize,
425 Regularisation<F> : Serialize,
426 F : Serialize,
427 S : Serialize,
428 P : Serialize,
429 K : Serialize",
430 deserialize = "Cube<F, N> : for<'a> Deserialize<'a>,
431 NoiseDistr : for<'a> Deserialize<'a>,
432 [usize; N] : for<'a> Deserialize<'a>,
433 RNDM<F, N> : for<'a> Deserialize<'a>,
434 Regularisation<F> : for<'a> Deserialize<'a>,
435 F : for<'a> Deserialize<'a>,
436 S : for<'a> Deserialize<'a>,
437 P : for<'a> Deserialize<'a>,
438 K : for<'a> Deserialize<'a>,",
439 ))]
420 pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize> 440 pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize>
421 where F : Float, 441 where F : Float,
422 [usize; N] : Serialize,
423 NoiseDistr : Distribution<F>, 442 NoiseDistr : Distribution<F>,
424 S : Sensor<F, N>, 443 S : Sensor<F, N>,
425 P : Spread<F, N>, 444 P : Spread<F, N>,
426 K : SimpleConvolutionKernel<F, N>, 445 K : SimpleConvolutionKernel<F, N>,
427 { 446 {
450 /// A map of default configurations for algorithms 469 /// A map of default configurations for algorithms
451 #[serde(skip)] 470 #[serde(skip)]
452 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, 471 pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>,
453 } 472 }
454 473
455 #[derive(Debug, Clone, Serialize)] 474 #[derive(Debug, Clone, Serialize, Deserialize)]
475 #[serde(bound(
476 serialize = "ExperimentV2<F, NoiseDistr, S, K, P, N> : Serialize,
477 B : Serialize,
478 F : Serialize",
479 deserialize = "ExperimentV2<F, NoiseDistr, S, K, P, N >: for<'a> Deserialize<'a>,
480 B : for<'a> Deserialize<'a>,
481 F : for<'a> Deserialize<'a>",
482 ))]
456 pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N : usize> 483 pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N : usize>
457 where F : Float, 484 where F : Float,
458 [usize; N] : Serialize, 485 [usize; N] : Serialize,
459 NoiseDistr : Distribution<F>, 486 NoiseDistr : Distribution<F>,
460 S : Sensor<F, N>, 487 S : Sensor<F, N>,
539 String, 566 String,
540 ) -> Result<(RNDM<F, N>, Z), RunError>, 567 ) -> Result<(RNDM<F, N>, Z), RunError>,
541 ) -> DynError 568 ) -> DynError
542 where 569 where
543 PlotLookup : Plotting<N>, 570 PlotLookup : Plotting<N>,
571 DeltaMeasure<Loc<F, N>, F> : Serialize,
544 { 572 {
545 let mut logs = Vec::new(); 573 let mut logs = Vec::new();
546 574
547 let iterator_options = AlgIteratorOptions{ 575 let iterator_options = AlgIteratorOptions{
548 max_iter : cli.max_iter, 576 max_iter : cli.max_iter,
645 PlotLookup : Plotting<N>, 673 PlotLookup : Plotting<N>,
646 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, 674 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
647 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 675 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
648 RNDM<F, N> : SpikeMerging<F>, 676 RNDM<F, N> : SpikeMerging<F>,
649 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, 677 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug,
678 [F; N] : Serialize,
679 [[F; 2]; N] : Serialize,
650 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>, 680 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>,
651 // PreadjointCodomain : Space + Bounded<F> + DifferentiableRealMapping<F, N>, 681 // PreadjointCodomain : Space + Bounded<F> + DifferentiableRealMapping<F, N>,
652 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, 682 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
653 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, 683 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
654 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, 684 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
910 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, 940 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
911 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 941 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
912 RNDM<F, N> : SpikeMerging<F>, 942 RNDM<F, N> : SpikeMerging<F>,
913 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, 943 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug,
914 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, 944 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug,
945 [F; N] : Serialize,
946 [[F; 2]; N] : Serialize,
915 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>, 947 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>,
916 // PreadjointCodomain : Bounded<F> + DifferentiableRealMapping<F, N>, 948 // PreadjointCodomain : Bounded<F> + DifferentiableRealMapping<F, N>,
917 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, 949 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
918 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, 950 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
919 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, 951 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
1167 opA : &A, 1199 opA : &A,
1168 b_hat : &A::Observable, 1200 b_hat : &A::Observable,
1169 b : &A::Observable, 1201 b : &A::Observable,
1170 kernel_plot_width : F, 1202 kernel_plot_width : F,
1171 ) -> DynError 1203 ) -> DynError
1172 where F : Float + ToNalgebraRealField, 1204 where
1173 Sensor : RealMapping<F, N> + Support<F, N> + Clone, 1205 F : Float + ToNalgebraRealField,
1174 Spread : RealMapping<F, N> + Support<F, N> + Clone, 1206 Sensor : RealMapping<F, N> + Support<F, N> + Clone,
1175 Kernel : RealMapping<F, N> + Support<F, N>, 1207 Spread : RealMapping<F, N> + Support<F, N> + Clone,
1176 Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>, 1208 Kernel : RealMapping<F, N> + Support<F, N>,
1177 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, 1209 Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>,
1178 𝒟::Codomain : RealMapping<F, N>, 1210 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>,
1179 A : ForwardModel<RNDM<F, N>, F>, 1211 𝒟::Codomain : RealMapping<F, N>,
1180 for<'a> &'a A::Observable : Instance<A::Observable>, 1212 A : ForwardModel<RNDM<F, N>, F>,
1181 A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>, 1213 for<'a> &'a A::Observable : Instance<A::Observable>,
1182 PlotLookup : Plotting<N>, 1214 A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>,
1183 Cube<F, N> : SetOrd { 1215 PlotLookup : Plotting<N>,
1216 Cube<F, N> : SetOrd,
1217 DeltaMeasure<Loc<F, N>, F> : Serialize,
1218 {
1184 1219
1185 if cli.plot < PlotLevel::Data { 1220 if cli.plot < PlotLevel::Data {
1186 return Ok(()) 1221 return Ok(())
1187 } 1222 }
1188 1223

mercurial