Tue, 31 Dec 2024 08:48:50 -0500
Split out and generalise Weighted
/*! Arithmetic of [`Mapping`]s. */ use serde::Serialize; use crate::types::*; use crate::instance::{Space, Instance}; use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping}; /// A trait for encoding constant [`Float`] values pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> { /// The type of the value type Type : Float; /// Returns the value of the constant fn value(&self) -> Self::Type; } impl<F : Float> Constant for F { type Type = F; #[inline] fn value(&self) -> F { *self } } /// Weighting of a [`Support`] and [`Apply`] by scalar multiplication; /// output of [`Support::weigh`]. #[derive(Copy,Clone,Debug,Serialize)] pub struct Weighted<T, C : Constant> { /// The weight pub weight : C, /// The base [`Support`] or [`Apply`] being weighted. pub base_fn : T, } impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C> where F : Float, D : Space, T : Mapping<D, Codomain=V>, V : Space + ClosedMul<F>, C : Constant<Type=F> { type Codomain = V; #[inline] fn apply<I : Instance<D>>(&self, x : I) -> Self::Codomain { self.base_fn.apply(x) * self.weight.value() } } impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C> where F : Float, D : Space, T : DifferentiableMapping<D, DerivativeDomain=V>, V : Space + std::ops::Mul<F, Output=V>, C : Constant<Type=F> { type Derivative = V; #[inline] fn differential_impl<I : Instance<D>>(&self, x : I) -> Self::Derivative { self.base_fn.differential(x) * self.weight.value() } }