src/rand_distr.rs

changeset 0
eb3c7813b67a
child 23
9869fa1e0ccd
equal deleted inserted replaced
-1:000000000000 0:eb3c7813b67a
1 //! Random distribution wrappers and implementations
2
3 use numeric_literals::replace_float_literals;
4 use rand::Rng;
5 use rand_distr::{Distribution, Normal, StandardNormal, NormalError};
6 use serde::{Serialize, Deserialize};
7 use serde::ser::{Serializer, SerializeStruct};
8 use alg_tools::types::*;
9
10 /// Wrapper for [`Normal`] that can be serialized by serde.
11 pub struct SerializableNormal<T : Float>(Normal<T>)
12 where StandardNormal : Distribution<T>;
13
14 impl<T : Float> Distribution<T> for SerializableNormal<T>
15 where StandardNormal : Distribution<T> {
16 fn sample<R>(&self, rng: &mut R) -> T
17 where
18 R : Rng + ?Sized
19 { self.0.sample(rng) }
20 }
21
22 impl<T : Float> SerializableNormal<T>
23 where StandardNormal : Distribution<T> {
24 pub fn new(mean : T, std_dev : T) -> Result<SerializableNormal<T>, NormalError> {
25 Ok(SerializableNormal(Normal::new(mean, std_dev)?))
26 }
27 }
28
29 impl<F> Serialize for SerializableNormal<F>
30 where
31 StandardNormal : Distribution<F>,
32 F: Float + Serialize,
33 {
34 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
35 where
36 S: Serializer,
37 {
38 let mut s = serializer.serialize_struct("Normal", 2)?;
39 s.serialize_field("mean", &self.0.mean())?;
40 s.serialize_field("std_dev", &self.0.std_dev())?;
41 s.end()
42 }
43 }
44
45 /// Salt-and-pepper noise distribution
46 ///
47 /// This is the distribution that outputs each $\\{-m,0,m\\}$ with the corresponding
48 /// probabilities $\\{1-p, p/2, p/2\\}$.
49 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
50 pub struct SaltAndPepper<T : Float>{
51 /// The magnitude parameter $m$
52 magnitude : T,
53 /// The probability parameter $p$
54 probability : T
55 }
56
57 /// Error for [`SaltAndPepper`].
58 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
59 pub enum SaltAndPepperError {
60 /// The probability parameter $p$ is not in the range [0, 1].
61 InvalidProbability,
62 }
63 impl std::error::Error for SaltAndPepperError {}
64
65 impl std::fmt::Display for SaltAndPepperError {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.write_str(match self {
68 SaltAndPepperError::InvalidProbability =>
69 " The probability parameter is not in the range [0, 1].",
70 })
71 }
72 }
73
74 #[replace_float_literals(T::cast_from(literal))]
75 impl<T : Float> SaltAndPepper<T> {
76 pub fn new(magnitude : T, probability : T) -> Result<SaltAndPepper<T>, SaltAndPepperError> {
77 if probability > 1.0 || probability < 0.0 {
78 Err(SaltAndPepperError::InvalidProbability)
79 } else {
80 Ok(SaltAndPepper { magnitude, probability })
81 }
82 }
83 }
84
85 #[replace_float_literals(T::cast_from(literal))]
86 impl<T : Float> Distribution<T> for SaltAndPepper<T> {
87 fn sample<R>(&self, rng: &mut R) -> T
88 where
89 R : Rng + ?Sized
90 {
91 let (p, sign) : (float, bool) = rng.gen();
92 match (p < self.probability.as_(), sign) {
93 (false, _) => 0.0,
94 (true, true) => self.magnitude,
95 (true, false) => -self.magnitude,
96 }
97 }
98 }

mercurial