Sat, 19 Oct 2024 10:46:13 -0500
Some distance functions etc.
use alg_tools::iterate::AlgIteratorFactory; use alg_tools::mapping::{Mapping, Sum}; use serde::Serialize; use std::iter::Sum as SumTrait; use crate::manifold::ManifoldPoint; /// Trait for function objects that implement gradients pub trait Grad<M : ManifoldPoint> { fn grad(&self, x : &M) -> M::Tangent; } /// Trait for function objects that implement gradient steps pub trait Desc<M : ManifoldPoint> { fn desc(&self, τ : f64, x : M) -> M; } /*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T { fn desc(&self, τ : f64, x : M) -> M { x.exp(self.grad(x) * τ) } }*/ impl<M, T > Desc<M> for Sum<M, T> where M : ManifoldPoint, T : Grad<M> + Mapping<M, Codomain=f64>, M::Tangent : SumTrait { fn desc(&self, τ : f64, x : M) -> M { let t : M::Tangent = self.iter() .map(|f| f.grad(&x)) .sum(); x.exp(&(t * τ)) } } /// Trait for function objects that implement proximal steps pub trait Prox<M : ManifoldPoint> { fn prox(&self, τ : f64, x : M) -> M; } #[derive(Clone,Debug,Serialize)] pub struct IterInfo<M> { value : f64, point : M, } pub fn forward_backward<M, F, G, I>( f : &F, g : &G, mut x : M, τ : f64, iterator : I ) -> M where M : ManifoldPoint, F : Desc<M> + Mapping<M, Codomain = f64>, G : Prox<M> + Mapping<M, Codomain = f64>, I : AlgIteratorFactory<IterInfo<M>> { for i in iterator.iter() { x = g.prox(τ, f.desc(τ, x)); i.if_verbose(|| { IterInfo { value : f.apply(&x) + g.apply(&x), point : x.clone(), } }) } x }