src/fb.rs

changeset 5
f248e1434c3b
parent 4
e09437844ad9
child 7
8979a6638424
equal deleted inserted replaced
4:e09437844ad9 5:f248e1434c3b
1 1
2 use alg_tools::iterate::AlgIteratorFactory; 2 use alg_tools::iterate::AlgIteratorFactory;
3 use alg_tools::mapping::{Mapping, Sum}; 3 use alg_tools::mapping::{Mapping, Sum};
4 use serde::Serialize; 4 use serde::Serialize;
5 use std::iter::Sum as SumTrait;
5 6
6 use crate::manifold::ManifoldPoint; 7 use crate::manifold::ManifoldPoint;
7 8
9 /// Trait for function objects that implement gradients
10 pub trait Grad<M : ManifoldPoint> {
11 fn grad(&self, x : &M) -> M::Tangent;
12 }
13
8 /// Trait for function objects that implement gradient steps 14 /// Trait for function objects that implement gradient steps
9 pub trait Desc<M : ManifoldPoint> { 15 pub trait Desc<M : ManifoldPoint> {
10 fn desc(&self, τ : f64, pt : M) -> M; 16 fn desc(&self, τ : f64, x : M) -> M;
17 }
18
19 /*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T {
20 fn desc(&self, τ : f64, x : M) -> M {
21 x.exp(self.grad(x) * τ)
22 }
23 }*/
24
25 impl<M, T > Desc<M> for Sum<M, T>
26 where M : ManifoldPoint,
27 T : Grad<M> + Mapping<M, Codomain=f64>,
28 M::Tangent : SumTrait {
29 fn desc(&self, τ : f64, x : M) -> M {
30 let t : M::Tangent = self.iter()
31 .map(|f| f.grad(&x))
32 .sum();
33 x.exp(&(t * τ))
34 }
11 } 35 }
12 36
13 /// Trait for function objects that implement proximal steps 37 /// Trait for function objects that implement proximal steps
14 pub trait Prox<M : ManifoldPoint> { 38 pub trait Prox<M : ManifoldPoint> {
15 fn prox(&self, τ : f64, pt : M) -> M; 39 fn prox(&self, τ : f64, x : M) -> M;
16 } 40 }
17 41
18 #[derive(Clone,Debug,Serialize)] 42 #[derive(Clone,Debug,Serialize)]
19 pub struct IterInfo<M> { 43 pub struct IterInfo<M> {
20 value : f64, 44 value : f64,

mercurial