/*!
Utilities for mapping over various container types.
*/

#[cfg(feature = "nightly")]
use std::mem::MaybeUninit;
use itertools::izip;

/// Trait for a fixed-length container type.
///
/// Implemented by [`Loc`][crate::loc::Loc] vectors, [`Cube`][crate::sets::Cube]s,
/// and basic arrays.
///
/// The const parameter `N` is the number of elements in the container: the mapping
/// functions of this module collect the output of [`fl_iter`][FixedLength::fl_iter]
/// into arrays of length `N`, so implementations should yield exactly `N` elements.
pub trait FixedLength<const N : usize> {
    /// Type of elements of the container.
    type Elem;
    /// Type of iterators over the elements of the container.
    type Iter : Iterator<Item = Self::Elem>;

    /// Returns an iterator over the elements of the container, consuming `self`.
    fn fl_iter(self) -> Self::Iter;
}

/// Trait for a mutable fixed-length container type.
///
/// Extends [`FixedLength`] with in-place access to the elements of the container.
pub trait FixedLengthMut<const N : usize> : FixedLength<N> {
    /// Iterator yielding mutable references, valid for `'a`, to the elements of
    /// the container.
    type IterMut<'a> : Iterator<Item = &'a mut Self::Elem>
    where
        Self : 'a;

    /// Borrows the container mutably and returns an iterator over mutable
    /// references to its elements.
    fn fl_iter_mut(&mut self) -> Self::IterMut<'_>;
}

impl<A, const N : usize> FixedLength<N> for [A; N] {
    type Elem = A;
    type Iter = std::array::IntoIter<A, N>;

    /// Consumes the array, yielding its elements by value.
    #[inline]
    fn fl_iter(self) -> Self::Iter {
        IntoIterator::into_iter(self)
    }
}

impl<A, const N : usize> FixedLengthMut<N> for [A; N] {
    type IterMut<'a> = std::slice::IterMut<'a, A> where A : 'a;

    /// Borrows the array mutably, yielding `&mut` references to its elements.
    #[inline]
    fn fl_iter_mut(&mut self) -> Self::IterMut<'_> {
        IntoIterator::into_iter(self)
    }
}

impl<'a, A, const N : usize> FixedLength<N> for &'a [A; N] {
    type Elem = &'a A;
    type Iter = std::slice::Iter<'a, A>;

    /// Yields shared references to the elements of the borrowed array.
    #[inline]
    fn fl_iter(self) -> Self::Iter {
        IntoIterator::into_iter(self)
    }
}

// Builds a pattern (or expression) from a list of identifiers: a lone identifier
// written with a trailing comma stands for itself, while two or more identifiers
// become the tuple `(a, b, …)`. Used below to destructure the items produced by
// `izip!`, which (presumably) yields bare items rather than 1-tuples for a single
// iterator.
macro_rules! tuple_or_singleton {
    ($single:ident,) => ( $single );
    ($($many:ident),+) => ( ($($many),+) );
}

// Generates a pair of array-producing mapping functions over `FixedLength`
// containers: `$name`, which passes the zipped elements to `f`, and
// `$name_indexed`, which additionally passes the element index first.
//
// Syntax: `make_mapmany!(name, name_indexed, vars… ; ElemTypes…, ContainerTypes…)`,
// with one variable, element type and container type per container argument.
macro_rules! make_mapmany {
    ($name:ident, $name_indexed:ident, $var0:ident $($var:ident)* ;
     $etype0:ident $($etype:ident)*, $ctype0:ident $($ctype:ident)*) => {
        /// Map over [`FixedLength`] container(s), returning an array.
        ///
        /// `f` receives the elements of the container(s) at each position, in the
        /// order produced by [`FixedLength::fl_iter`]. It may be `FnMut`, so it can
        /// carry state across calls.
        #[inline]
        pub fn $name<
            $etype0,
            $($etype,)*
            $ctype0 : FixedLength<N,Elem=$etype0>,
            $($ctype : FixedLength<N,Elem=$etype>,)*
            Res,
            const N : usize
        >(
            $var0 : $ctype0,
            $($var : $ctype,)*
            mut f : impl FnMut($etype0, $($etype),*) -> Res
        ) -> [Res; N] {
            let zipit = izip!($var0.fl_iter(), $($var.fl_iter()),*);
            let map = zipit.map(|tuple_or_singleton!($var0, $($var),*)| f($var0, $($var),*));
            collect_into_array_unchecked(map)
        }

        /// Map over [`FixedLength`] container(s) and element indices, returning an array.
        ///
        /// `f` receives the index (counting from 0) followed by the elements of the
        /// container(s) at that index. It may be `FnMut`, so it can carry state
        /// across calls.
        #[inline]
        pub fn $name_indexed<
            $etype0,
            $($etype,)*
            $ctype0 : FixedLength<N,Elem=$etype0>,
            $($ctype : FixedLength<N,Elem=$etype>,)*
            Res,
            const N : usize
        >(
            $var0 : $ctype0,
            $($var : $ctype,)*
            mut f : impl FnMut(usize, $etype0, $($etype),*) -> Res
        ) -> [Res; N] {
            let zipit = (0..N).zip(izip!($var0.fl_iter(), $($var.fl_iter()),*));
            let map = zipit.map(|(i, tuple_or_singleton!($var0, $($var),*))| f(i, $var0, $($var),*));
            collect_into_array_unchecked(map)
        }
    }
}

make_mapmany!(map1, map1_indexed, a;           A,           CA);
make_mapmany!(map2, map2_indexed, a b;         A B,         CA CB);
make_mapmany!(map3, map3_indexed, a b c;       A B C,       CA CB CC);
make_mapmany!(map4, map4_indexed, a b c d;     A B C D,     CA CB CC CD);
make_mapmany!(map5, map5_indexed, a b c d e;   A B C D E,   CA CB CC CD CE);
make_mapmany!(map6, map6_indexed, a b c d e f; A B C D E F, CA CB CC CD CE CF);

// Generates a pair of in-place mapping functions over `FixedLength` containers,
// where the first container is borrowed mutably: `$name` passes `&mut` to the
// first container's element plus the other containers' elements to `f`, and
// `$name_indexed` additionally passes the element index first.
//
// Syntax matches `make_mapmany!`.
macro_rules! make_mapmany_mut {
    ($name:ident, $name_indexed:ident, $var0:ident $($var:ident)* ;
     $etype0:ident $($etype:ident)*, $ctype0:ident $($ctype:ident)*) => {
        /// Map over [`FixedLength`] container(s) with mutable references to the first container.
        ///
        /// `f` may be `FnMut`, so it can carry state across calls.
        #[inline]
        pub fn $name<
            $etype0,
            $($etype,)*
            $ctype0 : FixedLengthMut<N,Elem=$etype0>,
            $($ctype : FixedLength<N,Elem=$etype>,)*
            const N : usize
        > (
            $var0 : &mut $ctype0,
            $($var : $ctype,)*
            mut f : impl FnMut(&mut $etype0, $($etype),*)
        ) {
            let zipit = izip!($var0.fl_iter_mut(), $($var.fl_iter()),*);
            zipit.for_each(|tuple_or_singleton!($var0, $($var),*)| f($var0, $($var),*));
        }

        /// Map over [`FixedLength`] container(s) and element indices
        /// with mutable references to the first container.
        ///
        /// `f` receives the index (counting from 0) first. It may be `FnMut`, so it
        /// can carry state across calls.
        #[inline]
        pub fn $name_indexed<
            $etype0,
            $($etype,)*
            $ctype0 : FixedLengthMut<N,Elem=$etype0>,
            $($ctype : FixedLength<N,Elem=$etype>,)*
            const N : usize
        > (
            $var0 : &mut $ctype0,
            $($var : $ctype,)*
            mut f : impl FnMut(usize, &mut $etype0, $($etype),*)
        ) {
            let zipit = (0..N).zip(izip!($var0.fl_iter_mut(), $($var.fl_iter()),*));
            zipit.for_each(|(i, tuple_or_singleton!($var0, $($var),*))| f(i, $var0, $($var),*));
        }
    }
}

make_mapmany_mut!(map1_mut, map1_indexed_mut, a;           A,           CA);
make_mapmany_mut!(map2_mut, map2_indexed_mut, a b;         A B,         CA CB);
make_mapmany_mut!(map3_mut, map3_indexed_mut, a b c;       A B C,       CA CB CC);
make_mapmany_mut!(map4_mut, map4_indexed_mut, a b c d;     A B C D,     CA CB CC CD);
make_mapmany_mut!(map5_mut, map5_indexed_mut, a b c d e;   A B C D E,   CA CB CC CD CE);
make_mapmany_mut!(map6_mut, map6_indexed_mut, a b c d e f; A B C D E F, CA CB CC CD CE CF);


/// Initialise an array of length `N` by calling `f` multiple times.
///
/// `f` is called exactly `N` times, in order from the first element to the last,
/// so a stateful (`FnMut`) closure can be used, e.g. to produce a counter.
#[inline]
pub fn array_init<A, F : FnMut() -> A, const N : usize>(mut f : F) -> [A; N] {
    core::array::from_fn(|_| f())
}

// /// Initialise an array of length `N` by calling `f` with the index of each element.
// #[inline]
// pub fn array_gen<A, F :  Fn(usize) -> A, const N : usize>(f : F) -> [A; N] {
//     //[(); N].indexmap(|i, _| f(i))
//     core::array::from_fn(f)
// }



/// Iterator returned by [`foldmap`][FoldMappable::foldmap] applied to an iterator.
///
/// Each call to `next` passes the current accumulator and the next element of the
/// underlying iterator to `f`, stores the accumulator that `f` returns, and yields
/// the mapped value.
pub struct FoldMap<I : Iterator<Item=A>, A, B, J : Copy, F : Fn(J, A) -> (J, B)> {
    /// Underlying iterator.
    iter : I,
    /// Fold-and-map function: `(accumulator, element) -> (new accumulator, output)`.
    f : F,
    /// Current accumulator.
    j : J,
}

impl<A, B, I : Iterator<Item=A>, J : Copy, F : Fn(J, A) -> (J, B)> Iterator for FoldMap<I, A, B, J, F> {
    type Item = B;

    #[inline]
    fn next(&mut self) -> Option<B> {
        self.iter.next().map(|a| {
            let (jnew, b) = (self.f)(self.j, a);
            self.j = jnew;
            b
        })
    }

    /// One item is produced per item of the underlying iterator, so its size hint
    /// carries over unchanged (letting e.g. `collect` preallocate).
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

/// Iterator returned by [`indexmap`][IndexMappable::indexmap] applied to an iterator.
///
/// Each call to `next` passes the running index (starting from the initial `j`)
/// and the next element of the underlying iterator to `f`, and yields the result.
pub struct IndexMap<I : Iterator<Item=A>, A, B, F : Fn(usize, A) -> B> {
    /// Underlying iterator.
    iter : I,
    /// Mapping function: `(index, element) -> output`.
    f : F,
    /// Index of the next element.
    j : usize,
}

impl<A, B, I : Iterator<Item=A>, F : Fn(usize, A) -> B> Iterator for IndexMap<I, A, B, F> {
    type Item = B;

    #[inline]
    fn next(&mut self) -> Option<B> {
        self.iter.next().map(|a| {
            let b = (self.f)(self.j, a);
            self.j += 1;
            b
        })
    }

    /// One item is produced per item of the underlying iterator, so its size hint
    /// carries over unchanged (letting e.g. `collect` preallocate).
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

/// Trait for things that can be foldmapped.
///
/// `A` is the type of elements of `Self`, and `J` the accumulator type for the folding.
pub trait FoldMappable<A, J> : Sized {
    /// Result of [`foldmap`][FoldMappable::foldmap] for mapped element type `B`
    /// and fold-and-map function `F`.
    type Output<B, F>
    where
        F : Fn(J, A) -> (J, B);

    /// Fold and map over `self` with `f`, starting from the accumulator `j`.
    ///
    /// `f` takes the current accumulator and an element, and returns the next
    /// accumulator together with the mapped element. The output type depends on
    /// the implementation, but will generally have elements of type `B`.
    fn foldmap<B, F : Fn(J, A) -> (J, B)>(self, j : J, f : F) -> Self::Output<B, F>;
}

/// Trait for things that can be indexmapped.
///
/// `A` is the type of elements of `Self`.
pub trait IndexMappable<A> : Sized {
    /// Result of [`indexmap`][IndexMappable::indexmap] for mapped element type `B`
    /// and mapping function `F`.
    type Output<B, F>
    where
        F : Fn(usize, A) -> B;

    /// Map over the elements of `self` together with their indices.
    ///
    /// `f` takes the index and the element. The output type depends on the
    /// implementation, but will generally have elements of type `B`.
    fn indexmap<B, F : Fn(usize, A) -> B>(self, f : F) -> Self::Output<B, F>;
}

impl<'a, A, J : Copy> FoldMappable<&'a A, J> for std::slice::Iter<'a, A> {
    type Output<B, F> = FoldMap<Self, &'a A, B, J, F>
    where
        F : Fn(J, &'a A) -> (J, B);

    /// Lazily folds and maps over the borrowed slice elements, starting from `j`;
    /// nothing is evaluated until the returned iterator is advanced.
    #[inline]
    fn foldmap<B, F : Fn(J, &'a A) -> (J, B)>(self, j : J, f : F) -> Self::Output<B, F> {
        let iter = self;
        FoldMap { f, j, iter }
    }
}

impl<'a, A> IndexMappable<&'a A> for std::slice::Iter<'a, A> {
    type Output<B, F> = IndexMap<Self, &'a A, B, F>
    where
        F : Fn(usize, &'a A) -> B;

    /// Lazily maps over the borrowed slice elements together with their indices,
    /// counting from 0.
    #[inline]
    fn indexmap<B, F : Fn(usize, &'a A) -> B>(self, f : F) -> Self::Output<B, F> {
        let start : usize = 0;
        IndexMap { f, iter : self, j : start }
    }
}


impl<A, J : Copy, const N : usize> FoldMappable<A, J> for std::array::IntoIter<A, N> {
    type Output<B, F> = FoldMap<Self, A, B, J, F>
    where
        F : Fn(J, A) -> (J, B);

    /// Lazily folds and maps over the remaining array elements by value,
    /// starting from the accumulator `j`.
    #[inline]
    fn foldmap<B, F : Fn(J, A) -> (J, B)>(self, j : J, f : F) -> Self::Output<B, F> {
        let iter = self;
        FoldMap { f, iter, j }
    }
}

impl<A, const N : usize> IndexMappable<A>
for std::array::IntoIter<A, N> {
    type Output<B, F> = IndexMap<Self, A, B, F> where F : Fn(usize, A) -> B;

    /// Lazily maps over the remaining array elements by value, together with
    /// their positions counted from 0.
    #[inline]
    fn indexmap<B, F : Fn(usize, A) -> B>(self, f : F) -> Self::Output<B, F> {
        IndexMap { iter : self, j : 0, f }
    }
}

impl<A, J : Copy, const N : usize> FoldMappable<A, J> for [A; N] {
    type Output<B, F> = [B; N] where F : Fn(J, A) -> (J, B);

    /// Folds and maps over the array by value, returning an array of the mapped
    /// elements.
    ///
    /// The accumulator is threaded through the elements in order, starting from
    /// `j`; its final value is discarded.
    #[inline]
    fn foldmap<B, F : Fn(J, A) -> (J, B)>(self, j : J, f : F) -> [B; N] {
        let it = self.into_iter().foldmap(j, f);
        collect_into_array_unchecked(it)
    }
}

impl<A, const N : usize> IndexMappable<A> for [A; N] {
    type Output<B, F> = [B; N] where F : Fn(usize, A) -> B;

    /// Maps over the array by value together with element indices (counting from
    /// 0), returning an array of the results.
    #[inline]
    fn indexmap<B, F : Fn(usize, A) -> B>(self, f : F) -> [B; N] {
        let it = self.into_iter().indexmap(f);
        collect_into_array_unchecked(it)
    }
}

/// Collects exactly `N` items from `iter` into an array without heap allocation.
///
/// This is adapted and simplified from the internal `collect_into_array` of
/// `core::array`, so as to not involve `ControlFlow`, `Try` etc. (Pulling in
/// everything including `NeverShortCircuit` turned out too much to maintain
/// here.) Only stable `MaybeUninit`/pointer APIs are used, so this does not
/// break when the unstable `MaybeUninit` array/slice helpers change.
///
/// Pulls items from `iter` until `N` of them have been collected and returns
/// them as an array; no further items are pulled from `iter`.
///
/// # Panics
///
/// Panics if `iter` yields fewer than `N` items. All items already yielded are
/// dropped in that case, as well as when `iter.next()` itself panics.
#[cfg(feature = "nightly")]
#[inline]
pub(crate) fn collect_into_array_unchecked<
    T,
    I : Iterator<Item=T>,
    const N: usize
>(iter: I) -> [T; N]
{
    /// Drops the first `initialized` (already written) elements of `array_mut`
    /// when dropped, so that a panic part-way through does not leak them.
    struct Guard<'a, T, const N: usize> {
        array_mut: &'a mut [MaybeUninit<T>; N],
        initialized: usize,
    }

    impl<T, const N: usize> Drop for Guard<'_, T, N> {
        #[inline]
        fn drop(&mut self) {
            debug_assert!(self.initialized <= N);

            // SAFETY: the first `initialized` elements have been written and are
            // owned by nobody else, and `MaybeUninit<T>` has the same layout as
            // `T`, so this is a valid slice of initialized `T`s to drop.
            unsafe {
                let written = core::ptr::slice_from_raw_parts_mut(
                    self.array_mut.as_mut_ptr().cast::<T>(),
                    self.initialized,
                );
                core::ptr::drop_in_place(written);
            }
        }
    }

    if N == 0 {
        // SAFETY: An empty array is always inhabited and has no validity invariants.
        return unsafe { core::mem::zeroed() };
    }

    // SAFETY: an array of `MaybeUninit`s requires no initialization.
    let mut array : [MaybeUninit<T>; N] = unsafe { MaybeUninit::uninit().assume_init() };
    let mut guard = Guard { array_mut: &mut array, initialized: 0 };

    for item in iter {
        // SAFETY: `guard.initialized` starts at 0, is increased by one per item,
        // and we return as soon as it reaches N (which is `array.len()`), so it
        // is in bounds here.
        unsafe {
            guard.array_mut.get_unchecked_mut(guard.initialized).write(item);
        }
        guard.initialized += 1;

        // Check if the whole array was initialized.
        if guard.initialized == N {
            // The elements are now owned by the result: disarm the guard.
            core::mem::forget(guard);

            // SAFETY: all `N` elements are initialized, and
            // `[MaybeUninit<T>; N]` has the same layout as `[T; N]`.
            return unsafe {
                (&array as *const [MaybeUninit<T>; N]).cast::<[T; N]>().read()
            };
        }
    }

    unreachable!("Something went wrong with iterator length")
}

/// Collects the items of `iter` into an array of length `N`.
///
/// Portable fallback used when the `nightly` feature is disabled: the items are
/// first gathered into a `Vec`, which is then converted into an array.
///
/// # Panics
///
/// Panics if `iter` does not yield exactly `N` items.
#[cfg(not(feature = "nightly"))]
#[inline]
pub(crate) fn collect_into_array_unchecked<
    T,
    I : Iterator<Item=T>,
    const N: usize
>(iter: I) -> [T; N]
{
    let items : Vec<T> = iter.collect();
    items.try_into().unwrap_or_else(|_ : Vec<T>| {
        panic!("collect_into_array failure: should not happen")
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mapx_test() {
        let a = [0,1,2];
        let mut b = [2,1,0];
        assert_eq!(map1(a, |x| x+1), [1,2,3]);
        assert_eq!(map2(a, b, |x, y| x+y), [2,2,2]);
        assert_eq!(map1_indexed(a, |i, y| y-i), [0,0,0]);
        map1_indexed_mut(&mut b, |i, y| *y=i);
        assert_eq!(b, a);
    }

    #[test]
    fn mapx_more_test() {
        let sums = map3([1, 2, 3], [4, 5, 6], [7, 8, 9], |x : i32, y : i32, z : i32| x + y + z);
        assert_eq!(sums, [12, 15, 18]);
        // Mapping over a borrowed array passes references to its elements.
        assert_eq!(map1(&[1, 2, 3], |x : &i32| *x * 2), [2, 4, 6]);
        let mut c = [1, 2, 3];
        map2_mut(&mut c, [10, 20, 30], |x : &mut i32, y : i32| *x += y);
        assert_eq!(c, [11, 22, 33]);
        let idx = map2_indexed([1, 1, 1], [2, 2, 2], |i, x : i32, y : i32| (i, x + y));
        assert_eq!(idx, [(0, 3), (1, 3), (2, 3)]);
    }

    #[test]
    fn foldmap_test() {
        let a = [1, 2, 3];
        assert_eq!(a.foldmap(0, |acc : i32, x : i32| (acc + x, acc + x)), [1, 3, 6]);
        // Yields the accumulator *before* each step: 10, 10-1, 10-1-2.
        let it = a.iter().foldmap(10, |acc : i32, x : &i32| (acc - *x, acc));
        assert_eq!(it.collect::<Vec<_>>(), vec![10, 9, 7]);
    }

    #[test]
    fn indexmap_test() {
        let a : [usize; 3] = [5, 6, 7];
        assert_eq!(a.indexmap(|i, x : usize| i * 10 + x), [5, 16, 27]);
        let it = a.iter().indexmap(|i, x : &usize| i + *x);
        assert_eq!(it.collect::<Vec<_>>(), vec![5, 7, 9]);
    }

    #[test]
    fn array_init_test() {
        let z : [u8; 4] = array_init(|| 7);
        assert_eq!(z, [7; 4]);
    }
}
