/*!
Implementation of the surface of the 3D cube as a [`ManifoldPoint`].
*/

use serde_repr::*;
use serde::Serialize;
use alg_tools::loc::Loc;
use alg_tools::norms::{Norm, L2};
use crate::manifold::{EmbeddedManifoldPoint, FacedManifoldPoint, ManifoldPoint};

/// The six faces of the unit cube on which an [`OnCube`] point may lie.
///
/// The explicit `u8` discriminants (1–6) are the representation used by the
/// `Serialize_repr`/`Deserialize_repr` derives. Writing $(x, y, z)$ for the 3D
/// coordinates produced by [`Face::embedded_coords`], each variant documents
/// the plane its face lies in.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize_repr, Deserialize_repr)]
#[repr(u8)]
pub enum Face {
    /// Face in the plane $z = 0$.
    F1 = 1,
    /// Face in the plane $x = 1$.
    F2 = 2,
    /// Face in the plane $x = 0$.
    F3 = 3,
    /// Face in the plane $y = 0$.
    F4 = 4,
    /// Face in the plane $y = 1$.
    F5 = 5,
    /// Face in the plane $z = 1$.
    F6 = 6,
}
use Face::*;

/// Two-dimensional coordinates.
///
/// Used both for positions relative to a face (possibly unfolded beyond its
/// unit square) and for tangent vectors of [`OnCube`].
pub type Point = Loc<f64, 2>;

/// The four neighbours of a face, in the order returned by
/// [`Face::adjacent_faces`]: left, right, down, up.
pub type AdjacentFaces = [Face; 4];

/// A route over the cube from an implicit source face to a `destination` face,
/// described by the faces it passes through.
#[derive(Clone, Debug, Serialize)]
pub enum Path {
    /// Goes straight to `destination`. As produced by [`PathIter`], the
    /// destination is then either the source face itself or adjacent to it.
    Direct {
        /// Face on which the path ends.
        destination: Face,
    },
    /// Reaches `destination` by passing through exactly one other face,
    /// `intermediate`.
    Indirect {
        /// Face on which the path ends.
        destination: Face,
        /// Face traversed between the source and `destination`.
        intermediate: Face,
    },
}

/// Iterator over the [`Path`]s leading from a source face to a destination face.
///
/// The source face itself is implicit; each variant stores only what is
/// needed to resume the enumeration.
#[derive(Clone, Debug)]
pub enum PathIter {
    /// Source and destination coincide, so a single [`Path::Direct`] is produced.
    Same {
        /// Destination face (which is also the source face).
        destination: Face,
        /// Set once the single [`Path::Direct`] has been produced.
        exhausted: bool,
    },
    /// Source and destination differ. The candidate intermediate faces are
    /// scanned in order, each producing a [`Path::Direct`] or
    /// [`Path::Indirect`] item (or being skipped).
    Indirect {
        /// Destination face.
        destination: Face,
        /// Candidate intermediate faces; `Face::paths` fills this with the
        /// faces adjacent to the source.
        intermediate: AdjacentFaces,
        /// Index in `intermediate` of the next candidate to examine.
        current: usize,
    },
}

impl std::iter::Iterator for PathIter {
    type Item = Path;

    /// Returns the next path to the destination.
    ///
    /// For [`PathIter::Indirect`], a candidate equal to the destination yields a
    /// [`Path::Direct`]. The candidate opposing the destination is skipped.
    /// Every other candidate yields a [`Path::Indirect`].
    fn next(&mut self) -> Option<Self::Item> {
        match *self {
            PathIter::Same { destination, ref mut exhausted } => {
                if *exhausted {
                    None
                } else {
                    *exhausted = true;
                    Some(Path::Direct { destination })
                }
            },
            PathIter::Indirect { destination, intermediate : ref i, ref mut current } => {
                while *current < i.len() {
                    let intermediate = i[*current];
                    *current += 1;
                    if intermediate == destination {
                        // The destination is itself a neighbour of the source.
                        return Some(Path::Direct { destination })
                    } else if intermediate != destination.opposing_face() {
                        return Some(Path::Indirect { destination, intermediate })
                    }
                    // Paths never go through the face opposing the destination.
                }
                None
            }
        }
    }

    /// Exact number of paths still to be produced.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = match *self {
            PathIter::Same { exhausted, .. } => usize::from(!exhausted),
            PathIter::Indirect { destination, ref intermediate, current } => {
                let opposing = destination.opposing_face();
                // Each remaining candidate yields one path, except the face opposing
                // the destination (see `next`). An out-of-range counter means that
                // nothing remains.
                intermediate
                    .get(current..)
                    .unwrap_or_default()
                    .iter()
                    .filter(|&&face| face != opposing)
                    .count()
            }
        };
        (remaining, Some(remaining))
    }
}

/// After returning `None`, a [`PathIter`] keeps returning `None`. `Same` stays
/// `exhausted`, and the `Indirect` counter never decreases.
impl std::iter::FusedIterator for PathIter {}

impl std::fmt::Display for Face {
    /// Writes the face's name, `F1` through `F6`. Formatting flags such as
    /// width are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name: &'static str = match self {
            F1 => "F1",
            F2 => "F2",
            F3 => "F3",
            F4 => "F4",
            F5 => "F5",
            F6 => "F6",
        };
        f.write_str(name)
    }
}

impl Face {
    /// Return an aray of all faces
    pub fn all() -> [Face; 6] {
        [F1, F2, F3, F4, F5, F6]
    }

    /// Returns an array of the four faces adjacent to `self` in the
    /// order [left, right, down, up] in the `self`-relative unfolding.
    pub fn adjacent_faces(&self) -> AdjacentFaces {
        match *self {
            F1 => [F3, F2, F4, F5],
            F2 => [F4, F5, F1, F6],
            F3 => [F5, F4, F1, F6],
            F4 => [F3, F2, F1, F6],
            F5 => [F2, F3, F1, F6],
            F6 => [F3, F2, F4, F5],
        }
    }

    /// Returns the face opposing `self`.
    pub fn opposing_face(&self) -> Face {
        match *self {
            F1 => F6,
            F2 => F3,
            F3 => F2,
            F4 => F5,
            F5 => F4,
            F6 => F1,
        }
    }

    /// Converts a point on an adjacent face to the coordinate system of `self`.
    pub fn convert_adjacent(&self, adjacent : Face, p: &Point) -> Option<Point> {
        let Loc([x, y]) = *p;
        let mk = |x, y| Some(Loc([x, y]));
        match adjacent {
            F1 =>  match *self {
                F2 => mk(y, x - 1.0),
                F3 => mk(1.0 - y, -x),
                F4 => mk(x, -y),
                F5 => mk(1.0 - x, y - 1.0),
                F1 => mk(x, y),
                F6 => None,
            },
            F2 => match *self {
                F1 => mk(y + 1.0, x),
                F4 => mk(x + 1.0, y),
                F5 => mk(x - 1.0, y),
                F6 => mk(2.0 - y, x),
                F2 => mk(x, y),
                F3 => None,
            },
            F3 => match *self {
                F1 => mk(-y, 1.0 - x),
                F4 => mk(x - 1.0, y),
                F5 => mk(x + 1.0, y),
                F6 => mk(y - 1.0, 1.0 - x),
                F3 => mk(x, y),
                F2 => None,
            },
            F4 => match *self {
                F1 => mk(x, -y),
                F2 => mk(x - 1.0, y),
                F3 => mk(x + 1.0, y),
                F6 => mk(x, y - 1.0),
                F4 => mk(x, y),
                F5 => None,
            },
            F5 => match *self {
                F1 => mk(1.0 -x, y + 1.0),
                F2 => mk(x + 1.0, y),
                F3 => mk(x - 1.0, y),
                F6 => mk(1.0 -x, 2.0 - y),
                F5 => mk(x, y),
                F4 => None,
            },
            F6 => match *self {
                F2 => mk(y, 2.0 - x),
                F3 => mk(1.0 - y, x + 1.0),
                F4 => mk(x, y + 1.0),
                F5 => mk(1.0 - x, 2.0 - y),
                F6 => mk(x, y),
                F1 => None,
            }
        }
    }

    /// Converts a point behind a path to the coordinate system of `self`.
    pub fn convert(&self, path : &Path, p: &Point) -> Point {
        use Path::*;
        //dbg!(*self, path);
        match path {
            &Direct{ destination : d} => self.convert_adjacent(d, p),
            &Indirect{ destination : d, intermediate : i }
                => {self.convert_adjacent(i, &i.convert_adjacent(d, p).unwrap())}
        }.unwrap()
    }


    /// Returns an iterator over all the paths from `self` to `other`.
    fn paths(&self, other : Face) -> PathIter {
        if other == *self {
            PathIter::Same {
                destination : other,
                exhausted : false
            }
        } else {
            PathIter::Indirect {
                intermediate : self.adjacent_faces(),
                destination : other,
                current : 0
            }
        }
    }

    /// Indicates whether an unfolded point `p` is on this face, i.e.,
    /// has coordinates in $\[0,1\]^2$.
    pub fn is_in_face(&self, p: &Point) -> bool {
        p.iter().all(|t| 0.0 <= *t && *t <= 1.0)
    }

    /// Given an unfolded point `p` and a destination point `d` in unfolded coordinates,
    /// but possibly outside this face, find the crossing point of the line between
    /// `p` and `d` on an edge of (`self`). Return the point and the edge presented
    /// by an adjacent face.
    ///
    /// Crossing at corners is decided arbitrarily.
    pub fn find_crossing(&self, p :& Point, d : &Point) -> (Face, Point) {
        //assert!(self.is_in_face(p));
        
        if self.is_in_face(d) {
            return (*self, *p)
        }

        use std::cmp::Ordering::*;

        let &Loc([x, y]) = p;
        let &Loc([xd, yd]) = d;
        let tx = xd - x;
        let ty = yd - y;

        // Move towards tangent as (x + s tx, y + s ty) for the largest s<=1.0 for which
        // both coordinates is within [0, 1]. Also gives the direction of move along
        // each coordinate.
        let (sx, dirx) = match tx.partial_cmp(&0.0) {
            Some(Less) =>  (1.0f64.min(-x/tx), Less),
            Some(Greater) => (1.0f64.min((1.0-x)/tx), Greater),
            _ => (1.0, Equal)
        };
        let (sy, diry) = match ty.partial_cmp(&0.0) {
            Some(Less) => (1.0f64.min(-y/ty), Less),
            Some(Greater) => (1.0f64.min((1.0-y)/ty), Greater),
            _ => (1.0, Equal),
        };

        // TODO: how to properly handle corners? Just throw an error?
        let (crossing, c) = match (sx < sy, dirx, diry) {
            // x move is less than y move, so crossing is either on left or right edge
            (true, Less, _)      => (self.adjacent_faces()[0], sx),
            (true, Greater, _)   => (self.adjacent_faces()[1], sx),
            (true, Equal, _)     => (*self, sx),
            // y move is less than x move, so crossing is either on bottom or top edge
            (false, _, Less)     => (self.adjacent_faces()[2], sy),
            (false, _, Greater)  => (self.adjacent_faces()[3], sy),
            (false, _, Equal)    => (*self, sy),
        };
        (crossing, Loc([x + c*tx, y + c*ty]))
    }

    /// Get embedded 3D coordinates
    pub fn embedded_coords(&self, p : &Point) -> Loc<f64, 3> {
        let &Loc([x, y]) = p;
        Loc(match *self {
            F1 => [x, y, 0.0],
            F2 => [1.0, x, y],
            F3 => [0.0, 1.0-x, y],
            F4 => [x, 0.0, y],
            F5 => [1.0 - x, 1.0, y],
            F6 => [x, y, 1.0],
        })
    }
}

/// A point on the surface of the unit cube $\[0,1\]^3$, stored as a face
/// together with coordinates relative to that face.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct OnCube {
    /// The face on which the point lies.
    face: Face,
    /// Coordinates relative to `face`. [`OnCube::new`] requires them to lie in
    /// $\[0,1\]^2$.
    point: Point,
}

impl OnCube {
    /// Constructs a point on the cube from a `face` and coordinates `point`
    /// relative to that face.
    ///
    /// # Panics
    ///
    /// Panics if `point` lies outside $\[0,1\]^2$.
    pub fn new(face : Face, point : Point) -> Self {
        assert!(face.is_in_face(&point));
        Self { face, point }
    }

    /// Computes the logarithmic map towards `other` together with its length.
    ///
    /// Each candidate path from `self.face` to `other.face` unfolds `other`
    /// into the coordinates of `self.face`. The shortest resulting
    /// displacement is kept; on ties, the first one wins.
    /// Returns `(tangent, length)`.
    fn log_dist(&self, other : &Self) -> (<Self as ManifoldPoint>::Tangent, f64) {
        let start = (Loc([0.0, 0.0]), f64::INFINITY);
        self.face
            .paths(other.face)
            .map(|path| self.face.convert(&path, &other.point) - &self.point)
            .fold(start, |(best_tan, best_len), tan| {
                let len = tan.norm(L2);
                if len < best_len {
                    (tan, len)
                } else {
                    (best_tan, best_len)
                }
            })
    }
}

impl FacedManifoldPoint for OnCube {
    type Face = Face;

    /// The cube face on which this point lies.
    fn face(&self) -> Face {
        let Self { face, .. } = self;
        *face
    }
}


impl EmbeddedManifoldPoint for OnCube {
    type EmbeddedCoords = Loc<f64, 3>;

    /// Get embedded 3D coordinates
    fn embedded_coords(&self) -> Loc<f64, 3> {
        self.face.embedded_coords(&self.point)
    }
}

impl ManifoldPoint for OnCube {
    type Tangent = Point;

    /// Moves from this point along `tangent` over the cube surface.
    ///
    /// Whenever the straight segment leaves the current face, both the crossing
    /// point and the target are re-expressed in the coordinates of the face
    /// across the crossed edge. This repeats until the target lies on the
    /// current face.
    fn exp(self, tangent : &Self::Tangent) -> Self {
        // (current face, current position, target), with both positions in the
        // current face's (unfolded) coordinates.
        let mut state = (self.face, self.point, self.point + tangent);
        loop {
            let (face, from, to) = state;
            let (next, crossing) = face.find_crossing(&from, &to);
            if next == face {
                return OnCube { face, point : to };
            }
            state = (
                next,
                next.convert_adjacent(face, &crossing).unwrap(),
                next.convert_adjacent(face, &to).unwrap(),
            );
        }
    }

    /// Logarithmic map: the shortest unfolded displacement from `self` to `other`.
    fn log(&self, other : &Self) -> Self::Tangent {
        let (tangent, _length) = self.log_dist(other);
        tangent
    }

    /// Distance to `other`: the length of the tangent returned by `log`.
    fn dist_to(&self, other : &Self) -> f64 {
        let (_tangent, length) = self.log_dist(other);
        length
    }

    /// The zero tangent vector.
    fn tangent_origin(&self) -> Self::Tangent {
        Loc([0.0; 2])
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Tests that the distances between the centers of all the faces are correctly calculated.
    #[test]
    fn center_distance() {
        let center = Loc([0.5, 0.5]);

        for f1 in Face::all() {
            let p1 = OnCube { face : f1, point : center };
            for f2 in Face::all() {
                let p2 = OnCube { face : f2, point : center };
                if f1 == f2 {
                    assert_eq!(p1.dist_to(&p2), 0.0);
                } else if f1.opposing_face() == f2 {
                    assert_eq!(p1.dist_to(&p2), 2.0);
                } else {
                    assert_eq!(p1.dist_to(&p2), 1.0);
                }
            }
        }
    }

    /// Tests that the distances between points on the boundaries of distinct faces are
    /// correctly calculated.
    #[test]
    fn boundary_distance() {
        let left = Loc([0.0, 0.5]);
        let right = Loc([1.0, 0.5]);
        let down = Loc([0.5, 0.0]);
        let up = Loc([0.5, 1.0]);
        let center = Loc([0.5, 0.5]);

        for f1 in Face::all() {
            let pl = OnCube { face : f1, point : left };
            let pr = OnCube { face : f1, point : right };
            let pd = OnCube { face : f1, point : down };
            let pu = OnCube { face : f1, point : up };
            let a = f1.adjacent_faces();
            let al = OnCube { face : a[0], point : center };
            let ar = OnCube { face : a[1], point : center };
            let ad = OnCube { face : a[2], point : center };
            let au = OnCube { face : a[3], point : center };
            let ao = OnCube { face : f1.opposing_face(), point : center };

            assert_eq!(pl.dist_to(&al), 0.5);
            assert_eq!(pr.dist_to(&ar), 0.5);
            assert_eq!(pd.dist_to(&ad), 0.5);
            assert_eq!(pu.dist_to(&au), 0.5);
            assert_eq!(pl.dist_to(&ao), 1.5);
            assert_eq!(pr.dist_to(&ao), 1.5);
            assert_eq!(pd.dist_to(&ao), 1.5);
            assert_eq!(pu.dist_to(&ao), 1.5);
        }
    }


    /// Tests that the conversions between the coordinate systems of each face is working correctly.
    #[test]
    fn convert_adjacent() {
        let point = Loc([0.4, 0.6]);

        for f1 in Face::all() {
            for f2 in Face::all() {
                println!("{:?}-{:?}", f1, f2);
                match f1.convert_adjacent(f2, &point) {
                    None => assert_eq!(f2.opposing_face(), f1),
                    Some(q) => {
                        match f2.convert_adjacent(f1, &q) {
                            None => assert_eq!(f1.opposing_face(), f2),
                            Some(p) => assert!((p-&point).norm(L2) < 1e-9),
                        }
                    }
                }
            }
        }
    }

    /// Tests that a point on an edge of a face, and its image on the adjacent
    /// face across that edge, embed to the same 3D point.
    #[test]
    fn adjacent_edges_agree() {
        let t = 0.3;
        // Edge points ordered like `adjacent_faces`: left, right, down, up.
        let edges = [Loc([0.0, t]), Loc([1.0, t]), Loc([t, 0.0]), Loc([t, 1.0])];

        for f in Face::all() {
            let adj = f.adjacent_faces();
            for (&g, e) in adj.iter().zip(edges.iter()) {
                let q = g.convert_adjacent(f, e).unwrap();
                assert!(g.is_in_face(&q), "{:?}->{:?}", f, g);
                let Loc(u) = f.embedded_coords(e);
                let Loc(v) = g.embedded_coords(&q);
                for k in 0..3 {
                    assert!((u[k] - v[k]).abs() < 1e-12, "{:?}->{:?}", f, g);
                }
            }
        }
    }

    /// Tests the number of candidate paths between each pair of faces.
    #[test]
    fn path_counts() {
        for f1 in Face::all() {
            for f2 in Face::all() {
                let expected = if f1 == f2 {
                    1
                } else if f1.opposing_face() == f2 {
                    4
                } else {
                    3
                };
                assert_eq!(f1.paths(f2).count(), expected, "{:?}->{:?}", f1, f2);
            }
        }
    }

    // This will fail, as different return path does not guarantee
    // that a point outside the face will be returned to its point of origin.
    // #[test]
    // fn convert_paths() {
    //     let point = Loc([0.4, 0.6]);

    //     for f1 in Face::all() {
    //         for f2 in Face::all() {
    //             for p1 in f2.paths(f1) {
    //                 for p2 in f1.paths(f2) {
    //                     println!("{:?}-{:?}; {:?} {:?}", f1, f2, p1, p2);
    //                     let v = &f2.convert(&p1, &point);
    //                     let q = f1.convert(&p2, v);
    //                     assert!((q-&point).norm(L2) < 1e-9,
    //                             "norm({}-{}) ≥ 1e-9 (dest {})", q, &point, &v);
    //                 }
    //             }
    //         }
    //     }
    // }

    /// Tests that the logarithmic map is working correctly between adjacent faces.
    #[test]
    fn log_adjacent() {
        let p1 = OnCube{ face : F1, point : Loc([0.5, 0.5])};
        let p2 = OnCube{ face : F2, point : Loc([0.5, 0.5])};

        assert_eq!(p1.log(&p2).norm(L2), 1.0);
    }

    /// Tests that the logarithmic map is working correctly between opposing faces.
    #[test]
    fn log_opposing_equal() {
        let p1 = OnCube{ face : F1, point : Loc([0.5, 0.5])};
        let p2 = OnCube{ face : F6, point : Loc([0.5, 0.5])};

        assert_eq!(p1.log(&p2).norm(L2), 2.0);
    }

    /// Tests that the logarithmic map is working correctly between opposing faces when there
    /// is a unique shortest geodesic.
    #[test]
    fn log_opposing_unique_shortest() {
        let p1 = OnCube{ face : F1, point : Loc([0.3, 0.25])};
        let p2 = OnCube{ face : F6, point : Loc([0.3, 0.25])};

        assert_eq!(p1.log(&p2).norm(L2), 1.5);
    }

    /// Tests that `exp` undoes `log` between the centers of all pairs of faces,
    /// including moves that cross one edge (adjacent faces) or two edges
    /// (opposing faces).
    #[test]
    fn exp_inverts_log() {
        let center = Loc([0.5, 0.5]);

        for f1 in Face::all() {
            let p1 = OnCube { face : f1, point : center };
            for f2 in Face::all() {
                let p2 = OnCube { face : f2, point : center };
                let q = p1.clone().exp(&p1.log(&p2));
                assert_eq!(q.face, f2, "{:?}->{:?}", f1, f2);
                assert!((q.point - &p2.point).norm(L2) < 1e-9, "{:?}->{:?}", f1, f2);
            }
        }
    }
}

