/*!
Implementation of the forward-backward method on manifolds.
*/

use alg_tools::iterate::{AlgIteratorFactory, LogRepr};
use alg_tools::mapping::{Mapping, Sum};
use serde::Serialize;
use std::iter::Sum as SumTrait;
use colored::ColoredString;
use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint};

/// Trait for function objects that implement gradients
pub trait Grad<M>
where
    M : ManifoldPoint,
{
    /// Evaluates the gradient of `self` at the point `x`, returned as a value
    /// of the associated tangent type `M::Tangent`.
    fn grad(&self, x : &M) -> M::Tangent;
}

/// Trait for function objects that implement gradient steps
pub trait Desc<M>
where
    M : ManifoldPoint,
{
    /// Performs a gradient step of `self` starting from `x` with step length
    /// `τ`, consuming `x` and returning the new point.
    fn desc(&self, τ : f64, x : M) -> M;
}

/*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T {
     fn desc(&self, τ : f64, x : M) -> M {
        x.exp(self.grad(x) * τ)
     }
}*/

/// Gradient descent step for a sum of smooth functions.
///
/// The gradient of a sum is the sum of the gradients, so the component
/// gradients are all evaluated at the same point `x`. They are added up in
/// `M::Tangent` before a single exponential-map step is taken.
impl<M, T> Desc<M> for Sum<M, T>
where
    M : ManifoldPoint,
    T : Grad<M> + Mapping<M, Codomain = f64>,
    M::Tangent : SumTrait,
{
    /// Computes `exp_x(-τ ∑ᵢ ∇fᵢ(x))` for the components `fᵢ` of the sum.
    ///
    /// `τ` is the (positive) step length; `x` is consumed and the new point
    /// returned.
    fn desc(&self, τ : f64, x : M) -> M {
        let grad_sum : M::Tangent = self.iter()
                                        .map(|f| f.grad(&x))
                                        .sum();
        // A descent step moves *against* the gradient; scaling by `+τ`
        // would instead perform gradient ascent.
        x.exp(&(grad_sum * (-τ)))
    }
}

/// Trait for function objects that implement proximal steps
pub trait Prox<M : ManifoldPoint> {
    /// Calculates the proximal map of `self` at `x` for the step length `τ`.
    ///
    /// Consumes `x` and returns the resulting point.
    fn prox(&self, τ : f64, x : M) -> M;
}

/// This structure is used to store information from algorithm iterations
#[derive(Debug, Clone, Serialize)]
pub struct IterInfo<M> {
    /// Objective function value recorded for this iteration
    pub value : f64,
    /// The iterate at which `value` was recorded
    pub point : M,
}

impl<M> LogRepr for IterInfo<M>
where
    M : ManifoldPoint + EmbeddedManifoldPoint,
{
    /// Renders the record as the function value, a tab and a space, and then
    /// the embedded coordinates of the iterate.
    fn logrepr(&self) -> ColoredString {
        let coords = self.point.embedded_coords();
        let line = format!("{}\t {}", self.value, coords);
        line.as_str().into()
    }
}

/// The forward-backward method on manifolds.
///
/// `f` is the smooth, `g` the nonsmooth function, `x` the initial iterate,
/// `τ` the step length parameter, and `iterator` controls the iteration count
/// and verbosity. Return the final iterate.
pub fn forward_backward<M, F, G, I>(
    f : &F,
    g : &G,
    mut x : M,
    τ : f64,
    iterator : I
) -> M
where
    M : ManifoldPoint + EmbeddedManifoldPoint,
    F : Desc<M> + Mapping<M, Codomain = f64>,
    G : Prox<M> + Mapping<M, Codomain = f64>,
    I : AlgIteratorFactory<IterInfo<M>>,
{
    // Builds a log record: the objective value `f + g` at `point`, together
    // with a copy of `point` itself.
    let snapshot = |point : &M| {
        let value = f.apply(point) + g.apply(point);
        IterInfo { value, point : point.clone() }
    };

    // `iterator` decides how many steps are taken.
    for state in iterator.iter_init(|| snapshot(&x)) {
        // Forward step on the smooth part, then backward (proximal) step on
        // the nonsmooth part, both with step length `τ`.
        let forward = f.desc(τ, x);
        x = g.prox(τ, forward);

        // The record is only built when `iterator` asks for verbose output.
        state.if_verbose(|| snapshot(&x));
    }

    x
}
