1 |
1 |
2 use alg_tools::iterate::AlgIteratorFactory; |
2 use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
3 use alg_tools::mapping::{Mapping, Sum}; |
3 use alg_tools::mapping::{Mapping, Sum}; |
4 use serde::Serialize; |
4 use serde::Serialize; |
5 use std::iter::Sum as SumTrait; |
5 use std::iter::Sum as SumTrait; |
|
6 use colored::ColoredString; |
6 |
7 |
7 use crate::manifold::ManifoldPoint; |
8 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; |
8 |
9 |
9 /// Trait for function objects that implement gradients |
10 /// Trait for function objects that implement gradients |
10 pub trait Grad<M : ManifoldPoint> { |
11 pub trait Grad<M : ManifoldPoint> { |
11 fn grad(&self, x : &M) -> M::Tangent; |
12 fn grad(&self, x : &M) -> M::Tangent; |
12 } |
13 } |
/// Per-iteration record passed to the algorithm iterator
/// (`AlgIteratorFactory<IterInfo<M>>`): a scalar value paired with a point.
///
/// NOTE(review): `value` is presumably the objective value at `point` —
/// confirm against the loop body that constructs it.
pub struct IterInfo<M> {
    /// Scalar reported for this iteration; printed first by `logrepr`.
    value: f64,
    /// Point associated with this iteration; its embedded coordinates are
    /// printed by `logrepr`.
    point: M,
}
47 |
48 |
|
49 impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { |
|
50 fn logrepr(&self) -> ColoredString { |
|
51 format!("{}\t {}", |
|
52 self.value, |
|
53 self.point.embedded_coords() |
|
54 ).as_str().into() |
|
55 } |
|
56 } |
|
57 |
48 pub fn forward_backward<M, F, G, I>( |
58 pub fn forward_backward<M, F, G, I>( |
49 f : &F, |
59 f : &F, |
50 g : &G, |
60 g : &G, |
51 mut x : M, |
61 mut x : M, |
52 τ : f64, |
62 τ : f64, |
53 iterator : I |
63 iterator : I |
54 ) -> M |
64 ) -> M |
55 where M : ManifoldPoint, |
65 where M : ManifoldPoint + EmbeddedManifoldPoint, |
56 F : Desc<M> + Mapping<M, Codomain = f64>, |
66 F : Desc<M> + Mapping<M, Codomain = f64>, |
57 G : Prox<M> + Mapping<M, Codomain = f64>, |
67 G : Prox<M> + Mapping<M, Codomain = f64>, |
58 I : AlgIteratorFactory<IterInfo<M>> { |
68 I : AlgIteratorFactory<IterInfo<M>> { |
59 |
69 |
60 for i in iterator.iter() { |
70 for i in iterator.iter() { |