src/rand_distr.rs

branch
dev
changeset 61
4f468d35fa29
parent 23
9869fa1e0ccd
--- a/src/rand_distr.rs	Sun Apr 27 15:03:51 2025 -0500
+++ b/src/rand_distr.rs	Thu Feb 26 11:38:43 2026 -0500
@@ -1,35 +1,42 @@
 //! Random distribution wrappers and implementations
 
+use alg_tools::types::*;
 use numeric_literals::replace_float_literals;
 use rand::Rng;
-use rand_distr::{Distribution, Normal, StandardNormal, NormalError};
-use serde::{Serialize, Deserialize};
-use serde::ser::{Serializer, SerializeStruct};
-use alg_tools::types::*;
+use rand_distr::{Distribution, Normal, NormalError, StandardNormal};
+use serde::ser::{SerializeStruct, Serializer};
+use serde::{Deserialize, Serialize};
 
 /// Wrapper for [`Normal`] that can be serialized by serde.
 #[derive(Debug)]
-pub struct SerializableNormal<T : Float>(Normal<T>)
-where StandardNormal : Distribution<T>;
+pub struct SerializableNormal<T: Float>(Normal<T>)
+where
+    StandardNormal: Distribution<T>;
 
-impl<T : Float> Distribution<T> for SerializableNormal<T>
-where StandardNormal : Distribution<T> {
+impl<T: Float> Distribution<T> for SerializableNormal<T>
+where
+    StandardNormal: Distribution<T>,
+{
     fn sample<R>(&self, rng: &mut R) -> T
     where
-        R : Rng + ?Sized
-    { self.0.sample(rng) }
+        R: Rng + ?Sized,
+    {
+        self.0.sample(rng)
+    }
 }
 
-impl<T : Float> SerializableNormal<T>
-where StandardNormal : Distribution<T> {
-    pub fn new(mean : T, std_dev : T) -> Result<SerializableNormal<T>, NormalError> {
+impl<T: Float> SerializableNormal<T>
+where
+    StandardNormal: Distribution<T>,
+{
+    pub fn new(mean: T, std_dev: T) -> Result<SerializableNormal<T>, NormalError> {
         Ok(SerializableNormal(Normal::new(mean, std_dev)?))
     }
 }
 
 impl<F> Serialize for SerializableNormal<F>
 where
-    StandardNormal : Distribution<F>,
+    StandardNormal: Distribution<F>,
     F: Float + Serialize,
 {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@@ -48,11 +55,11 @@
 /// This is the distribution that outputs each $\\{-m,0,m\\}$ with the corresponding
 /// probabilities $\\{1-p, p/2, p/2\\}$.
 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
-pub struct SaltAndPepper<T : Float>{
+pub struct SaltAndPepper<T: Float> {
     /// The magnitude parameter $m$
-    magnitude : T,
+    magnitude: T,
     /// The probability parameter $p$
-    probability : T
+    probability: T,
 }
 
 /// Error for [`SaltAndPepper`].
@@ -66,15 +73,16 @@
 impl std::fmt::Display for SaltAndPepperError {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.write_str(match self {
-            SaltAndPepperError::InvalidProbability =>
-                " The probability parameter is not in the range [0, 1].",
+            SaltAndPepperError::InvalidProbability => {
+                " The probability parameter is not in the range [0, 1]."
+            }
         })
     }
 }
 
 #[replace_float_literals(T::cast_from(literal))]
-impl<T : Float> SaltAndPepper<T> {
-    pub fn new(magnitude : T, probability : T) -> Result<SaltAndPepper<T>, SaltAndPepperError> {
+impl<T: Float> SaltAndPepper<T> {
+    pub fn new(magnitude: T, probability: T) -> Result<SaltAndPepper<T>, SaltAndPepperError> {
         if probability > 1.0 || probability < 0.0 {
             Err(SaltAndPepperError::InvalidProbability)
         } else {
@@ -84,16 +92,16 @@
 }
 
 #[replace_float_literals(T::cast_from(literal))]
-impl<T : Float> Distribution<T> for SaltAndPepper<T> {
+impl<T: Float> Distribution<T> for SaltAndPepper<T> {
     fn sample<R>(&self, rng: &mut R) -> T
     where
-        R : Rng + ?Sized
+        R: Rng + ?Sized,
     {
-        let (p, sign) : (float, bool) = rng.gen();
+        let (p, sign): (float, bool) = rng.random();
         match (p < self.probability.as_(), sign) {
-            (false, _)      =>  0.0,
-            (true, true)    =>  self.magnitude,
-            (true, false)   => -self.magnitude,
+            (false, _) => 0.0,
+            (true, true) => self.magnitude,
+            (true, false) => -self.magnitude,
         }
     }
 }

mercurial