/*!
Implementation of scaling of functions on a manifold by a scalar.
*/

use alg_tools::mapping::{Mapping, Instance};
use crate::manifold::ManifoldPoint;
use crate::fb::{Grad, Desc, Prox};

/// A function of type `G` multiplied by a scalar factor `α`.
///
/// Evaluated as a mapping, a `Scaled<G>` yields `α` times the value of the base
/// function `g`.
///
/// The derives are conditional on `G`: `Scaled<G>` is `Clone`/`Copy`/`Debug`
/// exactly when `G` is, so wrapping a function does not lose those capabilities.
#[derive(Clone, Copy, Debug)]
pub struct Scaled<G> {
    /// Scaling factor
    α : f64,
    /// The base function
    g : G,
}

impl<G> Scaled<G> {
    /// Creates a new scaled function.
    #[allow(dead_code)]
    pub fn new(α : f64, g : G) -> Self {
        Scaled{ α, g }
    }
}

impl<M, G> Mapping<M> for Scaled<G>
where
    G : Mapping<M, Codomain = f64>,
{
    type Codomain = f64;

    /// Evaluates the scaled function at `x`: the value of `g` at `x`, times `α`.
    fn apply<I : Instance<M>>(&self, x : I) -> Self::Codomain {
        let base_value = self.g.apply(x);
        self.α * base_value
    }
}

impl<M, G> Desc<M> for Scaled<G>
where
    M : ManifoldPoint,
    G : Desc<M>,
{
    /// Descent step for `α * g` from `x` with step length `τ`.
    ///
    /// Delegates to the base function's `desc` with step length `α * τ`.
    /// Assuming `desc(τ, x)` steps by `τ` along the negative gradient, this is
    /// exact, because scaling a function by `α` scales its gradient by `α`.
    fn desc(&self, τ : f64, x : M) -> M {
        let effective_step = self.α * τ;
        self.g.desc(effective_step, x)
    }
}

impl<M, G> Grad<M> for Scaled<G>
where
    M : ManifoldPoint,
    G : Grad<M>,
{
    /// Gradient of `α * g` at `x`: the base function's gradient at `x`,
    /// multiplied by `α`.
    fn grad(&self, x : &M) -> M::Tangent {
        // Keep the tangent on the left: the product is formed as
        // `Tangent * f64`, matching the original operand order.
        let base_gradient = self.g.grad(x);
        base_gradient * self.α
    }
}

impl<M, G> Prox<M> for Scaled<G>
where
    M : ManifoldPoint,
    G : Prox<M>,
{
    /// Proximal step for `α * g` at `x` with parameter `τ`.
    ///
    /// Delegates to the base function's `prox` with parameter `α * τ`.
    /// Assuming `prox(τ, x)` evaluates prox_{τ g}(x), this relies on the identity
    /// prox_{τ (α g)} = prox_{(τ α) g}.
    /// NOTE(review): for `α < 0` the scaled function is generally not convex,
    /// so the delegated parameter is negative; callers presumably keep `α >= 0`.
    fn prox(&self, τ : f64, x : M) -> M {
        let effective_parameter = self.α * τ;
        self.g.prox(effective_parameter, x)
    }
}