Fri, 28 Mar 2025 08:38:05 -0500
bump version
| 13 | 1 | /*! |
| 2 | Implementation of the forward-backward method on manifolds. | |
| 3 | */ | |
| 4 | 4 | |
| 7 | 5 | use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
| 57 | 6 | use alg_tools::mapping::Mapping; |
| 7 | use alg_tools::operator_arithmetic::MappingSum; | |
| 4 | 8 | use serde::Serialize; |
| 57 | 9 | use std::iter::Sum; |
| 7 | 10 | use colored::ColoredString; |
| 11 | use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; | |
| 4 | 12 | |
| 5 | 13 | /// Trait for function objects that implement gradients |
| 14 | pub trait Grad<M : ManifoldPoint> { | |
| 46 | 15 | /// Calculates the gradient of `self` at `x`. |
| 5 | 16 | fn grad(&self, x : &M) -> M::Tangent; |
| 17 | } | |
| 18 | ||
| 4 | 19 | /// Trait for function objects that implement gradient steps |
| 20 | pub trait Desc<M : ManifoldPoint> { | |
| 46 | 21 | /// Calculates the gradient steps of `self` at `x` for the step length `τ`. |
| 5 | 22 | fn desc(&self, τ : f64, x : M) -> M; |
| 23 | } | |
| 24 | ||
| 25 | /*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T { | |
| 26 | fn desc(&self, τ : f64, x : M) -> M { | |
| 27 | x.exp(self.grad(x) * τ) | |
| 28 | } | |
| 29 | }*/ | |
| 30 | ||
| 57 | 31 | impl<M, T> Desc<M> for MappingSum<T> |
| 5 | 32 | where M : ManifoldPoint, |
| 33 | T : Grad<M> + Mapping<M, Codomain=f64>, | |
| 57 | 34 | M::Tangent : Sum { |
| 5 | 35 | fn desc(&self, τ : f64, x : M) -> M { |
| 36 | let t : M::Tangent = self.iter() | |
| 37 | .map(|f| f.grad(&x)) | |
| 38 | .sum(); | |
| 39 | x.exp(&(t * τ)) | |
| 40 | } | |
| 4 | 41 | } |
| 42 | ||
| 43 | /// Trait for function objects that implement proximal steps | |
| 44 | pub trait Prox<M : ManifoldPoint> { | |
| 46 | 45 | /// Calculates the proximap map of `self` at `x` for the step length `τ`. |
| 5 | 46 | fn prox(&self, τ : f64, x : M) -> M; |
| 4 | 47 | } |
| 48 | ||
| 12 | 49 | /// This structure is used to store information from algorithm iterations |
| 4 | 50 | #[derive(Clone,Debug,Serialize)] |
| 51 | pub struct IterInfo<M> { | |
| 12 | 52 | /// Function value |
| 53 | pub value : f64, | |
| 54 | /// Current iterate | |
| 55 | pub point : M, | |
| 4 | 56 | } |
| 57 | ||
| 7 | 58 | impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { |
| 59 | fn logrepr(&self) -> ColoredString { | |
| 60 | format!("{}\t {}", | |
| 61 | self.value, | |
| 62 | self.point.embedded_coords() | |
| 63 | ).as_str().into() | |
| 64 | } | |
| 65 | } | |
| 66 | ||
| 12 | 67 | /// The forward-backward method on manifolds. |
| 68 | /// | |
| 69 | /// `f` is the smooth, `g` the nonsmooth function, `x` the initial iterate, | |
| 70 | /// `τ` the step length parameter, and `iterator` controls the iteration count | |
| 71 | /// and verbosity. Return the final iterate. | |
| 4 | 72 | pub fn forward_backward<M, F, G, I>( |
| 73 | f : &F, | |
| 74 | g : &G, | |
| 75 | mut x : M, | |
| 76 | τ : f64, | |
| 77 | iterator : I | |
| 78 | ) -> M | |
| 56 | 79 | where |
| 80 | M : ManifoldPoint + EmbeddedManifoldPoint, | |
| 81 | F : Desc<M> + Mapping<M, Codomain = f64>, | |
| 82 | G : Prox<M> + Mapping<M, Codomain = f64>, | |
| 83 | I : AlgIteratorFactory<IterInfo<M>> | |
| 84 | { | |
|
22
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
85 | |
|
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
86 | // Closure that calculates current status |
|
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
87 | let status = |x : &M| IterInfo { |
|
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
88 | value : f.apply(x) + g.apply(x), |
|
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
89 | point : x.clone(), |
|
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
90 | }; |
| 12 | 91 | |
| 92 | // Perform as many iterations as requested by `iterator`. | |
|
22
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
93 | for i in iterator.iter_init(|| status(&x)) { |
| 12 | 94 | // Forward-backward step |
| 4 | 95 | x = g.prox(τ, f.desc(τ, x)); |
| 96 | ||
| 12 | 97 | // If requested by `iterator`, calculate function value and store iterate. |
|
22
cecdde4ff5c9
Also store initial iterate in log
Tuomo Valkonen <tuomov@iki.fi>
parents:
13
diff
changeset
|
98 | i.if_verbose(|| status(&x)) |
| 4 | 99 | } |
| 100 | ||
| 12 | 101 | // Return final iterate. |
| 4 | 102 | x |
| 103 | } |