--- a/src/fb.rs Mon Oct 21 23:07:01 2024 -0500 +++ b/src/fb.rs Tue Oct 22 08:27:45 2024 -0500 @@ -4,7 +4,6 @@ use serde::Serialize; use std::iter::Sum as SumTrait; use colored::ColoredString; - use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; /// Trait for function objects that implement gradients @@ -40,10 +39,13 @@ fn prox(&self, τ : f64, x : M) -> M; } +/// This structure is used to store information from algorithm iterations #[derive(Clone,Debug,Serialize)] pub struct IterInfo<M> { - value : f64, - point : M, + /// Function value + pub value : f64, + /// Current iterate + pub point : M, } impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { @@ -55,6 +57,11 @@ } } +/// The forward-backward method on manifolds. +/// +/// `f` is the smooth, `g` the nonsmooth function, `x` the initial iterate, +/// `τ` the step length parameter, and `iterator` controls the iteration count +/// and verbosity. Return the final iterate. pub fn forward_backward<M, F, G, I>( f : &F, g : &G, @@ -66,10 +73,13 @@ F : Desc<M> + Mapping<M, Codomain = f64>, G : Prox<M> + Mapping<M, Codomain = f64>, I : AlgIteratorFactory<IterInfo<M>> { - + + // Perform as many iterations as requested by `iterator`. for i in iterator.iter() { + // Forward-backward step x = g.prox(τ, f.desc(τ, x)); + // If requested by `iterator`, calculate function value and store iterate. i.if_verbose(|| { IterInfo { value : f.apply(&x) + g.apply(&x), @@ -78,5 +88,6 @@ }) } + // Return final iterate. x }