src/operator_arithmetic.rs

changeset 90
b3c35d16affe
parent 75
e9f4550cfa18
child 86
d5b0e496b72f
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/operator_arithmetic.rs	Mon Feb 03 19:22:16 2025 -0500
@@ -0,0 +1,117 @@
+/*!
+Arithmetic of [`Mapping`]s.
+ */
+
+use serde::Serialize;
+use crate::types::*;
+use crate::instance::{Space, Instance};
+use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping};
+
+/// A trait for encoding constant [`Float`] values
+pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> {
+    /// The type of the value
+    type Type : Float;
+    /// Returns the value of the constant
+    fn value(&self) -> Self::Type;
+}
+
+impl<F : Float> Constant for F {
+    type Type = F;
+    #[inline]
+    fn value(&self) -> F { *self }
+}
+
+/// Weighting of a [`Mapping`] by scalar multiplication.
+#[derive(Copy,Clone,Debug,Serialize)]
+pub struct Weighted<T, C : Constant> {
+    /// The weight
+    pub weight : C,
+    /// The base [`Mapping`] being weighted.
+    pub base_fn : T,
+}
+
+impl<T, C> Weighted<T, C>
+where
+    C : Constant,
+{
+    /// Construct from an iterator.
+    pub fn new(weight : C, base_fn : T) -> Self {
+        Weighted{ weight, base_fn }
+    }
+}
+
+impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C>
+where
+    F : Float,
+    D : Space,
+    T : Mapping<D, Codomain=V>,
+    V : Space + ClosedMul<F>,
+    C : Constant<Type=F>
+{
+    type Codomain = V;
+
+    #[inline]
+    fn apply<I : Instance<D>>(&self, x : I) -> Self::Codomain {
+        self.base_fn.apply(x) * self.weight.value()
+    }
+}
+
+impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C>
+where
+    F : Float,
+    D : Space,
+    T : DifferentiableMapping<D, DerivativeDomain=V>,
+    V : Space + std::ops::Mul<F, Output=V>,
+    C : Constant<Type=F>
+{
+    type Derivative = V;
+
+    #[inline]
+    fn differential_impl<I : Instance<D>>(&self, x : I) -> Self::Derivative {
+        self.base_fn.differential(x) * self.weight.value()
+    }
+}
+
+/// A sum of [`Mapping`]s.
+#[derive(Serialize, Debug, Clone)]
+pub struct MappingSum<M>(Vec<M>);
+
+impl< M> MappingSum<M> {
+    /// Construct from an iterator.
+    pub fn new<I : IntoIterator<Item = M>>(iter : I) -> Self {
+        MappingSum(iter.into_iter().collect())
+    }
+
+    /// Iterate over the component functions of the sum
+    pub fn iter(&self) -> std::slice::Iter<'_, M> {
+        self.0.iter()
+    }
+}
+
+impl<Domain, M> Mapping<Domain> for MappingSum<M>
+where
+    Domain : Space + Clone,
+    M : Mapping<Domain>,
+    M::Codomain : std::iter::Sum + Clone
+{
+    type Codomain = M::Codomain;
+
+    fn apply<I : Instance<Domain>>(&self, x : I) -> Self::Codomain {
+        let xr = x.ref_instance();
+        self.0.iter().map(|c| c.apply(xr)).sum()
+    }
+}
+
+impl<Domain, M> DifferentiableImpl<Domain> for MappingSum< M>
+where
+    Domain : Space + Clone,
+    M : DifferentiableMapping<Domain>,
+    M :: DerivativeDomain : std::iter::Sum
+{
+    type Derivative = M::DerivativeDomain;
+
+    fn differential_impl<I : Instance<Domain>>(&self, x : I) -> Self::Derivative {
+        let xr = x.ref_instance();
+        self.0.iter().map(|c| c.differential(xr)).sum()
+    }
+}

mercurial