--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/operator_arithmetic.rs Mon Feb 03 19:22:16 2025 -0500 @@ -0,0 +1,117 @@ +/*! +Arithmetic of [`Mapping`]s. + */ + +use serde::Serialize; +use crate::types::*; +use crate::instance::{Space, Instance}; +use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping}; + +/// A trait for encoding constant [`Float`] values +pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> { + /// The type of the value + type Type : Float; + /// Returns the value of the constant + fn value(&self) -> Self::Type; +} + +impl<F : Float> Constant for F { + type Type = F; + #[inline] + fn value(&self) -> F { *self } +} + +/// Weighting of a [`Mapping`] by scalar multiplication. +#[derive(Copy,Clone,Debug,Serialize)] +pub struct Weighted<T, C : Constant> { + /// The weight + pub weight : C, + /// The base [`Mapping`] being weighted. + pub base_fn : T, +} + +impl<T, C> Weighted<T, C> +where + C : Constant, +{ + /// Construct from an iterator. + pub fn new(weight : C, base_fn : T) -> Self { + Weighted{ weight, base_fn } + } +} + +impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C> +where + F : Float, + D : Space, + T : Mapping<D, Codomain=V>, + V : Space + ClosedMul<F>, + C : Constant<Type=F> +{ + type Codomain = V; + + #[inline] + fn apply<I : Instance<D>>(&self, x : I) -> Self::Codomain { + self.base_fn.apply(x) * self.weight.value() + } +} + +impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C> +where + F : Float, + D : Space, + T : DifferentiableMapping<D, DerivativeDomain=V>, + V : Space + std::ops::Mul<F, Output=V>, + C : Constant<Type=F> +{ + type Derivative = V; + + #[inline] + fn differential_impl<I : Instance<D>>(&self, x : I) -> Self::Derivative { + self.base_fn.differential(x) * self.weight.value() + } +} + +/// A sum of [`Mapping`]s. +#[derive(Serialize, Debug, Clone)] +pub struct MappingSum<M>(Vec<M>); + +impl< M> MappingSum<M> { + /// Construct from an iterator. + pub fn new<I : IntoIterator<Item = M>>(iter : I) -> Self { + MappingSum(iter.into_iter().collect()) + } + + /// Iterate over the component functions of the sum + pub fn iter(&self) -> std::slice::Iter<'_, M> { + self.0.iter() + } +} + +impl<Domain, M> Mapping<Domain> for MappingSum<M> +where + Domain : Space + Clone, + M : Mapping<Domain>, + M::Codomain : std::iter::Sum + Clone +{ + type Codomain = M::Codomain; + + fn apply<I : Instance<Domain>>(&self, x : I) -> Self::Codomain { + let xr = x.ref_instance(); + self.0.iter().map(|c| c.apply(xr)).sum() + } +} + +impl<Domain, M> DifferentiableImpl<Domain> for MappingSum< M> +where + Domain : Space + Clone, + M : DifferentiableMapping<Domain>, + M :: DerivativeDomain : std::iter::Sum +{ + type Derivative = M::DerivativeDomain; + + fn differential_impl<I : Instance<Domain>>(&self, x : I) -> Self::Derivative { + let xr = x.ref_instance(); + self.0.iter().map(|c| c.differential(xr)).sum() + } +}