/*!
Implementation of the surface of the 3D cube as a [`ManifoldPoint`].
*/

use serde_repr::*;
use serde::Serialize;
use alg_tools::loc::Loc;
use alg_tools::norms::{Norm, L2};
use crate::manifold::{ManifoldPoint, EmbeddedManifoldPoint};

/// All the different faces of a [`OnCube`].
///
/// The explicit `u8` discriminants (1–6) are the representation used by the
/// `Serialize_repr` / `Deserialize_repr` derives; renumbering them would change
/// the serialized form.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize_repr, Deserialize_repr)]
#[repr(u8)]
pub enum Face {F1 = 1, F2 = 2, F3 = 3, F4 = 4, F5 = 5, F6 = 6}
// Bring the variants into module scope so the code below can write `F1` instead of `Face::F1`.
use Face::*;

/// Face-relative 2D coordinates. The face itself is `[0,1]²`; coordinates
/// outside that square denote points on an unfolding into neighbouring faces.
/// Also used as the tangent type of [`OnCube`].
pub type Point = Loc<f64, 2>;

/// The four faces adjacent to a face, in the order [left, right, down, up]
/// (see [`Face::adjacent_faces`]).
pub type AdjacentFaces = [Face; 4];

/// A route across the cube surface. A point on the destination face is unfolded
/// along it into the coordinate system of the starting face (see [`Face::convert`]).
#[derive(Clone, Debug, Serialize)]
pub enum Path {
    /// The destination is the starting face itself or shares an edge with it.
    Direct { destination : Face },
    /// The destination is first unfolded onto `intermediate`, and from there onto
    /// the starting face; `intermediate` must be adjacent to both.
    Indirect { destination : Face, intermediate : Face },
}

/// An iterator over paths on a cube, from a source face to a destination face.
///
/// Constructed by `Face::paths`.
#[derive(Clone, Debug)]
pub enum PathIter {
    /// Yields a single [`Path::Direct`] to the contained face, then becomes `Exhausted`.
    Direct(Face),
    /// Yields one [`Path::Indirect`] to `destination` for each entry of
    /// `intermediate`, starting at index `current`, then becomes `Exhausted`.
    Indirect{ destination : Face, intermediate : AdjacentFaces, current : usize},
    /// No further paths; `next` returns `None`.
    Exhausted,
}

impl std::iter::Iterator for PathIter {
    type Item = Path;
    
    fn next(&mut self) -> Option<Self::Item> {
        use PathIter::*;
        match self {
            &mut Exhausted => None,
            &mut Direct(destination) => {
                *self = Exhausted;
                Some(Path::Direct { destination })
            },
            &mut Indirect{destination, intermediate : ref i, ref mut current} => {
                if *current < i.len() {
                    let intermediate = i[*current];
                    *current += 1;
                    Some(Path::Indirect{ destination, intermediate })
                } else {
                    *self = Exhausted;
                    None
                }
            }
        }
    }
}

impl std::fmt::Display for Face {
    /// Writes the face's name (`F1` … `F6`) to the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            F1 => "F1",
            F2 => "F2",
            F3 => "F3",
            F4 => "F4",
            F5 => "F5",
            F6 => "F6",
        })
    }
}

impl Face {
    /// Return an aray of all faces
    pub fn all() -> [Face; 6] {
        [F1, F2, F3, F4, F5, F6]
    }

    /// Returns an array of the four faces adjacent to `self` in the
    /// order [left, right, down, up] in the `self`-relative unfolding.
    pub fn adjacent_faces(&self) -> AdjacentFaces {
        match *self {
            F1 => [F3, F2, F4, F5],
            F2 => [F4, F5, F1, F6],
            F3 => [F5, F4, F1, F6],
            F4 => [F3, F2, F1, F6],
            F5 => [F2, F3, F1, F6],
            F6 => [F3, F2, F4, F5],
        }
    }

    /// Returns the face opposing `self`.
    pub fn opposing_face(&self) -> Face {
        match *self {
            F1 => F6,
            F2 => F3,
            F3 => F2,
            F4 => F5,
            F5 => F4,
            F6 => F1,
        }
    }

    /// Converts a point on an adjacent face to the coordinate system of `self`.
    pub fn convert_adjacent(&self, adjacent : Face, p: &Point) -> Option<Point> {
        let Loc([x, y]) = *p;
        let mk = |x, y| Some(Loc([x, y]));
        match adjacent {
            F1 =>  match *self {
                F2 => mk(y, x - 1.0),
                F3 => mk(1.0 - y, -x),
                F4 => mk(x, -y),
                F5 => mk(1.0 - x, y - 1.0),
                F1 => mk(x, y),
                F6 => None,
            },
            F2 => match *self {
                F1 => mk(y + 1.0, x),
                F4 => mk(x + 1.0, y),
                F5 => mk(x - 1.0, y),
                F6 => mk(2.0 - y, x),
                F2 => mk(x, y),
                F3 => None,
            },
            F3 => match *self {
                F1 => mk(-y, 1.0 - x),
                F4 => mk(x - 1.0, y),
                F5 => mk(x + 1.0, y),
                F6 => mk(y - 1.0, 1.0 - x),
                F3 => mk(x, y),
                F2 => None,
            },
            F4 => match *self {
                F1 => mk(x, -y),
                F2 => mk(x - 1.0, y),
                F3 => mk(x + 1.0, y),
                F6 => mk(x, y - 1.0),
                F4 => mk(x, y),
                F5 => None,
            },
            F5 => match *self {
                F1 => mk(1.0 -x, y + 1.0),
                F2 => mk(x + 1.0, y),
                F3 => mk(x - 1.0, y),
                F6 => mk(1.0 -x, 2.0 - y),
                F5 => mk(x, y),
                F4 => None,
            },
            F6 => match *self {
                F2 => mk(y, 2.0 - x),
                F3 => mk(1.0 - y, x + 1.0),
                F4 => mk(x, y + 1.0),
                F5 => mk(1.0 - x, 2.0 - y),
                F6 => mk(x, y),
                F1 => None,
            }
        }
    }

    /// Converts a point behind a path to the coordinate system of `self`.
    pub fn convert(&self, path : &Path, p: &Point) -> Point {
        use Path::*;
        //dbg!(*self, path);
        match path {
            &Direct{ destination : d} => self.convert_adjacent(d, p),
            &Indirect{ destination : d, intermediate : i }
                => {self.convert_adjacent(i, &i.convert_adjacent(d, p).unwrap())}
        }.unwrap()
    }


    /// Returns an iterator over all the paths from `self` to `other`.
    fn paths(&self, other : Face) -> PathIter {
        //dbg!(self, other);
        if self.opposing_face() == other {
            PathIter::Indirect {
                intermediate : self.adjacent_faces(),
                destination : other,
                current : 0
            }
        } else {
            PathIter::Direct(other)
        }
    }

    /// Indicates whether an unfolded point `p` is on this face, i.e.,
    /// has coordinates in [0,1]².
    pub fn is_in_face(&self, p: &Point) -> bool {
        p.iter().map(|t| t.abs()).all(|t| 0.0 <= t && t <= 1.0)
    }

    /// Given an unfolded point `p`, possibly outside this face, finds
    /// the edge, presented by an adjacent face, in whose direction it is.
    ///
    /// **TODO:** this does not correctly handle corners, i.e., when the point is not in
    /// the direction of an adjacent face.
    pub fn find_crossing(&self, p : &Point) -> Face {
        let &Loc([x, y]) = p;
        use std::cmp::Ordering::*;
        let crossing = |t| match (0.0 <= t, t<=1.0) {
            (false, _) => Less,
            (_, false) => Greater,
            _ => Equal,
        };

        // TODO: how to properly handle corners? Just throw an error?
        match (crossing(x), crossing(y)) {
            (Equal, Equal) => *self,
            (Less, _) => self.adjacent_faces()[0],
            (Greater, _) => self.adjacent_faces()[1],
            (Equal, Less) => self.adjacent_faces()[2],
            (Equal, Greater) => self.adjacent_faces()[3],
        }
    }

    /// Get embedded 3D coordinates
    pub fn embedded_coords(&self, p : &Point) -> Loc<f64, 3> {
        let &Loc([x, y]) = p;
        Loc(match *self {
            F1 => [x, y, 0.0],
            F2 => [1.0, x, y],
            F3 => [0.0, 1.0-x, y],
            F4 => [x, 0.0, y],
            F5 => [1.0 - x, 1.0, y],
            F6 => [x, y, 1.0],
        })
    }
}

/// A point on the surface of the cube: the face it lies on, together with
/// coordinates relative to that face.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct OnCube {
    /// The face the point lies on.
    face : Face,
    /// Coordinates relative to `face`, intended to lie in `[0,1]²`
    /// (see [`OnCube::new`]).
    point : Point,
}

impl OnCube {
    /// Constructs a point on the cube from a `face` and coordinates `point`
    /// relative to that face, which should lie in [0, 1]^2.
    ///
    /// # Panics
    ///
    /// Panics if `face.is_in_face(&point)` is false.
    pub fn new(face : Face, point : Point) -> Self {
        assert!(face.is_in_face(&point));
        Self { face, point }
    }

    /// Computes the logarithmic map towards `other` together with its length, in
    /// a single pass over the candidate unfolding paths.
    ///
    /// For each path, `other` is unfolded into this face's coordinates and its
    /// offset from `self.point` is measured. The strictly shortest offset wins, so
    /// ties keep the earliest path. If no candidate is shorter than infinity, the
    /// result is the zero vector with infinite length.
    fn log_dist(&self, other : &Self) -> (<Self as ManifoldPoint>::Tangent, f64) {
        let origin = &self.point;
        self.face
            .paths(other.face)
            .map(|path| {
                let candidate = self.face.convert(&path, &other.point) - origin;
                let length = candidate.norm(L2);
                (candidate, length)
            })
            .fold((Loc([0.0, 0.0]), f64::INFINITY), |best, cand| {
                if cand.1 < best.1 { cand } else { best }
            })
    }

    /// Returns the face this point lies on.
    pub fn face(&self) -> Face {
        self.face
    }
}


impl EmbeddedManifoldPoint for OnCube {
    type EmbeddedCoords = Loc<f64, 3>;

    /// Get embedded 3D coordinates
    fn embedded_coords(&self) -> Loc<f64, 3> {
        self.face.embedded_coords(&self.point)
    }
}

impl ManifoldPoint for OnCube {
    type Tangent = Point;

    /// Moves from `self` by `tangent` along the cube surface.
    ///
    /// The displaced point is re-expressed on whichever neighbouring face
    /// [`Face::find_crossing`] reports, repeatedly, until no further crossing is
    /// reported. Corner cases inherit the limitations of `find_crossing`.
    fn exp(self, tangent : &Self::Tangent) -> Self {
        let mut face = self.face;
        let mut point = self.point + tangent;
        let mut crossed = face.find_crossing(&point);
        while crossed != face {
            point = crossed.convert_adjacent(face, &point).unwrap();
            face = crossed;
            crossed = face.find_crossing(&point);
        }
        Self { face, point }
    }

    /// Returns the tangent at `self` pointing to `other` along the shortest
    /// candidate unfolding.
    fn log(&self, other : &Self) -> Self::Tangent {
        let (tangent, _) = self.log_dist(other);
        tangent
    }

    /// Returns the length of the shortest candidate unfolding from `self` to `other`.
    fn dist_to(&self, other : &Self) -> f64 {
        let (_, distance) = self.log_dist(other);
        distance
    }

    /// Returns the zero tangent vector.
    fn tangent_origin(&self) -> Self::Tangent {
        Loc([0.0; 2])
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Converting a point to another face and back must reproduce it; the only
    /// face pairs without a direct conversion are opposing ones.
    #[test]
    fn convert_adjacent() {
        let origin = Loc([0.4, 0.6]);

        for src in Face::all() {
            for dst in Face::all() {
                println!("{:?}-{:?}", src, dst);
                match src.convert_adjacent(dst, &origin) {
                    None => assert_eq!(dst.opposing_face(), src),
                    Some(there) => match dst.convert_adjacent(src, &there) {
                        Some(back) => assert!((back - &origin).norm(L2) < 1e-9),
                        None => assert_eq!(src.opposing_face(), dst),
                    },
                }
            }
        }
    }

    // This will fail, as different return path does not guarantee
    // that a point outside the face will be returned to its point of origin.
    // #[test]
    // fn convert_paths() {
    //     let point = Loc([0.4, 0.6]);

    //     for f1 in [F1, F2, F3, F4, F5, F6] {
    //         for f2 in [F1, F2, F3, F4, F5, F6] {
    //             for p1 in f2.paths(f1) {
    //                 for p2 in f1.paths(f2) {
    //                     println!("{:?}-{:?}; {:?} {:?}", f1, f2, p1, p2);
    //                     let v = &f2.convert(&p1, &point);
    //                     let q = f1.convert(&p2, v);
    //                     assert!((q-&point).norm(L2) < 1e-9,
    //                             "norm({}-{}) ≥ 1e-9 (dest {})", q, &point, &v);
    //                 }
    //             }
    //         }
    //     }
    // }

    /// Centres of adjacent faces are 0.5 + 0.5 = 1 apart across their shared edge.
    #[test]
    fn log_adjacent() {
        let from = OnCube { face : F1, point : Loc([0.5, 0.5]) };
        let to = OnCube { face : F2, point : Loc([0.5, 0.5]) };

        assert_eq!(from.log(&to).norm(L2), 1.0);
    }

    /// Centres of opposing faces are 0.5 + 1 + 0.5 = 2 apart through any of the
    /// four intermediate faces.
    #[test]
    fn log_opposing_equal() {
        let from = OnCube { face : F1, point : Loc([0.5, 0.5]) };
        let to = OnCube { face : F6, point : Loc([0.5, 0.5]) };

        assert_eq!(from.log(&to).norm(L2), 2.0);
    }

    /// Off-centre on opposing faces, the route through `F4` (0.25 + 1 + 0.25 = 1.5)
    /// is strictly the shortest.
    #[test]
    fn log_opposing_unique_shortest() {
        let from = OnCube { face : F1, point : Loc([0.3, 0.25]) };
        let to = OnCube { face : F6, point : Loc([0.3, 0.25]) };

        assert_eq!(from.log(&to).norm(L2), 1.5);
    }
}

