/*!
Traits for mathematical functions.
*/

use std::marker::PhantomData;
use crate::types::{Float};
use serde::Serialize;
use crate::loc::Loc;

/// A mapping (function) from `Domain` to the associated type [`Mapping::Codomain`].
///
/// Because `Domain` is a trait parameter, one type may implement `Mapping` for
/// several domains, each with its own `Codomain`.
pub trait Mapping<Domain> {
    /// The type of the values produced by the mapping.
    type Codomain;

    /// Evaluates the mapping at the point `x`.
    fn value(&self, x: Domain) -> Self::Codomain;
}

/// A helper trait alias for [`Mapping`]s that take a reference `&Loc<F, N>`
/// (for every lifetime) and produce a value of type `F` (bounded by `Float`).
///
/// It is blanket-implemented below for every type satisfying that bound, so it
/// never needs to be implemented by hand; use it as a shorthand trait bound.
pub trait RealRefMapping<F : Float, const N : usize>
: for<'a> Mapping<&'a Loc<F, N>, Codomain=F> {}

/// Blanket implementation: any `T` that maps `&Loc<F, N>` to `F` for all
/// lifetimes is automatically a [`RealRefMapping`].
impl<F : Float, T, const N : usize> RealRefMapping<F, N> for T
where T : for<'a> Mapping<&'a Loc<F, N>, Codomain=F> {}


/// A differentiable mapping from `Domain` to [`Mapping::Codomain`], whose
/// differentials are of type [`DifferentiableMapping::Differential`].
pub trait DifferentiableMapping<Domain> : Mapping<Domain> {
    /// The type of the differential of the mapping at a point.
    type Differential;

    /// Calculate the differential of the mapping at `x`.
    fn differential(&self, x : Domain) -> Self::Differential;
}

/// A [`Mapping`] with a `Float`-valued codomain whose minimum and maximum can
/// be computed.
///
/// Both methods return a pair `(point, value)`: a point of `Domain` at which
/// the extremum is attained, together with the mapping's value there.
pub trait RealMapping<Domain>: Mapping<Domain>
where
    Self::Codomain: Float,
{
    /// Returns `(minimiser, minimum)` of the mapping.
    ///
    /// NOTE(review): how `tolerance` bounds the accuracy (absolute vs. relative,
    /// on the value or on the point) is left to implementors — confirm there.
    fn minimise(&self, tolerance: Self::Codomain) -> (Domain, Self::Codomain);

    /// Returns `(maximiser, maximum)` of the mapping.
    ///
    /// `tolerance` has the same (implementor-defined) meaning as for `minimise`.
    fn maximise(&self, tolerance: Self::Codomain) -> (Domain, Self::Codomain);
}

/// A sum of [`Mapping`]s.
#[derive(Serialize, Debug, Clone)]
pub struct Sum<Domain, M : Mapping<Domain>> {
    components : Vec<M>,
    _domain : PhantomData<Domain>,
}

impl<Domain, M : Mapping<Domain>> Sum<Domain, M> {
    /// Construct a sum from the given components.
    ///
    /// Accepts anything iterable over `M` (an iterator, a `Vec<M>`, …) and
    /// stores the components in iteration order.
    pub fn new<I : IntoIterator<Item = M>>(iter : I) -> Self {
        Sum { components : iter.into_iter().collect(), _domain : PhantomData }
    }
}


impl<Domain, M> Mapping<Domain> for Sum<Domain, M>
where
    Domain: Copy,
    M: Mapping<Domain>,
    M::Codomain: std::iter::Sum,
{
    type Codomain = M::Codomain;

    /// Evaluates each component at `x`, in order, and adds up the results.
    ///
    /// `Domain: Copy` allows the same point to be passed by value to every
    /// component; with no components the result is whatever
    /// `std::iter::Sum` yields for an empty iterator.
    fn value(&self, x: Domain) -> Self::Codomain {
        let terms = self
            .components
            .iter()
            .map(move |component| component.value(x));
        terms.sum()
    }
}

impl<Domain, M> DifferentiableMapping<Domain> for Sum<Domain, M>
where
    Domain: Copy,
    M: DifferentiableMapping<Domain>,
    M::Codomain: std::iter::Sum,
    M::Differential: std::iter::Sum,
{
    type Differential = M::Differential;

    /// Computes each component's differential at `x`, in order, and adds them.
    ///
    /// Differentiation is linear, so the differential of a sum is the sum of
    /// the component differentials.
    fn differential(&self, x: Domain) -> Self::Differential {
        self.components
            .iter()
            .map(move |component| component.differential(x))
            .sum::<M::Differential>()
    }
}
