src/fb.rs

Mon, 21 Oct 2024 14:02:52 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Mon, 21 Oct 2024 14:02:52 -0500
changeset 9
0fa3ac0c248b
parent 7
8979a6638424
child 12
3b05a8b45b95
permissions
-rw-r--r--

change experiment


use alg_tools::iterate::{AlgIteratorFactory, LogRepr};
use alg_tools::mapping::{Mapping, Sum};
use serde::Serialize;
use std::iter::Sum as SumTrait;
use colored::ColoredString;

use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint};

/// Function objects that can report their gradient at a point of the manifold `M`.
///
/// The gradient is returned as a value of the tangent type `M::Tangent`.
pub trait Grad<M>
where
    M: ManifoldPoint,
{
    /// Gradient of the function at `x`.
    fn grad(&self, x: &M) -> M::Tangent;
}

/// Function objects that can take a gradient-type step from a point of `M`.
///
/// Implementors consume the current point and return the point reached after a
/// step of length `τ`.
pub trait Desc<M>
where
    M: ManifoldPoint,
{
    /// Step of length `τ` starting from `x`; returns the new point.
    fn desc(&self, τ: f64, x: M) -> M;
}

/*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T {
     fn desc(&self, τ : f64, x : M) -> M {
        x.exp(self.grad(x) * τ)
     }
}*/

/// Gradient-type step for a sum of functions.
///
/// The gradients of all summands are evaluated at `x` and added together using
/// [`std::iter::Sum`] on the tangent type. The total is scaled by `τ`, and `x` is
/// moved along it via `x.exp(..)` (presumably the manifold exponential map).
// NOTE(review): the step is taken along `+τ · Σ grad`. A descent step would
// normally use `-τ`, unless `Grad::grad` is meant to return the *negative*
// gradient or `exp` uses the opposite sign convention — confirm against the
// `Grad` implementations and `ManifoldPoint::exp`.
impl<M, T> Desc<M> for Sum<M, T>
where
    M: ManifoldPoint,
    T: Grad<M> + Mapping<M, Codomain = f64>,
    M::Tangent: SumTrait,
{
    fn desc(&self, τ: f64, x: M) -> M {
        // Total gradient of the sum at the current point.
        let total_grad = self
            .iter()
            .map(|summand| summand.grad(&x))
            .sum::<M::Tangent>();
        let step = total_grad * τ;
        x.exp(&step)
    }
}

/// Function objects that can take a proximal step from a point of `M`.
///
/// Implementors consume the current point and return the point produced by the
/// proximal step with parameter `τ`.
pub trait Prox<M>
where
    M: ManifoldPoint,
{
    /// Proximal step with parameter `τ` starting from `x`; returns the new point.
    fn prox(&self, τ: f64, x: M) -> M;
}

/// Per-iteration record handed to the iterator's verbose output by
/// [`forward_backward`].
///
/// The field order is also the serialization order: value first, then point.
#[derive(Debug, Clone, Serialize)]
pub struct IterInfo<M> {
    /// Objective value recorded for this iteration.
    value: f64,
    /// Iterate at which `value` was recorded.
    point: M,
}

impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> {
    fn logrepr(&self) -> ColoredString {
        format!("{}\t {}",
            self.value,
            self.point.embedded_coords()
        ).as_str().into()
    }
}

/// Forward–backward splitting for the objective `f + g`.
///
/// Each iteration applies the step [`Desc::desc`] of `f` with length `τ`, then
/// the proximal step [`Prox::prox`] of `g` with the same `τ`. The closure passed
/// to the iteration's `if_verbose` builds an [`IterInfo`] holding `f(x) + g(x)`
/// and a clone of the current iterate.
///
/// # Arguments
/// * `f` – function handled through its gradient-type step.
/// * `g` – function handled through its proximal step.
/// * `x` – starting point.
/// * `τ` – step length shared by both steps.
/// * `iterator` – factory driving the iteration loop and its verbose output.
///
/// # Returns
/// The iterate after the loop ends.
pub fn forward_backward<M, F, G, I>(
    f: &F,
    g: &G,
    x: M,
    τ: f64,
    iterator: I,
) -> M
where
    M: ManifoldPoint + EmbeddedManifoldPoint,
    F: Desc<M> + Mapping<M, Codomain = f64>,
    G: Prox<M> + Mapping<M, Codomain = f64>,
    I: AlgIteratorFactory<IterInfo<M>>,
{
    let mut current = x;

    for iteration in iterator.iter() {
        // Forward (gradient-type) step on f, then backward (proximal) step on g.
        let forward = f.desc(τ, current);
        current = g.prox(τ, forward);

        // Presumably only invoked by the iterator on verbose iterations, so the
        // objective is not computed every time — depends on alg_tools' `if_verbose`.
        iteration.if_verbose(|| IterInfo {
            value: f.apply(&current) + g.apply(&current),
            point: current.clone(),
        });
    }

    current
}

mercurial