|
1 |
|
2 use alg_tools::iterate::AlgIteratorFactory; |
|
3 use alg_tools::mapping::{Mapping, Sum}; |
|
4 use serde::Serialize; |
|
5 |
|
6 use crate::manifold::ManifoldPoint; |
|
7 |
|
8 /// Trait for function objects that implement gradient steps |
|
9 pub trait Desc<M : ManifoldPoint> { |
|
10 fn desc(&self, τ : f64, pt : M) -> M; |
|
11 } |
|
12 |
|
13 /// Trait for function objects that implement proximal steps |
|
14 pub trait Prox<M : ManifoldPoint> { |
|
15 fn prox(&self, τ : f64, pt : M) -> M; |
|
16 } |
|
17 |
|
18 #[derive(Clone,Debug,Serialize)] |
|
19 pub struct IterInfo<M> { |
|
20 value : f64, |
|
21 point : M, |
|
22 } |
|
23 |
|
24 pub fn forward_backward<M, F, G, I>( |
|
25 f : &F, |
|
26 g : &G, |
|
27 mut x : M, |
|
28 τ : f64, |
|
29 iterator : I |
|
30 ) -> M |
|
31 where M : ManifoldPoint, |
|
32 F : Desc<M> + Mapping<M, Codomain = f64>, |
|
33 G : Prox<M> + Mapping<M, Codomain = f64>, |
|
34 I : AlgIteratorFactory<IterInfo<M>> { |
|
35 |
|
36 for i in iterator.iter() { |
|
37 x = g.prox(τ, f.desc(τ, x)); |
|
38 |
|
39 i.if_verbose(|| { |
|
40 IterInfo { |
|
41 value : f.apply(&x) + g.apply(&x), |
|
42 point : x.clone(), |
|
43 } |
|
44 }) |
|
45 } |
|
46 |
|
47 x |
|
48 } |