/*!
Arithmetic of [`Mapping`]s.
 */

use serde::Serialize;
use crate::types::*;
use crate::instance::{Space, Instance};
use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping};

/// Trait for types that encode a constant [`Float`] value.
///
/// Implementors are `Copy`, thread-safe, `'static` handles that can either be
/// queried for their value with [`Constant::value`] or converted into it via
/// `Into<Self::Type>`.
pub trait Constant:
    Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type>
{
    /// Floating-point type of the encoded value.
    type Type: Float;

    /// Return the encoded value.
    fn value(&self) -> Self::Type;
}

/// Every [`Float`] is trivially a [`Constant`] whose value is itself.
impl<F> Constant for F
where
    F: Float,
{
    type Type = F;

    #[inline]
    fn value(&self) -> Self::Type {
        *self
    }
}

/// Scalar multiple `weight * base_fn` of a function.
///
/// This is the output of `Support::weigh`. The base may be a `Support`, an
/// `Apply`, or a [`Mapping`]; evaluation multiplies the base function's
/// output by `weight`.
#[derive(Clone, Copy, Debug, Serialize)]
pub struct Weighted<T, C: Constant> {
    /// Scalar factor applied to the output of `base_fn`.
    pub weight: C,
    /// The function being scaled.
    pub base_fn: T,
}

impl<T, C> Weighted<T, C>
where
    C: Constant,
{
    /// Construct the weighted function `weight * base_fn`.
    ///
    /// # Arguments
    /// * `weight` – the scalar multiplier.
    /// * `base_fn` – the function to be weighted.
    pub fn new(weight: C, base_fn: T) -> Self {
        Weighted { weight, base_fn }
    }
}

/// Evaluating a [`Weighted`] function evaluates the base function and scales
/// the result by the weight.
impl<T, V, D, F, C> Mapping<D> for Weighted<T, C>
where
    F: Float,
    D: Space,
    T: Mapping<D, Codomain = V>,
    V: Space + ClosedMul<F>,
    C: Constant<Type = F>,
{
    type Codomain = V;

    /// Returns `base_fn(x) * weight`.
    #[inline]
    fn apply<I: Instance<D>>(&self, x: I) -> Self::Codomain {
        self.base_fn.apply(x) * self.weight.value()
    }
}

/// The differential of `weight * base_fn` is `weight` times the differential
/// of `base_fn`.
impl<T, V, D, F, C> DifferentiableImpl<D> for Weighted<T, C>
where
    F: Float,
    D: Space,
    T: DifferentiableMapping<D, DerivativeDomain = V>,
    V: Space + std::ops::Mul<F, Output = V>,
    C: Constant<Type = F>,
{
    type Derivative = V;

    /// Returns `base_fn.differential(x) * weight`.
    #[inline]
    fn differential_impl<I: Instance<D>>(&self, x: I) -> Self::Derivative {
        self.base_fn.differential(x) * self.weight.value()
    }
}

/// Pointwise sum of a collection of [`Mapping`]s that share the type `M`.
#[derive(Clone, Debug, Serialize)]
pub struct MappingSum<M>(Vec<M>);

impl< M> MappingSum<M> {
    /// Construct from an iterator.
    pub fn new<I : IntoIterator<Item = M>>(iter : I) -> Self {
        MappingSum(iter.into_iter().collect())
    }

    /// Iterate over the component functions of the sum
    pub fn iter(&self) -> std::slice::Iter<'_, M> {
        self.0.iter()
    }
}

/// Evaluating a [`MappingSum`] evaluates every component at the same point and
/// adds the results with [`std::iter::Sum`]. An empty sum yields whatever
/// `Sum` produces for an empty iterator.
impl<Domain, M> Mapping<Domain> for MappingSum<M>
where
    Domain: Space + Clone,
    M: Mapping<Domain>,
    M::Codomain: std::iter::Sum + Clone,
{
    type Codomain = M::Codomain;

    fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain {
        // Convert the input once so the same point can be handed to every component.
        let point = x.ref_instance();
        self.0
            .iter()
            .map(|component| component.apply(point))
            .sum()
    }
}

/// The differential of a [`MappingSum`] is the sum of the components'
/// differentials, all evaluated at the same point.
impl<Domain, M> DifferentiableImpl<Domain> for MappingSum<M>
where
    Domain: Space + Clone,
    M: DifferentiableMapping<Domain>,
    M::DerivativeDomain: std::iter::Sum,
{
    type Derivative = M::DerivativeDomain;

    fn differential_impl<I: Instance<Domain>>(&self, x: I) -> Self::Derivative {
        // Convert the input once so the same point can be handed to every component.
        let point = x.ref_instance();
        self.0
            .iter()
            .map(|component| component.differential(point))
            .sum()
    }
}
