1 |
1 |
2 use alg_tools::iterate::AlgIteratorFactory; |
2 use alg_tools::iterate::AlgIteratorFactory; |
3 use alg_tools::mapping::{Mapping, Sum}; |
3 use alg_tools::mapping::{Mapping, Sum}; |
4 use serde::Serialize; |
4 use serde::Serialize; |
|
5 use std::iter::Sum as SumTrait; |
5 |
6 |
6 use crate::manifold::ManifoldPoint; |
7 use crate::manifold::ManifoldPoint; |
7 |
8 |
|
9 /// Trait for function objects that implement gradients |
|
10 pub trait Grad<M : ManifoldPoint> { |
|
11 fn grad(&self, x : &M) -> M::Tangent; |
|
12 } |
|
13 |
8 /// Trait for function objects that implement gradient steps |
14 /// Trait for function objects that implement gradient steps |
9 pub trait Desc<M : ManifoldPoint> { |
15 pub trait Desc<M : ManifoldPoint> { |
10 fn desc(&self, τ : f64, pt : M) -> M; |
16 fn desc(&self, τ : f64, x : M) -> M; |
|
17 } |
|
18 |
|
19 /*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T { |
|
20 fn desc(&self, τ : f64, x : M) -> M { |
|
21 x.exp(self.grad(x) * τ) |
|
22 } |
|
23 }*/ |
|
24 |
|
25 impl<M, T > Desc<M> for Sum<M, T> |
|
26 where M : ManifoldPoint, |
|
27 T : Grad<M> + Mapping<M, Codomain=f64>, |
|
28 M::Tangent : SumTrait { |
|
29 fn desc(&self, τ : f64, x : M) -> M { |
|
30 let t : M::Tangent = self.iter() |
|
31 .map(|f| f.grad(&x)) |
|
32 .sum(); |
|
33 x.exp(&(t * τ)) |
|
34 } |
11 } |
35 } |
12 |
36 |
13 /// Trait for function objects that implement proximal steps |
37 /// Trait for function objects that implement proximal steps |
14 pub trait Prox<M : ManifoldPoint> { |
38 pub trait Prox<M : ManifoldPoint> { |
15 fn prox(&self, τ : f64, pt : M) -> M; |
39 fn prox(&self, τ : f64, x : M) -> M; |
16 } |
40 } |
17 |
41 |
18 #[derive(Clone,Debug,Serialize)] |
42 #[derive(Clone,Debug,Serialize)] |
19 pub struct IterInfo<M> { |
43 pub struct IterInfo<M> { |
20 value : f64, |
44 value : f64, |