src/fb.rs

changeset 4
e09437844ad9
child 5
f248e1434c3b
equal deleted inserted replaced
3:ff4656da04af 4:e09437844ad9
1
2 use alg_tools::iterate::AlgIteratorFactory;
3 use alg_tools::mapping::{Mapping, Sum};
4 use serde::Serialize;
5
6 use crate::manifold::ManifoldPoint;
7
8 /// Trait for function objects that implement gradient steps
9 pub trait Desc<M : ManifoldPoint> {
10 fn desc(&self, τ : f64, pt : M) -> M;
11 }
12
13 /// Trait for function objects that implement proximal steps
14 pub trait Prox<M : ManifoldPoint> {
15 fn prox(&self, τ : f64, pt : M) -> M;
16 }
17
/// Per-iteration report built by [`forward_backward`].
///
/// An instance is constructed inside the closure handed to the iteration's
/// `if_verbose` hook (presumably only evaluated when the iterator decides to
/// report — confirm against `alg_tools::iterate`).
///
/// NOTE: `serde` serializes fields in declaration order, and field names become
/// the serialized keys. Reordering or renaming fields changes the output format.
#[derive(Clone,Debug,Serialize)]
pub struct IterInfo<M> {
    // In `forward_backward` this is `f(x) + g(x)` evaluated at the current iterate.
    value : f64,
    // A clone of the current iterate.
    point : M,
}
23
24 pub fn forward_backward<M, F, G, I>(
25 f : &F,
26 g : &G,
27 mut x : M,
28 τ : f64,
29 iterator : I
30 ) -> M
31 where M : ManifoldPoint,
32 F : Desc<M> + Mapping<M, Codomain = f64>,
33 G : Prox<M> + Mapping<M, Codomain = f64>,
34 I : AlgIteratorFactory<IterInfo<M>> {
35
36 for i in iterator.iter() {
37 x = g.prox(τ, f.desc(τ, x));
38
39 i.if_verbose(|| {
40 IterInfo {
41 value : f.apply(&x) + g.apply(&x),
42 point : x.clone(),
43 }
44 })
45 }
46
47 x
48 }

mercurial