| 1 //! Random distribution wrappers and implementations |
1 //! Random distribution wrappers and implementations |
| 2 |
2 |
| |
3 use alg_tools::types::*; |
| 3 use numeric_literals::replace_float_literals; |
4 use numeric_literals::replace_float_literals; |
| 4 use rand::Rng; |
5 use rand::Rng; |
| 5 use rand_distr::{Distribution, Normal, StandardNormal, NormalError}; |
6 use rand_distr::{Distribution, Normal, NormalError, StandardNormal}; |
| 6 use serde::{Serialize, Deserialize}; |
7 use serde::ser::{SerializeStruct, Serializer}; |
| 7 use serde::ser::{Serializer, SerializeStruct}; |
8 use serde::{Deserialize, Serialize}; |
| 8 use alg_tools::types::*; |
|
| 9 |
9 |
| 10 /// Wrapper for [`Normal`] that can be serialized by serde. |
10 /// Wrapper for [`Normal`] that can be serialized by serde. |
| 11 #[derive(Debug)] |
11 #[derive(Debug)] |
| 12 pub struct SerializableNormal<T : Float>(Normal<T>) |
12 pub struct SerializableNormal<T: Float>(Normal<T>) |
| 13 where StandardNormal : Distribution<T>; |
13 where |
| |
14 StandardNormal: Distribution<T>; |
| 14 |
15 |
| 15 impl<T : Float> Distribution<T> for SerializableNormal<T> |
16 impl<T: Float> Distribution<T> for SerializableNormal<T> |
| 16 where StandardNormal : Distribution<T> { |
17 where |
| |
18 StandardNormal: Distribution<T>, |
| |
19 { |
| 17 fn sample<R>(&self, rng: &mut R) -> T |
20 fn sample<R>(&self, rng: &mut R) -> T |
| 18 where |
21 where |
| 19 R : Rng + ?Sized |
22 R: Rng + ?Sized, |
| 20 { self.0.sample(rng) } |
23 { |
| |
24 self.0.sample(rng) |
| |
25 } |
| 21 } |
26 } |
| 22 |
27 |
| 23 impl<T : Float> SerializableNormal<T> |
28 impl<T: Float> SerializableNormal<T> |
| 24 where StandardNormal : Distribution<T> { |
29 where |
| 25 pub fn new(mean : T, std_dev : T) -> Result<SerializableNormal<T>, NormalError> { |
30 StandardNormal: Distribution<T>, |
| |
31 { |
| |
32 pub fn new(mean: T, std_dev: T) -> Result<SerializableNormal<T>, NormalError> { |
| 26 Ok(SerializableNormal(Normal::new(mean, std_dev)?)) |
33 Ok(SerializableNormal(Normal::new(mean, std_dev)?)) |
| 27 } |
34 } |
| 28 } |
35 } |
| 29 |
36 |
| 30 impl<F> Serialize for SerializableNormal<F> |
37 impl<F> Serialize for SerializableNormal<F> |
| 31 where |
38 where |
| 32 StandardNormal : Distribution<F>, |
39 StandardNormal: Distribution<F>, |
| 33 F: Float + Serialize, |
40 F: Float + Serialize, |
| 34 { |
41 { |
| 35 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> |
42 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> |
| 36 where |
43 where |
| 37 S: Serializer, |
44 S: Serializer, |
| 46 /// Salt-and-pepper noise distribution |
53 /// Salt-and-pepper noise distribution |
| 47 /// |
54 /// |
| 48 /// This is the distribution that outputs each $\\{-m,0,m\\}$ with the corresponding |
55 /// This is the distribution that outputs each $\\{-m,0,m\\}$ with the corresponding |
| 49 /// probabilities $\\{1-p, p/2, p/2\\}$. |
56 /// probabilities $\\{1-p, p/2, p/2\\}$. |
| 50 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
57 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
| 51 pub struct SaltAndPepper<T : Float>{ |
58 pub struct SaltAndPepper<T: Float> { |
| 52 /// The magnitude parameter $m$ |
59 /// The magnitude parameter $m$ |
| 53 magnitude : T, |
60 magnitude: T, |
| 54 /// The probability parameter $p$ |
61 /// The probability parameter $p$ |
| 55 probability : T |
62 probability: T, |
| 56 } |
63 } |
| 57 |
64 |
| 58 /// Error for [`SaltAndPepper`]. |
65 /// Error for [`SaltAndPepper`]. |
| 59 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
66 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
| 60 pub enum SaltAndPepperError { |
67 pub enum SaltAndPepperError { |
| 64 impl std::error::Error for SaltAndPepperError {} |
71 impl std::error::Error for SaltAndPepperError {} |
| 65 |
72 |
| 66 impl std::fmt::Display for SaltAndPepperError { |
73 impl std::fmt::Display for SaltAndPepperError { |
| 67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 68 f.write_str(match self { |
75 f.write_str(match self { |
| 69 SaltAndPepperError::InvalidProbability => |
76 SaltAndPepperError::InvalidProbability => { |
| 70 " The probability parameter is not in the range [0, 1].", |
77 " The probability parameter is not in the range [0, 1]." |
| |
78 } |
| 71 }) |
79 }) |
| 72 } |
80 } |
| 73 } |
81 } |
| 74 |
82 |
| 75 #[replace_float_literals(T::cast_from(literal))] |
83 #[replace_float_literals(T::cast_from(literal))] |
| 76 impl<T : Float> SaltAndPepper<T> { |
84 impl<T: Float> SaltAndPepper<T> { |
| 77 pub fn new(magnitude : T, probability : T) -> Result<SaltAndPepper<T>, SaltAndPepperError> { |
85 pub fn new(magnitude: T, probability: T) -> Result<SaltAndPepper<T>, SaltAndPepperError> { |
| 78 if probability > 1.0 || probability < 0.0 { |
86 if probability > 1.0 || probability < 0.0 { |
| 79 Err(SaltAndPepperError::InvalidProbability) |
87 Err(SaltAndPepperError::InvalidProbability) |
| 80 } else { |
88 } else { |
| 81 Ok(SaltAndPepper { magnitude, probability }) |
89 Ok(SaltAndPepper { magnitude, probability }) |
| 82 } |
90 } |
| 83 } |
91 } |
| 84 } |
92 } |
| 85 |
93 |
| 86 #[replace_float_literals(T::cast_from(literal))] |
94 #[replace_float_literals(T::cast_from(literal))] |
| 87 impl<T : Float> Distribution<T> for SaltAndPepper<T> { |
95 impl<T: Float> Distribution<T> for SaltAndPepper<T> { |
| 88 fn sample<R>(&self, rng: &mut R) -> T |
96 fn sample<R>(&self, rng: &mut R) -> T |
| 89 where |
97 where |
| 90 R : Rng + ?Sized |
98 R: Rng + ?Sized, |
| 91 { |
99 { |
| 92 let (p, sign) : (float, bool) = rng.gen(); |
100 let (p, sign): (float, bool) = rng.random(); |
| 93 match (p < self.probability.as_(), sign) { |
101 match (p < self.probability.as_(), sign) { |
| 94 (false, _) => 0.0, |
102 (false, _) => 0.0, |
| 95 (true, true) => self.magnitude, |
103 (true, true) => self.magnitude, |
| 96 (true, false) => -self.magnitude, |
104 (true, false) => -self.magnitude, |
| 97 } |
105 } |
| 98 } |
106 } |
| 99 } |
107 } |