Forward-backward skeleton

Fri, 18 Oct 2024 19:52:06 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Fri, 18 Oct 2024 19:52:06 -0500
changeset 4
e09437844ad9
parent 3
ff4656da04af
child 5
f248e1434c3b

Forward-backward skeleton

Cargo.toml file | annotate | diff | comparison | revisions
src/fb.rs file | annotate | diff | comparison | revisions
src/main.rs file | annotate | diff | comparison | revisions
--- a/Cargo.toml	Fri Oct 18 14:22:45 2024 -0500
+++ b/Cargo.toml	Fri Oct 18 19:52:06 2024 -0500
@@ -18,4 +18,5 @@
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-alg_tools = { version = "~0.2.0-dev", path = "../alg_tools", default-features = false }
+serde = { version = "1.0", features = ["derive"] }
+alg_tools = { version = "~0.3.0-dev", path = "../alg_tools", default-features = false }
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/fb.rs	Fri Oct 18 19:52:06 2024 -0500
@@ -0,0 +1,48 @@
+
+use alg_tools::iterate::AlgIteratorFactory;
+use alg_tools::mapping::{Mapping, Sum};
+use serde::Serialize;
+
+use crate::manifold::ManifoldPoint;
+
+/// Trait for function objects that implement gradient steps
+pub trait Desc<M : ManifoldPoint> {
+    fn desc(&self, τ : f64, pt : M) -> M;
+}
+
+/// Trait for function objects that implement proximal steps
+pub trait Prox<M : ManifoldPoint> {
+    fn prox(&self, τ : f64, pt : M) -> M;
+}
+
+#[derive(Clone,Debug,Serialize)]
+pub struct IterInfo<M> {
+    value : f64,
+    point : M,
+}
+
+pub fn forward_backward<M, F, G, I>(
+    f : &F,
+    g : &G,
+    mut x : M,
+    τ : f64,
+    iterator : I
+) -> M
+where M : ManifoldPoint,
+      F : Desc<M> +  Mapping<M, Codomain = f64>,
+      G : Prox<M> +  Mapping<M, Codomain = f64>,
+      I : AlgIteratorFactory<IterInfo<M>> {
+    
+    for i in iterator.iter() {
+        x = g.prox(τ, f.desc(τ, x));
+
+        i.if_verbose(|| {
+            IterInfo {
+                value : f.apply(&x) + g.apply(&x),
+                point : x.clone(),
+            }
+        })
+    }
+
+    x
+}
--- a/src/main.rs	Fri Oct 18 14:22:45 2024 -0500
+++ b/src/main.rs	Fri Oct 18 19:52:06 2024 -0500
@@ -1,5 +1,12 @@
+
+// We use unicode. We would like to use much more of it than Rust allows.
+// Live with it. Embrace it.
+#![allow(uncommon_codepoints)]
+#![allow(mixed_script_confusables)]
+#![allow(confusable_idents)]
 
 mod manifold;
+mod fb;
 mod cube;
 
 fn main() {

mercurial