src/rand_distr.rs

Tue, 21 Mar 2023 20:31:01 +0200

author
Tuomo Valkonen <tuomov@iki.fi>
date
Tue, 21 Mar 2023 20:31:01 +0200
changeset 25
79943be70720
parent 23
9869fa1e0ccd
permissions
-rw-r--r--

Implement non-negativity constraints for the conditional gradient methods

//! Random distribution wrappers and implementations

use numeric_literals::replace_float_literals;
use rand::Rng;
use rand_distr::{Distribution, Normal, StandardNormal, NormalError};
use serde::{Serialize, Deserialize};
use serde::ser::{Serializer, SerializeStruct};
use alg_tools::types::*;

/// Wrapper for [`Normal`] that can be serialized by serde.
///
/// The wrapped distribution is kept in the single private field; values are
/// constructed with [`SerializableNormal::new`]. Sampling delegates to the
/// wrapped [`Normal`], and serialization writes its `mean` and `std_dev`.
#[derive(Debug)]
pub struct SerializableNormal<T : Float>(Normal<T>)
where StandardNormal : Distribution<T>;

impl<T : Float> Distribution<T> for SerializableNormal<T>
where StandardNormal : Distribution<T> {
    fn sample<R>(&self, rng: &mut R) -> T
    where
        R : Rng + ?Sized
    { self.0.sample(rng) }
}

impl<T : Float> SerializableNormal<T>
where StandardNormal : Distribution<T> {
    pub fn new(mean : T, std_dev : T) -> Result<SerializableNormal<T>, NormalError> {
        Ok(SerializableNormal(Normal::new(mean, std_dev)?))
    }
}

impl<F> Serialize for SerializableNormal<F>
where
    StandardNormal : Distribution<F>,
    F: Float + Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut s = serializer.serialize_struct("Normal", 2)?;
        s.serialize_field("mean", &self.0.mean())?;
        s.serialize_field("std_dev", &self.0.std_dev())?;
        s.end()
    }
}

/// Salt-and-pepper noise distribution
///
/// This is the distribution that outputs each of $\\{-m,0,m\\}$ with the corresponding
/// probabilities $\\{p/2, 1-p, p/2\\}$: a sample is non-zero with probability $p$, and
/// the sign of a non-zero sample is chosen by an independent random `bool`.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct SaltAndPepper<T : Float>{
    /// The magnitude parameter $m$
    magnitude : T,
    /// The probability parameter $p$, i.e., the probability of a non-zero sample
    probability : T
}

/// Error for [`SaltAndPepper`].
///
/// Returned by [`SaltAndPepper::new`] when the supplied parameters are rejected.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub enum SaltAndPepperError {
    /// The probability parameter $p$ is not in the range $[0, 1]$.
    InvalidProbability,
}
// No `source` override is provided, so the default `Error` methods apply.
impl std::error::Error for SaltAndPepperError {}

impl std::fmt::Display for SaltAndPepperError {
    /// Writes a human-readable description of the error to `f`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The message carries no leading whitespace, so that it composes cleanly
        // when embedded in a larger message by the caller.
        f.write_str(match self {
            SaltAndPepperError::InvalidProbability =>
                "The probability parameter is not in the range [0, 1].",
        })
    }
}

#[replace_float_literals(T::cast_from(literal))]
impl<T : Float> SaltAndPepper<T> {
    /// Creates a salt-and-pepper distribution with the given `magnitude` $m$ and
    /// `probability` $p$.
    ///
    /// # Errors
    ///
    /// Returns [`SaltAndPepperError::InvalidProbability`] when `probability` does not
    /// satisfy $0 ≤ p ≤ 1$. A NaN `probability` is rejected as well.
    pub fn new(magnitude : T, probability : T) -> Result<SaltAndPepper<T>, SaltAndPepperError> {
        // Stated as a positive range test on purpose: every comparison with NaN is
        // false, so NaN falls through to the error branch instead of being accepted.
        if probability >= 0.0 && probability <= 1.0 {
            Ok(SaltAndPepper { magnitude, probability })
        } else {
            Err(SaltAndPepperError::InvalidProbability)
        }
    }
}

#[replace_float_literals(T::cast_from(literal))]
impl<T : Float> Distribution<T> for SaltAndPepper<T> {
    /// Draws one sample: zero unless the random draw falls below the probability
    /// parameter, in which case the magnitude with a randomly chosen sign.
    fn sample<R : Rng + ?Sized>(&self, rng: &mut R) -> T {
        // A single `gen` call supplies both the test value and the sign.
        let (draw, positive) : (float, bool) = rng.gen();
        let threshold : float = self.probability.as_();

        if draw < threshold {
            if positive { self.magnitude } else { -self.magnitude }
        } else {
            0.0
        }
    }
}

mercurial