/*!
Some convex analysis basics
*/

use std::marker::PhantomData;
use crate::types::*;
use crate::mapping::{Mapping, Space, ArithmeticTrue};
use crate::linops::IdOp;
use crate::instance::{Instance, InstanceMut, DecompositionMut,};
use crate::norms::*;

/// Marker trait for convex, scalar-valued mappings.
///
/// It provides no methods of its own and exists purely to be used as a trait bound;
/// the supertrait fixes the codomain of the mapping to the scalar type `F`.
///
/// TODO: should constrain `Mapping::Codomain` to implement a partial order,
/// but this makes everything complicated with little benefit.
pub trait ConvexMapping<Domain, F = f64> : Mapping<Domain, Codomain = F>
where
    Domain : Space,
    F : Num,
{}

/// Trait for mappings that possess a Fenchel conjugate
///
/// A `Conjugable` mapping is not itself required to be convex, whereas its
/// conjugate type must implement [`ConvexMapping`] on the dual space.
pub trait Conjugable<Domain, F = f64> : Mapping<Domain>
where
    Domain : HasDual<F>,
    F : Num,
{
    /// Type of the conjugate, a convex mapping on the dual space of `Domain`.
    type Conjugate<'a> : ConvexMapping<Domain::DualSpace, F> where Self : 'a;

    /// Returns the Fenchel conjugate of `self`.
    fn conjugate(&self) -> Self::Conjugate<'_>;
}

/// Trait for mappings that possess a Fenchel preconjugate
///
/// Unlike [`Conjugable`], the preconjugate type is not required to implement
/// [`ConvexMapping`]; instead the `Preconjugable` mapping itself must be convex.
pub trait Preconjugable<Domain, Predual, F = f64> : ConvexMapping<Domain, F>
where
    Predual : HasDual<F>,
    Domain : Space,
    F : Num,
{
    /// Type of the preconjugate, a mapping defined on `Predual`.
    type Preconjugate<'a> : Mapping<Predual> where Self : 'a;

    /// Returns the Fenchel preconjugate of `self`.
    fn preconjugate(&self) -> Self::Preconjugate<'_>;
}

/// Trait for mappings with a proximal map
///
/// For a weight τ > 0, the proximal map of a function $f$ is conventionally
/// $\mathrm{prox}_{τ f}(z) = \arg\min_x f(x) + \frac{1}{2τ}\|x - z\|^2$.
/// The weight has the type of the mapping's [`Mapping::Codomain`].
///
/// Implementors only need to provide [`Prox::prox_mapping`]; [`Prox::prox`] and
/// [`Prox::prox_mut`] have default implementations built on top of it.
pub trait Prox<Domain : Space> : Mapping<Domain> {
    /// Type of the proximal mapping; it maps `Domain` into itself.
    type Prox<'a> : Mapping<Domain, Codomain=Domain> where Self : 'a;

    /// Returns a proximal mapping with weight τ
    fn prox_mapping(&self, τ : Self::Codomain) -> Self::Prox<'_>;

    /// Calculate the proximal mapping with weight τ at `z`
    ///
    /// The default implementation builds [`Prox::prox_mapping`] and applies it to `z`.
    fn prox<I : Instance<Domain>>(&self, τ : Self::Codomain, z : I) -> Domain {
        self.prox_mapping(τ).apply(z)
    }

    /// Calculate the proximal mapping with weight τ in-place, replacing `*y` by its image
    ///
    /// The default implementation evaluates [`Prox::prox`] on a shared borrow of `*y`
    /// and then assigns the resulting new value back into `*y`.
    fn prox_mut<'b>(&self, τ : Self::Codomain, y : &'b mut Domain)
    where
        &'b mut Domain : InstanceMut<Domain>,
        Domain::Decomp : DecompositionMut<Domain>,
        for<'a> &'a Domain : Instance<Domain>,
    {
        *y = self.prox(τ, &*y);
    }
}


/// Fenchel conjugate of a norm: the indicator function of the unit ball of the
/// norm described by the wrapped [`NormMapping`].
///
/// Evaluates to zero when the wrapped norm of the argument is at most one, and to
/// +∞ otherwise. [`Conjugable::conjugate`] for [`NormMapping`] builds it from the
/// dual exponent.
pub struct NormConjugate<F : Float, E : NormExponent>(NormMapping<F, E>);

// Norms are convex functions, so every scalar-valued norm mapping is a convex mapping.
impl<Domain : Space, E : NormExponent, F : Float> ConvexMapping<Domain, F>
for NormMapping<F, E>
where
    Self : Mapping<Domain, Codomain=F>,
{}


// Evaluation of the norm conjugate as the indicator of the unit ball of the `E`-norm.
impl<Domain : Space + Norm<F, E>, E : NormExponent, F : Float> Mapping<Domain>
for NormConjugate<F, E> {
    type Codomain = F;
    type ArithmeticOptIn = ArithmeticTrue;

    /// Returns `F::ZERO` when the `E`-norm of `d` is at most one, `F::INFINITY` otherwise.
    fn apply<I : Instance<Domain>>(&self, d : I) -> F {
        let norm_of_d = d.eval(|x| x.norm(self.0.exponent));
        if norm_of_d <= F::ONE {
            F::ZERO
        } else {
            F::INFINITY
        }
    }
}

// The indicator of a norm ball (a convex set) is convex.
impl<Domain : Space, E : NormExponent, F : Float> ConvexMapping<Domain, F>
for NormConjugate<F, E>
where
    Self : Mapping<Domain, Codomain=F>,
{}


// The Fenchel conjugate of a norm is the indicator of the unit ball of the dual norm.
impl<E, F, Domain> Conjugable<Domain, F> for NormMapping<F, E>
where
    F : Float,
    E : HasDualExponent,
    Domain : Space + HasDual<F> + Norm<F, E>,
    <Domain as HasDual<F>>::DualSpace : Norm<F, E::DualExp>,
{
    type Conjugate<'a> = NormConjugate<F, E::DualExp> where Self : 'a;

    fn conjugate(&self) -> Self::Conjugate<'_> {
        let dual_norm = self.exponent.dual_exponent().as_mapping();
        NormConjugate(dual_norm)
    }
}

// Proximal map of the norm conjugate, i.e. of the indicator of the unit ball of the
// `E`-norm.
//
// For an indicator function δ_B and any weight τ > 0, prox_{τ δ_B} is the projection
// onto B itself: scaling an indicator by τ does not change it. The weight is therefore
// ignored, and the projection radius is one, matching the `‖x‖ ≤ 1` test used when
// evaluating `NormConjugate`.
impl<Domain, E, F> Prox<Domain> for NormConjugate<F, E>
where
    Domain : Space + Norm<F, E>,
    E : NormExponent,
    F : Float,
    NormProjection<F, E> : Mapping<Domain, Codomain=Domain>,
{
    type Prox<'a> = NormProjection<F, E> where Self : 'a;

    #[inline]
    fn prox_mapping(&self, _τ : Self::Codomain) -> Self::Prox<'_> {
        NormProjection{ α : F::ONE, exponent : self.0.exponent }
    }
}

/// Projection onto the ball of radius `α` with respect to the norm selected by `exponent`.
///
/// Serves as the proximal map of [`NormConjugate`].
pub struct NormProjection<F : Float, E : NormExponent> {
    /// Radius of the ball (passed to `Projection::proj_ball`).
    α : F,
    /// Exponent selecting the norm whose ball is projected onto.
    exponent : E,
}

/*
impl<F, Domain> Mapping<Domain> for NormProjection<F, L2>
where
    Domain : Space + Euclidean<F> + std::ops::MulAssign<F>,
    F : Float,
{
    type Codomain = Domain;

    fn apply<I : Instance<Domain>>(&self, d : I) -> Domain {
        d.own().proj_ball2(self.α)
    }
}
*/

// Applying a `NormProjection` delegates to `Projection::proj_ball` on an owned copy of the input.
impl<F, E, Domain> Mapping<Domain> for NormProjection<F, E>
where
    F : Float,
    E : NormExponent,
    Domain : Space + Projection<F, E> + std::ops::MulAssign<F>,
{
    type Codomain = Domain;
    type ArithmeticOptIn = ArithmeticTrue;

    fn apply<I : Instance<Domain>>(&self, d : I) -> Domain {
        let owned = d.own();
        owned.proj_ball(self.α, self.exponent)
    }
}

/// The zero mapping
///
/// Evaluates to `F::ZERO` everywhere on `Domain`. It carries no data; the type
/// parameters are only tracked through [`PhantomData`].
pub struct Zero<Domain : Space, F : Num>(PhantomData<(Domain, F)>);

impl<Domain : Space, F : Num> Zero<Domain, F> {
    /// Creates the zero mapping.
    pub fn new() -> Self {
        Zero(PhantomData)
    }
}

// Implemented by hand: `#[derive(Default)]` would add needless
// `Domain : Default` and `F : Default` bounds.
impl<Domain : Space, F : Num> Default for Zero<Domain, F> {
    fn default() -> Self {
        Self::new()
    }
}

impl<Domain, F> Mapping<Domain> for Zero<Domain, F>
where
    Domain : Space,
    F : Num,
{
    type Codomain = F;
    type ArithmeticOptIn = ArithmeticTrue;

    /// Evaluates the zero mapping: the argument is ignored and `F::ZERO` is returned.
    fn apply<I : Instance<Domain>>(&self, _point : I) -> Self::Codomain {
        F::ZERO
    }
}

// Constant functions, in particular the zero function, are convex.
impl<Domain, F> ConvexMapping<Domain, F> for Zero<Domain, F>
where
    Domain : Space,
    F : Num,
{}


impl<Domain : HasDual<F>, F : Float> Conjugable<Domain, F> for Zero<Domain, F> {
    type Conjugate<'a> = ZeroIndicator<Domain::DualSpace, F> where Self : 'a;

    #[inline]
    fn conjugate(&self) -> Self::Conjugate<'_> {
        ZeroIndicator::new()
    }
}

impl<Domain, Predual, F : Float> Preconjugable<Domain, Predual, F> for Zero<Domain, F>
where
    Domain : Space,
    Predual : HasDual<F>
{
    type Preconjugate<'a> = ZeroIndicator<Predual, F> where Self : 'a;

    #[inline]
    fn preconjugate(&self) -> Self::Preconjugate<'_> {
        ZeroIndicator::new()
    }
}

impl<Domain : Space + Clone, F : Num> Prox<Domain> for Zero<Domain, F> {
    type Prox<'a> = IdOp<Domain> where Self : 'a;

    #[inline]
    fn prox_mapping(&self, _τ : Self::Codomain) -> Self::Prox<'_> {
        IdOp::new()
    }
}


/// The zero indicator
///
/// The indicator function of the origin: zero at the origin and +∞ elsewhere.
/// It carries no data; the type parameters are only tracked through [`PhantomData`].
pub struct ZeroIndicator<Domain : Space, F : Num>(PhantomData<(Domain, F)>);

impl<Domain : Space, F : Num> ZeroIndicator<Domain, F> {
    /// Creates the indicator of the origin.
    pub fn new() -> Self {
        ZeroIndicator(PhantomData)
    }
}

// Implemented by hand: `#[derive(Default)]` would add needless
// `Domain : Default` and `F : Default` bounds.
impl<Domain : Space, F : Num> Default for ZeroIndicator<Domain, F> {
    fn default() -> Self {
        Self::new()
    }
}

impl<Domain, F> Mapping<Domain> for ZeroIndicator<Domain, F>
where
    Domain : Normed<F>,
    F : Float,
{
    type Codomain = F;
    type ArithmeticOptIn = ArithmeticTrue;

    /// Evaluates the indicator of the origin: `F::ZERO` at zero, `F::INFINITY` elsewhere.
    fn apply<I : Instance<Domain>>(&self, x : I) -> Self::Codomain {
        x.eval(|point| {
            let at_origin = point.is_zero();
            if at_origin { F::ZERO } else { F::INFINITY }
        })
    }
}

// The indicator of a single point (a convex set) is convex.
impl<Domain, F> ConvexMapping<Domain, F> for ZeroIndicator<Domain, F>
where
    Domain : Normed<F>,
    F : Float,
{}

impl<Domain : HasDual<F>, F : Float> Conjugable<Domain, F> for ZeroIndicator<Domain, F> {
    type Conjugate<'a> = Zero<Domain::DualSpace, F> where Self : 'a;

    #[inline]
    fn conjugate(&self) -> Self::Conjugate<'_> {
        Zero::new()
    }
}

impl<Domain, Predual, F : Float> Preconjugable<Domain, Predual, F> for ZeroIndicator<Domain, F>
where
    Domain : Normed<F>,
    Predual : HasDual<F>
{
    type Preconjugate<'a> = Zero<Predual, F> where Self : 'a;

    #[inline]
    fn preconjugate(&self) -> Self::Preconjugate<'_> {
        Zero::new()
    }
}
