diff -r 1f19c6bbf07b -r 3868555d135c src/operator_arithmetic.rs --- a/src/operator_arithmetic.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/operator_arithmetic.rs Fri May 15 14:46:30 2026 -0500 @@ -2,72 +2,74 @@ Arithmetic of [`Mapping`]s. */ -use serde::Serialize; +use crate::instance::{ClosedSpace, Instance, Space}; +use crate::mapping::{DifferentiableImpl, DifferentiableMapping, Mapping}; use crate::types::*; -use crate::instance::{Space, Instance}; -use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping}; +use serde::Serialize; /// A trait for encoding constant [`Float`] values -pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into { +pub trait Constant: Copy + Sync + Send + 'static + std::fmt::Debug + Into { /// The type of the value - type Type : Float; + type Type: Float; /// Returns the value of the constant fn value(&self) -> Self::Type; } -impl Constant for F { +impl Constant for F { type Type = F; #[inline] - fn value(&self) -> F { *self } + fn value(&self) -> F { + *self + } } /// Weighting of a [`Mapping`] by scalar multiplication. -#[derive(Copy,Clone,Debug,Serialize)] -pub struct Weighted { +#[derive(Copy, Clone, Debug, Serialize)] +pub struct Weighted { /// The weight - pub weight : C, + pub weight: C, /// The base [`Mapping`] being weighted. - pub base_fn : T, + pub base_fn: T, } impl Weighted where - C : Constant, + C: Constant, { /// Construct from an iterator. - pub fn new(weight : C, base_fn : T) -> Self { - Weighted{ weight, base_fn } + pub fn new(weight: C, base_fn: T) -> Self { + Weighted { weight, base_fn } } } -impl<'a, T, V, D, F, C> Mapping for Weighted +impl<'a, T, D, F, C> Mapping for Weighted where - F : Float, - D : Space, - T : Mapping, - V : Space + ClosedMul, - C : Constant + F: Float, + D: Space, + T: Mapping, + T::Codomain: ClosedMul, + C: Constant, { - type Codomain = V; + type Codomain = T::Codomain; #[inline] - fn apply>(&self, x : I) -> Self::Codomain { + fn apply>(&self, x: I) -> Self::Codomain { self.base_fn.apply(x) * self.weight.value() } } impl<'a, T, V, D, F, C> DifferentiableImpl for Weighted where - F : Float, - D : Space, - T : DifferentiableMapping, - V : Space + std::ops::Mul, - C : Constant + F: Float, + D: Space, + T: DifferentiableMapping, + V: ClosedSpace + std::ops::Mul, + C: Constant, { type Derivative = V; #[inline] - fn differential_impl>(&self, x : I) -> Self::Derivative { + fn differential_impl>(&self, x: I) -> Self::Derivative { self.base_fn.differential(x) * self.weight.value() } } @@ -76,9 +78,9 @@ #[derive(Serialize, Debug, Clone)] pub struct MappingSum(Vec); -impl< M> MappingSum { +impl MappingSum { /// Construct from an iterator. - pub fn new>(iter : I) -> Self { + pub fn new>(iter: I) -> Self { MappingSum(iter.into_iter().collect()) } @@ -90,28 +92,26 @@ impl Mapping for MappingSum where - Domain : Space + Clone, - M : Mapping, - M::Codomain : std::iter::Sum + Clone + Domain: Space + Clone, + M: Mapping, + M::Codomain: std::iter::Sum + Clone, { type Codomain = M::Codomain; - fn apply>(&self, x : I) -> Self::Codomain { - let xr = x.ref_instance(); - self.0.iter().map(|c| c.apply(xr)).sum() + fn apply>(&self, x: I) -> Self::Codomain { + x.eval_ref(|xr| self.0.iter().map(|c| c.apply(xr)).sum()) } } -impl DifferentiableImpl for MappingSum< M> +impl DifferentiableImpl for MappingSum where - Domain : Space + Clone, - M : DifferentiableMapping, - M :: DerivativeDomain : std::iter::Sum + Domain: Space, + M: DifferentiableMapping, + M::DerivativeDomain: std::iter::Sum, { type Derivative = M::DerivativeDomain; - fn differential_impl>(&self, x : I) -> Self::Derivative { - let xr = x.ref_instance(); - self.0.iter().map(|c| c.differential(xr)).sum() + fn differential_impl>(&self, x: I) -> Self::Derivative { + x.eval_ref(|xr| self.0.iter().map(|c| c.differential(xr)).sum()) } }