1 /*! |
1 /*! |
2 Implementation of the forward-backward method on manifolds. |
2 Implementation of the forward-backward method on manifolds. |
3 */ |
3 */ |
4 |
4 |
5 use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
5 use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
6 use alg_tools::mapping::{Mapping, Sum}; |
6 use alg_tools::mapping::Mapping; |
|
7 use alg_tools::operator_arithmetic::MappingSum; |
7 use serde::Serialize; |
8 use serde::Serialize; |
8 use std::iter::Sum as SumTrait; |
9 use std::iter::Sum; |
9 use colored::ColoredString; |
10 use colored::ColoredString; |
10 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; |
11 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; |
11 |
12 |
12 /// Trait for function objects that implement gradients |
13 /// Trait for function objects that implement gradients |
13 pub trait Grad<M : ManifoldPoint> { |
14 pub trait Grad<M : ManifoldPoint> { |
25 fn desc(&self, τ : f64, x : M) -> M { |
26 fn desc(&self, τ : f64, x : M) -> M { |
26 x.exp(self.grad(x) * τ) |
27 x.exp(self.grad(x) * τ) |
27 } |
28 } |
28 }*/ |
29 }*/ |
29 |
30 |
30 impl<M, T > Desc<M> for Sum<M, T> |
31 impl<M, T> Desc<M> for MappingSum<T> |
31 where M : ManifoldPoint, |
32 where M : ManifoldPoint, |
32 T : Grad<M> + Mapping<M, Codomain=f64>, |
33 T : Grad<M> + Mapping<M, Codomain=f64>, |
33 M::Tangent : SumTrait { |
34 M::Tangent : Sum { |
34 fn desc(&self, τ : f64, x : M) -> M { |
35 fn desc(&self, τ : f64, x : M) -> M { |
35 let t : M::Tangent = self.iter() |
36 let t : M::Tangent = self.iter() |
36 .map(|f| f.grad(&x)) |
37 .map(|f| f.grad(&x)) |
37 .sum(); |
38 .sum(); |
38 x.exp(&(t * τ)) |
39 x.exp(&(t * τ)) |