--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/rand_distr.rs Thu Dec 01 23:07:35 2022 +0200 @@ -0,0 +1,98 @@ +//! Random distribution wrappers and implementations + +use numeric_literals::replace_float_literals; +use rand::Rng; +use rand_distr::{Distribution, Normal, StandardNormal, NormalError}; +use serde::{Serialize, Deserialize}; +use serde::ser::{Serializer, SerializeStruct}; +use alg_tools::types::*; + +/// Wrapper for [`Normal`] that can be serialized by serde. +pub struct SerializableNormal<T : Float>(Normal<T>) +where StandardNormal : Distribution<T>; + +impl<T : Float> Distribution<T> for SerializableNormal<T> +where StandardNormal : Distribution<T> { + fn sample<R>(&self, rng: &mut R) -> T + where + R : Rng + ?Sized + { self.0.sample(rng) } +} + +impl<T : Float> SerializableNormal<T> +where StandardNormal : Distribution<T> { + pub fn new(mean : T, std_dev : T) -> Result<SerializableNormal<T>, NormalError> { + Ok(SerializableNormal(Normal::new(mean, std_dev)?)) + } +} + +impl<F> Serialize for SerializableNormal<F> +where + StandardNormal : Distribution<F>, + F: Float + Serialize, +{ + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + let mut s = serializer.serialize_struct("Normal", 2)?; + s.serialize_field("mean", &self.0.mean())?; + s.serialize_field("std_dev", &self.0.std_dev())?; + s.end() + } +} + +/// Salt-and-pepper noise distribution +/// +/// This is the distribution that outputs each $\\{-m,0,m\\}$ with the corresponding +/// probabilities $\\{1-p, p/2, p/2\\}$. +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +pub struct SaltAndPepper<T : Float>{ + /// The magnitude parameter $m$ + magnitude : T, + /// The probability parameter $p$ + probability : T +} + +/// Error for [`SaltAndPepper`]. +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +pub enum SaltAndPepperError { + /// The probability parameter $p$ is not in the range [0, 1]. + InvalidProbability, +} +impl std::error::Error for SaltAndPepperError {} + +impl std::fmt::Display for SaltAndPepperError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + SaltAndPepperError::InvalidProbability => + " The probability parameter is not in the range [0, 1].", + }) + } +} + +#[replace_float_literals(T::cast_from(literal))] +impl<T : Float> SaltAndPepper<T> { + pub fn new(magnitude : T, probability : T) -> Result<SaltAndPepper<T>, SaltAndPepperError> { + if probability > 1.0 || probability < 0.0 { + Err(SaltAndPepperError::InvalidProbability) + } else { + Ok(SaltAndPepper { magnitude, probability }) + } + } +} + +#[replace_float_literals(T::cast_from(literal))] +impl<T : Float> Distribution<T> for SaltAndPepper<T> { + fn sample<R>(&self, rng: &mut R) -> T + where + R : Rng + ?Sized + { + let (p, sign) : (float, bool) = rng.gen(); + match (p < self.probability.as_(), sign) { + (false, _) => 0.0, + (true, true) => self.magnitude, + (true, false) => -self.magnitude, + } + } +}