| 1 /*! |
1 /*! |
| 2 Implementation of the forward-backward method on manifolds. |
2 Implementation of the forward-backward method on manifolds. |
| 3 */ |
3 */ |
| 4 |
4 |
| 5 use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
5 use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; |
| 6 use alg_tools::mapping::{Mapping, Sum}; |
6 use alg_tools::mapping::Mapping; |
| |
7 use alg_tools::operator_arithmetic::MappingSum; |
| 7 use serde::Serialize; |
8 use serde::Serialize; |
| 8 use std::iter::Sum as SumTrait; |
9 use std::iter::Sum; |
| 9 use colored::ColoredString; |
10 use colored::ColoredString; |
| 10 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; |
11 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; |
| 11 |
12 |
| 12 /// Trait for function objects that implement gradients |
13 /// Trait for function objects that implement gradients |
| 13 pub trait Grad<M : ManifoldPoint> { |
14 pub trait Grad<M : ManifoldPoint> { |
| 25 fn desc(&self, τ : f64, x : M) -> M { |
26 fn desc(&self, τ : f64, x : M) -> M { |
| 26 x.exp(self.grad(x) * τ) |
27 x.exp(self.grad(x) * τ) |
| 27 } |
28 } |
| 28 }*/ |
29 }*/ |
| 29 |
30 |
| 30 impl<M, T > Desc<M> for Sum<M, T> |
31 impl<M, T> Desc<M> for MappingSum<T> |
| 31 where M : ManifoldPoint, |
32 where M : ManifoldPoint, |
| 32 T : Grad<M> + Mapping<M, Codomain=f64>, |
33 T : Grad<M> + Mapping<M, Codomain=f64>, |
| 33 M::Tangent : SumTrait { |
34 M::Tangent : Sum { |
| 34 fn desc(&self, τ : f64, x : M) -> M { |
35 fn desc(&self, τ : f64, x : M) -> M { |
| 35 let t : M::Tangent = self.iter() |
36 let t : M::Tangent = self.iter() |
| 36 .map(|f| f.grad(&x)) |
37 .map(|f| f.grad(&x)) |
| 37 .sum(); |
38 .sum(); |
| 38 x.exp(&(t * τ)) |
39 x.exp(&(t * τ)) |