/*!
Dot products, Euclidean spaces, norms, distances, and projections onto norm balls.
*/

use crate::types::*;
use std::ops::{Mul,MulAssign,Div,DivAssign,Add,Sub,AddAssign,SubAssign,Neg};
use serde::Serialize;
//use std::iter::Sum;

//
// Dot products
//

/// Space with a defined dot product.
///
/// `U` is the type of the right-hand operand and `F` the type of the resulting value.
pub trait Dot<U,F> {
    /// Computes the dot product `⟨self, other⟩`.
    fn dot(&self, other : &U) -> F;
}

//self.iter().zip(other.iter()).map(|(&x,&y)| x*y).sum()

//
// Euclidean spaces
//

/// A Euclidean (inner-product) space over the field `F`.
///
/// The inner product comes from the [`Dot`] supertrait, and the arithmetic supertraits make
/// elements form a vector space. Operations that produce a new element return the associated
/// [`Euclidean::Output`] type, which need not be `Self`.
pub trait Euclidean<F : Float> : Sized + Dot<Self,F>
        + Mul<F, Output=<Self as Euclidean<F>>::Output> + MulAssign<F>
        + Div<F, Output=<Self as Euclidean<F>>::Output> + DivAssign<F>
        + Add<Self, Output=<Self as Euclidean<F>>::Output>
        + Sub<Self, Output=<Self as Euclidean<F>>::Output>
        + for<'b> Add<&'b Self, Output=<Self as Euclidean<F>>::Output>
        + for<'b> Sub<&'b Self, Output=<Self as Euclidean<F>>::Output>
        + AddAssign<Self> + for<'b> AddAssign<&'b Self>
        + SubAssign<Self> + for<'b> SubAssign<&'b Self>
        + Neg<Output=<Self as Euclidean<F>>::Output> {
    /// Type of the new elements produced by arithmetic on `Self`; itself Euclidean.
    type Output : Euclidean<F>;

    /// Returns the origin (zero element) with the same dimensions as `self`.
    fn similar_origin(&self) -> <Self as Euclidean<F>>::Output;

    /// Squared 2-norm `‖x‖²`, computed as the dot product `⟨x, x⟩`.
    #[inline]
    fn norm2_squared(&self) -> F {
        <Self as Dot<Self, F>>::dot(self, self)
    }

    /// Half of the squared 2-norm, `‖x‖²/2`.
    #[inline]
    fn norm2_squared_div2(&self) -> F {
        let squared = self.norm2_squared();
        squared / F::TWO
    }

    /// The 2-norm `‖x‖`.
    #[inline]
    fn norm2(&self) -> F {
        let squared = self.norm2_squared();
        squared.sqrt()
    }

    /// Squared 2-distance `‖x - y‖²` between `self` and `other`.
    fn dist2_squared(&self, other : &Self) -> F;

    /// The 2-distance `‖x - y‖` between `self` and `other`.
    #[inline]
    fn dist2(&self, other : &Self) -> F {
        let squared = self.dist2_squared(other);
        squared.sqrt()
    }

    /// Returns `self` projected onto the closed 2-ball `{x : ‖x‖ ≤ ρ}`.
    #[inline]
    fn proj_ball2(self, ρ : F) -> Self {
        let mut projected = self;
        projected.proj_ball2_mut(ρ);
        projected
    }

    /// Projects `self` in place onto the closed 2-ball `{x : ‖x‖ ≤ ρ}`.
    ///
    /// Elements with `‖x‖ ≤ ρ` are left untouched; the others are scaled by `ρ/‖x‖`.
    #[inline]
    fn proj_ball2_mut(&mut self, ρ : F) {
        let norm = self.norm2();
        if norm > ρ {
            let scale = ρ / norm;
            *self *= scale;
        }
    }
}

/// Trait for [`Euclidean`] spaces with dimensions known at compile time.
///
/// Since the type fixes the dimensions, the origin can be constructed without a reference
/// element (compare [`Euclidean::similar_origin`]).
pub trait StaticEuclidean<F : Float> : Euclidean<F> {
    /// Returns the origin (zero element) of the space.
    fn origin() -> <Self as Euclidean<F>>::Output;
}

//
// Abstract norms
//

/// Marker selecting the 1-norm, passed as the `Exponent` argument of [`Norm`] and related traits.
#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)]
pub struct L1;

/// Marker selecting the 2-norm, passed as the `Exponent` argument of [`Norm`] and related traits.
///
/// Any [`Euclidean`] space that implements `Norm<F, L2>` gets [`Projection`] onto `L2` balls
/// through the Euclidean 2-ball projection.
#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)]
pub struct L2;

/// Marker selecting the ∞-norm, passed as the `Exponent` argument of [`Norm`] and related traits.
#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)]
pub struct Linfinity;

/// Marker selecting the 2,1-norm.
///
/// NOTE(review): presumably the sum of the 2-norms of inner vectors, as in the commented-out
/// generic implementation further down this file. Confirm before relying on it.
#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)]
pub struct L21;

/// Huber-smoothed 1-norm with smoothing parameter `γ` (the wrapped value).
///
/// On [`Euclidean`] spaces it is evaluated on the 2-norm `‖x‖`:
/// * `‖x‖` when `γ = 0`;
/// * `‖x‖ - γ/2` when `‖x‖ > γ`;
/// * `‖x‖²/(2γ)` otherwise.
#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)]
pub struct HuberL1<F : Float>(pub F);

/// Huber-smoothed 2,1-norm with smoothing parameter `γ` (the wrapped value).
///
/// NOTE(review): no implementation appears in this file. The semantics are presumed by
/// analogy with [`HuberL1`]; confirm against the implementing code.
#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)]
pub struct HuberL21<F : Float>(pub F);

/// A norm on `Self` with values in `F`, selected by the type `Exponent`.
///
/// The `Exponent` value both selects the norm and may carry parameters
/// (for example the `γ` of [`HuberL1`]).
pub trait Norm<F, Exponent> {
    /// Calculate the norm of `self` in the norm selected by `_p`.
    fn norm(&self, _p : Exponent) -> F;
}

/// Indicates that a [`Norm`] is dominated by another norm (`Exponent`) on `Elem` with the
/// corresponding field `F`.
///
/// The implementing type stands for the dominated norm `‖·‖`. It supplies a constant $C$
/// such that $‖x‖ ≤ C ‖x‖_p$ for elements $x$ of `Elem`.
pub trait Dominated<F : Num, Exponent, Elem> {
    /// Returns the factor $C$ in the inequality $‖x‖ ≤ C ‖x‖_p$.
    fn norm_factor(&self, p : Exponent) -> F;
    /// Given a norm-value $‖x‖_p$, calculates $C‖x‖_p$ such that $‖x‖ ≤ C‖x‖_p$.
    #[inline]
    fn from_norm(&self, p_norm : F, p : Exponent) -> F {
        // Upper bound on ‖x‖ obtained by scaling the known p-norm by the domination factor.
        p_norm * self.norm_factor(p)
    }
}

/// A distance associated with the norm selected by `Exponent` (hence the [`Norm`] supertrait).
pub trait Dist<F,Exponent> : Norm<F, Exponent> {
    /// Calculate the distance between `self` and `other` in the norm selected by `_p`.
    fn dist(&self, other : &Self, _p : Exponent) -> F;
}

/// Euclidean spaces that support projection onto balls of the norm selected by `Exponent`.
pub trait Projection<F, Exponent> : Norm<F, Exponent> + Euclidean<F> where F : Float {
    /// Returns `self` projected onto the ball of radius `ρ` in the norm selected by `q`.
    ///
    /// The default implementation consumes `self` and delegates to
    /// [`Projection::proj_ball_mut`].
    fn proj_ball(self, ρ : F, q : Exponent) -> Self {
        let mut projected = self;
        projected.proj_ball_mut(ρ, q);
        projected
    }

    /// Projects `self` in place onto the ball of radius `ρ` in the norm selected by `_q`.
    fn proj_ball_mut(&mut self, ρ : F, _q : Exponent);
}

/*impl<F : Float, E : Euclidean<F>> Norm<F, L2> for E {
    #[inline]
    fn norm(&self, _p : L2) -> F { self.norm2() }

    fn dist(&self, other : &Self, _p : L2) -> F { self.dist2(other) }
}*/

/// [`L2`]-ball projection for every [`Euclidean`] space with an [`L2`] [`Norm`].
///
/// Both methods forward to the corresponding Euclidean 2-ball projections.
impl<F : Float, E : Euclidean<F> + Norm<F, L2>> Projection<F, L2> for E {
    #[inline]
    fn proj_ball(self, ρ : F, _p : L2) -> Self {
        <E as Euclidean<F>>::proj_ball2(self, ρ)
    }

    #[inline]
    fn proj_ball_mut(&mut self, ρ : F, _p : L2) {
        <E as Euclidean<F>>::proj_ball2_mut(self, ρ)
    }
}

impl<F : Float> HuberL1<F> {
    /// Evaluates the Huber function of a norm, given the *squared* norm `xnsq = ‖x‖²`.
    ///
    /// With `γ` the wrapped parameter, returns
    /// * `‖x‖` when `γ = 0` (no smoothing);
    /// * `‖x‖ - γ/2` when `‖x‖ > γ`;
    /// * `‖x‖²/(2γ)` otherwise.
    ///
    /// The quadratic branch uses `xnsq` directly instead of squaring the root again.
    fn apply(self, xnsq : F) -> F {
        let HuberL1(γ) = self;
        let xn = xnsq.sqrt();
        if γ == F::ZERO {
            xn
        } else if xn > γ {
            xn - γ / F::TWO
        } else {
            // `xn` is a square root and therefore never negative, so the `x < -γ` case of
            // the scalar Huber function cannot arise here and is deliberately omitted.
            xnsq / (F::TWO * γ)
        }
    }
}

/// The [`HuberL1`] norm on a [`Euclidean`] space: the Huber function applied to the 2-norm.
impl<F : Float, E : Euclidean<F>> Norm<F, HuberL1<F>> for E {
    fn norm(&self, huber : HuberL1<F>) -> F {
        let norm_squared = <E as Euclidean<F>>::norm2_squared(self);
        huber.apply(norm_squared)
    }
}

/// The [`HuberL1`] distance on a [`Euclidean`] space: the Huber function applied to the
/// 2-distance between the two elements.
impl<F : Float, E : Euclidean<F>> Dist<F, HuberL1<F>> for E {
    fn dist(&self, other : &Self, huber : HuberL1<F>) -> F {
        let dist_squared = <E as Euclidean<F>>::dist2_squared(self, other);
        huber.apply(dist_squared)
    }
}

/*
#[inline]
pub fn mean<V>(x : V) -> V::Field where V : ValidArray {
     x.iter().sum()/x.len()
}

#[inline]
pub fn proj_nonneg_mut<V>(x : &mut V) -> &mut V where V : ValidArray {
    x.iter_mut().for_each(|&mut p| if p < 0 { *p = 0 } );
    x
}
*/


//
// 2,1-norm generic implementation
//

/*
pub trait InnerVectors {
    type Item;
    type Iter : Iterator<Item=Self::Item>;
    fn inner_vectors(&self) -> &Self::Iter;
}

pub trait InnerVectorsMut : InnerVectors {
    type IterMut : Iterator<Item=Self::Item>;
    fn inner_vectors_mut(&self) -> &mut Self::Item;
}

impl<F : Float + Sum, T : InnerVectors> Norm<F, L21> for T where T::Item : Norm<F, L2> {
    fn norm(&self, _ : L21) -> F {
        self.inner_vectors().map(|t| t.norm(L2)).sum()
    }
}

impl<F : Float + Sum, T : InnerVectorsMut + Euclidean<F>> Projection<F, L21>
for T where T::ItemMut : Projection<F, L2> {
    fn proj_ball_mut(&mut self, _ : L21) {
        self.inner_vectors_mut().for_each(|t| t.proj_ball_mut(L2));
    }
}
*/

