src/rand_distr.rs

changeset 0
eb3c7813b67a
child 23
9869fa1e0ccd
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/rand_distr.rs	Thu Dec 01 23:07:35 2022 +0200
@@ -0,0 +1,98 @@
+//! Random distribution wrappers and implementations
+
+use numeric_literals::replace_float_literals;
+use rand::Rng;
+use rand_distr::{Distribution, Normal, StandardNormal, NormalError};
+use serde::{Serialize, Deserialize};
+use serde::ser::{Serializer, SerializeStruct};
+use alg_tools::types::*;
+
+/// Wrapper for [`Normal`] that can be serialized by serde.
+pub struct SerializableNormal<T : Float>(Normal<T>)
+where StandardNormal : Distribution<T>;
+
+impl<T : Float> Distribution<T> for SerializableNormal<T>
+where StandardNormal : Distribution<T> {
+    fn sample<R>(&self, rng: &mut R) -> T
+    where
+        R : Rng + ?Sized
+    { self.0.sample(rng) }
+}
+
+impl<T : Float> SerializableNormal<T>
+where StandardNormal : Distribution<T> {
+    pub fn new(mean : T, std_dev : T) -> Result<SerializableNormal<T>, NormalError> {
+        Ok(SerializableNormal(Normal::new(mean, std_dev)?))
+    }
+}
+
+impl<F> Serialize for SerializableNormal<F>
+where
+    StandardNormal : Distribution<F>,
+    F: Float + Serialize,
+{
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let mut s = serializer.serialize_struct("Normal", 2)?;
+        s.serialize_field("mean", &self.0.mean())?;
+        s.serialize_field("std_dev", &self.0.std_dev())?;
+        s.end()
+    }
+}
+
+/// Salt-and-pepper noise distribution
+///
+/// This is the distribution that outputs each $\\{-m,0,m\\}$ with the corresponding
+/// probabilities $\\{1-p, p/2, p/2\\}$.
+#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
+pub struct SaltAndPepper<T : Float>{
+    /// The magnitude parameter $m$
+    magnitude : T,
+    /// The probability parameter $p$
+    probability : T
+}
+
+/// Error for [`SaltAndPepper`].
+#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
+pub enum SaltAndPepperError {
+    /// The probability parameter $p$ is not in the range [0, 1].
+    InvalidProbability,
+}
+impl std::error::Error for SaltAndPepperError {}
+
+impl std::fmt::Display for SaltAndPepperError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(match self {
+            SaltAndPepperError::InvalidProbability =>
+                " The probability parameter is not in the range [0, 1].",
+        })
+    }
+}
+
+#[replace_float_literals(T::cast_from(literal))]
+impl<T : Float> SaltAndPepper<T> {
+    pub fn new(magnitude : T, probability : T) -> Result<SaltAndPepper<T>, SaltAndPepperError> {
+        if probability > 1.0 || probability < 0.0 {
+            Err(SaltAndPepperError::InvalidProbability)
+        } else {
+            Ok(SaltAndPepper { magnitude, probability })
+        }
+    }
+}
+
+#[replace_float_literals(T::cast_from(literal))]
+impl<T : Float> Distribution<T> for SaltAndPepper<T> {
+    fn sample<R>(&self, rng: &mut R) -> T
+    where
+        R : Rng + ?Sized
+    {
+        let (p, sign) : (float, bool) = rng.gen();
+        match (p < self.probability.as_(), sign) {
+            (false, _)      =>  0.0,
+            (true, true)    =>  self.magnitude,
+            (true, false)   => -self.magnitude,
+        }
+    }
+}

mercurial