src/fb.rs

changeset 12
3b05a8b45b95
parent 7
8979a6638424
child 13
f67949050a32
equal deleted inserted replaced
11:933242e0f3b8 12:3b05a8b45b95
2 use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; 2 use alg_tools::iterate::{AlgIteratorFactory, LogRepr};
3 use alg_tools::mapping::{Mapping, Sum}; 3 use alg_tools::mapping::{Mapping, Sum};
4 use serde::Serialize; 4 use serde::Serialize;
5 use std::iter::Sum as SumTrait; 5 use std::iter::Sum as SumTrait;
6 use colored::ColoredString; 6 use colored::ColoredString;
7
8 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; 7 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint};
9 8
10 /// Trait for function objects that implement gradients 9 /// Trait for function objects that implement gradients
11 pub trait Grad<M : ManifoldPoint> { 10 pub trait Grad<M : ManifoldPoint> {
12 fn grad(&self, x : &M) -> M::Tangent; 11 fn grad(&self, x : &M) -> M::Tangent;
38 /// Trait for function objects that implement proximal steps 37 /// Trait for function objects that implement proximal steps
39 pub trait Prox<M : ManifoldPoint> { 38 pub trait Prox<M : ManifoldPoint> {
40 fn prox(&self, τ : f64, x : M) -> M; 39 fn prox(&self, τ : f64, x : M) -> M;
41 } 40 }
42 41
42 /// This structure is used to store information from algorithm iterations
43 #[derive(Clone,Debug,Serialize)] 43 #[derive(Clone,Debug,Serialize)]
44 pub struct IterInfo<M> { 44 pub struct IterInfo<M> {
45 value : f64, 45 /// Function value
46 point : M, 46 pub value : f64,
47 /// Current iterate
48 pub point : M,
47 } 49 }
48 50
49 impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { 51 impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> {
50 fn logrepr(&self) -> ColoredString { 52 fn logrepr(&self) -> ColoredString {
51 format!("{}\t {}", 53 format!("{}\t {}",
53 self.point.embedded_coords() 55 self.point.embedded_coords()
54 ).as_str().into() 56 ).as_str().into()
55 } 57 }
56 } 58 }
57 59
60 /// The forward-backward method on manifolds.
61 ///
62 /// `f` is the smooth, `g` the nonsmooth function, `x` the initial iterate,
63 /// `τ` the step length parameter, and `iterator` controls the iteration count
64 /// and verbosity. Return the final iterate.
58 pub fn forward_backward<M, F, G, I>( 65 pub fn forward_backward<M, F, G, I>(
59 f : &F, 66 f : &F,
60 g : &G, 67 g : &G,
61 mut x : M, 68 mut x : M,
62 τ : f64, 69 τ : f64,
64 ) -> M 71 ) -> M
65 where M : ManifoldPoint + EmbeddedManifoldPoint, 72 where M : ManifoldPoint + EmbeddedManifoldPoint,
66 F : Desc<M> + Mapping<M, Codomain = f64>, 73 F : Desc<M> + Mapping<M, Codomain = f64>,
67 G : Prox<M> + Mapping<M, Codomain = f64>, 74 G : Prox<M> + Mapping<M, Codomain = f64>,
68 I : AlgIteratorFactory<IterInfo<M>> { 75 I : AlgIteratorFactory<IterInfo<M>> {
69 76
77 // Perform as many iterations as requested by `iterator`.
70 for i in iterator.iter() { 78 for i in iterator.iter() {
79 // Forward-backward step
71 x = g.prox(τ, f.desc(τ, x)); 80 x = g.prox(τ, f.desc(τ, x));
72 81
82 // If requested by `iterator`, calculate function value and store iterate.
73 i.if_verbose(|| { 83 i.if_verbose(|| {
74 IterInfo { 84 IterInfo {
75 value : f.apply(&x) + g.apply(&x), 85 value : f.apply(&x) + g.apply(&x),
76 point : x.clone(), 86 point : x.clone(),
77 } 87 }
78 }) 88 })
79 } 89 }
80 90
91 // Return final iterate.
81 x 92 x
82 } 93 }

mercurial