src/fb.rs

changeset 4
e09437844ad9
child 5
f248e1434c3b
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/fb.rs	Fri Oct 18 19:52:06 2024 -0500
@@ -0,0 +1,48 @@
+
+use alg_tools::iterate::AlgIteratorFactory;
+use alg_tools::mapping::{Mapping, Sum};
+use serde::Serialize;
+
+use crate::manifold::ManifoldPoint;
+
+/// Trait for function objects that implement gradient steps
+pub trait Desc<M : ManifoldPoint> {
+    fn desc(&self, τ : f64, pt : M) -> M;
+}
+
+/// Trait for function objects that implement proximal steps
+pub trait Prox<M : ManifoldPoint> {
+    fn prox(&self, τ : f64, pt : M) -> M;
+}
+
+#[derive(Clone,Debug,Serialize)]
+pub struct IterInfo<M> {
+    value : f64,
+    point : M,
+}
+
+pub fn forward_backward<M, F, G, I>(
+    f : &F,
+    g : &G,
+    mut x : M,
+    τ : f64,
+    iterator : I
+) -> M
+where M : ManifoldPoint,
+      F : Desc<M> +  Mapping<M, Codomain = f64>,
+      G : Prox<M> +  Mapping<M, Codomain = f64>,
+      I : AlgIteratorFactory<IterInfo<M>> {
+    
+    for i in iterator.iter() {
+        x = g.prox(τ, f.desc(τ, x));
+
+        i.if_verbose(|| {
+            IterInfo {
+                value : f.apply(&x) + g.apply(&x),
+                point : x.clone(),
+            }
+        })
+    }
+
+    x
+}

mercurial