--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/fb.rs Fri Oct 18 19:52:06 2024 -0500 @@ -0,0 +1,48 @@ + +use alg_tools::iterate::AlgIteratorFactory; +use alg_tools::mapping::{Mapping, Sum}; +use serde::Serialize; + +use crate::manifold::ManifoldPoint; + +/// Trait for function objects that implement gradient steps +pub trait Desc<M : ManifoldPoint> { + fn desc(&self, τ : f64, pt : M) -> M; +} + +/// Trait for function objects that implement proximal steps +pub trait Prox<M : ManifoldPoint> { + fn prox(&self, τ : f64, pt : M) -> M; +} + +#[derive(Clone,Debug,Serialize)] +pub struct IterInfo<M> { + value : f64, + point : M, +} + +pub fn forward_backward<M, F, G, I>( + f : &F, + g : &G, + mut x : M, + τ : f64, + iterator : I +) -> M +where M : ManifoldPoint, + F : Desc<M> + Mapping<M, Codomain = f64>, + G : Prox<M> + Mapping<M, Codomain = f64>, + I : AlgIteratorFactory<IterInfo<M>> { + + for i in iterator.iter() { + x = g.prox(τ, f.desc(τ, x)); + + i.if_verbose(|| { + IterInfo { + value : f.apply(&x) + g.apply(&x), + point : x.clone(), + } + }) + } + + x +}