Fri, 06 Dec 2024 14:57:11 -0500
More documentation
/*!
Implementation of the forward-backward method on manifolds.
*/
4 | 4 | |
7 | 5 | use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
4 | 6 | use alg_tools::mapping::{Mapping, Sum}; |
7 | use serde::Serialize; | |
5 | 8 | use std::iter::Sum as SumTrait; |
7 | 9 | use colored::ColoredString; |
10 | use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; | |
4 | 11 | |
5 | 12 | /// Trait for function objects that implement gradients |
13 | pub trait Grad<M : ManifoldPoint> { | |
46 | 14 | /// Calculates the gradient of `self` at `x`. |
5 | 15 | fn grad(&self, x : &M) -> M::Tangent; |
16 | } | |
17 | ||
4 | 18 | /// Trait for function objects that implement gradient steps |
19 | pub trait Desc<M : ManifoldPoint> { | |
46 | 20 | /// Calculates the gradient steps of `self` at `x` for the step length `τ`. |
5 | 21 | fn desc(&self, τ : f64, x : M) -> M; |
22 | } | |
23 | ||
24 | /*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T { | |
25 | fn desc(&self, τ : f64, x : M) -> M { | |
26 | x.exp(self.grad(x) * τ) | |
27 | } | |
28 | }*/ | |
29 | ||
30 | impl<M, T > Desc<M> for Sum<M, T> | |
31 | where M : ManifoldPoint, | |
32 | T : Grad<M> + Mapping<M, Codomain=f64>, | |
33 | M::Tangent : SumTrait { | |
34 | fn desc(&self, τ : f64, x : M) -> M { | |
35 | let t : M::Tangent = self.iter() | |
36 | .map(|f| f.grad(&x)) | |
37 | .sum(); | |
38 | x.exp(&(t * τ)) | |
39 | } | |
4 | 40 | } |
41 | ||
42 | /// Trait for function objects that implement proximal steps | |
43 | pub trait Prox<M : ManifoldPoint> { | |
46 | 44 | /// Calculates the proximap map of `self` at `x` for the step length `τ`. |
5 | 45 | fn prox(&self, τ : f64, x : M) -> M; |
4 | 46 | } |
47 | ||
12 | 48 | /// This structure is used to store information from algorithm iterations |
4 | 49 | #[derive(Clone,Debug,Serialize)] |
50 | pub struct IterInfo<M> { | |
12 | 51 | /// Function value |
52 | pub value : f64, | |
53 | /// Current iterate | |
54 | pub point : M, | |
4 | 55 | } |
56 | ||
7 | 57 | impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { |
58 | fn logrepr(&self) -> ColoredString { | |
59 | format!("{}\t {}", | |
60 | self.value, | |
61 | self.point.embedded_coords() | |
62 | ).as_str().into() | |
63 | } | |
64 | } | |
65 | ||
12 | 66 | /// The forward-backward method on manifolds. |
67 | /// | |
68 | /// `f` is the smooth, `g` the nonsmooth function, `x` the initial iterate, | |
69 | /// `τ` the step length parameter, and `iterator` controls the iteration count | |
70 | /// and verbosity. Return the final iterate. | |
4 | 71 | pub fn forward_backward<M, F, G, I>( |
72 | f : &F, | |
73 | g : &G, | |
74 | mut x : M, | |
75 | τ : f64, | |
76 | iterator : I | |
77 | ) -> M | |
7 | 78 | where M : ManifoldPoint + EmbeddedManifoldPoint, |
4 | 79 | F : Desc<M> + Mapping<M, Codomain = f64>, |
80 | G : Prox<M> + Mapping<M, Codomain = f64>, | |
81 | I : AlgIteratorFactory<IterInfo<M>> { | |
22
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
82 | |
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
83 | // Closure that calculates current status |
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
84 | let status = |x : &M| IterInfo { |
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
85 | value : f.apply(x) + g.apply(x), |
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
86 | point : x.clone(), |
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
87 | }; |
12 | 88 | |
89 | // Perform as many iterations as requested by `iterator`. | |
22
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
90 | for i in iterator.iter_init(|| status(&x)) { |
12 | 91 | // Forward-backward step |
4 | 92 | x = g.prox(τ, f.desc(τ, x)); |
93 | ||
12 | 94 | // If requested by `iterator`, calculate function value and store iterate. |
22
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
95 | i.if_verbose(|| status(&x)) |
4 | 96 | } |
97 | ||
12 | 98 | // Return final iterate. |
4 | 99 | x |
100 | } |