src/operator_arithmetic.rs

Wed, 03 Sep 2025 20:19:41 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Wed, 03 Sep 2025 20:19:41 -0500
branch
dev
changeset 171
fa8df5a14486
parent 151
402d717bb5c0
permissions
-rw-r--r--

decompose

/*!
Arithmetic of [`Mapping`]s.
 */

use crate::instance::{ClosedSpace, Instance, Space};
use crate::mapping::{DifferentiableImpl, DifferentiableMapping, Mapping};
use crate::types::*;
use serde::Serialize;

/// A trait for encoding constant [`Float`] values.
///
/// Implementors are `Copy`, `Send + Sync`, `'static` and `Debug`. They yield their value as
/// the float type [`Constant::Type`] through [`Constant::value`]. They can also be converted
/// into that type with [`Into`].
pub trait Constant: Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> {
    /// The float type of the encoded value.
    type Type: Float;
    /// Returns the value of the constant.
    fn value(&self) -> Self::Type;
}

/// Every [`Float`] is trivially a [`Constant`] whose value is the float itself.
impl<F> Constant for F
where
    F: Float,
{
    type Type = F;

    /// Returns a copy of `self`.
    #[inline]
    fn value(&self) -> Self::Type {
        *self
    }
}

/// A [`Mapping`] multiplied by a scalar weight.
///
/// Evaluating (or differentiating) a `Weighted` evaluates (or differentiates) `base_fn`.
/// The result is then multiplied by the value of `weight`.
#[derive(Copy, Clone, Debug, Serialize)]
pub struct Weighted<T, C>
where
    C: Constant,
{
    /// The scalar factor applied to the output of `base_fn`.
    pub weight: C,
    /// The underlying [`Mapping`] whose values are scaled.
    pub base_fn: T,
}

impl<T, C> Weighted<T, C>
where
    C: Constant,
{
    /// Construct a weighted mapping that scales the output of `base_fn` by `weight`.
    ///
    /// # Arguments
    /// * `weight` – the scalar factor.
    /// * `base_fn` – the mapping being weighted.
    pub fn new(weight: C, base_fn: T) -> Self {
        Weighted { weight, base_fn }
    }
}

/// Evaluation of a [`Weighted`] mapping: the base mapping's value times the weight.
///
/// The base mapping's codomain must be closed under multiplication by the weight's float
/// type `F`.
impl<T, D, F, C> Mapping<D> for Weighted<T, C>
where
    F: Float,
    D: Space,
    T: Mapping<D>,
    T::Codomain: ClosedMul<F>,
    C: Constant<Type = F>,
{
    type Codomain = T::Codomain;

    /// Evaluates `base_fn` at `x` and multiplies the result by the weight.
    ///
    /// The weight is applied on the right (`value * weight`).
    #[inline]
    fn apply<I: Instance<D>>(&self, x: I) -> Self::Codomain {
        self.base_fn.apply(x) * self.weight.value()
    }
}

/// Differentiation of a [`Weighted`] mapping.
///
/// The result is the base mapping's differential, scaled by the weight.
impl<T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C>
where
    F: Float,
    D: Space,
    T: DifferentiableMapping<D, DerivativeDomain = V>,
    V: ClosedSpace + std::ops::Mul<F, Output = V>,
    C: Constant<Type = F>,
{
    type Derivative = V;

    /// Computes the differential of `base_fn` at `x`, then multiplies it by the weight.
    ///
    /// The weight is applied on the right.
    #[inline]
    fn differential_impl<I: Instance<D>>(&self, x: I) -> Self::Derivative {
        self.base_fn.differential(x) * self.weight.value()
    }
}

/// A pointwise sum of [`Mapping`]s, stored as a list of component mappings.
///
/// Evaluating or differentiating it sums the corresponding results of all components.
#[derive(Debug, Clone, Serialize)]
pub struct MappingSum<M>(Vec<M>);

impl<M> MappingSum<M> {
    /// Construct from an iterator.
    pub fn new<I: IntoIterator<Item = M>>(iter: I) -> Self {
        MappingSum(iter.into_iter().collect())
    }

    /// Iterate over the component functions of the sum
    pub fn iter(&self) -> std::slice::Iter<'_, M> {
        self.0.iter()
    }
}

/// Evaluation of a [`MappingSum`]: the sum of every component's value at the same point.
impl<Domain, M> Mapping<Domain> for MappingSum<M>
where
    Domain: Space + Clone,
    M: Mapping<Domain>,
    M::Codomain: std::iter::Sum + Clone,
{
    type Codomain = M::Codomain;

    /// Evaluates each component at `x` (passed to it by reference) and sums the results.
    ///
    /// The result of an empty sum is whatever [`std::iter::Sum`] yields for no items.
    fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain {
        x.eval_ref(|point| {
            let Self(components) = self;
            components
                .iter()
                .map(|component| component.apply(point))
                .sum()
        })
    }
}

/// Differentiation of a [`MappingSum`]: the sum of every component's differential.
impl<Domain, M> DifferentiableImpl<Domain> for MappingSum<M>
where
    Domain: Space,
    M: DifferentiableMapping<Domain>,
    M::DerivativeDomain: std::iter::Sum,
{
    type Derivative = M::DerivativeDomain;

    /// Differentiates each component at `x` (passed to it by reference) and sums the results.
    fn differential_impl<I: Instance<Domain>>(&self, x: I) -> Self::Derivative {
        let Self(components) = self;
        x.eval_ref(|point| {
            components
                .iter()
                .map(|component| component.differential(point))
                .sum()
        })
    }
}

mercurial