src/fb.rs

changeset 5
f248e1434c3b
parent 4
e09437844ad9
child 7
8979a6638424
--- a/src/fb.rs	Fri Oct 18 19:52:06 2024 -0500
+++ b/src/fb.rs	Sat Oct 19 10:46:13 2024 -0500
@@ -2,17 +2,41 @@
 use alg_tools::iterate::AlgIteratorFactory;
 use alg_tools::mapping::{Mapping, Sum};
 use serde::Serialize;
+use std::iter::Sum as SumTrait;
 
 use crate::manifold::ManifoldPoint;
 
+/// Trait for function objects that implement gradients
+pub trait Grad<M : ManifoldPoint> {
+    fn grad(&self, x : &M) -> M::Tangent;
+}
+
 /// Trait for function objects that implement gradient steps
 pub trait Desc<M : ManifoldPoint> {
-    fn desc(&self, τ : f64, pt : M) -> M;
+    fn desc(&self, τ : f64, x : M) -> M;
+}
+
+/*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T {
+     fn desc(&self, τ : f64, x : M) -> M {
+        x.exp(self.grad(x) * τ)
+     }
+}*/
+
+impl<M, T > Desc<M> for Sum<M, T>
+where M : ManifoldPoint,
+      T : Grad<M> + Mapping<M, Codomain=f64>,
+      M::Tangent : SumTrait {
+    fn desc(&self, τ : f64, x : M) -> M {
+        let t : M::Tangent = self.iter()
+                                 .map(|f| f.grad(&x))
+                                 .sum();
+        x.exp(&(t * τ))
+    }
 }
 
 /// Trait for function objects that implement proximal steps
 pub trait Prox<M : ManifoldPoint> {
-    fn prox(&self, τ : f64, pt : M) -> M;
+    fn prox(&self, τ : f64, x : M) -> M;
 }
 
 #[derive(Clone,Debug,Serialize)]

mercurial