
use alg_tools::iterate::AlgIteratorFactory;
use alg_tools::mapping::{Mapping, Sum};
use serde::Serialize;

use crate::manifold::ManifoldPoint;

/// Function objects that can take a gradient step from a point.
///
/// [`forward_backward`] uses this as its forward (explicit) step.
pub trait Desc<M: ManifoldPoint> {
    /// Takes the gradient step implemented by this object, starting at `x`.
    ///
    /// `τ` is the step parameter. [`forward_backward`] passes the same `τ` here
    /// as it does to [`Prox::prox`]. The point `x` is consumed, and the new point
    /// is returned.
    fn desc(&self, τ: f64, x: M) -> M;
}

/// Function objects that can take a proximal step from a point.
///
/// [`forward_backward`] uses this as its backward (implicit) step.
pub trait Prox<M: ManifoldPoint> {
    /// Takes the proximal step implemented by this object, starting at `x`.
    ///
    /// `τ` is the step parameter. [`forward_backward`] passes the same `τ` here
    /// as it does to [`Desc::desc`]. The point `x` is consumed, and the new point
    /// is returned.
    fn prox(&self, τ: f64, x: M) -> M;
}

/// Per-iteration report produced by [`forward_backward`] on verbose iterations.
///
/// The fields are public so that code receiving logged `IterInfo` values can
/// read them directly, not only through `Debug` or `Serialize`. The field
/// order is significant: it fixes the order of the serialized fields.
#[derive(Clone,Debug,Serialize)]
pub struct IterInfo<M> {
    /// Objective value at the current iterate.
    ///
    /// [`forward_backward`] fills this with `f(x) + g(x)`.
    pub value : f64,
    /// Copy of the current iterate `x`.
    pub point : M,
}

/// Runs forward–backward splitting on `f + g`, starting from `x`.
///
/// Each iteration yielded by `iterator.iter()` applies two steps:
/// 1. A gradient (forward) step on `f` via [`Desc::desc`].
/// 2. A proximal (backward) step on `g` via [`Prox::prox`].
///
/// Both steps use the same step parameter `τ`:
///
/// ```text
/// x ← g.prox(τ, f.desc(τ, x))
/// ```
///
/// After each update, a closure is handed to the iteration state's
/// `if_verbose`. When that closure is called, it builds an [`IterInfo`] holding
/// `f(x) + g(x)` and a clone of the updated `x`. `iterator` determines both the
/// number of iterations and when the closure is called.
///
/// # Arguments
/// * `f` – provides gradient steps ([`Desc`]) and its value as a `Mapping` to `f64`.
/// * `g` – provides proximal steps ([`Prox`]) and its value as a `Mapping` to `f64`.
/// * `x` – initial point. It is consumed and updated in place each iteration.
/// * `τ` – step parameter passed to both `desc` and `prox`.
/// * `iterator` – iteration factory that drives the loop and receives `IterInfo`s.
///
/// # Returns
/// The final iterate. If `iterator` yields no iterations, this is `x` unchanged.
pub fn forward_backward<M, F, G, I>(
    f : &F,
    g : &G,
    mut x : M,
    τ : f64,
    iterator : I
) -> M
where M : ManifoldPoint,
      F : Desc<M> + Mapping<M, Codomain = f64>,
      G : Prox<M> + Mapping<M, Codomain = f64>,
      I : AlgIteratorFactory<IterInfo<M>> {

    for state in iterator.iter() {
        // Forward step: explicit gradient step on `f`.
        let forward = f.desc(τ, x);
        // Backward step: proximal step on `g`, applied to the forward result.
        x = g.prox(τ, forward);

        // Passed as a closure, so the objective is computed only if
        // `if_verbose` actually calls it.
        state.if_verbose(|| {
            let value = f.apply(&x) + g.apply(&x);
            IterInfo { value, point : x.clone() }
        });
    }

    x
}
