Tue, 06 Dec 2022 14:12:20 +0200
v1.0.0-pre-arxiv (missing arXiv links)
//! Random distribution wrappers and implementations use numeric_literals::replace_float_literals; use rand::Rng; use rand_distr::{Distribution, Normal, StandardNormal, NormalError}; use serde::{Serialize, Deserialize}; use serde::ser::{Serializer, SerializeStruct}; use alg_tools::types::*; /// Wrapper for [`Normal`] that can be serialized by serde. pub struct SerializableNormal<T : Float>(Normal<T>) where StandardNormal : Distribution<T>; impl<T : Float> Distribution<T> for SerializableNormal<T> where StandardNormal : Distribution<T> { fn sample<R>(&self, rng: &mut R) -> T where R : Rng + ?Sized { self.0.sample(rng) } } impl<T : Float> SerializableNormal<T> where StandardNormal : Distribution<T> { pub fn new(mean : T, std_dev : T) -> Result<SerializableNormal<T>, NormalError> { Ok(SerializableNormal(Normal::new(mean, std_dev)?)) } } impl<F> Serialize for SerializableNormal<F> where StandardNormal : Distribution<F>, F: Float + Serialize, { fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer, { let mut s = serializer.serialize_struct("Normal", 2)?; s.serialize_field("mean", &self.0.mean())?; s.serialize_field("std_dev", &self.0.std_dev())?; s.end() } } /// Salt-and-pepper noise distribution /// /// This is the distribution that outputs each $\\{-m,0,m\\}$ with the corresponding /// probabilities $\\{1-p, p/2, p/2\\}$. #[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub struct SaltAndPepper<T : Float>{ /// The magnitude parameter $m$ magnitude : T, /// The probability parameter $p$ probability : T } /// Error for [`SaltAndPepper`]. #[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub enum SaltAndPepperError { /// The probability parameter $p$ is not in the range [0, 1]. InvalidProbability, } impl std::error::Error for SaltAndPepperError {} impl std::fmt::Display for SaltAndPepperError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str(match self { SaltAndPepperError::InvalidProbability => " The probability parameter is not in the range [0, 1].", }) } } #[replace_float_literals(T::cast_from(literal))] impl<T : Float> SaltAndPepper<T> { pub fn new(magnitude : T, probability : T) -> Result<SaltAndPepper<T>, SaltAndPepperError> { if probability > 1.0 || probability < 0.0 { Err(SaltAndPepperError::InvalidProbability) } else { Ok(SaltAndPepper { magnitude, probability }) } } } #[replace_float_literals(T::cast_from(literal))] impl<T : Float> Distribution<T> for SaltAndPepper<T> { fn sample<R>(&self, rng: &mut R) -> T where R : Rng + ?Sized { let (p, sign) : (float, bool) = rng.gen(); match (p < self.probability.as_(), sign) { (false, _) => 0.0, (true, true) => self.magnitude, (true, false) => -self.magnitude, } } }