--- a/src/fb.rs Fri Oct 18 19:52:06 2024 -0500 +++ b/src/fb.rs Sat Oct 19 10:46:13 2024 -0500 @@ -2,17 +2,41 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::mapping::{Mapping, Sum}; use serde::Serialize; +use std::iter::Sum as SumTrait; use crate::manifold::ManifoldPoint; +/// Trait for function objects that implement gradients +pub trait Grad<M : ManifoldPoint> { + fn grad(&self, x : &M) -> M::Tangent; +} + /// Trait for function objects that implement gradient steps pub trait Desc<M : ManifoldPoint> { - fn desc(&self, τ : f64, pt : M) -> M; + fn desc(&self, τ : f64, x : M) -> M; +} + +/*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T { + fn desc(&self, τ : f64, x : M) -> M { + x.exp(self.grad(x) * τ) + } +}*/ + +impl<M, T > Desc<M> for Sum<M, T> +where M : ManifoldPoint, + T : Grad<M> + Mapping<M, Codomain=f64>, + M::Tangent : SumTrait { + fn desc(&self, τ : f64, x : M) -> M { + let t : M::Tangent = self.iter() + .map(|f| f.grad(&x)) + .sum(); + x.exp(&(t * τ)) + } } /// Trait for function objects that implement proximal steps pub trait Prox<M : ManifoldPoint> { - fn prox(&self, τ : f64, pt : M) -> M; + fn prox(&self, τ : f64, x : M) -> M; } #[derive(Clone,Debug,Serialize)]