2 use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
2 use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
3 use alg_tools::mapping::{Mapping, Sum}; |
3 use alg_tools::mapping::{Mapping, Sum}; |
4 use serde::Serialize; |
4 use serde::Serialize; |
5 use std::iter::Sum as SumTrait; |
5 use std::iter::Sum as SumTrait; |
6 use colored::ColoredString; |
6 use colored::ColoredString; |
7 |
|
8 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; |
7 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; |
9 |
8 |
10 /// Trait for function objects that implement gradients |
9 /// Trait for function objects that implement gradients |
11 pub trait Grad<M : ManifoldPoint> { |
10 pub trait Grad<M : ManifoldPoint> { |
12 fn grad(&self, x : &M) -> M::Tangent; |
11 fn grad(&self, x : &M) -> M::Tangent; |
38 /// Trait for function objects that implement proximal steps |
37 /// Trait for function objects that implement proximal steps |
39 pub trait Prox<M : ManifoldPoint> { |
38 pub trait Prox<M : ManifoldPoint> { |
40 fn prox(&self, τ : f64, x : M) -> M; |
39 fn prox(&self, τ : f64, x : M) -> M; |
41 } |
40 } |
42 |
41 |
|
42 /// This structure is used to store information from algorithm iterations |
43 #[derive(Clone,Debug,Serialize)] |
43 #[derive(Clone,Debug,Serialize)] |
44 pub struct IterInfo<M> { |
44 pub struct IterInfo<M> { |
45 value : f64, |
45 /// Function value |
46 point : M, |
46 pub value : f64, |
|
47 /// Current iterate |
|
48 pub point : M, |
47 } |
49 } |
48 |
50 |
49 impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { |
51 impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { |
50 fn logrepr(&self) -> ColoredString { |
52 fn logrepr(&self) -> ColoredString { |
51 format!("{}\t {}", |
53 format!("{}\t {}", |
53 self.point.embedded_coords() |
55 self.point.embedded_coords() |
54 ).as_str().into() |
56 ).as_str().into() |
55 } |
57 } |
56 } |
58 } |
57 |
59 |
|
60 /// The forward-backward method on manifolds. |
|
61 /// |
|
62 /// `f` is the smooth, `g` the nonsmooth function, `x` the initial iterate, |
|
63 /// `τ` the step length parameter, and `iterator` controls the iteration count |
|
64 /// and verbosity. Return the final iterate. |
58 pub fn forward_backward<M, F, G, I>( |
65 pub fn forward_backward<M, F, G, I>( |
59 f : &F, |
66 f : &F, |
60 g : &G, |
67 g : &G, |
61 mut x : M, |
68 mut x : M, |
62 τ : f64, |
69 τ : f64, |
64 ) -> M |
71 ) -> M |
65 where M : ManifoldPoint + EmbeddedManifoldPoint, |
72 where M : ManifoldPoint + EmbeddedManifoldPoint, |
66 F : Desc<M> + Mapping<M, Codomain = f64>, |
73 F : Desc<M> + Mapping<M, Codomain = f64>, |
67 G : Prox<M> + Mapping<M, Codomain = f64>, |
74 G : Prox<M> + Mapping<M, Codomain = f64>, |
68 I : AlgIteratorFactory<IterInfo<M>> { |
75 I : AlgIteratorFactory<IterInfo<M>> { |
69 |
76 |
|
77 // Perform as many iterations as requested by `iterator`. |
70 for i in iterator.iter() { |
78 for i in iterator.iter() { |
|
79 // Forward-backward step |
71 x = g.prox(τ, f.desc(τ, x)); |
80 x = g.prox(τ, f.desc(τ, x)); |
72 |
81 |
|
82 // If requested by `iterator`, calculate function value and store iterate. |
73 i.if_verbose(|| { |
83 i.if_verbose(|| { |
74 IterInfo { |
84 IterInfo { |
75 value : f.apply(&x) + g.apply(&x), |
85 value : f.apply(&x) + g.apply(&x), |
76 point : x.clone(), |
86 point : x.clone(), |
77 } |
87 } |
78 }) |
88 }) |
79 } |
89 } |
80 |
90 |
|
91 // Return final iterate. |
81 x |
92 x |
82 } |
93 } |