/*!
Abstract linear operators.
*/

use numeric_literals::replace_float_literals;
use std::marker::PhantomData;
use serde::Serialize;
use crate::types::*;
pub use crate::mapping::{Mapping, Space, Composition};
use crate::direct_product::Pair;
use crate::instance::Instance;
use crate::norms::{NormExponent, PairNorm, L1, L2, Linfinity, Norm};

/// Marker trait for mappings on the space `X` that are linear.
///
/// It adds no items of its own to [`Mapping`]; implementing it is an assertion
/// of linearity that other traits in this module build on.
pub trait Linear<X: Space>: Mapping<X> {}

/// In-place linear combination of elements (“axpy” in BLAS terminology).
///
/// `Self` plays the role of the target `y`; the source `x` is any
/// [`Instance`] of the space `X`, which defaults to `Self`.
#[replace_float_literals(F::cast_from(literal))]
pub trait AXPY<F, X = Self>: Space + std::ops::MulAssign<F>
where
    F: Num,
    X: Space,
{
    /// Owned type returned by [`AXPY::similar_origin`].
    type Owned: AXPY<F, X>;

    /// Overwrites the target `y` (that is, `self`) with `β * y + α * x`.
    fn axpy<I: Instance<X>>(&mut self, α: F, x: I, β: F);

    /// Overwrites `self` with `x`.
    ///
    /// The default implementation calls [`AXPY::axpy`] with `α = 1` and `β = 0`.
    fn copy_from<I: Instance<X>>(&mut self, x: I) {
        self.axpy(1.0, x, 0.0)
    }

    /// Overwrites `self` with `α * x`.
    ///
    /// The default implementation calls [`AXPY::axpy`] with `β = 0`.
    fn scale_from<I: Instance<X>>(&mut self, α: F, x: I) {
        self.axpy(α, x, 0.0)
    }

    /// Returns, as an owned value, a zero element “similar” to `self`.
    fn similar_origin(&self) -> Self::Owned;

    /// Sets `self` to zero.
    fn set_zero(&mut self);
}

/// In-place application of a [`Linear`] operator `A` (= `Self`) into an existing target.
///
/// `X` is the domain of `A`; the target type `Y` defaults to the codomain of `A` on `X`.
#[replace_float_literals(F::cast_from(literal))]
pub trait GEMV<F: Num, X: Space, Y = <Self as Mapping<X>>::Codomain>: Linear<X> {
    /// Overwrites `y` with `α * A x + β * y`.
    fn gemv<I: Instance<X>>(&self, y: &mut Y, α: F, x: I, β: F);

    /// Overwrites `y` with `A x`.
    ///
    /// The default implementation calls [`GEMV::gemv`] with `α = 1` and `β = 0`.
    fn apply_mut<I: Instance<X>>(&self, y: &mut Y, x: I) {
        self.gemv(y, 1.0, x, 0.0)
    }

    /// Adds `A x` to `y`.
    ///
    /// The default implementation calls [`GEMV::gemv`] with `α = β = 1`.
    fn apply_add<I: Instance<X>>(&self, y: &mut Y, x: I) {
        self.gemv(y, 1.0, x, 1.0)
    }
}


/// Linear operators that can report a bound on their operator norm.
///
/// `XExp` selects the norm on the domain `X`, `CodExp` the norm on the codomain,
/// and `F` (default `f64`) is the scalar type of the bound.
pub trait BoundedLinear<X, XExp, CodExp, F = f64>: Linear<X>
where
    F: Num,
    X: Space + Norm<F, XExp>,
    XExp: NormExponent,
    CodExp: NormExponent,
{
    /// Returns an upper bound on the operator norm $\|A\|$ of $A$ = `self`, measured
    /// from the `xexp` norm on `X` to the `codexp` norm on the codomain.
    ///
    /// The value need not be the norm itself; any bound that can reasonably be
    /// implemented is acceptable.
    fn opnorm_bound(&self, xexp: XExp, codexp: CodExp) -> F;
}

// Linear operator application into mutable target. The [`AsRef`] bound
// is used to guarantee compatibility with `Yʹ` and `Self::Codomain`;
// the former is assumed to be e.g. a view into the latter.

/*impl<X,Y,T> Fn(&X) -> Y for T where T : Linear<X,Codomain=Y> {
    fn call(&self, x : &X) -> Y {
        self.apply(x)
    }
}*/

/// Linear operators on `X` whose adjoint can be formed.
///
/// `Yʹ` is the domain on which the adjoint acts.
pub trait Adjointable<X, Yʹ>: Linear<X>
where
    X: Space,
    Yʹ: Space,
{
    /// Codomain of the adjoint operator.
    type AdjointCodomain: Space;

    /// The adjoint operator, parametrised by the lifetime of the borrow of `self`.
    type Adjoint<'a>: Linear<Yʹ, Codomain = Self::AdjointCodomain>
    where
        Self: 'a;

    /// Returns the adjoint operator of `self`.
    fn adjoint(&self) -> Self::Adjoint<'_>;
}

/// Trait for forming a preadjoint of an operator.
///
/// For an operator $A$ this is an operator $A\_\*$
/// such that its adjoint $(A\_\*)^\*=A$. The space `X` is the domain of the `Self`
/// operator. The space `Ypre` is the predual of its codomain, and is the
/// domain of the preadjoint operator. `Self::Preadjoint` should be
/// [`Adjointable`]`<Ypre, X>`.
/// We do not make additional restrictions on `Self::Preadjoint` (in particular, it
/// does not have to be adjointable) to allow `X` to be a subspace yet the preadjoint
/// have the full space as the codomain, etc.
pub trait Preadjointable<X: Space, Ypre: Space>: Linear<X> {
    /// Codomain of the preadjoint operator.
    type PreadjointCodomain: Space;

    /// The preadjoint operator, parametrised by the lifetime of the borrow of `self`.
    type Preadjoint<'a>: Linear<Ypre, Codomain = Self::PreadjointCodomain>
    where
        Self: 'a;

    /// Form the preadjoint operator of `self`.
    fn preadjoint(&self) -> Self::Preadjoint<'_>;
}

/// Adjointable operators $A: X → Y$ between reflexive spaces $X$ and $Y$.
///
/// Shorthand for [`Adjointable`] whose adjoint is defined on the codomain of `Self`.
/// It is implemented automatically for every such operator.
pub trait SimplyAdjointable<X: Space>: Adjointable<X, <Self as Mapping<X>>::Codomain> {}

impl<X: Space, T> SimplyAdjointable<X> for T
where
    T: Adjointable<X, <Self as Mapping<X>>::Codomain>,
{
}

/// The identity operator
#[derive(Clone, Copy, Debug, Serialize, Eq, PartialEq)]
pub struct IdOp<X>(PhantomData<X>);

impl<X> IdOp<X> {
    /// Creates the identity operator on `X`.
    pub fn new() -> IdOp<X> {
        IdOp(PhantomData)
    }
}

// Written by hand rather than derived so that no `X: Default` bound is imposed;
// the operator carries no data, so `default()` is the same as `new()`.
impl<X> Default for IdOp<X> {
    fn default() -> Self {
        IdOp::new()
    }
}

/// The identity maps every element of `X` to itself.
impl<X> Mapping<X> for IdOp<X>
where
    X: Clone + Space,
{
    type Codomain = X;

    /// Returns the argument, converted into an owned `X` via [`Instance::own`].
    fn apply<I: Instance<X>>(&self, x: I) -> X {
        x.own()
    }
}

/// The identity is linear.
impl<X> Linear<X> for IdOp<X> where X: Clone + Space {}

/// For the identity, `y ← α x + β y` is exactly an [`AXPY::axpy`] on the target.
#[replace_float_literals(F::cast_from(literal))]
impl<F, X, Y> GEMV<F, X, Y> for IdOp<X>
where
    F: Num,
    X: Clone + Space,
    Y: AXPY<F, X>,
{
    fn gemv<I: Instance<X>>(&self, y: &mut Y, α: F, x: I, β: F) {
        y.axpy(α, x, β)
    }

    /// Overwrites `y` with `x` through [`AXPY::copy_from`].
    fn apply_mut<I: Instance<X>>(&self, y: &mut Y, x: I) {
        y.copy_from(x)
    }
}

/// When domain and codomain carry the same norm `E`, the bound reported for the
/// identity is one.
impl<F, X, E> BoundedLinear<X, E, E, F> for IdOp<X>
where
    F: Num,
    E: NormExponent,
    X: Space + Clone + Norm<F, E>,
{
    fn opnorm_bound(&self, _xexp: E, _codexp: E) -> F {
        F::ONE
    }
}

/// The adjoint of the identity on `X` is again the identity on `X`.
impl<X> Adjointable<X, X> for IdOp<X>
where
    X: Clone + Space,
{
    type AdjointCodomain = X;
    type Adjoint<'a> = IdOp<X> where X: 'a;

    fn adjoint(&self) -> Self::Adjoint<'_> {
        IdOp::new()
    }
}

/// The preadjoint of the identity on `X` is again the identity on `X`.
impl<X> Preadjointable<X, X> for IdOp<X>
where
    X: Clone + Space,
{
    type PreadjointCodomain = X;
    type Preadjoint<'a> = IdOp<X> where X: 'a;

    fn preadjoint(&self) -> Self::Preadjoint<'_> {
        IdOp::new()
    }
}


impl<S, T, E, X> Linear<X> for Composition<S, T, E>
where
    X : Space,
    T : Linear<X>,
    S : Linear<T::Codomain>
{ }

/// In-place application of `S ∘ T`: the inner operator `T` is applied first into a
/// fresh value, after which the outer operator `S` writes into `y`.
impl<F, S, T, E, X, Y> GEMV<F, X, Y> for Composition<S, T, E>
where
    F: Num,
    X: Space,
    T: Linear<X>,
    S: GEMV<F, T::Codomain, Y>,
{
    fn gemv<I: Instance<X>>(&self, y: &mut Y, α: F, x: I, β: F) {
        let tx = self.inner.apply(x);
        self.outer.gemv(y, α, tx, β)
    }

    /// Overwrites `y` with `S(T x)`.
    fn apply_mut<I: Instance<X>>(&self, y: &mut Y, x: I) {
        let tx = self.inner.apply(x);
        self.outer.apply_mut(y, tx)
    }

    /// Adds `S(T x)` to `y`.
    fn apply_add<I: Instance<X>>(&self, y: &mut Y, x: I) {
        let tx = self.inner.apply(x);
        self.outer.apply_add(y, tx)
    }
}

/// Submultiplicative bound ‖S ∘ T‖ ≤ ‖S‖ ‖T‖, where the intermediate space `Z` is
/// measured in the norm stored in `intermediate_norm_exponent`.
impl<F, S, T, X, Z, Xexp, Yexp, Zexp> BoundedLinear<X, Xexp, Yexp, F>
    for Composition<S, T, Zexp>
where
    F: Num,
    X: Space + Norm<F, Xexp>,
    Z: Space + Norm<F, Zexp>,
    Xexp: NormExponent,
    Yexp: NormExponent,
    Zexp: NormExponent,
    T: BoundedLinear<X, Xexp, Zexp, F, Codomain = Z>,
    S: BoundedLinear<Z, Zexp, Yexp, F>,
{
    fn opnorm_bound(&self, xexp: Xexp, yexp: Yexp) -> F {
        let mid = self.intermediate_norm_exponent;
        let outer_bound = self.outer.opnorm_bound(mid, yexp);
        let inner_bound = self.inner.opnorm_bound(xexp, mid);
        outer_bound * inner_bound
    }
}

/// “Row operator” $(S, T)$; $(S, T)(x, y)=Sx + Ty$.
///
/// Acts on pairs: `S` is applied to the first component, `T` to the second,
/// and the two results are added.
pub struct RowOp<S, T>(pub S, pub T);

use std::ops::Add;

impl<A, B, S, T> Mapping<Pair<A, B>> for RowOp<S, T>
where
    A: Space,
    B: Space,
    S: Mapping<A>,
    T: Mapping<B>,
    S::Codomain: Add<T::Codomain>,
    <S::Codomain as Add<T::Codomain>>::Output: Space,
{
    type Codomain = <S::Codomain as Add<T::Codomain>>::Output;

    fn apply<I: Instance<Pair<A, B>>>(&self, x: I) -> Self::Codomain {
        let Pair(first, second) = x.decompose();
        let s_part = self.0.apply(first);
        let t_part = self.1.apply(second);
        s_part + t_part
    }
}

/// The row operator $(S, T)$ is linear whenever both `S` and `T` are.
impl<A, B, S, T> Linear<Pair<A, B>> for RowOp<S, T>
where
    A: Space,
    B: Space,
    S: Linear<A>,
    T: Linear<B>,
    S::Codomain: Add<T::Codomain>,
    <S::Codomain as Add<T::Codomain>>::Output: Space,
{
}


/// In-place application of the row operator $(S, T)$ to `x = (u, v)`.
///
/// Every method first lets `S` write into `y` with the caller's scaling, and then
/// *accumulates* the contribution of `T` on top of it, so that `y` ends up holding
/// the full sum `S u + T v`.
impl<F, S, T, Y, U, V> GEMV<F, Pair<U, V>, Y> for RowOp<S, T>
where
    U: Space,
    V: Space,
    S: GEMV<F, U, Y>,
    T: GEMV<F, V, Y>,
    F: Num,
    Self: Linear<Pair<U, V>, Codomain = Y>,
{
    /// Computes `y = α (S u + T v) + β y`.
    fn gemv<I: Instance<Pair<U, V>>>(&self, y: &mut Y, α: F, x: I, β: F) {
        let Pair(u, v) = x.decompose();
        self.0.gemv(y, α, u, β);
        // `y` already contains `α S u + β y`; add `α T v` without rescaling it.
        self.1.gemv(y, α, v, F::ONE);
    }

    /// Computes `y = S u + T v`.
    fn apply_mut<I: Instance<Pair<U, V>>>(&self, y: &mut Y, x: I) {
        let Pair(u, v) = x.decompose();
        self.0.apply_mut(y, u);
        // Must add, not overwrite: a second `apply_mut` would discard `S u`.
        self.1.apply_add(y, v);
    }

    /// Computes `y += S u + T v`.
    fn apply_add<I: Instance<Pair<U, V>>>(&self, y: &mut Y, x: I) {
        let Pair(u, v) = x.decompose();
        self.0.apply_add(y, u);
        self.1.apply_add(y, v);
    }
}

/// “Column operator” $(S; T)$; $(S; T)x=(Sx, Tx)$.
///
/// Applies both `S` and `T` to the same argument and pairs up the results.
pub struct ColOp<S, T>(pub S, pub T);

impl<A, S, T> Mapping<A> for ColOp<S, T>
where
    A: Space,
    S: Mapping<A>,
    T: Mapping<A>,
{
    type Codomain = Pair<S::Codomain, T::Codomain>;

    fn apply<I: Instance<A>>(&self, a: I) -> Self::Codomain {
        // `S` receives `a.ref_instance()` so that `a` itself can still be handed to `T`.
        let top = self.0.apply(a.ref_instance());
        let bottom = self.1.apply(a);
        Pair(top, bottom)
    }
}

/// The column operator $(S; T)$ is linear whenever both `S` and `T` are.
///
/// Requiring only [`Mapping`] here would wrongly mark columns of non-linear maps
/// as linear; the bounds now match those of [`RowOp`] and [`DiagOp`].
impl<A, S, T> Linear<A> for ColOp<S, T>
where
    A: Space,
    S: Linear<A>,
    T: Linear<A>,
{
}

impl<F, S, T, A, B, X> GEMV<F, X, Pair<A, B>> for ColOp<S, T>
where
    X : Space,
    S : GEMV<F, X, A>,
    T : GEMV<F, X, B>,
    F : Num,
    Self : Linear<X, Codomain=Pair<A, B>>
{
    fn gemv<I : Instance<X>>(&self, y : &mut Pair<A, B>, α : F, x : I, β : F) {
        self.0.gemv(&mut y.0, α, x.ref_instance(), β);
        self.1.gemv(&mut y.1, α, x, β);
    }

    fn apply_mut<I : Instance<X>>(&self, y : &mut Pair<A, B>, x : I){
        self.0.apply_mut(&mut y.0, x.ref_instance());
        self.1.apply_mut(&mut y.1, x);
    }

    /// Computes `y += Ax`, where `A` is `Self`
    fn apply_add<I : Instance<X>>(&self, y : &mut Pair<A, B>, x : I){
        self.0.apply_add(&mut y.0, x.ref_instance());
        self.1.apply_add(&mut y.1, x);
    }
}


/// The adjoint of the row operator $(S, T)$ is the column operator $(S^\*; T^\*)$.
///
/// Both `S` and `T` must have adjoints acting on the same space `Yʹ`. The adjoint
/// maps `Yʹ` into the pair of their adjoint codomains.
impl<A, B, Yʹ, S, T> Adjointable<Pair<A,B>, Yʹ> for RowOp<S, T>
where
    A : Space,
    B : Space,
    Yʹ : Space,
    S : Adjointable<A, Yʹ>,
    T : Adjointable<B, Yʹ>,
    Self : Linear<Pair<A, B>>,
    // NOTE(review): the explicit bound below is disabled. The `Linear` bounds on the
    // `Adjoint` associated types appear to suffice here, unlike in the
    // `Preadjointable` impl, which states the analogous bound explicitly.
    // for<'a> ColOp<S::Adjoint<'a>, T::Adjoint<'a>> : Linear<
    //     Yʹ,
    //     Codomain=Pair<S::AdjointCodomain, T::AdjointCodomain>
    // >,
{
    type AdjointCodomain = Pair<S::AdjointCodomain, T::AdjointCodomain>;
    type Adjoint<'a> = ColOp<S::Adjoint<'a>, T::Adjoint<'a>> where Self : 'a;

    /// Stacks the adjoints of the two blocks into a column operator.
    fn adjoint(&self) -> Self::Adjoint<'_> {
        ColOp(self.0.adjoint(), self.1.adjoint())
    }
}

/// The preadjoint of the row operator $(S, T)$ is the column operator $(S\_\*; T\_\*)$.
///
/// Both `S` and `T` must have preadjoints acting on the same space `Yʹ`. The
/// `for<'a>` bound explicitly requires that the column of those preadjoints is
/// [`Linear`] on `Yʹ`, with the pair of their codomains as its codomain.
impl<A, B, Yʹ, S, T> Preadjointable<Pair<A,B>, Yʹ> for RowOp<S, T>
where
    A : Space,
    B : Space,
    Yʹ : Space,
    S : Preadjointable<A, Yʹ>,
    T : Preadjointable<B, Yʹ>,
    Self : Linear<Pair<A, B>>,
    for<'a> ColOp<S::Preadjoint<'a>, T::Preadjoint<'a>> : Linear<
        Yʹ, Codomain=Pair<S::PreadjointCodomain, T::PreadjointCodomain>,
    >,
{
    type PreadjointCodomain = Pair<S::PreadjointCodomain, T::PreadjointCodomain>;
    type Preadjoint<'a> = ColOp<S::Preadjoint<'a>, T::Preadjoint<'a>> where Self : 'a;

    /// Stacks the preadjoints of the two blocks into a column operator.
    fn preadjoint(&self) -> Self::Preadjoint<'_> {
        ColOp(self.0.preadjoint(), self.1.preadjoint())
    }
}


/// The adjoint of the column operator $(S; T)$ is the row operator $(S^\*, T^\*)$.
///
/// The adjoints act on `Xʹ` and `Yʹ` respectively and must share the codomain `R`.
/// `R` is required to be `ClosedAdd`, presumably so that the row operator can add
/// the two contributions within `R`.
impl<A, Xʹ, Yʹ, R, S, T> Adjointable<A,Pair<Xʹ,Yʹ>> for ColOp<S, T>
where
    A : Space,
    Xʹ : Space,
    Yʹ : Space,
    R : Space + ClosedAdd,
    S : Adjointable<A, Xʹ, AdjointCodomain = R>,
    T : Adjointable<A, Yʹ, AdjointCodomain = R>,
    Self : Linear<A>,
    // NOTE(review): the explicit bound below is disabled, as in the `RowOp`
    // `Adjointable` impl; the `Preadjointable` counterpart keeps it.
    // for<'a> RowOp<S::Adjoint<'a>, T::Adjoint<'a>> : Linear<
    //     Pair<Xʹ,Yʹ>,
    //     Codomain=R,
    // >,
{
    type AdjointCodomain = R;
    type Adjoint<'a> = RowOp<S::Adjoint<'a>, T::Adjoint<'a>> where Self : 'a;

    /// Places the adjoints of the two blocks side by side in a row operator.
    fn adjoint(&self) -> Self::Adjoint<'_> {
        RowOp(self.0.adjoint(), self.1.adjoint())
    }
}

/// The preadjoint of the column operator $(S; T)$ is the row operator $(S\_\*, T\_\*)$.
///
/// The preadjoints act on `Xʹ` and `Yʹ` respectively and must share the codomain
/// `R`, which is required to be `ClosedAdd`. The `for<'a>` bound explicitly requires
/// that the row of those preadjoints is [`Linear`] on `Pair<Xʹ,Yʹ>` into `R`.
impl<A, Xʹ, Yʹ, R, S, T> Preadjointable<A,Pair<Xʹ,Yʹ>> for ColOp<S, T>
where
    A : Space,
    Xʹ : Space,
    Yʹ : Space,
    R : Space + ClosedAdd,
    S : Preadjointable<A, Xʹ, PreadjointCodomain = R>,
    T : Preadjointable<A, Yʹ, PreadjointCodomain = R>,
    Self : Linear<A>,
    for<'a> RowOp<S::Preadjoint<'a>, T::Preadjoint<'a>> : Linear<
        Pair<Xʹ,Yʹ>, Codomain = R,
    >,
{
    type PreadjointCodomain = R;
    type Preadjoint<'a> = RowOp<S::Preadjoint<'a>, T::Preadjoint<'a>> where Self : 'a;

    /// Places the preadjoints of the two blocks side by side in a row operator.
    fn preadjoint(&self) -> Self::Preadjoint<'_> {
        RowOp(self.0.preadjoint(), self.1.preadjoint())
    }
}

/// Diagonal operator $\operatorname{diag}(S, T)$; $(x, y) \mapsto (Sx, Ty)$.
///
/// `S` acts on the first component of a pair and `T` on the second,
/// independently of each other.
pub struct DiagOp<S, T>(pub S, pub T);

impl<A, B, S, T> Mapping<Pair<A, B>> for DiagOp<S, T>
where
    A: Space,
    B: Space,
    S: Mapping<A>,
    T: Mapping<B>,
{
    type Codomain = Pair<S::Codomain, T::Codomain>;

    fn apply<I: Instance<Pair<A, B>>>(&self, x: I) -> Self::Codomain {
        let Pair(head, tail) = x.decompose();
        let first = self.0.apply(head);
        let second = self.1.apply(tail);
        Pair(first, second)
    }
}

/// The diagonal operator is linear whenever both diagonal blocks are.
impl<A, B, S, T> Linear<Pair<A, B>> for DiagOp<S, T>
where
    A: Space,
    B: Space,
    S: Linear<A>,
    T: Linear<B>,
{
}

/// In-place application of $\operatorname{diag}(S, T)$: `S` updates `y.0` from the
/// first input component and `T` updates `y.1` from the second.
impl<F, S, T, A, B, U, V> GEMV<F, Pair<U, V>, Pair<A, B>> for DiagOp<S, T>
where
    F: Num,
    A: Space,
    B: Space,
    U: Space,
    V: Space,
    S: GEMV<F, U, A>,
    T: GEMV<F, V, B>,
    Self: Linear<Pair<U, V>, Codomain = Pair<A, B>>,
{
    fn gemv<I: Instance<Pair<U, V>>>(&self, y: &mut Pair<A, B>, α: F, x: I, β: F) {
        let Pair(u, v) = x.decompose();
        let Pair(ya, yb) = y;
        self.0.gemv(ya, α, u, β);
        self.1.gemv(yb, α, v, β);
    }

    fn apply_mut<I: Instance<Pair<U, V>>>(&self, y: &mut Pair<A, B>, x: I) {
        let Pair(u, v) = x.decompose();
        let Pair(ya, yb) = y;
        self.0.apply_mut(ya, u);
        self.1.apply_mut(yb, v);
    }

    /// Adds `(S u, T v)` to `y` componentwise.
    fn apply_add<I: Instance<Pair<U, V>>>(&self, y: &mut Pair<A, B>, x: I) {
        let Pair(u, v) = x.decompose();
        let Pair(ya, yb) = y;
        self.0.apply_add(ya, u);
        self.1.apply_add(yb, v);
    }
}

/// The adjoint of $\operatorname{diag}(S, T)$ is $\operatorname{diag}(S^\*, T^\*)$.
///
/// The adjoint acts on `Pair<Xʹ,Yʹ>`. Its codomain `R` is fixed by the `for<'a>`
/// bound, which requires the diagonal operator of the two adjoints to be
/// [`Linear`] with codomain `R`.
impl<A, B, Xʹ, Yʹ, R, S, T> Adjointable<Pair<A,B>, Pair<Xʹ,Yʹ>> for DiagOp<S, T>
where
    A : Space,
    B : Space,
    Xʹ: Space,
    Yʹ : Space,
    R : Space,
    S : Adjointable<A, Xʹ>,
    T : Adjointable<B, Yʹ>,
    Self : Linear<Pair<A, B>>,
    for<'a> DiagOp<S::Adjoint<'a>, T::Adjoint<'a>> : Linear<
        Pair<Xʹ,Yʹ>, Codomain=R,
    >,
{
    type AdjointCodomain = R;
    type Adjoint<'a> = DiagOp<S::Adjoint<'a>, T::Adjoint<'a>> where Self : 'a;

    /// Takes the adjoint of each diagonal block separately.
    fn adjoint(&self) -> Self::Adjoint<'_> {
        DiagOp(self.0.adjoint(), self.1.adjoint())
    }
}

/// The preadjoint of $\operatorname{diag}(S, T)$ is $\operatorname{diag}(S\_\*, T\_\*)$.
///
/// The preadjoint acts on `Pair<Xʹ,Yʹ>`. Its codomain `R` is fixed by the `for<'a>`
/// bound, which requires the diagonal operator of the two preadjoints to be
/// [`Linear`] with codomain `R`.
impl<A, B, Xʹ, Yʹ, R, S, T> Preadjointable<Pair<A,B>, Pair<Xʹ,Yʹ>> for DiagOp<S, T>
where
    A : Space,
    B : Space,
    Xʹ: Space,
    Yʹ : Space,
    R : Space,
    S : Preadjointable<A, Xʹ>,
    T : Preadjointable<B, Yʹ>,
    Self : Linear<Pair<A, B>>,
    for<'a> DiagOp<S::Preadjoint<'a>, T::Preadjoint<'a>> : Linear<
        Pair<Xʹ,Yʹ>, Codomain=R,
    >,
{
    type PreadjointCodomain = R;
    type Preadjoint<'a> = DiagOp<S::Preadjoint<'a>, T::Preadjoint<'a>> where Self : 'a;

    /// Takes the preadjoint of each diagonal block separately.
    fn preadjoint(&self) -> Self::Preadjoint<'_> {
        DiagOp(self.0.preadjoint(), self.1.preadjoint())
    }
}

/// Block operator
/// $\begin{pmatrix} S_{11} & S_{12} \\\\ S_{21} & S_{22} \end{pmatrix}$,
/// built as a column of two row operators. It maps $(x, y)$ to
/// $(S_{11}x + S_{12}y,\ S_{21}x + S_{22}y)$.
pub type BlockOp<S11, S12, S21, S22> =
    ColOp<RowOp<S11, S12>, RowOp<S21, S22>>;


/// Combines two non-negative bounds as the ℓ¹ norm of `(a, b)`, i.e. `a + b`.
fn pair_combine_l1<F: Float>(a: F, b: F) -> F {
    a + b
}

/// Combines two non-negative bounds as the ℓ² norm of `(a, b)`, i.e. `√(a² + b²)`.
fn pair_combine_l2<F: Float>(a: F, b: F) -> F {
    a.hypot(b)
}

/// Combines two non-negative bounds as the ℓ^∞ norm of `(a, b)`, i.e. `max(a, b)`.
fn pair_combine_linfinity<F: Float>(a: F, b: F) -> F {
    a.max(b)
}

// Implements `BoundedLinear` for `RowOp` and `ColOp` with respect to pair norms
// `PairNorm<ExpA, ExpB, $expj>`. These are taken to be
// ‖(a, b)‖ = ‖(‖a‖_A, ‖b‖_B)‖_J with J = `$expj` (NOTE(review): assumed from the
// `PairNorm` norm implementation in `crate::norms`; confirm there).
//
// * `RowOp`: ‖Sa + Tb‖ ≤ ‖S‖‖a‖ + ‖T‖‖b‖ ≤ ‖(‖S‖, ‖T‖)‖_{J*} ‖(a, b)‖ by Hölder's
//   inequality, where J* is the dual exponent of J. `$rowcomb` must therefore
//   combine the component bounds in the J* norm.
// * `ColOp`: ‖(Sa, Ta)‖ = ‖(‖Sa‖, ‖Ta‖)‖_J ≤ ‖(‖S‖, ‖T‖)‖_J ‖a‖ by monotonicity of
//   ℓ^p norms on non-negative vectors. `$colcomb` must therefore combine in the J norm.
//
// `max` is the right combination only for J = ℓ¹ (rows) and J = ℓ^∞ (columns).
// For example, `RowOp(I, I)` maps (1, 1) to 2 while ‖(1, 1)‖_∞ = 1.
macro_rules! pairnorm {
    ($expj:ty, $rowcomb:ident, $colcomb:ident) => {
        impl<F, A, B, S, T, ExpA, ExpB, ExpR>
        BoundedLinear<Pair<A, B>, PairNorm<ExpA, ExpB, $expj>, ExpR, F>
        for RowOp<S, T>
        where
            F: Float,
            A: Space + Norm<F, ExpA>,
            B: Space + Norm<F, ExpB>,
            S: BoundedLinear<A, ExpA, ExpR, F>,
            T: BoundedLinear<B, ExpB, ExpR, F>,
            S::Codomain: Add<T::Codomain>,
            <S::Codomain as Add<T::Codomain>>::Output: Space,
            ExpA: NormExponent,
            ExpB: NormExponent,
            ExpR: NormExponent,
        {
            fn opnorm_bound(
                &self,
                PairNorm(expa, expb, _): PairNorm<ExpA, ExpB, $expj>,
                expr: ExpR,
            ) -> F {
                // Triangle inequality followed by Hölder's inequality: the component
                // bounds are combined in the norm dual to the pair exponent.
                let na = self.0.opnorm_bound(expa, expr);
                let nb = self.1.opnorm_bound(expb, expr);
                $rowcomb(na, nb)
            }
        }

        impl<F, A, S, T, ExpA, ExpS, ExpT>
        BoundedLinear<A, ExpA, PairNorm<ExpS, ExpT, $expj>, F>
        for ColOp<S, T>
        where
            F: Float,
            A: Space + Norm<F, ExpA>,
            S: BoundedLinear<A, ExpA, ExpS, F>,
            T: BoundedLinear<A, ExpA, ExpT, F>,
            ExpA: NormExponent,
            ExpS: NormExponent,
            ExpT: NormExponent,
        {
            fn opnorm_bound(
                &self,
                expa: ExpA,
                PairNorm(exps, expt, _): PairNorm<ExpS, ExpT, $expj>,
            ) -> F {
                // ‖(Sa, Ta)‖ is the pair-exponent norm of (‖Sa‖, ‖Ta‖); bounding each
                // entry gives the same norm of the component bounds times ‖a‖.
                let ns = self.0.opnorm_bound(expa, exps);
                let nt = self.1.opnorm_bound(expa, expt);
                $colcomb(ns, nt)
            }
        }
    };
}

// ℓ¹ pairs: rows combine in the dual ℓ^∞ norm, columns in ℓ¹.
pairnorm!(L1, pair_combine_linfinity, pair_combine_l1);
// ℓ² is self-dual.
pairnorm!(L2, pair_combine_l2, pair_combine_l2);
// ℓ^∞ pairs: rows combine in the dual ℓ¹ norm, columns in ℓ^∞.
pairnorm!(Linfinity, pair_combine_l1, pair_combine_linfinity);

