--- a/src/operator_arithmetic.rs Thu May 08 22:53:31 2025 -0500 +++ b/src/operator_arithmetic.rs Sun May 11 02:03:45 2025 -0500 @@ -2,72 +2,74 @@ Arithmetic of [`Mapping`]s. */ -use serde::Serialize; +use crate::instance::{Instance, Space}; +use crate::mapping::{DifferentiableImpl, DifferentiableMapping, Mapping}; use crate::types::*; -use crate::instance::{Space, Instance}; -use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping}; +use serde::Serialize; /// A trait for encoding constant [`Float`] values -pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> { +pub trait Constant: Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> { /// The type of the value - type Type : Float; + type Type: Float; /// Returns the value of the constant fn value(&self) -> Self::Type; } -impl<F : Float> Constant for F { +impl<F: Float> Constant for F { type Type = F; #[inline] - fn value(&self) -> F { *self } + fn value(&self) -> F { + *self + } } /// Weighting of a [`Mapping`] by scalar multiplication. -#[derive(Copy,Clone,Debug,Serialize)] -pub struct Weighted<T, C : Constant> { +#[derive(Copy, Clone, Debug, Serialize)] +pub struct Weighted<T, C: Constant> { /// The weight - pub weight : C, + pub weight: C, /// The base [`Mapping`] being weighted. - pub base_fn : T, + pub base_fn: T, } impl<T, C> Weighted<T, C> where - C : Constant, + C: Constant, { /// Construct from an iterator. - pub fn new(weight : C, base_fn : T) -> Self { - Weighted{ weight, base_fn } + pub fn new(weight: C, base_fn: T) -> Self { + Weighted { weight, base_fn } } } impl<'a, T, V, D, F, C> Mapping<D> for Weighted<T, C> where - F : Float, - D : Space, - T : Mapping<D, Codomain=V>, - V : Space + ClosedMul<F>, - C : Constant<Type=F> + F: Float, + D: Space, + T: Mapping<D, Codomain = V>, + V: Space + ClosedMul<F>, + C: Constant<Type = F>, { type Codomain = V; #[inline] - fn apply<I : Instance<D>>(&self, x : I) -> Self::Codomain { + fn apply<I: Instance<D>>(&self, x: I) -> Self::Codomain { self.base_fn.apply(x) * self.weight.value() } } impl<'a, T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C> where - F : Float, - D : Space, - T : DifferentiableMapping<D, DerivativeDomain=V>, - V : Space + std::ops::Mul<F, Output=V>, - C : Constant<Type=F> + F: Float, + D: Space, + T: DifferentiableMapping<D, DerivativeDomain = V>, + V: Space + std::ops::Mul<F, Output = V>, + C: Constant<Type = F>, { type Derivative = V; #[inline] - fn differential_impl<I : Instance<D>>(&self, x : I) -> Self::Derivative { + fn differential_impl<I: Instance<D>>(&self, x: I) -> Self::Derivative { self.base_fn.differential(x) * self.weight.value() } } @@ -76,9 +78,9 @@ #[derive(Serialize, Debug, Clone)] pub struct MappingSum<M>(Vec<M>); -impl< M> MappingSum<M> { +impl<M> MappingSum<M> { /// Construct from an iterator. - pub fn new<I : IntoIterator<Item = M>>(iter : I) -> Self { + pub fn new<I: IntoIterator<Item = M>>(iter: I) -> Self { MappingSum(iter.into_iter().collect()) } @@ -90,27 +92,27 @@ impl<Domain, M> Mapping<Domain> for MappingSum<M> where - Domain : Space + Clone, - M : Mapping<Domain>, - M::Codomain : std::iter::Sum + Clone + Domain: Space + Clone, + M: Mapping<Domain>, + M::Codomain: std::iter::Sum + Clone, { type Codomain = M::Codomain; - fn apply<I : Instance<Domain>>(&self, x : I) -> Self::Codomain { + fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain { let xr = x.ref_instance(); self.0.iter().map(|c| c.apply(xr)).sum() } } -impl<Domain, M> DifferentiableImpl<Domain> for MappingSum< M> +impl<Domain, M> DifferentiableImpl<Domain> for MappingSum<M> where - Domain : Space + Clone, - M : DifferentiableMapping<Domain>, - M :: DerivativeDomain : std::iter::Sum + Domain: Space, + M: DifferentiableMapping<Domain>, + M::DerivativeDomain: std::iter::Sum, { type Derivative = M::DerivativeDomain; - fn differential_impl<I : Instance<Domain>>(&self, x : I) -> Self::Derivative { + fn differential_impl<I: Instance<Domain>>(&self, x: I) -> Self::Derivative { let xr = x.ref_instance(); self.0.iter().map(|c| c.differential(xr)).sum() }