Mon, 21 Oct 2024 23:07:01 -0500
Basic face image export
use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; use alg_tools::mapping::{Mapping, Sum}; use serde::Serialize; use std::iter::Sum as SumTrait; use colored::ColoredString; use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; /// Trait for function objects that implement gradients pub trait Grad<M : ManifoldPoint> { fn grad(&self, x : &M) -> M::Tangent; } /// Trait for function objects that implement gradient steps pub trait Desc<M : ManifoldPoint> { fn desc(&self, τ : f64, x : M) -> M; } /*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T { fn desc(&self, τ : f64, x : M) -> M { x.exp(self.grad(x) * τ) } }*/ impl<M, T > Desc<M> for Sum<M, T> where M : ManifoldPoint, T : Grad<M> + Mapping<M, Codomain=f64>, M::Tangent : SumTrait { fn desc(&self, τ : f64, x : M) -> M { let t : M::Tangent = self.iter() .map(|f| f.grad(&x)) .sum(); x.exp(&(t * τ)) } } /// Trait for function objects that implement proximal steps pub trait Prox<M : ManifoldPoint> { fn prox(&self, τ : f64, x : M) -> M; } #[derive(Clone,Debug,Serialize)] pub struct IterInfo<M> { value : f64, point : M, } impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { fn logrepr(&self) -> ColoredString { format!("{}\t {}", self.value, self.point.embedded_coords() ).as_str().into() } } pub fn forward_backward<M, F, G, I>( f : &F, g : &G, mut x : M, τ : f64, iterator : I ) -> M where M : ManifoldPoint + EmbeddedManifoldPoint, F : Desc<M> + Mapping<M, Codomain = f64>, G : Prox<M> + Mapping<M, Codomain = f64>, I : AlgIteratorFactory<IterInfo<M>> { for i in iterator.iter() { x = g.prox(τ, f.desc(τ, x)); i.if_verbose(|| { IterInfo { value : f.apply(&x) + g.apply(&x), point : x.clone(), } }) } x }