src/fb.rs

Wed, 06 Nov 2024 14:58:31 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Wed, 06 Nov 2024 14:58:31 -0500
changeset 21
5f2b65738e66
parent 13
f67949050a32
child 22
cecdde4ff5c9
permissions
-rw-r--r--

Add visualisation TikZ

/*!
Implementation of the forward-backward method on manifolds.
*/

use alg_tools::iterate::{AlgIteratorFactory, LogRepr};
use alg_tools::mapping::{Mapping, Sum};
use serde::Serialize;
use std::iter::Sum as SumTrait;
use colored::ColoredString;
use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint};

/// Function objects that can evaluate their gradient at a point.
///
/// The gradient at `x` is returned as an element of the point's tangent type
/// `M::Tangent`.
pub trait Grad<M>
where
    M: ManifoldPoint,
{
    /// Evaluate the gradient at the point `x`.
    fn grad(&self, x: &M) -> M::Tangent;
}

/// Function objects that can take a gradient (forward) step.
///
/// `desc(τ, x)` consumes the current point `x` and returns the point reached
/// by a gradient step with step length parameter `τ`.
pub trait Desc<M>
where
    M: ManifoldPoint,
{
    /// Take a gradient step with step length parameter `τ` starting from `x`.
    fn desc(&self, τ: f64, x: M) -> M;
}

/*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T {
     fn desc(&self, τ : f64, x : M) -> M {
        x.exp(self.grad(x) * τ)
     }
}*/

/// Gradient step for a sum of functions.
///
/// Each component's gradient is evaluated at `x`. The results are added up with
/// `M::Tangent`'s `Sum` implementation, scaled by `τ`, and then followed along
/// the exponential map from `x`.
impl<M, T> Desc<M> for Sum<M, T>
where
    M: ManifoldPoint,
    T: Grad<M> + Mapping<M, Codomain = f64>,
    M::Tangent: SumTrait,
{
    fn desc(&self, τ: f64, x: M) -> M {
        // Gradient of the sum: add up the component gradients at the same point.
        let total_grad: M::Tangent = self.iter().map(|component| component.grad(&x)).sum();
        // NOTE(review): this moves along `+τ·∇`. If `Grad::grad` returns the true
        // gradient, a descent step would use `-τ` here. Confirm the sign convention
        // of the `Grad` implementations before changing it.
        let step = total_grad * τ;
        x.exp(&step)
    }
}

/// Function objects that can take a proximal (backward) step.
///
/// `prox(τ, x)` consumes the point `x` and returns the result of the proximal
/// step with parameter `τ` (presumably `prox_{τ g}(x)` for the function `g`
/// represented by `self`; verify against implementors).
pub trait Prox<M>
where
    M: ManifoldPoint,
{
    /// Take a proximal step with parameter `τ` starting from `x`.
    fn prox(&self, τ: f64, x: M) -> M;
}

/// Record of a single algorithm iteration, as handed to the iteration logger.
#[derive(Clone, Debug, Serialize)]
pub struct IterInfo<M> {
    /// Objective function value at `point`.
    pub value: f64,
    /// The iterate this record describes.
    pub point: M,
}

impl<M> LogRepr for IterInfo<M>
where
    M: ManifoldPoint + EmbeddedManifoldPoint,
{
    /// Render the record as the function value, a tab and a space, and then
    /// the embedded coordinates of the iterate.
    fn logrepr(&self) -> ColoredString {
        let coords = self.point.embedded_coords();
        let line = format!("{}\t {}", self.value, coords);
        line.as_str().into()
    }
}

/// The forward-backward method on manifolds.
///
/// Each iteration performs `x ← g.prox(τ, f.desc(τ, x))`: a gradient (forward)
/// step on the smooth part, followed by a proximal (backward) step on the
/// nonsmooth part.
///
/// # Parameters
/// - `f`: the smooth function; must support gradient steps and evaluation.
/// - `g`: the nonsmooth function; must support proximal steps and evaluation.
/// - `x`: the initial iterate.
/// - `τ`: the step length parameter, used by both the forward and the backward step.
/// - `iterator`: controls how many iterations run and when progress is reported.
///
/// # Returns
/// The final iterate.
pub fn forward_backward<M, F, G, I>(
    f: &F,
    g: &G,
    mut x: M,
    τ: f64,
    iterator: I,
) -> M
where
    M: ManifoldPoint + EmbeddedManifoldPoint,
    F: Desc<M> + Mapping<M, Codomain = f64>,
    G: Prox<M> + Mapping<M, Codomain = f64>,
    I: AlgIteratorFactory<IterInfo<M>>,
{
    for state in iterator.iter() {
        // Forward step on `f`, then backward (proximal) step on `g`.
        let forward = f.desc(τ, x);
        x = g.prox(τ, forward);

        // The closure runs only when `iterator` requests verbose output, so the
        // objective is evaluated and the iterate cloned only on those iterations.
        state.if_verbose(|| IterInfo {
            value: f.apply(&x) + g.apply(&x),
            point: x.clone(),
        });
    }

    x
}

mercurial