| 2 use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
2 use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
| 3 use alg_tools::mapping::{Mapping, Sum}; |
3 use alg_tools::mapping::{Mapping, Sum}; |
| 4 use serde::Serialize; |
4 use serde::Serialize; |
| 5 use std::iter::Sum as SumTrait; |
5 use std::iter::Sum as SumTrait; |
| 6 use colored::ColoredString; |
6 use colored::ColoredString; |
| 7 |
|
| 8 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; |
7 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; |
| 9 |
8 |
| 10 /// Trait for function objects that implement gradients |
9 /// Trait for function objects that implement gradients |
| 11 pub trait Grad<M : ManifoldPoint> { |
10 pub trait Grad<M : ManifoldPoint> { |
| 12 fn grad(&self, x : &M) -> M::Tangent; |
11 fn grad(&self, x : &M) -> M::Tangent; |
| 38 /// Trait for function objects that implement proximal steps |
37 /// Trait for function objects that implement proximal steps |
| 39 pub trait Prox<M : ManifoldPoint> { |
38 pub trait Prox<M : ManifoldPoint> { |
| 40 fn prox(&self, τ : f64, x : M) -> M; |
39 fn prox(&self, τ : f64, x : M) -> M; |
| 41 } |
40 } |
| 42 |
41 |
| |
42 /// This structure is used to store information from algorithm iterations |
| 43 #[derive(Clone,Debug,Serialize)] |
43 #[derive(Clone,Debug,Serialize)] |
| 44 pub struct IterInfo<M> { |
44 pub struct IterInfo<M> { |
| 45 value : f64, |
45 /// Function value |
| 46 point : M, |
46 pub value : f64, |
| |
47 /// Current iterate |
| |
48 pub point : M, |
| 47 } |
49 } |
| 48 |
50 |
| 49 impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { |
51 impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { |
| 50 fn logrepr(&self) -> ColoredString { |
52 fn logrepr(&self) -> ColoredString { |
| 51 format!("{}\t {}", |
53 format!("{}\t {}", |
| 53 self.point.embedded_coords() |
55 self.point.embedded_coords() |
| 54 ).as_str().into() |
56 ).as_str().into() |
| 55 } |
57 } |
| 56 } |
58 } |
| 57 |
59 |
| |
60 /// The forward-backward method on manifolds. |
| |
61 /// |
| |
62 /// `f` is the smooth, `g` the nonsmooth function, `x` the initial iterate, |
| |
63 /// `τ` the step length parameter, and `iterator` controls the iteration count |
| |
64 /// and verbosity. Return the final iterate. |
| 58 pub fn forward_backward<M, F, G, I>( |
65 pub fn forward_backward<M, F, G, I>( |
| 59 f : &F, |
66 f : &F, |
| 60 g : &G, |
67 g : &G, |
| 61 mut x : M, |
68 mut x : M, |
| 62 τ : f64, |
69 τ : f64, |
| 64 ) -> M |
71 ) -> M |
| 65 where M : ManifoldPoint + EmbeddedManifoldPoint, |
72 where M : ManifoldPoint + EmbeddedManifoldPoint, |
| 66 F : Desc<M> + Mapping<M, Codomain = f64>, |
73 F : Desc<M> + Mapping<M, Codomain = f64>, |
| 67 G : Prox<M> + Mapping<M, Codomain = f64>, |
74 G : Prox<M> + Mapping<M, Codomain = f64>, |
| 68 I : AlgIteratorFactory<IterInfo<M>> { |
75 I : AlgIteratorFactory<IterInfo<M>> { |
| 69 |
76 |
| |
77 // Perform as many iterations as requested by `iterator`. |
| 70 for i in iterator.iter() { |
78 for i in iterator.iter() { |
| |
79 // Forward-backward step |
| 71 x = g.prox(τ, f.desc(τ, x)); |
80 x = g.prox(τ, f.desc(τ, x)); |
| 72 |
81 |
| |
82 // If requested by `iterator`, calculate function value and store iterate. |
| 73 i.if_verbose(|| { |
83 i.if_verbose(|| { |
| 74 IterInfo { |
84 IterInfo { |
| 75 value : f.apply(&x) + g.apply(&x), |
85 value : f.apply(&x) + g.apply(&x), |
| 76 point : x.clone(), |
86 point : x.clone(), |
| 77 } |
87 } |
| 78 }) |
88 }) |
| 79 } |
89 } |
| 80 |
90 |
| |
91 // Return final iterate. |
| 81 x |
92 x |
| 82 } |
93 } |