Fri, 18 Oct 2024 19:52:06 -0500
Forward-backward skeleton
use alg_tools::iterate::AlgIteratorFactory; use alg_tools::mapping::{Mapping, Sum}; use serde::Serialize; use crate::manifold::ManifoldPoint; /// Trait for function objects that implement gradient steps pub trait Desc<M : ManifoldPoint> { fn desc(&self, τ : f64, pt : M) -> M; } /// Trait for function objects that implement proximal steps pub trait Prox<M : ManifoldPoint> { fn prox(&self, τ : f64, pt : M) -> M; } #[derive(Clone,Debug,Serialize)] pub struct IterInfo<M> { value : f64, point : M, } pub fn forward_backward<M, F, G, I>( f : &F, g : &G, mut x : M, τ : f64, iterator : I ) -> M where M : ManifoldPoint, F : Desc<M> + Mapping<M, Codomain = f64>, G : Prox<M> + Mapping<M, Codomain = f64>, I : AlgIteratorFactory<IterInfo<M>> { for i in iterator.iter() { x = g.prox(τ, f.desc(τ, x)); i.if_verbose(|| { IterInfo { value : f.apply(&x) + g.apply(&x), point : x.clone(), } }) } x }