/*!
Direct products of the form $A \times B$.

TODO: This could easily be made much more generic if `derive_more` could derive
arithmetic operations on references.
*/

use core::ops::{Mul,MulAssign,Div,DivAssign,Add,AddAssign,Sub,SubAssign,Neg};
use std::clone::Clone;
use serde::{Serialize, Deserialize};
use crate::types::{Num, Float};
use crate::{maybe_lifetime, maybe_ref};
use crate::euclidean::{Dot, Euclidean};
use crate::instance::{Instance, InstanceMut, Decomposition, DecompositionMut, MyCow};
use crate::mapping::Space;
use crate::linops::AXPY;
use crate::loc::Loc;
use crate::norms::{Norm, PairNorm, NormExponent, Normed, HasDual, L2};

/// An element $(a, b)$ of the direct product $A \times B$, stored as a tuple struct whose
/// public fields `.0` and `.1` hold the components in `A` and `B` respectively.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct Pair<A, B> (pub A, pub B);

impl<A, B> Pair<A,B> {
    /// Creates the pair $(a, b)$; equivalent to writing `Pair(a, b)`.
    pub fn new(a : A, b : B) -> Pair<A,B> {
        Self(a, b)
    }
}

impl<A, B> From<(A,B)> for Pair<A,B> {
    /// Wraps the tuple `(a, b)` as `Pair(a, b)`.
    #[inline]
    fn from(tuple : (A, B)) -> Pair<A,B> {
        let (a, b) = tuple;
        Self::new(a, b)
    }
}

impl<A, B> From<Pair<A,B>> for (A,B) {
    /// Unwraps `Pair(a, b)` into the tuple `(a, b)`.
    #[inline]
    fn from(pair : Pair<A, B>) -> (A,B) {
        let Pair(first, second) = pair;
        (first, second)
    }
}

/// Implements the binary operator `$trait` (method `$fn`) for `Pair<$a, $b>` against a
/// generic right-hand side `Pair<Ai, Bi>`, operating component by component.
///
/// `$refl` and `$refr` (`ref` or `noref`) select the left and right operand forms. They are
/// forwarded to `maybe_lifetime!` (for types) and `maybe_ref!` (for expressions). Presumably
/// `ref` keeps the `&'l` / `&'r` reference and borrows the component, while `noref` uses the
/// plain type and value. TODO: confirm against those macros' definitions.
macro_rules! impl_binop {
    // Entry point: build the operand types for the requested modes and forward to `@doit`.
    (($a : ty, $b : ty), $trait : ident, $fn : ident, $refl:ident, $refr:ident) => {
        impl_binop!(@doit: $a, $b, $trait, $fn;
                           maybe_lifetime!($refl, &'l Pair<$a,$b>),
                           (maybe_lifetime!($refl, &'l $a),
                            maybe_lifetime!($refl, &'l $b));
                           maybe_lifetime!($refr, &'r Pair<Ai,Bi>),
                           (maybe_lifetime!($refr, &'r Ai),
                            maybe_lifetime!($refr, &'r Bi));
                           $refl, $refr);
    };

    // Worker arm:
    // - `$self` / `$in` are the full left / right operand types.
    // - `($aself, $bself)` / `($ain, $bin)` are the matching component types.
    // The impl applies whenever both component pairs support `$trait`, and `Output` pairs
    // the two component outputs.
    (@doit: $a:ty, $b:ty,
            $trait:ident, $fn:ident;
            $self:ty, ($aself:ty, $bself:ty);
            $in:ty, ($ain:ty, $bin:ty);
            $refl:ident, $refr:ident) => {
        impl<'l, 'r, Ai, Bi> $trait<$in>
        for $self
        where $aself: $trait<$ain>,
              $bself: $trait<$bin> {
            type Output = Pair<<$aself as $trait<$ain>>::Output,
                               <$bself as $trait<$bin>>::Output>;

            #[inline]
            fn $fn(self, y : $in) -> Self::Output {
                Pair(maybe_ref!($refl, self.0).$fn(maybe_ref!($refr, y.0)),
                     maybe_ref!($refl, self.1).$fn(maybe_ref!($refr, y.1)))
            }
        }
    };
}

/// Implements the compound-assignment operator `$trait` (method `$fn`, e.g. `+=`) for
/// `Pair<$a, $b>` with a generic right-hand side `Pair<Ai, Bi>`, updating each component
/// in place.
///
/// `$refr` (`ref` or `noref`) selects the right-hand side form and is forwarded to
/// `maybe_lifetime!` / `maybe_ref!`. Presumably `ref` means `&'r Pair<Ai, Bi>` and `noref`
/// means `Pair<Ai, Bi>`. TODO: confirm against those macros.
macro_rules! impl_assignop {
    // Entry point: build the right-hand side types and forward to `@doit`.
    (($a : ty, $b : ty), $trait : ident, $fn : ident, $refr:ident) => {
        impl_assignop!(@doit: $a, $b,
                              $trait, $fn;
                              maybe_lifetime!($refr, &'r Pair<Ai,Bi>),
                              (maybe_lifetime!($refr, &'r Ai),
                               maybe_lifetime!($refr, &'r Bi));
                              $refr);
    };
    // Worker: `$in` is the full right-hand side type, and `($ain, $bin)` its component types.
    (@doit: $a : ty, $b : ty,
            $trait:ident, $fn:ident;
            $in:ty, ($ain:ty, $bin:ty);
            $refr:ident) => {
        impl<'r, Ai, Bi> $trait<$in>
        for Pair<$a,$b>
        where $a: $trait<$ain>,
              $b: $trait<$bin> {
            // Assignment operators return unit; the `-> ()` annotation is redundant.
            #[inline]
            fn $fn(&mut self, y : $in) {
                self.0.$fn(maybe_ref!($refr, y.0));
                self.1.$fn(maybe_ref!($refr, y.1));
            }
        }
    }
}

/// Implements `Pair<$a, $b> $trait $field`, i.e. an operation such as `*` or `/` with a
/// scalar on the right-hand side, applying the scalar to each component.
///
/// `$refl` (`ref` or `noref`) selects the left operand form and is forwarded to
/// `maybe_lifetime!` / `maybe_ref!`. Presumably `ref` means `&'l Pair<$a, $b>` and `noref`
/// means `Pair<$a, $b>`. The scalar `a` is passed by value to both components, so `$field`
/// must be `Copy`.
macro_rules! impl_scalarop {
    // Entry point: build the left operand types and forward to `@doit`.
    (($a : ty, $b : ty), $field : ty, $trait : ident, $fn : ident, $refl:ident) => {
        impl_scalarop!(@doit: $field,
                              $trait, $fn;
                              maybe_lifetime!($refl, &'l Pair<$a,$b>),
                              (maybe_lifetime!($refl, &'l $a),
                               maybe_lifetime!($refl, &'l $b));
                              $refl);
    };
    // Worker: `$self` is the full left operand type, and `($aself, $bself)` its component types.
    (@doit: $field : ty,
            $trait:ident, $fn:ident;
            $self:ty, ($aself:ty, $bself:ty);
            $refl:ident) => {
        // Scalar as Rhs
        impl<'l> $trait<$field>
        for $self
        where $aself: $trait<$field>,
              $bself: $trait<$field> {
            type Output = Pair<<$aself as $trait<$field>>::Output,
                               <$bself as $trait<$field>>::Output>;
            #[inline]
            fn $fn(self, a : $field) -> Self::Output {
                Pair(maybe_ref!($refl, self.0).$fn(a),
                     maybe_ref!($refl, self.1).$fn(a))
            }
        }
    }
}

// Implements `$field $trait Pair<$a, $b>`, i.e. an operation with the scalar on the
// LEFT-hand side (such as `2.0 * pair`), applying the scalar to each component.
//
// Currently not invoked anywhere. The original author noted that generating these impls
// led to a compiler overflow, hence the `allow(unused_macros)`.
#[allow(unused_macros)]
macro_rules! impl_scalarlhs_op {
    // Entry point: build the right operand (pair) types and forward to `@doit`.
    (($a : ty, $b : ty), $field : ty, $trait:ident, $fn:ident, $refr:ident) => {
        impl_scalarlhs_op!(@doit: $trait, $fn,
                                  maybe_lifetime!($refr, &'r Pair<$a,$b>),
                                  (maybe_lifetime!($refr, &'r $a),
                                   maybe_lifetime!($refr, &'r $b));
                                  $refr, $field);
    };
    // Worker: implements `$trait<$in>` directly on the scalar type `$field`.
    (@doit: $trait:ident, $fn:ident,
            $in:ty, ($ain:ty, $bin:ty);
            $refr:ident, $field:ty) => {
        impl<'r> $trait<$in>
        for $field
        where $field : $trait<$ain>
                     + $trait<$bin> {
            type Output = Pair<<$field as $trait<$ain>>::Output,
                               <$field as $trait<$bin>>::Output>;
            #[inline]
            fn $fn(self, x : $in) -> Self::Output {
                Pair(self.$fn(maybe_ref!($refr, x.0)),
                     self.$fn(maybe_ref!($refr, x.1)))
            }
        }
    };
}

/// Implements the compound assignment `Pair<$a, $b> $trait $field` (e.g. `*=` or `/=` by a
/// scalar), applying the scalar to each component in place.
///
/// The scalar `a` is passed by value to both components, so `$field` must be `Copy`.
macro_rules! impl_scalar_assignop {
    (($a : ty, $b : ty), $field : ty, $trait : ident, $fn : ident) => {
        // Unlike the other helpers, this impl introduces no references, so it needs no
        // lifetime parameter. Assignment operators return unit, so no `-> ()` is written.
        impl $trait<$field>
        for Pair<$a, $b>
        where $a: $trait<$field>, $b: $trait<$field> {
            #[inline]
            fn $fn(&mut self, a : $field) {
                self.0.$fn(a);
                self.1.$fn(a);
            }
        }
    }
}

/// Implements the unary operator `$trait` (method `$fn`, e.g. `Neg::neg`) for
/// `Pair<$a, $b>`, applying it to each component.
///
/// `$refl` (`ref` or `noref`) selects the operand form and is forwarded to
/// `maybe_lifetime!` / `maybe_ref!`. Presumably `ref` means `&'l Pair<$a, $b>` and `noref`
/// means `Pair<$a, $b>`. TODO: confirm.
macro_rules! impl_unaryop {
    // Entry point: build the operand types and forward to `@doit`.
    (($a : ty, $b : ty), $trait:ident, $fn:ident, $refl:ident) => {
        impl_unaryop!(@doit: $trait, $fn;
                             maybe_lifetime!($refl, &'l Pair<$a,$b>),
                             (maybe_lifetime!($refl, &'l $a),
                              maybe_lifetime!($refl, &'l $b));
                             $refl);
    };
    // Worker: `$self` is the full operand type, and `($aself, $bself)` its component types.
    // `Output` pairs the two component outputs.
    (@doit: $trait:ident, $fn:ident;
            $self:ty, ($aself:ty, $bself:ty);
            $refl : ident) => {
        impl<'l> $trait
        for $self
        where $aself: $trait,
              $bself: $trait {
            type Output = Pair<<$aself as $trait>::Output,
                               <$bself as $trait>::Output>;
            #[inline]
            fn $fn(self) -> Self::Output {
                Pair(maybe_ref!($refl, self.0).$fn(),
                     maybe_ref!($refl, self.1).$fn())
            }
        }
    }
}

/// Implements the vector-space arithmetic of `Pair<$a, $b>` over the scalar type `$field`:
///
/// * `+` and `-` against any `Pair<Ai, Bi>` whose components support them, for every
///   combination of the `ref` / `noref` operand modes;
/// * `+=` and `-=` with a `ref` or `noref` right-hand pair;
/// * `*` and `/` by a `$field` scalar on the right, for `ref` and `noref` pairs;
/// * `*=` and `/=` by a `$field` scalar;
/// * unary `-` for `ref` and `noref` pairs.
///
/// The `@…` arms are internal dispatch targets of the first arm.
///
/// NOTE(review): although this macro is `#[macro_export]`ed, its expansion refers
/// unqualified to:
/// * the non-exported helper macros (`impl_binop!`, `impl_assignop!`, …);
/// * `Pair` and the operator traits;
/// * itself.
///
/// It presumably only works where all of those names are in scope, such as this module.
/// Confirm before invoking it from another crate.
#[macro_export]
macro_rules! impl_pair_vectorspace_ops {
    // Entry point: generate every operator family for the given pair and scalar types.
    (($a:ty, $b:ty), $field:ty) => {
        impl_pair_vectorspace_ops!(@binary, ($a, $b), Add, add);
        impl_pair_vectorspace_ops!(@binary, ($a, $b), Sub, sub);
        impl_pair_vectorspace_ops!(@assign, ($a, $b), AddAssign, add_assign);
        impl_pair_vectorspace_ops!(@assign, ($a, $b), SubAssign, sub_assign);
        impl_pair_vectorspace_ops!(@scalar, ($a, $b), $field, Mul, mul);
        impl_pair_vectorspace_ops!(@scalar, ($a, $b), $field, Div, div);
        // Scalar-on-the-left products (`@scalar_lhs`, i.e. `$field * Pair`) are deliberately
        // not generated: the original author notes that doing so caused a compiler overflow.
        impl_pair_vectorspace_ops!(@scalar_assign, ($a, $b), $field, MulAssign, mul_assign);
        impl_pair_vectorspace_ops!(@scalar_assign, ($a, $b), $field, DivAssign, div_assign);
        impl_pair_vectorspace_ops!(@unary, ($a, $b), Neg, neg);
    };
    // Binary operators: all four ref/noref combinations of (left, right).
    (@binary, ($a : ty, $b : ty), $trait : ident, $fn : ident) => {
        impl_binop!(($a, $b), $trait, $fn, ref, ref);
        impl_binop!(($a, $b), $trait, $fn, ref, noref);
        impl_binop!(($a, $b), $trait, $fn, noref, ref);
        impl_binop!(($a, $b), $trait, $fn, noref, noref);
    };
    // Compound assignment with a pair on the right, in both ref/noref forms.
    (@assign, ($a : ty, $b : ty), $trait : ident, $fn :ident) => {
        impl_assignop!(($a, $b), $trait, $fn, ref);
        impl_assignop!(($a, $b), $trait, $fn, noref);
    };
    // Binary operators with a scalar on the right, for both ref/noref pair forms.
    (@scalar, ($a : ty, $b : ty), $field : ty, $trait : ident, $fn :ident) => {
        impl_scalarop!(($a, $b), $field, $trait, $fn, ref);
        impl_scalarop!(($a, $b), $field, $trait, $fn, noref);
    };
    // Scalar on the left; currently unused (see the note in the first arm).
    (@scalar_lhs, ($a : ty, $b : ty), $field : ty, $trait : ident, $fn : ident) => {
        impl_scalarlhs_op!(($a, $b), $field, $trait, $fn, ref);
        impl_scalarlhs_op!(($a, $b), $field, $trait, $fn, noref);
    };
    // Compound assignment by a scalar.
    (@scalar_assign, ($a : ty, $b : ty), $field : ty, $trait : ident, $fn : ident) => {
        impl_scalar_assignop!(($a, $b), $field, $trait, $fn);
    };
    // Unary operators, for both ref/noref forms.
    (@unary, ($a : ty, $b : ty), $trait : ident, $fn :  ident) => {
        impl_unaryop!(($a, $b), $trait, $fn, ref);
        impl_unaryop!(($a, $b), $trait, $fn, noref);
    };
}

// Vector-space arithmetic for pairs of scalars: `Pair<f32, f32>` over `f32` and
// `Pair<f64, f64>` over `f64`.
impl_pair_vectorspace_ops!((f32, f32), f32);
impl_pair_vectorspace_ops!((f64, f64), f64);

/// Inner product on $A \times B$ as the sum of the component inner products,
/// $\langle (a, b), (u, v) \rangle = \langle a, u \rangle + \langle b, v \rangle$.
impl<A, B, U, V, F> Dot<Pair<U, V>, F> for Pair<A, B>
where
    A : Dot<U, F>,
    B : Dot<V, F>,
    F : Num
{
    fn dot(&self, other : &Pair<U, V>) -> F {
        let Pair(u, v) = other;
        let first = self.0.dot(u);
        let second = self.1.dot(v);
        first + second
    }
}

/// Pair of the `Euclidean<F>::Output` types of `A` and `B`. The `Euclidean` impl for
/// `Pair<A, B>` uses it as the result type of owned arithmetic and of `similar_origin`.
type PairOutput<F, A, B> = Pair<<A as Euclidean<F>>::Output, <B as Euclidean<F>>::Output>;

/// Euclidean structure on $A \times B$:
///
/// * origins are formed component by component;
/// * squared distances add up, $\|(a, b) - (u, v)\|^2 = \|a - u\|^2 + \|b - v\|^2$.
impl<A, B, F> Euclidean<F> for Pair<A, B>
where
    A : Euclidean<F>,
    B : Euclidean<F>,
    F : Float,
    PairOutput<F, A, B> : Euclidean<F>,
    Self : Sized + Dot<Self,F>
          + Mul<F, Output=PairOutput<F, A, B>> + MulAssign<F>
          + Div<F, Output=PairOutput<F, A, B>> + DivAssign<F>
          + Add<Self, Output=PairOutput<F, A, B>>
          + Sub<Self, Output=PairOutput<F, A, B>>
          + for<'b> Add<&'b Self, Output=PairOutput<F, A, B>>
          + for<'b> Sub<&'b Self, Output=PairOutput<F, A, B>>
          + AddAssign<Self> + for<'b> AddAssign<&'b Self>
          + SubAssign<Self> + for<'b> SubAssign<&'b Self>
          + Neg<Output=PairOutput<F, A, B>>
{
    type Output = PairOutput<F, A, B>;

    fn similar_origin(&self) -> PairOutput<F, A, B> {
        let origin_first = self.0.similar_origin();
        let origin_second = self.1.similar_origin();
        Pair(origin_first, origin_second)
    }

    fn dist2_squared(&self, other : &Self) -> F {
        let Pair(u, v) = other;
        let first = self.0.dist2_squared(u);
        let second = self.1.dist2_squared(v);
        first + second
    }
}

/// Component-wise [`AXPY`] for pairs. Each operation decomposes the argument pair and
/// applies the same operation, with the same scalars, to the first and then the second
/// component.
impl<F, A, B> AXPY for Pair<A, B>
where
    A : AXPY<Field=F>,
    B : AXPY<Field=F>,
    F : Num + AXPY
{
    type Field = F;

    fn axpy<I : Instance<Pair<A,B>>>(&mut self, α : F, x : I, β : F) {
        let Pair(x_first, x_second) = x.decompose();
        self.0.axpy(α, x_first, β);
        self.1.axpy(α, x_second, β);
    }

    fn copy_from<I : Instance<Pair<A,B>>>(&mut self, x : I) {
        let Pair(src_first, src_second) = x.decompose();
        self.0.copy_from(src_first);
        self.1.copy_from(src_second);
    }

    fn scale_from<I : Instance<Pair<A,B>>>(&mut self, α : F, x : I) {
        let Pair(src_first, src_second) = x.decompose();
        self.0.scale_from(α, src_first);
        self.1.scale_from(α, src_second);
    }
}

/// The [`Decomposition`] used for [`Pair`]s: `D` decomposes the first component and `Q`
/// decomposes the second.
#[derive(Clone, Copy, Debug)]
pub struct PairDecomposition<D, Q>(D, Q);

/// A pair is a [`Space`] whenever both components are. It is decomposed component-wise
/// through [`PairDecomposition`].
impl<A, B> Space for Pair<A, B>
where
    A : Space,
    B : Space,
{
    type Decomp = PairDecomposition<A::Decomp, B::Decomp>;
}

/// Component-wise [`Decomposition`] of a [`Pair`]. Decompositions and references of a pair
/// are pairs of the component decompositions and references, and lifting is done one
/// component at a time.
impl<A, B, D, Q> Decomposition<Pair<A, B>> for PairDecomposition<D,Q>
where
    A : Space,
    B : Space,
    D : Decomposition<A>,
    Q : Decomposition<B>,
{
    type Decomposition<'b> = Pair<D::Decomposition<'b>, Q::Decomposition<'b>> where Pair<A, B> : 'b;
    type Reference<'b> = Pair<D::Reference<'b>, Q::Reference<'b>> where Pair<A, B> : 'b;

    #[inline]
    fn lift<'b>(reference : Self::Reference<'b>) -> Self::Decomposition<'b> {
        let Pair(first, second) = reference;
        let lifted_first = D::lift(first);
        let lifted_second = Q::lift(second);
        Pair(lifted_first, lifted_second)
    }
}

/// An owned `Pair<U, V>` is an [`Instance`] of `Pair<A, B>`, with decomposition
/// `PairDecomposition<D, Q>`, whenever:
///
/// * `U` is an instance of `A` via `D`, and
/// * `V` is an instance of `B` via `Q`.
///
/// Every method delegates to the two components.
impl<A, B, U, V, D, Q> Instance<Pair<A, B>, PairDecomposition<D, Q>> for Pair<U, V>
where
    A : Space,
    B : Space,
    D : Decomposition<A>,
    Q : Decomposition<B>,
    U : Instance<A, D>,
    V : Instance<B, Q>,
{
    /// Decomposes each component by value and pairs the results.
    #[inline]
    fn decompose<'b>(self)
        -> <PairDecomposition<D, Q> as Decomposition<Pair<A, B>>>::Decomposition<'b>
    where Self : 'b, Pair<A, B> : 'b
    {
        Pair(self.0.decompose(), self.1.decompose())
    }

    /// Pairs the references returned by each component's `ref_instance`.
    #[inline]
    fn ref_instance(&self)
        -> <PairDecomposition<D, Q> as Decomposition<Pair<A, B>>>::Reference<'_>
    {
        Pair(self.0.ref_instance(), self.1.ref_instance())
    }

    /// Always produces `MyCow::Owned`, converting each component with `own`.
    #[inline]
    fn cow<'b>(self) -> MyCow<'b, Pair<A, B>> where Self : 'b{
        MyCow::Owned(Pair(self.0.own(), self.1.own()))
    }

    /// Converts each component into its owned form with `own`.
    #[inline]
    fn own(self) -> Pair<A, B> {
        Pair(self.0.own(), self.1.own())
    }
}


/// A borrowed `&Pair<U, V>` is an [`Instance`] of `Pair<A, B>` under the same conditions as
/// an owned pair, plus `&U : Instance<A, D>` and `&V : Instance<B, Q>`. `own` appears to
/// rely on these extra bounds to call `own()` on the borrowed components.
impl<'a, A, B, U, V, D, Q> Instance<Pair<A, B>, PairDecomposition<D, Q>> for &'a Pair<U, V>
where
    A : Space,
    B : Space,
    D : Decomposition<A>,
    Q : Decomposition<B>,
    U : Instance<A, D>,
    V : Instance<B, Q>,
    &'a U : Instance<A, D>,
    &'a V : Instance<B, Q>,
{
    /// Borrows each component via `ref_instance` and lifts the resulting references into
    /// decompositions with `D::lift` and `Q::lift`.
    #[inline]
    fn decompose<'b>(self)
        -> <PairDecomposition<D, Q> as Decomposition<Pair<A, B>>>::Decomposition<'b>
    where Self : 'b, Pair<A, B> : 'b
    {
        Pair(D::lift(self.0.ref_instance()), Q::lift(self.1.ref_instance()))
    }

    /// Pairs the references returned by each component's `ref_instance`.
    #[inline]
    fn ref_instance(&self)
        -> <PairDecomposition<D, Q> as Decomposition<Pair<A, B>>>::Reference<'_>
    {
        Pair(self.0.ref_instance(), self.1.ref_instance())
    }

    /// Always produces `MyCow::Owned`, by converting the pair with `own`.
    #[inline]
    fn cow<'b>(self) -> MyCow<'b, Pair<A, B>> where Self : 'b {
        MyCow::Owned(self.own())
    }

    /// Converts the borrowed components into an owned `Pair<A, B>`.
    #[inline]
    fn own(self) -> Pair<A, B> {
        // `u` and `v` are `&'a U` / `&'a V` here, so `own` presumably resolves through the
        // `&'a U : Instance<A, D>` / `&'a V : Instance<B, Q>` bounds above.
        let Pair(ref u, ref v) = self;
        Pair(u.own(), v.own())
    }

}

/// Mutable decomposition of a [`Pair`]. Its mutable reference type pairs the mutable
/// reference types of the two component decompositions.
impl<A : Space, B : Space, D : DecompositionMut<A>, Q : DecompositionMut<B>>
DecompositionMut<Pair<A, B>> for PairDecomposition<D, Q>
{
    type ReferenceMut<'b> = Pair<D::ReferenceMut<'b>, Q::ReferenceMut<'b>>
    where Pair<A, B> : 'b;
}

impl<A, B, U, V, D, Q> InstanceMut<Pair<A, B>, PairDecomposition<D, Q>> for Pair<U, V>
where
    A : Space,
    B : Space,
    D : DecompositionMut<A>,
    Q : DecompositionMut<B>,
    U : InstanceMut<A, D>,
    V : InstanceMut<B, Q>,
{
    #[inline]
    fn ref_instance_mut(&mut self)
        -> <PairDecomposition<D, Q> as DecompositionMut<Pair<A, B>>>::ReferenceMut<'_>
    {
        Pair(self.0.ref_instance_mut(), self.1.ref_instance_mut())
    }
}

impl<'a, A, B, U, V, D, Q> InstanceMut<Pair<A, B>, PairDecomposition<D, Q>> for &'a mut Pair<U, V>
where
    A : Space,
    B : Space,
    D : DecompositionMut<A>,
    Q : DecompositionMut<B>,
    U : InstanceMut<A, D>,
    V : InstanceMut<B, Q>,
{
    #[inline]
    fn ref_instance_mut(&mut self)
        -> <PairDecomposition<D, Q> as DecompositionMut<Pair<A, B>>>::ReferenceMut<'_>
    {
        Pair(self.0.ref_instance_mut(), self.1.ref_instance_mut())
    }
}


/// Norm of a pair under [`PairNorm`]:
///
/// 1. measure the components with exponents `ExpA` and `ExpB`;
/// 2. collect the two values into a `Loc<F, 2>`;
/// 3. return that vector's `ExpJ`-norm.
impl<F, A, B, ExpA, ExpB, ExpJ> Norm<F, PairNorm<ExpA, ExpB, ExpJ>>
for Pair<A,B>
where
    F : Num,
    ExpA : NormExponent,
    ExpB : NormExponent,
    ExpJ : NormExponent,
    A : Norm<F, ExpA>,
    B : Norm<F, ExpB>,
    Loc<F, 2> : Norm<F, ExpJ>,
{
    fn norm(&self, exponent : PairNorm<ExpA, ExpB, ExpJ>) -> F {
        let PairNorm(exp_first, exp_second, exp_joint) = exponent;
        let component_norms = Loc([self.0.norm(exp_first), self.1.norm(exp_second)]);
        component_norms.norm(exp_joint)
    }
}


/// A pair is normed with a [`PairNorm`]. Each component is measured with its own
/// `norm_exponent`, and the two values are combined with the [`L2`] norm.
impl<F : Float, A, B> Normed<F> for Pair<A,B>
where
    A : Normed<F>,
    B : Normed<F>,
{
    type NormExp = PairNorm<A::NormExp, B::NormExp, L2>;

    #[inline]
    fn is_zero(&self) -> bool {
        // The pair is zero exactly when both components are; stop at the first non-zero one.
        if !self.0.is_zero() {
            return false;
        }
        self.1.is_zero()
    }

    #[inline]
    fn norm_exponent(&self) -> Self::NormExp {
        let exp_first = self.0.norm_exponent();
        let exp_second = self.1.norm_exponent();
        PairNorm(exp_first, exp_second, L2)
    }
}

/// The dual of $A \times B$ is taken to be $A^* \times B^*$, the pair of the component duals.
impl<F : Float, A : HasDual<F>, B : HasDual<F>> HasDual<F> for Pair<A,B> {
    type DualSpace = Pair<A::DualSpace, B::DualSpace>;
}
