|
1 //! Random distribution wrappers and implementations |
|
2 |
|
3 use numeric_literals::replace_float_literals; |
|
4 use rand::Rng; |
|
5 use rand_distr::{Distribution, Normal, StandardNormal, NormalError}; |
|
6 use serde::{Serialize, Deserialize}; |
|
7 use serde::ser::{Serializer, SerializeStruct}; |
|
8 use alg_tools::types::*; |
|
9 |
|
10 /// Wrapper for [`Normal`] that can be serialized by serde. |
|
11 pub struct SerializableNormal<T : Float>(Normal<T>) |
|
12 where StandardNormal : Distribution<T>; |
|
13 |
|
14 impl<T : Float> Distribution<T> for SerializableNormal<T> |
|
15 where StandardNormal : Distribution<T> { |
|
16 fn sample<R>(&self, rng: &mut R) -> T |
|
17 where |
|
18 R : Rng + ?Sized |
|
19 { self.0.sample(rng) } |
|
20 } |
|
21 |
|
22 impl<T : Float> SerializableNormal<T> |
|
23 where StandardNormal : Distribution<T> { |
|
24 pub fn new(mean : T, std_dev : T) -> Result<SerializableNormal<T>, NormalError> { |
|
25 Ok(SerializableNormal(Normal::new(mean, std_dev)?)) |
|
26 } |
|
27 } |
|
28 |
|
29 impl<F> Serialize for SerializableNormal<F> |
|
30 where |
|
31 StandardNormal : Distribution<F>, |
|
32 F: Float + Serialize, |
|
33 { |
|
34 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> |
|
35 where |
|
36 S: Serializer, |
|
37 { |
|
38 let mut s = serializer.serialize_struct("Normal", 2)?; |
|
39 s.serialize_field("mean", &self.0.mean())?; |
|
40 s.serialize_field("std_dev", &self.0.std_dev())?; |
|
41 s.end() |
|
42 } |
|
43 } |
|
44 |
|
45 /// Salt-and-pepper noise distribution |
|
46 /// |
|
47 /// This is the distribution that outputs each $\\{-m,0,m\\}$ with the corresponding |
|
48 /// probabilities $\\{1-p, p/2, p/2\\}$. |
|
49 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
|
50 pub struct SaltAndPepper<T : Float>{ |
|
51 /// The magnitude parameter $m$ |
|
52 magnitude : T, |
|
53 /// The probability parameter $p$ |
|
54 probability : T |
|
55 } |
|
56 |
|
57 /// Error for [`SaltAndPepper`]. |
|
58 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
|
59 pub enum SaltAndPepperError { |
|
60 /// The probability parameter $p$ is not in the range [0, 1]. |
|
61 InvalidProbability, |
|
62 } |
|
63 impl std::error::Error for SaltAndPepperError {} |
|
64 |
|
65 impl std::fmt::Display for SaltAndPepperError { |
|
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
|
67 f.write_str(match self { |
|
68 SaltAndPepperError::InvalidProbability => |
|
69 " The probability parameter is not in the range [0, 1].", |
|
70 }) |
|
71 } |
|
72 } |
|
73 |
|
74 #[replace_float_literals(T::cast_from(literal))] |
|
75 impl<T : Float> SaltAndPepper<T> { |
|
76 pub fn new(magnitude : T, probability : T) -> Result<SaltAndPepper<T>, SaltAndPepperError> { |
|
77 if probability > 1.0 || probability < 0.0 { |
|
78 Err(SaltAndPepperError::InvalidProbability) |
|
79 } else { |
|
80 Ok(SaltAndPepper { magnitude, probability }) |
|
81 } |
|
82 } |
|
83 } |
|
84 |
|
85 #[replace_float_literals(T::cast_from(literal))] |
|
86 impl<T : Float> Distribution<T> for SaltAndPepper<T> { |
|
87 fn sample<R>(&self, rng: &mut R) -> T |
|
88 where |
|
89 R : Rng + ?Sized |
|
90 { |
|
91 let (p, sign) : (float, bool) = rng.gen(); |
|
92 match (p < self.probability.as_(), sign) { |
|
93 (false, _) => 0.0, |
|
94 (true, true) => self.magnitude, |
|
95 (true, false) => -self.magnitude, |
|
96 } |
|
97 } |
|
98 } |