src/fb.rs

changeset 57
1afca417d71b
parent 56
34f8ec636368
equal deleted inserted replaced
56:34f8ec636368 57:1afca417d71b
1 /*! 1 /*!
2 Implementation of the forward-backward method on manifolds. 2 Implementation of the forward-backward method on manifolds.
3 */ 3 */
4 4
5 use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; 5 use alg_tools::iterate::{AlgIteratorFactory, LogRepr};
6 use alg_tools::mapping::{Mapping, Sum}; 6 use alg_tools::mapping::Mapping;
7 use alg_tools::operator_arithmetic::MappingSum;
7 use serde::Serialize; 8 use serde::Serialize;
8 use std::iter::Sum as SumTrait; 9 use std::iter::Sum;
9 use colored::ColoredString; 10 use colored::ColoredString;
10 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; 11 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint};
11 12
12 /// Trait for function objects that implement gradients 13 /// Trait for function objects that implement gradients
13 pub trait Grad<M : ManifoldPoint> { 14 pub trait Grad<M : ManifoldPoint> {
25 fn desc(&self, τ : f64, x : M) -> M { 26 fn desc(&self, τ : f64, x : M) -> M {
26 x.exp(self.grad(x) * τ) 27 x.exp(self.grad(x) * τ)
27 } 28 }
28 }*/ 29 }*/
29 30
30 impl<M, T > Desc<M> for Sum<M, T> 31 impl<M, T> Desc<M> for MappingSum<T>
31 where M : ManifoldPoint, 32 where M : ManifoldPoint,
32 T : Grad<M> + Mapping<M, Codomain=f64>, 33 T : Grad<M> + Mapping<M, Codomain=f64>,
33 M::Tangent : SumTrait { 34 M::Tangent : Sum {
34 fn desc(&self, τ : f64, x : M) -> M { 35 fn desc(&self, τ : f64, x : M) -> M {
35 let t : M::Tangent = self.iter() 36 let t : M::Tangent = self.iter()
36 .map(|f| f.grad(&x)) 37 .map(|f| f.grad(&x))
37 .sum(); 38 .sum();
38 x.exp(&(t * τ)) 39 x.exp(&(t * τ))

mercurial