src/rand_distr.rs

branch
dev
changeset 61
4f468d35fa29
parent 23
9869fa1e0ccd
equal deleted inserted replaced
60:9738b51d90d7 61:4f468d35fa29
1 //! Random distribution wrappers and implementations 1 //! Random distribution wrappers and implementations
2 2
3 use alg_tools::types::*;
3 use numeric_literals::replace_float_literals; 4 use numeric_literals::replace_float_literals;
4 use rand::Rng; 5 use rand::Rng;
5 use rand_distr::{Distribution, Normal, StandardNormal, NormalError}; 6 use rand_distr::{Distribution, Normal, NormalError, StandardNormal};
6 use serde::{Serialize, Deserialize}; 7 use serde::ser::{SerializeStruct, Serializer};
7 use serde::ser::{Serializer, SerializeStruct}; 8 use serde::{Deserialize, Serialize};
8 use alg_tools::types::*;
9 9
10 /// Wrapper for [`Normal`] that can be serialized by serde. 10 /// Wrapper for [`Normal`] that can be serialized by serde.
11 #[derive(Debug)] 11 #[derive(Debug)]
12 pub struct SerializableNormal<T : Float>(Normal<T>) 12 pub struct SerializableNormal<T: Float>(Normal<T>)
13 where StandardNormal : Distribution<T>; 13 where
14 StandardNormal: Distribution<T>;
14 15
15 impl<T : Float> Distribution<T> for SerializableNormal<T> 16 impl<T: Float> Distribution<T> for SerializableNormal<T>
16 where StandardNormal : Distribution<T> { 17 where
18 StandardNormal: Distribution<T>,
19 {
17 fn sample<R>(&self, rng: &mut R) -> T 20 fn sample<R>(&self, rng: &mut R) -> T
18 where 21 where
19 R : Rng + ?Sized 22 R: Rng + ?Sized,
20 { self.0.sample(rng) } 23 {
24 self.0.sample(rng)
25 }
21 } 26 }
22 27
23 impl<T : Float> SerializableNormal<T> 28 impl<T: Float> SerializableNormal<T>
24 where StandardNormal : Distribution<T> { 29 where
25 pub fn new(mean : T, std_dev : T) -> Result<SerializableNormal<T>, NormalError> { 30 StandardNormal: Distribution<T>,
31 {
32 pub fn new(mean: T, std_dev: T) -> Result<SerializableNormal<T>, NormalError> {
26 Ok(SerializableNormal(Normal::new(mean, std_dev)?)) 33 Ok(SerializableNormal(Normal::new(mean, std_dev)?))
27 } 34 }
28 } 35 }
29 36
30 impl<F> Serialize for SerializableNormal<F> 37 impl<F> Serialize for SerializableNormal<F>
31 where 38 where
32 StandardNormal : Distribution<F>, 39 StandardNormal: Distribution<F>,
33 F: Float + Serialize, 40 F: Float + Serialize,
34 { 41 {
35 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> 42 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
36 where 43 where
37 S: Serializer, 44 S: Serializer,
46 /// Salt-and-pepper noise distribution 53 /// Salt-and-pepper noise distribution
47 /// 54 ///
48 /// This is the distribution that outputs each $\\{-m,0,m\\}$ with the corresponding 55 /// This is the distribution that outputs each $\\{-m,0,m\\}$ with the corresponding
49 /// probabilities $\\{1-p, p/2, p/2\\}$. 56 /// probabilities $\\{1-p, p/2, p/2\\}$.
50 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 57 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
51 pub struct SaltAndPepper<T : Float>{ 58 pub struct SaltAndPepper<T: Float> {
52 /// The magnitude parameter $m$ 59 /// The magnitude parameter $m$
53 magnitude : T, 60 magnitude: T,
54 /// The probability parameter $p$ 61 /// The probability parameter $p$
55 probability : T 62 probability: T,
56 } 63 }
57 64
58 /// Error for [`SaltAndPepper`]. 65 /// Error for [`SaltAndPepper`].
59 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 66 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
60 pub enum SaltAndPepperError { 67 pub enum SaltAndPepperError {
64 impl std::error::Error for SaltAndPepperError {} 71 impl std::error::Error for SaltAndPepperError {}
65 72
66 impl std::fmt::Display for SaltAndPepperError { 73 impl std::fmt::Display for SaltAndPepperError {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.write_str(match self { 75 f.write_str(match self {
69 SaltAndPepperError::InvalidProbability => 76 SaltAndPepperError::InvalidProbability => {
70 " The probability parameter is not in the range [0, 1].", 77 " The probability parameter is not in the range [0, 1]."
78 }
71 }) 79 })
72 } 80 }
73 } 81 }
74 82
75 #[replace_float_literals(T::cast_from(literal))] 83 #[replace_float_literals(T::cast_from(literal))]
76 impl<T : Float> SaltAndPepper<T> { 84 impl<T: Float> SaltAndPepper<T> {
77 pub fn new(magnitude : T, probability : T) -> Result<SaltAndPepper<T>, SaltAndPepperError> { 85 pub fn new(magnitude: T, probability: T) -> Result<SaltAndPepper<T>, SaltAndPepperError> {
78 if probability > 1.0 || probability < 0.0 { 86 if probability > 1.0 || probability < 0.0 {
79 Err(SaltAndPepperError::InvalidProbability) 87 Err(SaltAndPepperError::InvalidProbability)
80 } else { 88 } else {
81 Ok(SaltAndPepper { magnitude, probability }) 89 Ok(SaltAndPepper { magnitude, probability })
82 } 90 }
83 } 91 }
84 } 92 }
85 93
86 #[replace_float_literals(T::cast_from(literal))] 94 #[replace_float_literals(T::cast_from(literal))]
87 impl<T : Float> Distribution<T> for SaltAndPepper<T> { 95 impl<T: Float> Distribution<T> for SaltAndPepper<T> {
88 fn sample<R>(&self, rng: &mut R) -> T 96 fn sample<R>(&self, rng: &mut R) -> T
89 where 97 where
90 R : Rng + ?Sized 98 R: Rng + ?Sized,
91 { 99 {
92 let (p, sign) : (float, bool) = rng.gen(); 100 let (p, sign): (float, bool) = rng.random();
93 match (p < self.probability.as_(), sign) { 101 match (p < self.probability.as_(), sign) {
94 (false, _) => 0.0, 102 (false, _) => 0.0,
95 (true, true) => self.magnitude, 103 (true, true) => self.magnitude,
96 (true, false) => -self.magnitude, 104 (true, false) => -self.magnitude,
97 } 105 }
98 } 106 }
99 } 107 }

mercurial