/*!
Traits for mathematical functions.
*/

use std::marker::PhantomData;
use std::borrow::Cow;
use crate::types::{Num, Float, ClosedMul};
use crate::loc::Loc;
pub use crate::instance::{Instance, Decomposition, BasicDecomposition, Space};
use crate::norms::{Norm, NormExponent};

/// Marker trait for the types that may be used as [`Mapping::ArithmeticOptIn`].
///
/// Its implementors in this block are [`ArithmeticTrue`] and [`ArithmeticFalse`].
pub trait ArithmeticOptIn {}

/// Opt-in marker; the [`Mapping`] implementations in this module select it.
pub struct ArithmeticTrue;
impl ArithmeticOptIn for ArithmeticTrue {}

/// Opt-out counterpart of [`ArithmeticTrue`].
pub struct ArithmeticFalse;
impl ArithmeticOptIn for ArithmeticFalse {}

/// A mapping from `Domain` to [`Mapping::Codomain`].
///
/// Implementors provide [`Mapping::apply`]; the remaining methods have default
/// implementations that wrap `self` into new mappings ([`Composition`] and
/// [`Weighted`]).
pub trait Mapping<Domain : Space> {
    /// The space that `self` maps into.
    type Codomain : Space;
    /// Marker implementing [`ArithmeticOptIn`]. The [`Mapping`] implementations
    /// in this module set it to [`ArithmeticTrue`].
    type ArithmeticOptIn : ArithmeticOptIn;

    /// Compute the value of `self` at `x`.
    fn apply<I : Instance<Domain>>(&self, x : I) -> Self::Codomain;

    /// Form the composition `self ∘ other`, i.e. the mapping `x ↦ self(other(x))`.
    ///
    /// No norm is assigned to the intermediate space `Domain`; use
    /// [`Mapping::compose_with_norm`] for that.
    #[inline]
    fn compose<X : Space, T : Mapping<X, Codomain=Domain>>(self, other : T)
        -> Composition<Self, T>
    where
        Self : Sized
    {
        Composition{ outer : self, inner : other, intermediate_norm_exponent : () }
    }

    /// Form the composition `self ∘ other`, assigning the norm exponent `norm`
    /// to the intermediate space `Domain`.
    ///
    /// `F` is the scalar type of that norm, i.e. the `F` in `Domain : Norm<F, E>`.
    #[inline]
    fn compose_with_norm<F, X, T, E>(
        self, other : T, norm : E
    ) -> Composition<Self, T, E>
    where
        Self : Sized,
        X : Space,
        T : Mapping<X, Codomain=Domain>,
        E : NormExponent,
        Domain : Norm<F, E>,
        F : Num
    {
        Composition{ outer : self, inner : other, intermediate_norm_exponent : norm }
    }

    /// Multiply `self` by the scalar `a`.
    #[inline]
    fn weigh<C>(self, a : C) -> Weighted<Self, C>
    where
        Self : Sized,
        C : Constant,
        Self::Codomain : ClosedMul<C::Type>,
    {
        Weighted { weight : a, base_fn : self }
    }
}

/// Shorthand for [`Mapping`]s from [`Loc<F, N>`] to `F`.
///
/// A blanket implementation covers every type with a matching [`Mapping`]
/// implementation, so this trait never needs to be implemented by hand.
pub trait RealMapping<F, const N : usize> : Mapping<Loc<F, N>, Codomain = F>
where
    F : Float,
{}

impl<F, T, const N : usize> RealMapping<F, N> for T
where
    F : Float,
    T : Mapping<Loc<F, N>, Codomain = F>,
{}

/// Shorthand for [`Mapping`]s from [`Loc<F, N>`] to [`Loc<F, M>`].
///
/// A blanket implementation covers every type with a matching [`Mapping`]
/// implementation, so this trait never needs to be implemented by hand.
pub trait RealVectorField<F, const N : usize, const M : usize>
    : Mapping<Loc<F, N>, Codomain = Loc<F, M>>
where
    F : Float,
{}

impl<F, T, const N : usize, const M : usize> RealVectorField<F, N, M> for T
where
    F : Float,
    T : Mapping<Loc<F, N>, Codomain = Loc<F, M>>,
{}

/// A differentiable mapping from `Domain` to [`Mapping::Codomain`].
///
/// The differential at a single point is a value of
/// [`DifferentiableMapping::DerivativeDomain`]; the differential viewed as a
/// mapping in its own right is [`DifferentiableMapping::Differential`].
///
/// Every `Clone` [`Mapping`] that also implements [`DifferentiableImpl`]
/// receives this trait through a blanket implementation.
pub trait DifferentiableMapping<Domain : Space> : Mapping<Domain> {
    /// The space that the differential at a point belongs to.
    type DerivativeDomain : Space;
    /// The differential as a mapping, possibly borrowing `self` for `'b`.
    type Differential<'b> : Mapping<Domain, Codomain=Self::DerivativeDomain> where Self : 'b;

    /// Calculate the differential of `self` at the point `x`.
    fn differential<I>(&self, x : I) -> Self::DerivativeDomain
    where
        I : Instance<Domain>;

    /// Form the differential mapping of `self`, borrowing `self`.
    fn diff_ref(&self) -> Self::Differential<'_>;

    /// Form the differential mapping of `self`, taking ownership of `self`.
    fn diff(self) -> Self::Differential<'static>;
}

/// Shorthand for [`DifferentiableMapping`]s from [`Loc<F, N>`] to `F` whose
/// pointwise differentials are again elements of [`Loc<F, N>`].
///
/// A blanket implementation covers every type with matching trait
/// implementations, so this trait never needs to be implemented by hand.
pub trait DifferentiableRealMapping<F, const N : usize>
    : DifferentiableMapping<Loc<F, N>, Codomain = F, DerivativeDomain = Loc<F, N>>
where
    F : Float,
{}

impl<F, T, const N : usize> DifferentiableRealMapping<F, N> for T
where
    F : Float,
    T : DifferentiableMapping<Loc<F, N>, Codomain = F, DerivativeDomain = Loc<F, N>>,
{}

/// Helper trait for implementing [`DifferentiableMapping`].
///
/// A type that is `Clone`, a [`Mapping`], and implements this trait receives
/// [`DifferentiableMapping`] from a blanket implementation.
pub trait DifferentiableImpl<X> : Sized
where
    X : Space,
{
    /// The space that the differential at a point belongs to.
    type Derivative : Space;

    /// Compute the differential of `self` at `x`, consuming the instance `x`.
    fn differential_impl<I>(&self, x : I) -> Self::Derivative
    where
        I : Instance<X>;
}

/// Blanket implementation of [`DifferentiableMapping`] for every `Clone`
/// [`Mapping`] that implements [`DifferentiableImpl`].
impl<T, Domain> DifferentiableMapping<Domain> for T
where
    Domain : Space,
    T : Clone + Mapping<Domain> + DifferentiableImpl<Domain>
{
    type DerivativeDomain = <T as DifferentiableImpl<Domain>>::Derivative;
    type Differential<'b> = Differential<'b, Domain, Self> where Self : 'b;

    /// Forwards to [`DifferentiableImpl::differential_impl`].
    #[inline]
    fn differential<I : Instance<Domain>>(&self, x : I) -> Self::DerivativeDomain {
        <T as DifferentiableImpl<Domain>>::differential_impl(self, x)
    }

    /// Wraps a borrow of `self`.
    fn diff_ref(&self) -> Differential<'_, Domain, Self> {
        Differential{ _space : PhantomData, g : Cow::Borrowed(self) }
    }

    /// Wraps `self` by value.
    fn diff(self) -> Differential<'static, Domain, Self> {
        Differential{ _space : PhantomData, g : Cow::Owned(self) }
    }
}

/// A sum of [`Mapping`]s sharing the domain `Domain`.
///
/// Its value at a point is the sum of the values of all components at that point.
#[derive(Serialize, Debug, Clone)]
pub struct Sum<Domain, M> {
    /// The component mappings being summed.
    components : Vec<M>,
    _domain : PhantomData<Domain>,
}

impl<Domain, M> Sum<Domain, M> {
    /// Construct a sum from any collection or iterator of component mappings.
    ///
    /// Accepts anything implementing [`IntoIterator`] (e.g. a `Vec<M>` or an
    /// iterator), so callers that pass an iterator keep working unchanged.
    pub fn new<I : IntoIterator<Item = M>>(iter : I) -> Self {
        Sum { components : iter.into_iter().collect(), _domain : PhantomData }
    }

    /// Iterate over the component functions of the sum.
    pub fn iter(&self) -> std::slice::Iter<'_, M> {
        self.components.iter()
    }
}


impl<Domain, M> Mapping<Domain> for Sum<Domain, M>
where
    Domain : Space + Clone,
    M : Mapping<Domain>,
    M::Codomain : std::iter::Sum + Clone
{
    type Codomain = M::Codomain;
    type ArithmeticOptIn = ArithmeticTrue;

    /// Evaluate every component at `x` and add the results.
    fn apply<I : Instance<Domain>>(&self, x : I) -> Self::Codomain {
        let point = x.ref_instance();
        self.components
            .iter()
            .map(|component| component.apply(point))
            .sum::<Self::Codomain>()
    }
}

impl<Domain, M> DifferentiableImpl<Domain> for Sum<Domain, M>
where
    Domain : Space + Clone,
    M : DifferentiableMapping<Domain>,
    M :: DerivativeDomain : std::iter::Sum
{
    type Derivative = M::DerivativeDomain;

    /// Differentiate every component at `x` and add the differentials.
    fn differential_impl<I : Instance<Domain>>(&self, x : I) -> Self::Derivative {
        let point = x.ref_instance();
        self.components
            .iter()
            .map(|component| component.differential(point))
            .sum::<Self::Derivative>()
    }
}

/// Container for the differential [`Mapping`] of a [`DifferentiableMapping`].
///
/// The base mapping is held in a [`Cow`]. The container owns it when produced by
/// [`DifferentiableMapping::diff`], and borrows it when produced by
/// [`DifferentiableMapping::diff_ref`].
pub struct Differential<'a, X, G : Clone> {
    g : Cow<'a, G>,
    _space : PhantomData<X>
}

impl<'a, X, G : Clone> Differential<'a, X, G> {
    /// Return a reference to the base mapping whose differential this is.
    pub fn base_fn(&self) -> &G {
        &self.g
    }
}

impl<'a, X, G> Mapping<X> for Differential<'a, X, G>
where
    X : Space,
    G : Clone + DifferentiableMapping<X>
{
    type Codomain = G::DerivativeDomain;
    type ArithmeticOptIn = ArithmeticTrue;

    /// Evaluate the differential of the base mapping at `x`.
    #[inline]
    fn apply<I : Instance<X>>(&self, x : I) -> Self::Codomain {
        let base : &G = &self.g;
        base.differential(x)
    }
}

/// Wrapper around a [`Mapping`] whose codomain is [`Loc`]`<F, 1>`, presenting it
/// as a mapping into `F` instead.
pub struct FlattenedCodomain<X, F, G> {
    g : G,
    _phantoms : PhantomData<(X, F)>
}

impl<F, X, G> Mapping<X> for FlattenedCodomain<X, F, G>
where
    F : Space,
    X : Space,
    G : Mapping<X, Codomain=Loc<F, 1>>
{
    type Codomain = F;
    type ArithmeticOptIn = ArithmeticTrue;

    /// Evaluate the wrapped mapping at `x` and convert the one-component result
    /// to `F` with `flatten1d`.
    #[inline]
    fn apply<I : Instance<X>>(&self, x : I) -> Self::Codomain {
        let wrapped = self.g.apply(x);
        wrapped.flatten1d()
    }
}

/// A blanket-implemented trait for constructing a [`FlattenedCodomain`] structure
/// for flattening the codomain of a [`Mapping`] from [`Loc`]`<F, 1>` to `F`.
///
/// Implemented for every [`Mapping`] whose codomain is [`Loc`]`<F, 1>`.
pub trait FlattenCodomain<X : Space, F> : Mapping<X, Codomain=Loc<F, 1>> + Sized {
    /// Flatten the codomain from [`Loc`]`<F, 1>` to `F`, taking ownership of `self`.
    fn flatten_codomain(self) -> FlattenedCodomain<X, F, Self> {
        FlattenedCodomain{ g : self, _phantoms : PhantomData }
    }
}

impl<X : Space, F, G : Sized + Mapping<X, Codomain=Loc<F, 1>>> FlattenCodomain<X, F> for G {}

/// Container for dimensional slicing [`Loc`]`<F, N>` codomain of a [`Mapping`] to `F`.
///
/// Evaluating it evaluates the base mapping and returns component `slice` of the
/// result. The base mapping is either owned or borrowed (see [`Cow`]).
pub struct SlicedCodomain<'a, X, F, G : Clone, const N : usize> {
    g : Cow<'a, G>,
    /// Index of the codomain component to extract; expected to be `< N`.
    slice : usize,
    _phantoms : PhantomData<(X, F)>
}

impl<'a, X, F, G, const N : usize> Mapping<X> for SlicedCodomain<'a, X, F, G, N>
where
    X : Space,
    F : Copy + Space,
    G : Mapping<X, Codomain=Loc<F, N>> + Clone,
{
    type Codomain = F;
    type ArithmeticOptIn = ArithmeticTrue;

    /// Evaluate the base mapping at `x` and return its component `slice`.
    ///
    /// # Panics
    ///
    /// Panics if `slice >= N`.
    #[inline]
    fn apply<I : Instance<X>>(&self, x : I) -> Self::Codomain {
        let components : [F; N] = (*self.g).apply(x).into();
        // The `SliceCodomain` constructors already assert `slice < N`. Checked
        // indexing keeps this sound even for values built elsewhere in the
        // module, at the cost of a single well-predicted comparison.
        components[self.slice]
    }
}

/// A blanket-implemented trait for constructing a [`SlicedCodomain`] structure
/// that extracts a single component of the codomain of a [`Mapping`], turning
/// [`Loc`]`<F, N>` into `F`.
///
/// Implemented for every `Clone` [`Mapping`] whose codomain is [`Loc`]`<F, N>`.
pub trait SliceCodomain<X : Space, F : Copy, const N : usize>
    : Mapping<X, Codomain=Loc<F, N>> + Clone + Sized
{
    /// Slice the codomain from [`Loc`]`<F, N>` to `F`, keeping component `slice`.
    /// Takes ownership of `self`.
    ///
    /// # Panics
    ///
    /// Panics if `slice >= N`.
    fn slice_codomain(self, slice : usize) -> SlicedCodomain<'static, X, F, Self, N> {
        assert!(slice < N, "slice index {} out of range for codomain dimension {}", slice, N);
        SlicedCodomain{ g : Cow::Owned(self), slice, _phantoms : PhantomData }
    }

    /// Slice the codomain from [`Loc`]`<F, N>` to `F`, keeping component `slice`.
    /// Borrows `self`.
    ///
    /// # Panics
    ///
    /// Panics if `slice >= N`.
    fn slice_codomain_ref(&self, slice : usize) -> SlicedCodomain<'_, X, F, Self, N> {
        assert!(slice < N, "slice index {} out of range for codomain dimension {}", slice, N);
        SlicedCodomain{ g : Cow::Borrowed(self), slice, _phantoms : PhantomData }
    }
}

impl<X : Space, F : Copy, G : Sized + Mapping<X, Codomain=Loc<F, N>> + Clone, const N : usize>
SliceCodomain<X, F, N>
for G {}


/// The composition `S ∘ T`: `inner` (of type `T`) is applied first, then
/// `outer` (of type `S`).
///
/// `E` stores a `NormExponent` for the intermediate space. It defaults to `()`,
/// which is what [`Mapping::compose`] uses when no norm is assigned.
pub struct Composition<S, T, E = ()> {
    /// The mapping applied second.
    pub outer : S,
    /// The mapping applied first.
    pub inner : T,
    /// Norm exponent of the intermediate space (`()` when none is assigned).
    pub intermediate_norm_exponent : E
}

impl<S, T, X, E> Mapping<X> for Composition<S, T, E>
where
    X : Space,
    T : Mapping<X>,
    S : Mapping<T::Codomain>
{
    type Codomain = S::Codomain;
    type ArithmeticOptIn = ArithmeticTrue;

    /// Evaluate `inner` at `x`, then `outer` at that intermediate value.
    #[inline]
    fn apply<I : Instance<X>>(&self, x : I) -> Self::Codomain {
        let intermediate = self.inner.apply(x);
        self.outer.apply(intermediate)
    }
}
