| |
1 |
| |
2 use alg_tools::iterate::AlgIteratorFactory; |
| |
3 use alg_tools::mapping::{Mapping, Sum}; |
| |
4 use serde::Serialize; |
| |
5 |
| |
6 use crate::manifold::ManifoldPoint; |
| |
7 |
| |
8 /// Trait for function objects that implement gradient steps |
| |
9 pub trait Desc<M : ManifoldPoint> { |
| |
10 fn desc(&self, τ : f64, pt : M) -> M; |
| |
11 } |
| |
12 |
| |
13 /// Trait for function objects that implement proximal steps |
| |
14 pub trait Prox<M : ManifoldPoint> { |
| |
15 fn prox(&self, τ : f64, pt : M) -> M; |
| |
16 } |
| |
17 |
| |
18 #[derive(Clone,Debug,Serialize)] |
| |
19 pub struct IterInfo<M> { |
| |
20 value : f64, |
| |
21 point : M, |
| |
22 } |
| |
23 |
| |
24 pub fn forward_backward<M, F, G, I>( |
| |
25 f : &F, |
| |
26 g : &G, |
| |
27 mut x : M, |
| |
28 τ : f64, |
| |
29 iterator : I |
| |
30 ) -> M |
| |
31 where M : ManifoldPoint, |
| |
32 F : Desc<M> + Mapping<M, Codomain = f64>, |
| |
33 G : Prox<M> + Mapping<M, Codomain = f64>, |
| |
34 I : AlgIteratorFactory<IterInfo<M>> { |
| |
35 |
| |
36 for i in iterator.iter() { |
| |
37 x = g.prox(τ, f.desc(τ, x)); |
| |
38 |
| |
39 i.if_verbose(|| { |
| |
40 IterInfo { |
| |
41 value : f.apply(&x) + g.apply(&x), |
| |
42 point : x.clone(), |
| |
43 } |
| |
44 }) |
| |
45 } |
| |
46 |
| |
47 x |
| |
48 } |