src/fb.rs

changeset 7
8979a6638424
parent 5
f248e1434c3b
child 12
3b05a8b45b95
equal deleted inserted replaced
6:df9628092285 7:8979a6638424
1 1
2 use alg_tools::iterate::AlgIteratorFactory; 2 use alg_tools::iterate::{AlgIteratorFactory, LogRepr};
3 use alg_tools::mapping::{Mapping, Sum}; 3 use alg_tools::mapping::{Mapping, Sum};
4 use serde::Serialize; 4 use serde::Serialize;
5 use std::iter::Sum as SumTrait; 5 use std::iter::Sum as SumTrait;
6 use colored::ColoredString;
6 7
7 use crate::manifold::ManifoldPoint; 8 use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint};
8 9
9 /// Trait for function objects that implement gradients 10 /// Trait for function objects that implement gradients
10 pub trait Grad<M : ManifoldPoint> { 11 pub trait Grad<M : ManifoldPoint> {
11 fn grad(&self, x : &M) -> M::Tangent; 12 fn grad(&self, x : &M) -> M::Tangent;
12 } 13 }
43 pub struct IterInfo<M> { 44 pub struct IterInfo<M> {
44 value : f64, 45 value : f64,
45 point : M, 46 point : M,
46 } 47 }
47 48
49 impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> {
50 fn logrepr(&self) -> ColoredString {
51 format!("{}\t {}",
52 self.value,
53 self.point.embedded_coords()
54 ).as_str().into()
55 }
56 }
57
48 pub fn forward_backward<M, F, G, I>( 58 pub fn forward_backward<M, F, G, I>(
49 f : &F, 59 f : &F,
50 g : &G, 60 g : &G,
51 mut x : M, 61 mut x : M,
52 τ : f64, 62 τ : f64,
53 iterator : I 63 iterator : I
54 ) -> M 64 ) -> M
55 where M : ManifoldPoint, 65 where M : ManifoldPoint + EmbeddedManifoldPoint,
56 F : Desc<M> + Mapping<M, Codomain = f64>, 66 F : Desc<M> + Mapping<M, Codomain = f64>,
57 G : Prox<M> + Mapping<M, Codomain = f64>, 67 G : Prox<M> + Mapping<M, Codomain = f64>,
58 I : AlgIteratorFactory<IterInfo<M>> { 68 I : AlgIteratorFactory<IterInfo<M>> {
59 69
60 for i in iterator.iter() { 70 for i in iterator.iter() {

mercurial