Sat, 19 Oct 2024 10:46:13 -0500
Some distance functions etc.
4 | 1 | |
2 | use alg_tools::iterate::AlgIteratorFactory; | |
3 | use alg_tools::mapping::{Mapping, Sum}; | |
4 | use serde::Serialize; | |
5 | 5 | use std::iter::Sum as SumTrait; |
4 | 6 | |
7 | use crate::manifold::ManifoldPoint; | |
8 | ||
5 | 9 | /// Trait for function objects that implement gradients |
10 | pub trait Grad<M : ManifoldPoint> { | |
11 | fn grad(&self, x : &M) -> M::Tangent; | |
12 | } | |
13 | ||
4 | 14 | /// Trait for function objects that implement gradient steps |
15 | pub trait Desc<M : ManifoldPoint> { | |
5 | 16 | fn desc(&self, τ : f64, x : M) -> M; |
17 | } | |
18 | ||
19 | /*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T { | |
20 | fn desc(&self, τ : f64, x : M) -> M { | |
21 | x.exp(self.grad(x) * τ) | |
22 | } | |
23 | }*/ | |
24 | ||
25 | impl<M, T > Desc<M> for Sum<M, T> | |
26 | where M : ManifoldPoint, | |
27 | T : Grad<M> + Mapping<M, Codomain=f64>, | |
28 | M::Tangent : SumTrait { | |
29 | fn desc(&self, τ : f64, x : M) -> M { | |
30 | let t : M::Tangent = self.iter() | |
31 | .map(|f| f.grad(&x)) | |
32 | .sum(); | |
33 | x.exp(&(t * τ)) | |
34 | } | |
4 | 35 | } |
36 | ||
37 | /// Trait for function objects that implement proximal steps | |
38 | pub trait Prox<M : ManifoldPoint> { | |
5 | 39 | fn prox(&self, τ : f64, x : M) -> M; |
4 | 40 | } |
41 | ||
42 | #[derive(Clone,Debug,Serialize)] | |
43 | pub struct IterInfo<M> { | |
44 | value : f64, | |
45 | point : M, | |
46 | } | |
47 | ||
48 | pub fn forward_backward<M, F, G, I>( | |
49 | f : &F, | |
50 | g : &G, | |
51 | mut x : M, | |
52 | τ : f64, | |
53 | iterator : I | |
54 | ) -> M | |
55 | where M : ManifoldPoint, | |
56 | F : Desc<M> + Mapping<M, Codomain = f64>, | |
57 | G : Prox<M> + Mapping<M, Codomain = f64>, | |
58 | I : AlgIteratorFactory<IterInfo<M>> { | |
59 | ||
60 | for i in iterator.iter() { | |
61 | x = g.prox(τ, f.desc(τ, x)); | |
62 | ||
63 | i.if_verbose(|| { | |
64 | IterInfo { | |
65 | value : f.apply(&x) + g.apply(&x), | |
66 | point : x.clone(), | |
67 | } | |
68 | }) | |
69 | } | |
70 | ||
71 | x | |
72 | } |