Mon, 21 Oct 2024 23:07:01 -0500
Basic face image export
4 | 1 | |
7 | 2 | use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
4 | 3 | use alg_tools::mapping::{Mapping, Sum}; |
4 | use serde::Serialize; | |
5 | 5 | use std::iter::Sum as SumTrait; |
7 | 6 | use colored::ColoredString; |
4 | 7 | |
7 | 8 | use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; |
4 | 9 | |
5 | 10 | /// Trait for function objects that implement gradients |
11 | pub trait Grad<M : ManifoldPoint> { | |
12 | fn grad(&self, x : &M) -> M::Tangent; | |
13 | } | |
14 | ||
4 | 15 | /// Trait for function objects that implement gradient steps |
16 | pub trait Desc<M : ManifoldPoint> { | |
5 | 17 | fn desc(&self, τ : f64, x : M) -> M; |
18 | } | |
19 | ||
20 | /*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T { | |
21 | fn desc(&self, τ : f64, x : M) -> M { | |
22 | x.exp(self.grad(x) * τ) | |
23 | } | |
24 | }*/ | |
25 | ||
26 | impl<M, T > Desc<M> for Sum<M, T> | |
27 | where M : ManifoldPoint, | |
28 | T : Grad<M> + Mapping<M, Codomain=f64>, | |
29 | M::Tangent : SumTrait { | |
30 | fn desc(&self, τ : f64, x : M) -> M { | |
31 | let t : M::Tangent = self.iter() | |
32 | .map(|f| f.grad(&x)) | |
33 | .sum(); | |
34 | x.exp(&(t * τ)) | |
35 | } | |
4 | 36 | } |
37 | ||
38 | /// Trait for function objects that implement proximal steps | |
39 | pub trait Prox<M : ManifoldPoint> { | |
5 | 40 | fn prox(&self, τ : f64, x : M) -> M; |
4 | 41 | } |
42 | ||
43 | #[derive(Clone,Debug,Serialize)] | |
44 | pub struct IterInfo<M> { | |
45 | value : f64, | |
46 | point : M, | |
47 | } | |
48 | ||
7 | 49 | impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { |
50 | fn logrepr(&self) -> ColoredString { | |
51 | format!("{}\t {}", | |
52 | self.value, | |
53 | self.point.embedded_coords() | |
54 | ).as_str().into() | |
55 | } | |
56 | } | |
57 | ||
4 | 58 | pub fn forward_backward<M, F, G, I>( |
59 | f : &F, | |
60 | g : &G, | |
61 | mut x : M, | |
62 | τ : f64, | |
63 | iterator : I | |
64 | ) -> M | |
7 | 65 | where M : ManifoldPoint + EmbeddedManifoldPoint, |
4 | 66 | F : Desc<M> + Mapping<M, Codomain = f64>, |
67 | G : Prox<M> + Mapping<M, Codomain = f64>, | |
68 | I : AlgIteratorFactory<IterInfo<M>> { | |
69 | ||
70 | for i in iterator.iter() { | |
71 | x = g.prox(τ, f.desc(τ, x)); | |
72 | ||
73 | i.if_verbose(|| { | |
74 | IterInfo { | |
75 | value : f.apply(&x) + g.apply(&x), | |
76 | point : x.clone(), | |
77 | } | |
78 | }) | |
79 | } | |
80 | ||
81 | x | |
82 | } |