| 1 |
1 |
| 2 use alg_tools::iterate::AlgIteratorFactory; |
2 use alg_tools::iterate::AlgIteratorFactory; |
| 3 use alg_tools::mapping::{Mapping, Sum}; |
3 use alg_tools::mapping::{Mapping, Sum}; |
| 4 use serde::Serialize; |
4 use serde::Serialize; |
| |
5 use std::iter::Sum as SumTrait; |
| 5 |
6 |
| 6 use crate::manifold::ManifoldPoint; |
7 use crate::manifold::ManifoldPoint; |
| 7 |
8 |
| |
9 /// Trait for function objects that implement gradients |
| |
10 pub trait Grad<M : ManifoldPoint> { |
| |
11 fn grad(&self, x : &M) -> M::Tangent; |
| |
12 } |
| |
13 |
| 8 /// Trait for function objects that implement gradient steps |
14 /// Trait for function objects that implement gradient steps |
| 9 pub trait Desc<M : ManifoldPoint> { |
15 pub trait Desc<M : ManifoldPoint> { |
| 10 fn desc(&self, τ : f64, pt : M) -> M; |
16 fn desc(&self, τ : f64, x : M) -> M; |
| |
17 } |
| |
18 |
| |
19 /*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T { |
| |
20 fn desc(&self, τ : f64, x : M) -> M { |
| |
21 x.exp(self.grad(x) * τ) |
| |
22 } |
| |
23 }*/ |
| |
24 |
| |
25 impl<M, T > Desc<M> for Sum<M, T> |
| |
26 where M : ManifoldPoint, |
| |
27 T : Grad<M> + Mapping<M, Codomain=f64>, |
| |
28 M::Tangent : SumTrait { |
| |
29 fn desc(&self, τ : f64, x : M) -> M { |
| |
30 let t : M::Tangent = self.iter() |
| |
31 .map(|f| f.grad(&x)) |
| |
32 .sum(); |
| |
33 x.exp(&(t * τ)) |
| |
34 } |
| 11 } |
35 } |
| 12 |
36 |
| 13 /// Trait for function objects that implement proximal steps |
37 /// Trait for function objects that implement proximal steps |
| 14 pub trait Prox<M : ManifoldPoint> { |
38 pub trait Prox<M : ManifoldPoint> { |
| 15 fn prox(&self, τ : f64, pt : M) -> M; |
39 fn prox(&self, τ : f64, x : M) -> M; |
| 16 } |
40 } |
| 17 |
41 |
| 18 #[derive(Clone,Debug,Serialize)] |
42 #[derive(Clone,Debug,Serialize)] |
| 19 pub struct IterInfo<M> { |
43 pub struct IterInfo<M> { |
| 20 value : f64, |
44 value : f64, |