Fri, 18 Oct 2024 19:52:06 -0500
Forward-backward skeleton
4 | 1 | |
2 | use alg_tools::iterate::AlgIteratorFactory; | |
3 | use alg_tools::mapping::{Mapping, Sum}; | |
4 | use serde::Serialize; | |
5 | ||
6 | use crate::manifold::ManifoldPoint; | |
7 | ||
8 | /// Trait for function objects that implement gradient steps | |
9 | pub trait Desc<M : ManifoldPoint> { | |
10 | fn desc(&self, τ : f64, pt : M) -> M; | |
11 | } | |
12 | ||
13 | /// Trait for function objects that implement proximal steps | |
14 | pub trait Prox<M : ManifoldPoint> { | |
15 | fn prox(&self, τ : f64, pt : M) -> M; | |
16 | } | |
17 | ||
18 | #[derive(Clone,Debug,Serialize)] | |
19 | pub struct IterInfo<M> { | |
20 | value : f64, | |
21 | point : M, | |
22 | } | |
23 | ||
24 | pub fn forward_backward<M, F, G, I>( | |
25 | f : &F, | |
26 | g : &G, | |
27 | mut x : M, | |
28 | τ : f64, | |
29 | iterator : I | |
30 | ) -> M | |
31 | where M : ManifoldPoint, | |
32 | F : Desc<M> + Mapping<M, Codomain = f64>, | |
33 | G : Prox<M> + Mapping<M, Codomain = f64>, | |
34 | I : AlgIteratorFactory<IterInfo<M>> { | |
35 | ||
36 | for i in iterator.iter() { | |
37 | x = g.prox(τ, f.desc(τ, x)); | |
38 | ||
39 | i.if_verbose(|| { | |
40 | IterInfo { | |
41 | value : f.apply(&x) + g.apply(&x), | |
42 | point : x.clone(), | |
43 | } | |
44 | }) | |
45 | } | |
46 | ||
47 | x | |
48 | } |