Mon, 21 Oct 2024 23:07:01 -0500
Basic face image export
| 4 | 1 | |
| 7 | 2 | use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
| 4 | 3 | use alg_tools::mapping::{Mapping, Sum}; |
| 4 | use serde::Serialize; | |
| 5 | 5 | use std::iter::Sum as SumTrait; |
| 7 | 6 | use colored::ColoredString; |
| 4 | 7 | |
| 7 | 8 | use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; |
| 4 | 9 | |
| 5 | 10 | /// Trait for function objects that implement gradients |
| 11 | pub trait Grad<M : ManifoldPoint> { | |
| 12 | fn grad(&self, x : &M) -> M::Tangent; | |
| 13 | } | |
| 14 | ||
| 4 | 15 | /// Trait for function objects that implement gradient steps |
| 16 | pub trait Desc<M : ManifoldPoint> { | |
| 5 | 17 | fn desc(&self, τ : f64, x : M) -> M; |
| 18 | } | |
| 19 | ||
| 20 | /*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T { | |
| 21 | fn desc(&self, τ : f64, x : M) -> M { | |
| 22 | x.exp(self.grad(x) * τ) | |
| 23 | } | |
| 24 | }*/ | |
| 25 | ||
| 26 | impl<M, T > Desc<M> for Sum<M, T> | |
| 27 | where M : ManifoldPoint, | |
| 28 | T : Grad<M> + Mapping<M, Codomain=f64>, | |
| 29 | M::Tangent : SumTrait { | |
| 30 | fn desc(&self, τ : f64, x : M) -> M { | |
| 31 | let t : M::Tangent = self.iter() | |
| 32 | .map(|f| f.grad(&x)) | |
| 33 | .sum(); | |
| 34 | x.exp(&(t * τ)) | |
| 35 | } | |
| 4 | 36 | } |
| 37 | ||
| 38 | /// Trait for function objects that implement proximal steps | |
| 39 | pub trait Prox<M : ManifoldPoint> { | |
| 5 | 40 | fn prox(&self, τ : f64, x : M) -> M; |
| 4 | 41 | } |
| 42 | ||
| 43 | #[derive(Clone,Debug,Serialize)] | |
| 44 | pub struct IterInfo<M> { | |
| 45 | value : f64, | |
| 46 | point : M, | |
| 47 | } | |
| 48 | ||
| 7 | 49 | impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { |
| 50 | fn logrepr(&self) -> ColoredString { | |
| 51 | format!("{}\t {}", | |
| 52 | self.value, | |
| 53 | self.point.embedded_coords() | |
| 54 | ).as_str().into() | |
| 55 | } | |
| 56 | } | |
| 57 | ||
| 4 | 58 | pub fn forward_backward<M, F, G, I>( |
| 59 | f : &F, | |
| 60 | g : &G, | |
| 61 | mut x : M, | |
| 62 | τ : f64, | |
| 63 | iterator : I | |
| 64 | ) -> M | |
| 7 | 65 | where M : ManifoldPoint + EmbeddedManifoldPoint, |
| 4 | 66 | F : Desc<M> + Mapping<M, Codomain = f64>, |
| 67 | G : Prox<M> + Mapping<M, Codomain = f64>, | |
| 68 | I : AlgIteratorFactory<IterInfo<M>> { | |
| 69 | ||
| 70 | for i in iterator.iter() { | |
| 71 | x = g.prox(τ, f.desc(τ, x)); | |
| 72 | ||
| 73 | i.if_verbose(|| { | |
| 74 | IterInfo { | |
| 75 | value : f.apply(&x) + g.apply(&x), | |
| 76 | point : x.clone(), | |
| 77 | } | |
| 78 | }) | |
| 79 | } | |
| 80 | ||
| 81 | x | |
| 82 | } |