Wed, 06 Nov 2024 14:58:31 -0500
Add visualisation TikZ
/*! Implementation of the forward-backward method on manifolds. */ use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; use alg_tools::mapping::{Mapping, Sum}; use serde::Serialize; use std::iter::Sum as SumTrait; use colored::ColoredString; use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; /// Trait for function objects that implement gradients pub trait Grad<M : ManifoldPoint> { fn grad(&self, x : &M) -> M::Tangent; } /// Trait for function objects that implement gradient steps pub trait Desc<M : ManifoldPoint> { fn desc(&self, τ : f64, x : M) -> M; } /*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T { fn desc(&self, τ : f64, x : M) -> M { x.exp(self.grad(x) * τ) } }*/ impl<M, T > Desc<M> for Sum<M, T> where M : ManifoldPoint, T : Grad<M> + Mapping<M, Codomain=f64>, M::Tangent : SumTrait { fn desc(&self, τ : f64, x : M) -> M { let t : M::Tangent = self.iter() .map(|f| f.grad(&x)) .sum(); x.exp(&(t * τ)) } } /// Trait for function objects that implement proximal steps pub trait Prox<M : ManifoldPoint> { fn prox(&self, τ : f64, x : M) -> M; } /// This structure is used to store information from algorithm iterations #[derive(Clone,Debug,Serialize)] pub struct IterInfo<M> { /// Function value pub value : f64, /// Current iterate pub point : M, } impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { fn logrepr(&self) -> ColoredString { format!("{}\t {}", self.value, self.point.embedded_coords() ).as_str().into() } } /// The forward-backward method on manifolds. /// /// `f` is the smooth, `g` the nonsmooth function, `x` the initial iterate, /// `τ` the step length parameter, and `iterator` controls the iteration count /// and verbosity. Return the final iterate. pub fn forward_backward<M, F, G, I>( f : &F, g : &G, mut x : M, τ : f64, iterator : I ) -> M where M : ManifoldPoint + EmbeddedManifoldPoint, F : Desc<M> + Mapping<M, Codomain = f64>, G : Prox<M> + Mapping<M, Codomain = f64>, I : AlgIteratorFactory<IterInfo<M>> { // Perform as many iterations as requested by `iterator`. for i in iterator.iter() { // Forward-backward step x = g.prox(τ, f.desc(τ, x)); // If requested by `iterator`, calculate function value and store iterate. i.if_verbose(|| { IterInfo { value : f.apply(&x) + g.apply(&x), point : x.clone(), } }) } // Return final iterate. x }