
/*!
Traits for representing the support of an [`Apply`], and for analysing the mapping on a [`Cube`].
*/
use serde::Serialize;
use std::ops::{MulAssign,DivAssign,Neg};
use crate::types::{Float, Num};
use crate::maputil::map2;
use crate::mapping::{Apply, Differentiable};
use crate::sets::Cube;
use crate::loc::Loc;
use super::aggregator::Bounds;
use crate::norms::{Norm, L1, L2, Linfinity};

/// A trait for encoding constant [`Float`] values.
///
/// Implementors are small `Copy` handles that produce their value on demand and can also be
/// converted into it.
pub trait Constant: Copy + Sync + Send + 'static + std::fmt::Debug + Into<Self::Type> {
    /// The floating-point type of the value.
    type Type: Float;

    /// Returns the value of the constant.
    fn value(&self) -> Self::Type;
}

/// Every [`Float`] is trivially a [`Constant`] whose value is itself.
impl<F> Constant for F
where
    F: Float,
{
    type Type = F;

    #[inline]
    fn value(&self) -> Self::Type {
        *self
    }
}


/// A trait for working with the supports of [`Apply`]s.
///
/// [`Apply`] is deliberately not a super-trait, to allow more general use.
pub trait Support<F: Num, const N: usize>: Sized + Sync + Send + 'static {
    /// Return a cube containing the support of the function represented by `self`.
    ///
    /// The hint may be larger than the actual support, but must contain it.
    fn support_hint(&self) -> Cube<F, N>;

    /// Indicate whether `x` is in the support of the function represented by `self`.
    fn in_support(&self, x: &Loc<F, N>) -> bool;

    // Indicate whether `cube` is fully in the support of the function represented by `self`.
    //fn fully_in_support(&self, cube : &Cube<F,N>) -> bool;

    /// Return an optional hint for bisecting the support.
    ///
    /// Each entry of the output is either `None` or a coordinate along the corresponding
    /// axis at which `cube` may be bisected.
    ///
    /// This is useful for nonsmooth functions, to make the finite element models used by
    /// [`BTFN`][super::btfn::BTFN] minimisation/maximisation compatible with points of
    /// non-differentiability.
    ///
    /// The default implementation returns `[None; N]`, i.e., no hint along any axis.
    #[inline]
    #[allow(unused_variables)] // `cube` is only needed by overriding implementations.
    fn bisection_hint(&self, cube: &Cube<F, N>) -> [Option<F>; N] {
        [None; N]
    }

    /// Translate `self` by `x`.
    #[inline]
    fn shift(self, x: Loc<F, N>) -> Shift<Self, F, N> {
        Shift { base_fn: self, shift: x }
    }

    /// Multiply `self` by the scalar `a`.
    #[inline]
    fn weigh<C: Constant<Type = F>>(self, a: C) -> Weighted<Self, C> {
        Weighted { base_fn: self, weight: a }
    }
}

/// Trait for globally analysing a property `A` of an [`Apply`].
///
/// Typically `A` is an [`Aggregator`][super::aggregator::Aggregator] such as
/// [`Bounds`][super::aggregator::Bounds].
pub trait GlobalAnalysis<F: Num, A> {
    /// Perform global analysis of the property `A` of `Self`.
    ///
    /// For example, when `A` is [`Bounds`][super::aggregator::Bounds], this returns global
    /// lower and upper bounds for the mapping represented by `self`.
    fn global_analysis(&self) -> A;
}

// default impl<F, A, N, L> GlobalAnalysis<F, A, N> for L
// where L : LocalAnalysis<F, A, N> {
//     #[inline]
//     fn global_analysis(&self) -> Bounds<F> {
//         self.local_analysis(&self.support_hint())
//     }
// }

/// Trait for locally analysing a property `A` of an [`Apply`] (implementing [`Support`])
/// within a [`Cube`].
///
/// Typically `A` is an [`Aggregator`][super::aggregator::Aggregator] such as
/// [`Bounds`][super::aggregator::Bounds].
pub trait LocalAnalysis<F: Num, A, const N: usize>: GlobalAnalysis<F, A> + Support<F, N> {
    /// Perform local analysis of the property `A` of `Self`.
    ///
    /// For example, when `A` is [`Bounds`][super::aggregator::Bounds], this returns lower
    /// and upper bounds within `cube` for the mapping represented by `self`.
    fn local_analysis(&self, cube: &Cube<F, N>) -> A;
}

/// Trait for determining the upper and lower bounds of a float-valued [`Apply`].
///
/// This is a blanket-implemented alias for [`GlobalAnalysis`]`<F, Bounds<F>>`.
/// [`Apply`] is not a supertrait, to allow flexibility in the implementation for either
/// reference or non-reference arguments.
pub trait Bounded<F: Float>: GlobalAnalysis<F, Bounds<F>> {
    /// Return lower and upper bounds for the values of `self`.
    #[inline]
    fn bounds(&self) -> Bounds<F> {
        <Self as GlobalAnalysis<F, Bounds<F>>>::global_analysis(self)
    }
}

// Any type with a global bounds analysis is automatically `Bounded`.
impl<F, T> Bounded<F> for T
where
    F: Float,
    T: GlobalAnalysis<F, Bounds<F>>,
{
}

/// Shift of [`Support`] and [`Apply`]; output of [`Support::shift`].
///
/// Represents the mapping `x ↦ base_fn(x - shift)`.
#[derive(Copy,Clone,Debug,Serialize)]
pub struct Shift<T, F, const N : usize> {
    /// The translation applied to `base_fn`.
    shift : Loc<F, N>,
    /// The [`Support`] or [`Apply`] being translated.
    base_fn : T,
}

impl<'a, T, V, F: Float, const N: usize> Apply<&'a Loc<F, N>> for Shift<T, F, N>
where
    T: Apply<Loc<F, N>, Output = V>,
{
    type Output = V;

    /// Evaluates the base function at `x - shift`.
    #[inline]
    fn apply(&self, x: &'a Loc<F, N>) -> Self::Output {
        let unshifted: Loc<F, N> = x - &self.shift;
        self.base_fn.apply(unshifted)
    }
}

impl<T, V, F: Float, const N: usize> Apply<Loc<F, N>> for Shift<T, F, N>
where
    T: Apply<Loc<F, N>, Output = V>,
{
    type Output = V;

    /// Evaluates the base function at `x - shift`.
    #[inline]
    fn apply(&self, x: Loc<F, N>) -> Self::Output {
        let unshifted: Loc<F, N> = x - &self.shift;
        self.base_fn.apply(unshifted)
    }
}

impl<'a, T, V, F: Float, const N: usize> Differentiable<&'a Loc<F, N>> for Shift<T, F, N>
where
    T: Differentiable<Loc<F, N>, Output = V>,
{
    type Output = V;

    /// The differential of the base function, taken at `x - shift`.
    #[inline]
    fn differential(&self, x: &'a Loc<F, N>) -> Self::Output {
        let unshifted: Loc<F, N> = x - &self.shift;
        self.base_fn.differential(unshifted)
    }
}

impl<T, V, F: Float, const N: usize> Differentiable<Loc<F, N>> for Shift<T, F, N>
where
    T: Differentiable<Loc<F, N>, Output = V>,
{
    type Output = V;

    /// The differential of the base function, taken at `x - shift`.
    #[inline]
    fn differential(&self, x: Loc<F, N>) -> Self::Output {
        let unshifted: Loc<F, N> = x - &self.shift;
        self.base_fn.differential(unshifted)
    }
}

impl<T, F : Float, const N : usize> Support<F,N> for Shift<T,F,N>
where T : Support<F, N> {
    /// The support hint of the base function, translated by `shift`.
    #[inline]
    fn support_hint(&self) -> Cube<F,N> {
        self.base_fn.support_hint().shift(&self.shift)
    }

    /// `x` is in the support iff `x - shift` is in the support of the base function.
    #[inline]
    fn in_support(&self, x : &Loc<F,N>) -> bool {
        self.base_fn.in_support(&(x - &self.shift))
    }

    // fn fully_in_support(&self, _cube : &Cube<F,N>) -> bool {
    //     //self.base_fn.fully_in_support(cube.shift(&vectorneg(self.shift)))
    //     todo!("Not implemented, but not used at the moment")
    // }

    /// Bisection hint for `cube`, given in the shifted coordinates.
    ///
    /// `cube` is first translated back into the coordinates of the base function (as in
    /// [`LocalAnalysis::local_analysis`]). The base hint is computed there, and any hinted
    /// coordinates are then translated forward by `shift`.
    #[inline]
    fn bisection_hint(&self, cube : &Cube<F,N>) -> [Option<F>; N] {
        // The base function reasons about cubes in its own, unshifted, coordinates.
        let base_cube = cube.shift(&(-self.shift));
        let base_hint = self.base_fn.bisection_hint(&base_cube);
        map2(base_hint, &self.shift, |h, s| h.map(|z| z + *s))
    }

}

impl<T, F : Float, const N : usize> GlobalAnalysis<F, Bounds<F>> for Shift<T,F,N>
where T : GlobalAnalysis<F, Bounds<F>> {
    /// Translation does not change the range of a function, so the global bounds are
    /// those of the base function. Only global analysis of `T` is needed.
    #[inline]
    fn global_analysis(&self) -> Bounds<F> {
        self.base_fn.global_analysis()
    }
}

impl<T, F: Float, const N: usize> LocalAnalysis<F, Bounds<F>, N> for Shift<T, F, N>
where
    T: LocalAnalysis<F, Bounds<F>, N>,
{
    /// Analyses the base function on `cube` translated back by `-shift`.
    #[inline]
    fn local_analysis(&self, cube: &Cube<F, N>) -> Bounds<F> {
        let unshift = -self.shift;
        let base_cube = cube.shift(&unshift);
        self.base_fn.local_analysis(&base_cube)
    }
}

// Implements each listed norm for `Shift`. These norms are invariant under translation,
// so every impl simply forwards to the base function.
macro_rules! impl_shift_norm {
    ($($norm:ident)*) => {
        $(
            impl<T, F: Float, const N: usize> Norm<F, $norm> for Shift<T, F, N>
            where
                T: Norm<F, $norm>,
            {
                #[inline]
                fn norm(&self, n: $norm) -> F {
                    <T as Norm<F, $norm>>::norm(&self.base_fn, n)
                }
            }
        )*
    };
}

impl_shift_norm!(L1 L2 Linfinity);

/// Weighting of a [`Support`] and [`Apply`] by scalar multiplication;
/// output of [`Support::weigh`].
///
/// Represents the mapping `x ↦ base_fn(x) * weight`.
#[derive(Copy, Clone, Debug, Serialize)]
pub struct Weighted<T, C: Constant> {
    /// The scalar weight.
    pub weight: C,
    /// The base [`Support`] or [`Apply`] being weighted.
    pub base_fn: T,
}

impl<'a, T, V, F: Float, C, const N: usize> Apply<&'a Loc<F, N>> for Weighted<T, C>
where
    T: for<'b> Apply<&'b Loc<F, N>, Output = V>,
    V: std::ops::Mul<F, Output = V>,
    C: Constant<Type = F>,
{
    type Output = V;

    /// Evaluates the base function at `x` and scales the result by the weight.
    #[inline]
    fn apply(&self, x: &'a Loc<F, N>) -> Self::Output {
        let unweighted = self.base_fn.apply(x);
        unweighted * self.weight.value()
    }
}

impl<T, V, F: Float, C, const N: usize> Apply<Loc<F, N>> for Weighted<T, C>
where
    T: Apply<Loc<F, N>, Output = V>,
    V: std::ops::Mul<F, Output = V>,
    C: Constant<Type = F>,
{
    type Output = V;

    /// Evaluates the base function at `x` and scales the result by the weight.
    #[inline]
    fn apply(&self, x: Loc<F, N>) -> Self::Output {
        let unweighted = self.base_fn.apply(x);
        unweighted * self.weight.value()
    }
}

impl<'a, T, V, F: Float, C, const N: usize> Differentiable<&'a Loc<F, N>> for Weighted<T, C>
where
    T: for<'b> Differentiable<&'b Loc<F, N>, Output = V>,
    V: std::ops::Mul<F, Output = V>,
    C: Constant<Type = F>,
{
    type Output = V;

    /// The differential of the base function at `x`, scaled by the weight.
    #[inline]
    fn differential(&self, x: &'a Loc<F, N>) -> Self::Output {
        let unweighted = self.base_fn.differential(x);
        unweighted * self.weight.value()
    }
}

impl<T, V, F: Float, C, const N: usize> Differentiable<Loc<F, N>> for Weighted<T, C>
where
    T: Differentiable<Loc<F, N>, Output = V>,
    V: std::ops::Mul<F, Output = V>,
    C: Constant<Type = F>,
{
    type Output = V;

    /// The differential of the base function at `x`, scaled by the weight.
    #[inline]
    fn differential(&self, x: Loc<F, N>) -> Self::Output {
        let unweighted = self.base_fn.differential(x);
        unweighted * self.weight.value()
    }
}

impl<'a, T, F : Float, C, const N : usize> Support<F,N> for Weighted<T, C>
where T : Support<F, N>,
      C : Constant<Type=F> {

    #[inline]
    fn support_hint(&self) -> Cube<F,N> {
        self.base_fn.support_hint()
    }

    #[inline]
    fn in_support(&self, x : &Loc<F,N>) -> bool {
        self.base_fn.in_support(x)
    }
    
    // fn fully_in_support(&self, cube : &Cube<F,N>) -> bool {
    //     self.base_fn.fully_in_support(cube)
    // }

    #[inline]
    fn bisection_hint(&self, cube : &Cube<F,N>) -> [Option<F>; N] {
        self.base_fn.bisection_hint(cube)
    }
}

impl<T, F: Float, C> GlobalAnalysis<F, Bounds<F>> for Weighted<T, C>
where
    T: GlobalAnalysis<F, Bounds<F>>,
    C: Constant<Type = F>,
{
    /// Scales the base bounds by the weight. A negative weight swaps the two bounds.
    #[inline]
    fn global_analysis(&self) -> Bounds<F> {
        let Bounds(lower, upper) = self.base_fn.global_analysis();
        debug_assert!(lower <= upper);
        let w = self.weight.value();
        if w < F::ZERO {
            Bounds(w * upper, w * lower)
        } else {
            Bounds(w * lower, w * upper)
        }
    }
}

impl<T, F: Float, C, const N: usize> LocalAnalysis<F, Bounds<F>, N> for Weighted<T, C>
where
    T: LocalAnalysis<F, Bounds<F>, N>,
    C: Constant<Type = F>,
{
    /// Scales the base bounds within `cube` by the weight. A negative weight swaps the two
    /// bounds.
    #[inline]
    fn local_analysis(&self, cube: &Cube<F, N>) -> Bounds<F> {
        let Bounds(lower, upper) = self.base_fn.local_analysis(cube);
        debug_assert!(lower <= upper);
        let w = self.weight.value();
        if w < F::ZERO {
            Bounds(w * upper, w * lower)
        } else {
            Bounds(w * lower, w * upper)
        }
    }
}

macro_rules! make_weighted_scalarop_rhs {
    ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => {
        impl<F : Float, T> std::ops::$trait_assign<F> for Weighted<T, F> {
            #[inline]
            fn $fn_assign(&mut self, t : F) {
                self.weight.$fn_assign(t);
            }
        }

        impl<'a, F : Float, T> std::ops::$trait<F> for Weighted<T, F> {
            type Output = Self;
            #[inline]
            fn $fn(mut self, t : F) -> Self {
                self.weight.$fn_assign(t);
                self
            }
        }

        impl<'a, F : Float, T> std::ops::$trait<F> for &'a Weighted<T, F>
        where T : Clone {
            type Output = Weighted<T, F>;
            #[inline]
            fn $fn(self, t : F) -> Self::Output {
                Weighted { weight : self.weight.$fn(t), base_fn : self.base_fn.clone() }
            }
        }
    }
}

make_weighted_scalarop_rhs!(Mul, mul, MulAssign, mul_assign);
make_weighted_scalarop_rhs!(Div, div, DivAssign, div_assign);

macro_rules! impl_weighted_norm {
    ($($norm:ident)*) => { $(
        impl<'a, T, F : Float> Norm<F, $norm> for Weighted<T,F>
        where T : Norm<F, $norm> {
            #[inline]
            fn norm(&self, n : $norm) -> F {
                self.base_fn.norm(n) * self.weight.abs()
            }
        }
    )* }
}

impl_weighted_norm!(L1 L2 Linfinity);


/// Normalisation of [`Support`] and [`Apply`] to L¹ norm 1.
///
/// The wrapped function is divided by its L¹ norm. A wrapped function whose L¹ norm is zero
/// is treated as identically zero. Currently only scalar-valued functions are supported.
#[derive(Copy, Clone, Debug, Serialize, PartialEq)]
pub struct Normalised<T>(
    /// The base [`Support`] or [`Apply`].
    pub T,
);

impl<'a, T, F: Float, const N: usize> Apply<&'a Loc<F, N>> for Normalised<T>
where
    T: Norm<F, L1> + for<'b> Apply<&'b Loc<F, N>, Output = F>,
{
    type Output = F;

    /// The base function at `x` divided by its L¹ norm, or zero if that norm is zero.
    #[inline]
    fn apply(&self, x: &'a Loc<F, N>) -> Self::Output {
        match self.0.norm(L1) {
            w if w == F::ZERO => F::ZERO,
            w => self.0.apply(x) / w,
        }
    }
}

impl<T, F: Float, const N: usize> Apply<Loc<F, N>> for Normalised<T>
where
    T: Norm<F, L1> + Apply<Loc<F, N>, Output = F>,
{
    type Output = F;

    /// The base function at `x` divided by its L¹ norm, or zero if that norm is zero.
    #[inline]
    fn apply(&self, x: Loc<F, N>) -> Self::Output {
        match self.0.norm(L1) {
            w if w == F::ZERO => F::ZERO,
            w => self.0.apply(x) / w,
        }
    }
}

impl<'a, T, F : Float, const N : usize> Support<F,N> for Normalised<T>
where T : Norm<F, L1> + Support<F, N> {

    #[inline]
    fn support_hint(&self) -> Cube<F,N> {
        self.0.support_hint()
    }

    #[inline]
    fn in_support(&self, x : &Loc<F,N>) -> bool {
        self.0.in_support(x)
    }
    
    // fn fully_in_support(&self, cube : &Cube<F,N>) -> bool {
    //     self.0.fully_in_support(cube)
    // }

    #[inline]
    fn bisection_hint(&self, cube : &Cube<F,N>) -> [Option<F>; N] {
        self.0.bisection_hint(cube)
    }
}

impl<T, F : Float> GlobalAnalysis<F, Bounds<F>> for Normalised<T>
where T : Norm<F, L1> + GlobalAnalysis<F, Bounds<F>> {
    /// Global bounds of the normalised function.
    ///
    /// [`Normalised`] evaluates to the base function divided by its L¹ norm `w`, or to zero
    /// when `w` is zero. The base bounds are therefore divided by `w`. Because `w` is a norm,
    /// it is non-negative, so the division preserves the order of the bounds.
    #[inline]
    fn global_analysis(&self) -> Bounds<F> {
        let Bounds(lower, upper) = self.0.global_analysis();
        debug_assert!(lower <= upper);
        let w = self.0.norm(L1);
        debug_assert!(w >= F::ZERO);
        if w == F::ZERO {
            // Consistent with `apply`, which is identically zero in this case.
            Bounds(F::ZERO, F::ZERO)
        } else {
            Bounds(lower / w, upper / w)
        }
    }
}

impl<T, F : Float, const N : usize> LocalAnalysis<F, Bounds<F>, N> for Normalised<T>
where T : Norm<F, L1> + LocalAnalysis<F, Bounds<F>, N> {
    /// Bounds of the normalised function within `cube`.
    ///
    /// [`Normalised`] evaluates to the base function divided by its L¹ norm `w`, or to zero
    /// when `w` is zero. The base bounds on `cube` are therefore divided by `w`. Because `w`
    /// is a norm, it is non-negative, so the division preserves the order of the bounds.
    #[inline]
    fn local_analysis(&self, cube : &Cube<F, N>) -> Bounds<F> {
        let Bounds(lower, upper) = self.0.local_analysis(cube);
        debug_assert!(lower <= upper);
        let w = self.0.norm(L1);
        debug_assert!(w >= F::ZERO);
        if w == F::ZERO {
            // Consistent with `apply`, which is identically zero in this case.
            Bounds(F::ZERO, F::ZERO)
        } else {
            Bounds(lower / w, upper / w)
        }
    }
}

impl<T, F: Float> Norm<F, L1> for Normalised<T>
where
    T: Norm<F, L1>,
{
    /// One, unless the wrapped function has zero L¹ norm, in which case zero.
    #[inline]
    fn norm(&self, _: L1) -> F {
        let base_norm = self.0.norm(L1);
        if base_norm != F::ZERO {
            F::ONE
        } else {
            F::ZERO
        }
    }
}

// Implements each listed norm for `Normalised` by dividing the base norm by the base L¹ norm.
// The result is zero when the base L¹ norm is zero.
macro_rules! impl_normalised_norm {
    ($($norm:ident)*) => {
        $(
            impl<T, F: Float> Norm<F, $norm> for Normalised<T>
            where
                T: Norm<F, $norm> + Norm<F, L1>,
            {
                #[inline]
                fn norm(&self, n: $norm) -> F {
                    let l1 = <T as Norm<F, L1>>::norm(&self.0, L1);
                    if l1 == F::ZERO {
                        F::ZERO
                    } else {
                        <T as Norm<F, $norm>>::norm(&self.0, n) / l1
                    }
                }
            }
        )*
    };
}

impl_normalised_norm!(L2 Linfinity);

/*
impl<F : Num, S : Support<F, N>, const N : usize> LocalAnalysis<F, NullAggregator, N> for S {
    fn local_analysis(&self, _cube : &Cube<F, N>) -> NullAggregator { NullAggregator }
}

impl<F : Float, S : Bounded<F>, const N : usize> LocalAnalysis<F, Bounds<F>, N> for S {
    #[inline]
    fn local_analysis(&self, cube : &Cube<F, N>) -> Bounds<F> {
        self.bounds(cube)
    }
}*/

/// Generator of [`Support`]-implementing component functions, based on
/// [ids][`Self::Id`] with low storage requirements.
pub trait SupportGenerator<F: Float, const N: usize>:
    MulAssign<F> + DivAssign<F> + Neg<Output = Self> + Clone + Sync + Send + 'static
{
    /// The identification type.
    type Id: 'static + Copy;
    /// The type of the [`Support`] (often also an [`Apply`]).
    type SupportType: 'static + Support<F, N>;
    /// An iterator over all `(id, support)` pairs of the generator.
    type AllDataIter<'a>: Iterator<Item = (Self::Id, Self::SupportType)>
    where
        Self: 'a;

    /// Returns the component identified by `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is an invalid identifier.
    fn support_for(&self, id: Self::Id) -> Self::SupportType;

    /// Returns the number of different components in this generator.
    fn support_count(&self) -> usize;

    /// Returns an iterator over all `(id, support)` pairs.
    fn all_data(&self) -> Self::AllDataIter<'_>;
}

