Sat, 09 Nov 2024 20:54:32 -0500
Some Differential GATs
src/mapping.rs | file | annotate | diff | comparison | revisions |
--- a/src/mapping.rs Sat Nov 09 20:36:23 2024 -0500 +++ b/src/mapping.rs Sat Nov 09 20:54:32 2024 -0500 @@ -3,7 +3,7 @@ */ use std::marker::PhantomData; -use crate::types::{Float}; +use crate::types::Float; use serde::Serialize; use crate::loc::Loc; @@ -52,10 +52,10 @@ /// Automatically implemented shorthand for referring to differentiable [`Mapping`]s from /// [`Loc<F, N>`] to `F`. pub trait DifferentiableRealMapping<F : Float, const N : usize> -: DifferentiableMapping<Loc<F, N>, Codomain = F, Differential=Loc<F, N>> {} +: DifferentiableMapping<Loc<F, N>, Codomain = F, DerivativeDomain=Loc<F, N>> {} impl<F : Float, T, const N : usize> DifferentiableRealMapping<F, N> for T -where T : DifferentiableMapping<Loc<F, N>, Codomain = F, Differential=Loc<F, N>> {} +where T : DifferentiableMapping<Loc<F, N>, Codomain = F, DerivativeDomain=Loc<F, N>> {} /// A helper trait alias for referring to [`Mapping`]s from [`Loc<F, N>`] to [`Loc<F, M>`]. @@ -80,29 +80,38 @@ /// This is automatically implemented when the relevant [`Differentiate`] are implemented. pub trait DifferentiableMapping<Domain> : Mapping<Domain> - + Differentiable<Domain, Derivative=Self::Differential> - + for<'a> Differentiable<&'a Domain, Derivative=Self::Differential> { - type Differential; + + Differentiable<Domain, Derivative=Self::DerivativeDomain> + + for<'a> Differentiable<&'a Domain, Derivative=Self::DerivativeDomain> { + type DerivativeDomain; + type Differential : Mapping<Domain, Codomain=Self::DerivativeDomain>; + type DifferentialRef<'b> : Mapping<Domain, Codomain=Self::DerivativeDomain> where Self : 'b; /// Form the differential mapping of `self`. - fn diff(self) -> Differential<Domain, Self> { + fn diff(self) -> Self::Differential; + /// Form the differential mapping of `self`. + fn diff_ref(&self) -> Self::DifferentialRef<'_>; +} + + +impl<Domain, Derivative, T> DifferentiableMapping<Domain> for T +where T : Mapping<Domain> + + Differentiable<Domain, Derivative=Derivative> + + for<'a> Differentiable<&'a Domain,Derivative=Derivative> { + type DerivativeDomain = Derivative; + type Differential = Differential<Domain, Self>; + type DifferentialRef<'b> = Differential<Domain, Ref<'b, Self>> where Self : 'b; + + /// Form the differential mapping of `self`. + fn diff(self) -> Self::Differential { Differential{ g : self, _space : PhantomData } } /// Form the differential mapping of `self`. - fn diff_ref(&self) -> Differential<Domain, Ref<'_, Self>> { + fn diff_ref(&self) -> Self::DifferentialRef<'_> { Differential{ g : Ref(self), _space : PhantomData } } } - -impl<Domain, Differential, T> DifferentiableMapping<Domain> for T -where T : Mapping<Domain> - + Differentiable<Domain, Derivative=Differential> - + for<'a> Differentiable<&'a Domain, Derivative=Differential> { - type Differential = Differential; -} - /// A sum of [`Mapping`]s. #[derive(Serialize, Debug, Clone)] pub struct Sum<Domain, M : Mapping<Domain>> { @@ -146,10 +155,10 @@ impl<Domain, M> Differentiable<Domain> for Sum<Domain, M> where M : DifferentiableMapping<Domain>, M :: Codomain : std::iter::Sum, - M :: Differential : std::iter::Sum, + M :: DerivativeDomain : std::iter::Sum, Domain : Copy { - type Derivative = M::Differential; + type Derivative = M::DerivativeDomain; fn differential(&self, x : Domain) -> Self::Derivative { self.components.iter().map(|c| c.differential(x)).sum() @@ -163,19 +172,19 @@ } impl<X, G : DifferentiableMapping<X>> Apply<X> for Differential<X, G> { - type Output = G::Differential; + type Output = G::DerivativeDomain; #[inline] - fn apply(&self, x : X) -> G::Differential { + fn apply(&self, x : X) -> Self::Output { self.g.differential(x) } } impl<'a, X, G : DifferentiableMapping<X>> Apply<&'a X> for Differential<X, G> { - type Output = G::Differential; + type Output = G::DerivativeDomain; #[inline] - fn apply(&self, x : &'a X) -> G::Differential { + fn apply(&self, x : &'a X) -> Self::Output { self.g.differential(x) } }