/*!
Abstract linear operators.
*/

use numeric_literals::replace_float_literals;
use std::marker::PhantomData;
use crate::types::*;
use serde::Serialize;
pub use crate::mapping::Apply;
use crate::direct_product::Pair;

/// A linear operator with domain `X`.
///
/// Implementors must be applicable both to owned values of `X` and to
/// references `&X`, with both applications producing [`Linear::Codomain`].
pub trait Linear<X>:
    Apply<X, Output = Self::Codomain> + for<'a> Apply<&'a X, Output = Self::Codomain>
{
    /// Type of the values produced by applying the operator.
    type Codomain;
}

/// In-place scaled accumulation in the style of BLAS `axpy`.
///
/// `F` is the scalar type; the right-hand side `X` defaults to `Self`.
#[replace_float_literals(F::cast_from(literal))]
pub trait AXPY<F: Num, X = Self> {
    /// Updates `self` to `β·self + α·x`.
    fn axpy(&mut self, α: F, x: &X, β: F);

    /// Overwrites `self` with `x` (an `axpy` with `α = 1`, `β = 0`).
    fn copy_from(&mut self, x: &X) {
        self.axpy(1.0, x, 0.0)
    }

    /// Overwrites `self` with `α·x` (an `axpy` with `β = 0`).
    fn scale_from(&mut self, α: F, x: &X) {
        self.axpy(α, x, 0.0)
    }
}

/// In-place application of a [`Linear`] operator in the style of BLAS `gemv`.
///
/// The output type `Y` defaults to the operator's codomain.
#[replace_float_literals(F::cast_from(literal))]
pub trait GEMV<F: Num, X, Y = <Self as Linear<X>>::Codomain>: Linear<X> {
    /// Updates `y` to `α·A·x + β·y`, where `A` is `self`.
    fn gemv(&self, y: &mut Y, α: F, x: &X, β: F);

    /// Overwrites `y` with `A·x` (a `gemv` with `α = 1`, `β = 0`).
    fn apply_mut(&self, y: &mut Y, x: &X) {
        self.gemv(y, 1.0, x, 0.0)
    }

    /// Adds `A·x` to `y` (a `gemv` with `α = 1`, `β = 1`).
    fn apply_add(&self, y: &mut Y, x: &X) {
        self.gemv(y, 1.0, x, 1.0)
    }
}


/// Linear operators for which some bound on the operator norm is available.
pub trait BoundedLinear<X>: Linear<X> {
    /// Floating-point type in which the bound is expressed.
    type FloatType: Float;

    /// Returns a bound on the operator norm $\|A\|$ of $A$=`self`.
    ///
    /// The result need not equal the norm; any bound that can reasonably be
    /// implemented is acceptable.
    fn opnorm_bound(&self) -> Self::FloatType;
}

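// Linear operator application into a mutable target. The [`AsRef`] bound
// is used to guarantee compatibility between `Yʹ` and `Self::Codomain`;
// the former is assumed to be, e.g., a view into the latter.
// NOTE(review): no `AsRef` bound appears in this module any more; in-place
// application into a mutable target is provided by [`GEMV`]. This comment
// may be stale.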

/*impl<X,Y,T> Fn(&X) -> Y for T where T : Linear<X,Codomain=Y> {
    fn call(&self, x : &X) -> Y {
        self.apply(x)
    }
}*/

/// Operators whose adjoint operator can be formed.
///
/// `X` is the domain of `self`; `Yʹ` is the domain of the adjoint.
pub trait Adjointable<X, Yʹ>: Linear<X> {
    /// Codomain of the adjoint operator.
    type AdjointCodomain;

    /// Type of the adjoint operator; it may borrow `self` for `'a`.
    type Adjoint<'a>: Linear<Yʹ, Codomain = Self::AdjointCodomain>
    where
        Self: 'a;

    /// Forms the adjoint operator of `self`.
    fn adjoint(&self) -> Self::Adjoint<'_>;
}

/// Trait for forming a preadjoint of an operator.
///
/// For an operator $A$ this is an operator $A\_\*$
/// such that its adjoint $(A\_\*)^\*=A$. The space `X` is the domain of the `Self`
/// operator. The space `Ypre` is the predual of its codomain, and is the
/// domain of the preadjoint operator. `Self::Preadjoint` is required to be
/// [`Adjointable`]`<Ypre, X>` with codomain `Self::PreadjointCodomain`.
pub trait Preadjointable<X,Ypre> : Linear<X> {
    /// Codomain of the preadjoint operator.
    type PreadjointCodomain;
    /// Type of the preadjoint operator; it may borrow `self` for `'a`.
    type Preadjoint<'a> : Adjointable<Ypre, X, Codomain=Self::PreadjointCodomain> where Self : 'a;

    /// Form the preadjoint operator of `self`.
    fn preadjoint(&self) -> Self::Preadjoint<'_>;
}

/// Adjointable operators $A: X → Y$ between reflexive spaces $X$ and $Y$.
///
/// The adjoint's domain is taken to be the codomain of the operator itself.
/// Through the blanket implementation, every operator that is
/// [`Adjointable`] in this sense is automatically `SimplyAdjointable`.
pub trait SimplyAdjointable<X> : Adjointable<X,<Self as Linear<X>>::Codomain> {}
impl<X,T> SimplyAdjointable<X> for T where T : Adjointable<X,<Self as Linear<X>>::Codomain> {}

/// The identity operator on `X`.
///
/// Holds no data; `X` appears only as a [`PhantomData`] marker.
#[derive(Clone,Copy,Debug,Serialize,Eq,PartialEq)]
pub struct IdOp<X> (PhantomData<X>);

impl<X> IdOp<X> {
    /// Creates the identity operator on `X`.
    ///
    /// Public because the tuple field is private: without it, an `IdOp`
    /// could not be constructed outside this module.
    pub fn new() -> IdOp<X> { IdOp(PhantomData) }
}

impl<X> Default for IdOp<X> {
    /// Same as [`IdOp::new`]; implemented by hand so that `X` need not be `Default`.
    fn default() -> Self { IdOp::new() }
}

/// Applying the identity to an owned value returns that value unchanged.
impl<X> Apply<X> for IdOp<X> {
    type Output = X;

    fn apply(&self, value: X) -> Self::Output {
        value
    }
}

/// Applying the identity to a borrowed value returns a clone of it.
impl<'a, X: Clone> Apply<&'a X> for IdOp<X> {
    type Output = X;

    fn apply(&self, value: &'a X) -> Self::Output {
        X::clone(value)
    }
}

/// The identity is linear on any clonable `X`, mapping `X` to itself.
impl<X: Clone> Linear<X> for IdOp<X> {
    type Codomain = X;
}


/// In-place application of the identity, delegated to [`AXPY`] on the target.
#[replace_float_literals(F::cast_from(literal))]
impl<F, X, Y> GEMV<F, X, Y> for IdOp<X>
where
    F: Num,
    X: Clone,
    Y: AXPY<F, X>,
{
    /// Updates `target` to `α·source + β·target`, since `A` is the identity.
    fn gemv(&self, target: &mut Y, α: F, source: &X, β: F) {
        target.axpy(α, source, β)
    }

    /// Overwrites `target` with `source`.
    fn apply_mut(&self, target: &mut Y, source: &X) {
        target.copy_from(source);
    }
}

impl<X: Clone> BoundedLinear<X> for IdOp<X> {
    type FloatType = float;

    /// Returns `1.0` as the operator-norm bound of the identity.
    fn opnorm_bound(&self) -> Self::FloatType {
        1.0
    }
}

impl<X> Adjointable<X,X> for IdOp<X> where X : Clone {
    type AdjointCodomain=X;
    type Adjoint<'a> = IdOp<X> where X : 'a;
    fn adjoint(&self) -> Self::Adjoint<'_> { IdOp::new() }
}

/// “Row operator” $(S, T)$; $(S, T)(x, y)=Sx + Ty$.
///
/// Cheap to copy and print whenever its component operators are.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RowOp<S, T>(pub S, pub T);

use std::ops::Add;

/// Applies `(S, T)` to an owned pair: `(S, T)(a, b) = Sa + Tb`.
impl<A, B, S, T> Apply<Pair<A, B>> for RowOp<S, T>
where
    S: Apply<A>,
    T: Apply<B>,
    S::Output: Add<T::Output>,
{
    type Output = <S::Output as Add<T::Output>>::Output;

    fn apply(&self, x: Pair<A, B>) -> Self::Output {
        let Pair(a, b) = x;
        let sa = self.0.apply(a);
        let tb = self.1.apply(b);
        sa + tb
    }
}

/// Applies `(S, T)` to a borrowed pair: `(S, T)(a, b) = Sa + Tb`, passing each
/// component to its operator by reference.
impl<'a, A, B, S, T> Apply<&'a Pair<A, B>> for RowOp<S, T>
where
    S : Apply<&'a A>,
    T : Apply<&'a B>,
    S::Output : Add<T::Output>
{
    type Output = <S::Output as Add<T::Output>>::Output;

    // Matching a `&Pair` binds `a: &'a A` and `b: &'a B` by reference;
    // explicit `ref` is redundant here (and rejected in edition 2024).
    fn apply(&self, Pair(a, b) : &'a Pair<A, B>) -> Self::Output {
        self.0.apply(a) + self.1.apply(b)
    }
}

/// A row operator is linear on `Pair<A, B>` whenever it can be applied to both
/// owned and borrowed pairs with the same output type `D`.
impl<A, B, S, T, D> Linear<Pair<A, B>> for RowOp<S, T>
where
    Self: Apply<Pair<A, B>, Output = D> + for<'a> Apply<&'a Pair<A, B>, Output = D>,
{
    type Codomain = D;
}

impl<F, S, T, Y, U, V> GEMV<F, Pair<U, V>, Y> for RowOp<S, T>
where
    S : GEMV<F, U, Y>,
    T : GEMV<F, V, Y>,
    F : Num,
    Self : Linear<Pair<U, V>, Codomain=Y>
{
    fn gemv(&self, y : &mut Y, α : F, x : &Pair<U, V>, β : F) {
        self.0.gemv(y, α, &x.0, β);
        self.1.gemv(y, α, &x.1, β);
    }

    fn apply_mut(&self, y : &mut Y, x : &Pair<U, V>){
        self.0.apply_mut(y, &x.0);
        self.1.apply_mut(y, &x.1);
    }

    /// Computes `y += Ax`, where `A` is `Self`
    fn apply_add(&self, y : &mut Y, x : &Pair<U, V>){
        self.0.apply_add(y, &x.0);
        self.1.apply_add(y, &x.1);
    }
}


/// “Column operator” $(S; T)$; $(S; T)x=(Sx, Tx)$.
///
/// Cheap to copy and print whenever its component operators are.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ColOp<S, T>(pub S, pub T);

/// Applies `(S; T)` to an owned input: `(S; T)a = (Sa, Ta)`.
///
/// `S` only needs a borrow of `a`, so it is applied first; `T` then consumes
/// `a`. The former `A : Borrow<A>` bound was trivially satisfied by the
/// standard blanket impl and has been dropped in favour of a plain `&a`.
impl<A, S, T, O> Apply<A> for ColOp<S, T>
where
    S : for<'a> Apply<&'a A, Output=O>,
    T : Apply<A>,
{
    type Output = Pair<O, T::Output>;

    fn apply(&self, a : A) -> Self::Output {
        let sa = self.0.apply(&a);
        Pair(sa, self.1.apply(a))
    }
}

/// A column operator is linear on `A` whenever it can be applied to both
/// owned and borrowed inputs with the same output type `D`.
impl<A, S, T, D> Linear<A> for ColOp<S, T>
where
    Self: Apply<A, Output = D> + for<'a> Apply<&'a A, Output = D>,
{
    type Codomain = D;
}

impl<F, S, T, A, B, X> GEMV<F, X, Pair<A, B>> for ColOp<S, T>
where
    S : GEMV<F, X, A>,
    T : GEMV<F, X, B>,
    F : Num,
    Self : Linear<X, Codomain=Pair<A, B>>
{
    fn gemv(&self, y : &mut Pair<A, B>, α : F, x : &X, β : F) {
        self.0.gemv(&mut y.0, α, x, β);
        self.1.gemv(&mut y.1, α, x, β);
    }

    fn apply_mut(&self, y : &mut Pair<A, B>, x : &X){
        self.0.apply_mut(&mut y.0, x);
        self.1.apply_mut(&mut y.1, x);
    }

    /// Computes `y += Ax`, where `A` is `Self`
    fn apply_add(&self, y : &mut Pair<A, B>, x : &X){
        self.0.apply_add(&mut y.0, x);
        self.1.apply_add(&mut y.1, x);
    }
}


impl<A, B, Yʹ, R, S, T> Adjointable<Pair<A,B>,Yʹ> for RowOp<S, T>
where
    S : Adjointable<A, Yʹ>,
    T : Adjointable<B, Yʹ>,
    Self : Linear<Pair<A, B>>,
    for<'a> ColOp<S::Adjoint<'a>, T::Adjoint<'a>> : Linear<Yʹ, Codomain=R>,
{
    type AdjointCodomain = R;
    type Adjoint<'a> = ColOp<S::Adjoint<'a>, T::Adjoint<'a>> where Self : 'a;

    fn adjoint(&self) -> Self::Adjoint<'_> {
        ColOp(self.0.adjoint(), self.1.adjoint())
    }
}

impl<A, B, Yʹ, R, S, T> Preadjointable<Pair<A,B>,Yʹ> for RowOp<S, T>
where
    S : Preadjointable<A, Yʹ>,
    T : Preadjointable<B, Yʹ>,
    Self : Linear<Pair<A, B>>,
    for<'a> ColOp<S::Preadjoint<'a>, T::Preadjoint<'a>> : Adjointable<Yʹ, Pair<A,B>, Codomain=R>,
{
    type PreadjointCodomain = R;
    type Preadjoint<'a> = ColOp<S::Preadjoint<'a>, T::Preadjoint<'a>> where Self : 'a;

    fn preadjoint(&self) -> Self::Preadjoint<'_> {
        ColOp(self.0.preadjoint(), self.1.preadjoint())
    }
}


impl<A, Xʹ, Yʹ, R, S, T> Adjointable<A,Pair<Xʹ,Yʹ>> for ColOp<S, T>
where
    S : Adjointable<A, Xʹ>,
    T : Adjointable<A, Yʹ>,
    Self : Linear<A>,
    for<'a> RowOp<S::Adjoint<'a>, T::Adjoint<'a>> : Linear<Pair<Xʹ,Yʹ>, Codomain=R>,
{
    type AdjointCodomain = R;
    type Adjoint<'a> = RowOp<S::Adjoint<'a>, T::Adjoint<'a>> where Self : 'a;

    fn adjoint(&self) -> Self::Adjoint<'_> {
        RowOp(self.0.adjoint(), self.1.adjoint())
    }
}

impl<A, Xʹ, Yʹ, R, S, T> Preadjointable<A,Pair<Xʹ,Yʹ>> for ColOp<S, T>
where
    S : Preadjointable<A, Xʹ>,
    T : Preadjointable<A, Yʹ>,
    Self : Linear<A>,
    for<'a> RowOp<S::Preadjoint<'a>, T::Preadjoint<'a>> : Adjointable<Pair<Xʹ,Yʹ>, A, Codomain=R>,
{
    type PreadjointCodomain = R;
    type Preadjoint<'a> = RowOp<S::Preadjoint<'a>, T::Preadjoint<'a>> where Self : 'a;

    fn preadjoint(&self) -> Self::Preadjoint<'_> {
        RowOp(self.0.preadjoint(), self.1.preadjoint())
    }
}

/// Diagonal operator $\mathrm{diag}(S, T)$; it maps `Pair(a, b)` to `Pair(Sa, Tb)`.
///
/// Cheap to copy and print whenever its component operators are.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DiagOp<S, T>(pub S, pub T);

/// Applies `diag(S, T)` to an owned pair, componentwise: `(a, b) ↦ (Sa, Tb)`.
impl<A, B, S, T> Apply<Pair<A, B>> for DiagOp<S, T>
where
    S: Apply<A>,
    T: Apply<B>,
{
    type Output = Pair<S::Output, T::Output>;

    fn apply(&self, x: Pair<A, B>) -> Self::Output {
        let Pair(a, b) = x;
        let sa = self.0.apply(a);
        Pair(sa, self.1.apply(b))
    }
}

/// Applies `diag(S, T)` to a borrowed pair, componentwise: `(a, b) ↦ (Sa, Tb)`,
/// passing each component to its operator by reference.
impl<'a, A, B, S, T> Apply<&'a Pair<A, B>> for DiagOp<S, T>
where
    S : Apply<&'a A>,
    T : Apply<&'a B>,
{
    type Output = Pair<S::Output, T::Output>;

    // Matching a `&Pair` binds `a: &'a A` and `b: &'a B` by reference;
    // explicit `ref` is redundant here (and rejected in edition 2024).
    fn apply(&self, Pair(a, b) : &'a Pair<A, B>) -> Self::Output {
        Pair(self.0.apply(a), self.1.apply(b))
    }
}

/// A diagonal operator is linear on `Pair<A, B>` whenever it can be applied to
/// both owned and borrowed pairs with the same output type `D`.
impl<A, B, S, T, D> Linear<Pair<A, B>> for DiagOp<S, T>
where
    Self: Apply<Pair<A, B>, Output = D> + for<'a> Apply<&'a Pair<A, B>, Output = D>,
{
    type Codomain = D;
}

impl<F, S, T, A, B, U, V> GEMV<F, Pair<U, V>, Pair<A, B>> for DiagOp<S, T>
where
    S : GEMV<F, U, A>,
    T : GEMV<F, V, B>,
    F : Num,
    Self : Linear<Pair<U, V>, Codomain=Pair<A, B>>
{
    fn gemv(&self, y : &mut Pair<A, B>, α : F, x : &Pair<U, V>, β : F) {
        self.0.gemv(&mut y.0, α, &x.0, β);
        self.1.gemv(&mut y.1, α, &x.1, β);
    }

    fn apply_mut(&self, y : &mut Pair<A, B>, x : &Pair<U, V>){
        self.0.apply_mut(&mut y.0, &x.0);
        self.1.apply_mut(&mut y.1, &x.1);
    }

    /// Computes `y += Ax`, where `A` is `Self`
    fn apply_add(&self, y : &mut Pair<A, B>, x : &Pair<U, V>){
        self.0.apply_add(&mut y.0, &x.0);
        self.1.apply_add(&mut y.1, &x.1);
    }
}

impl<A, B, Xʹ, Yʹ, R, S, T> Adjointable<Pair<A,B>,Pair<Xʹ,Yʹ>> for DiagOp<S, T>
where
    S : Adjointable<A, Xʹ>,
    T : Adjointable<B, Yʹ>,
    Self : Linear<Pair<A, B>>,
    for<'a> DiagOp<S::Adjoint<'a>, T::Adjoint<'a>> : Linear<Pair<Xʹ,Yʹ>, Codomain=R>,
{
    type AdjointCodomain = R;
    type Adjoint<'a> = DiagOp<S::Adjoint<'a>, T::Adjoint<'a>> where Self : 'a;

    fn adjoint(&self) -> Self::Adjoint<'_> {
        DiagOp(self.0.adjoint(), self.1.adjoint())
    }
}

impl<A, B, Xʹ, Yʹ, R, S, T> Preadjointable<Pair<A,B>,Pair<Xʹ,Yʹ>> for DiagOp<S, T>
where
    S : Preadjointable<A, Xʹ>,
    T : Preadjointable<B, Yʹ>,
    Self : Linear<Pair<A, B>>,
    for<'a> DiagOp<S::Preadjoint<'a>, T::Preadjoint<'a>> : Adjointable<Pair<Xʹ,Yʹ>, Pair<A, B>, Codomain=R>,
{
    type PreadjointCodomain = R;
    type Preadjoint<'a> = DiagOp<S::Preadjoint<'a>, T::Preadjoint<'a>> where Self : 'a;

    fn preadjoint(&self) -> Self::Preadjoint<'_> {
        DiagOp(self.0.preadjoint(), self.1.preadjoint())
    }
}

/// 2×2 block operator, mapping `Pair(x, y)` to
/// `Pair(S11 x + S12 y, S21 x + S22 y)`.
///
/// Built as a [`ColOp`] of two [`RowOp`]s. The first row `(S11, S12)` yields
/// the first output component and the second row `(S21, S22)` the second.
pub type BlockOp<S11, S12, S21, S22> = ColOp<RowOp<S11, S12>, RowOp<S21, S22>>;

