#!.venv/bin/python3
import argparse
import os
import re
import sys
from itertools import chain
from pathlib import Path
from statistics import median

import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import numpy as np
from dolfinx import fem
from src.convection_diffusion import ConvectionDiffusion


def find_files(dirname):
    """
    Find the per-iteration result files ``res_<n>.npz`` in a directory.

    Only regular files whose whole name matches ``res_<n>.npz`` (``<n>`` a
    run of digits) are returned; other files and sub-directories are ignored.

    Args:
        dirname: Directory to scan (``str`` or path-like).

    Returns:
        A list of ``(n, path)`` tuples, where ``n`` is the integer parsed from
        the file name and ``path`` is the file's ``Path``. The list is sorted
        numerically by ``n``. Equal numbers, e.g. ``res_1.npz`` and
        ``res_01.npz``, are ordered by file name, so the result does not
        depend on directory-listing order.

    Raises:
        OSError (e.g. FileNotFoundError): if ``dirname`` cannot be listed.
    """
    pattern = re.compile(r"^res_(\d+)\.npz$")

    matches = []
    for path in Path(dirname).iterdir():
        m = pattern.match(path.name)
        # Check the (cheap) name first, so only candidates pay for the stat.
        if m and path.is_file():
            matches.append((int(m.group(1)), path))

    # Numeric order (res_2 before res_10); the name breaks ties.
    matches.sort(key=lambda item: (item[0], item[1].name))
    return matches


def plot(prefix, simple=False, shaded=False, init_iter=None, save_only=False):
    """
    Interactively browse (or save) the per-iteration results in ``prefix``.

    Reads every ``res_<n>.npz`` in ``prefix`` (see ``find_files``) and the
    reference data in ``<parent of prefix>/omega_0.npz``. For the first five
    stored time frames it draws u and, unless ``simple``, also w, j and a
    bottom row with either the true û or w̄ₚ / ω₀. Recovered sources μ are
    overlaid as red dots and true sources as white crosses. Marker size grows
    with the third component of each source.

    Args:
        prefix: Directory holding the ``res_<n>.npz`` files. PDFs are saved
            there too.
        simple: Only draw the row of u (or û) panels.
        shaded: Use Gouraud-shaded ``tripcolor`` instead of ``tricontourf``.
        init_iter: Iteration number ``<n>`` to show first. Defaults to the
            lowest one.
        save_only: Save ``res_<n>.pdf`` and ``true_<n>.pdf`` for the initial
            iteration and return without showing a window.

    Interactive keys:
        right/space and left/backspace step by 1 file; shift+right/left step
        by 10; up/down step by 100; ``0`` jumps to the first file; ``t``
        toggles the true-solution view; ``z`` saves a PDF; ``q`` quits.

    Raises:
        FileNotFoundError: if ``prefix`` contains no ``res_<n>.npz`` files.
        ValueError: if ``init_iter`` matches no file.
    """
    quantisation = 32  # number of filled-contour levels per panel
    n_cols = 5  # number of time frames shown (one column each)

    iter_files = find_files(prefix)
    if not iter_files:
        raise FileNotFoundError("No res_<n>.npz files found in %s" % prefix)

    # Only ``pde.V`` is used: its dof coordinates are the points the stored
    # arrays are drawn on, so this mesh must match the one of the results.
    pde = ConvectionDiffusion(
        u0=lambda x: 0.0 * x[1],
        g=lambda x: 0.0 * x[1],  # ((x[1] + 2) / 4) ** 2,  # nonzero bcs
        # w0=lambda x: np.cos(5 * x[1]),
        domain_size=[0, 0.5, 0, 0.5],
        # nx=128,
        # ny=128,
    )

    coords = pde.V.tabulate_dof_coordinates()
    triang = tri.Triangulation(coords[:, 0], coords[:, 1])

    if simple:
        fig, axus = plt.subplots(1, n_cols)
        all_axes = lambda: axus
        fig.set_layout_engine("compressed")
    else:
        fig, (axus, axws, axjs, ax_wpbar_plus) = plt.subplots(
            4, n_cols, figsize=(15, 4)
        )
        all_axes = lambda: chain(axus, axws, axjs, ax_wpbar_plus)
        plt.tight_layout(pad=0.01)
        fig.subplots_adjust(
            left=0.02, right=0.95, bottom=0.02, top=0.93, wspace=0.2, hspace=0.2
        )

    fig.set_size_inches(13, 8)

    if not simple:
        # Parameter read-out in the bottom-right corner. This must be a
        # *figure* text. An Axes text (plt.text) is dropped by the ax.clear()
        # calls in do_plot and would never be displayed.
        info = fig.text(
            0.95,
            0.005,
            "",
            horizontalalignment="right",
            verticalalignment="bottom",
        )

    mpl.rcParams["axes.labelsize"] = 9

    # Placeholder mappable so the colorbars can be created up front. Each
    # colorbar is re-pointed at the real contour set in plot_array.
    norm = colors.Normalize(vmin=0, vmax=1)
    sm = cm.ScalarMappable(norm=norm, cmap="viridis")
    sm.set_array([])  # REQUIRED (matplotlib quirk)
    colorbar_u = fig.colorbar(sm, ax=axus[-1])
    if not simple:
        colorbar_w = fig.colorbar(sm, ax=axws[-1])
        colorbar_j = fig.colorbar(sm, ax=axjs[-1])

    # Reference data shared by every iteration, read once. Using the NpzFile
    # as a context manager closes the file that np.load keeps open.
    with np.load(Path(prefix).parent / "omega_0.npz") as data0:
        ω0 = data0["omega0"]
        true_μ = data0["true_mu"]
        # true_ω = data0["true_omega"]
        true_u_n_list_array = data0["true_u_n_list_array"]
        # The true (k, c1, c2) are optional in omega_0.npz.
        true_params = (
            (data0["k"], data0["c1"], data0["c2"]) if "k" in data0 else None
        )
    ω0_min = ω0.real.min()
    ω0_max = ω0.real.max()
    true_μ_alpha = [x[2] for x in true_μ]

    def do_plot(index_and_file, show_true=False):
        """
        Clear and redraw all panels for one ``(iteration, path)`` entry.

        With ``show_true``, the true trajectory û is drawn instead of u in
        simple mode. In full mode it replaces the w̄ₚ / ω₀ bottom row.
        """
        iteration, path = index_and_file

        print("🎨 Plotting iteration %d..." % iteration)

        with np.load(path) as data:
            u_n_list_array = data["u_n_list_array"]
            w_n_list_array = data["w_n_list_array"]
            j_n_list_array = data["j_n_list_array"]
            times = data["times"]
            mu = data["mu"]
            wp_bar_array = data["wp_bar_array"]
            # Only displayed, and therefore only required, in full mode.
            params = None if simple else (data["k"], data["c1"], data["c2"])

        # One colour range per quantity, shared by all frames of this file.
        u_min = min(u.real.min() for u in u_n_list_array)
        u_max = max(u.real.max() for u in u_n_list_array)
        if show_true:
            u_min = min(u_min, min(u.real.min() for u in true_u_n_list_array))
            u_max = max(u_max, max(u.real.max() for u in true_u_n_list_array))

        w_min = min(
            min(w.real.min() for w in w_n_list_array), wp_bar_array.real.min(), ω0_min
        )
        w_max = max(
            max(w.real.max() for w in w_n_list_array), wp_bar_array.real.max(), ω0_max
        )
        j_min = min(j.real.min() for j in j_n_list_array)
        j_max = max(j.real.max() for j in j_n_list_array)

        # Drop sources whose third component is zero and rows containing NaN.
        # The third component drives marker sizes and the title statistics.
        μ = [x for x in mu if x[2] != 0.0 and not np.isnan(x).any()]
        μ_alpha = [x[2] for x in μ]

        if not μ_alpha:
            alpha_mi, alpha_ma, alpha_me = 0, 0, 0
        else:
            alpha_mi, alpha_ma, alpha_me = min(μ_alpha), max(μ_alpha), median(μ_alpha)

        # Marker size grows linearly from 1 to 11 over the joint range of
        # recovered and true sources. default= avoids a crash when the
        # reference file holds no true sources.
        true_alpha_mi = min(true_μ_alpha, default=alpha_mi)
        true_alpha_ma = max(true_μ_alpha, default=alpha_ma)
        alpha_base = min(true_alpha_mi, alpha_mi)
        alpha_scale = max(true_alpha_ma, alpha_ma) - alpha_base
        if alpha_scale == 0:
            ms = lambda alpha: 6
        else:
            ms = lambda alpha: int(1 + (alpha - alpha_base) / alpha_scale * 10)

        def plot_array(
            name,
            t_idx,
            ax,
            u_array,
            u_min,
            u_max,
            μ,
            true_μ,
            colorbar=None,
        ):
            """
            Draw one scalar field on ``ax`` with optional source markers.

            A negative ``t_idx`` means "no time": the title is just ``name``.
            Otherwise ``t_idx`` is shown as a subscript. If ``colorbar`` is
            given, it is re-pointed at the new plot.
            """
            if u_min == u_max:
                # Contouring needs increasing levels, even for a constant field.
                levels = [u_min, u_min + 1e-9]
            else:
                levels = np.linspace(u_min, u_max, quantisation)
            try:
                if shaded:
                    concentration = ax.tripcolor(
                        triang,
                        u_array,
                        cmap="viridis",
                        shading="gouraud",
                    )
                else:
                    concentration = ax.tricontourf(
                        triang,
                        u_array,
                        levels=levels,
                        cmap="viridis",
                        antialiased=False,
                        zorder=-10,
                    )
                    # Rasterise the contours (zorder < 0) to keep PDFs small.
                    ax.set_rasterization_zorder(0)

                if colorbar:
                    colorbar.update_normal(concentration)
                    fmt = mpl.ticker.ScalarFormatter(useOffset=False)
                    fmt.set_scientific(False)
                    colorbar.ax.yaxis.set_major_formatter(fmt)
                    colorbar.update_ticks()
                    colorbar.ax.yaxis.get_offset_text().set_visible(False)
                    off = colorbar.ax.yaxis.get_offset_text()
                    off.set_x(10)

            except Exception as e:
                # Best effort: one bad panel must not abort the whole redraw.
                print(e)
            if μ is not None:
                for p in μ:
                    ax.plot([p[0]], [p[1]], "ro", markersize=ms(p[2]), label="Sources")
            if true_μ is not None:
                for p in true_μ:
                    ax.plot(
                        [p[0]], [p[1]], "wx", markersize=ms(p[2]), label="True sources"
                    )
            if t_idx >= 0:
                ax.set_title(
                    f"$%s_{{{t_idx:.1f}}}$" % name,
                    fontsize=20 if simple else 14,
                )
            else:
                ax.set_title(
                    "%s" % name,
                    fontsize=20 if simple else 14,
                )
            ax.set_aspect("equal")

        frames = range(len(u_n_list_array))

        for ax in all_axes():
            ax.clear()

        if not simple:
            for i, axu, axw, axj, t_idx in zip(frames, axus, axws, axjs, times):
                plot_array(
                    "u",
                    t_idx,
                    axu,
                    u_n_list_array[i].real,
                    u_min,
                    u_max,
                    μ,
                    true_μ,
                    colorbar_u,
                )
                plot_array(
                    "w",
                    t_idx,
                    axw,
                    w_n_list_array[i].real,
                    w_min,
                    w_max,
                    μ,
                    true_μ,
                    colorbar_w,
                )
                plot_array(
                    "j",
                    t_idx,
                    axj,
                    j_n_list_array[i].real,
                    j_min,
                    j_max,
                    μ,
                    true_μ,
                    colorbar_j,
                )

            if show_true:
                # Titled with the stored times, like the u row above it.
                for i, ax, t_idx in zip(frames, ax_wpbar_plus, times):
                    plot_array(
                        "û",
                        t_idx,
                        ax,
                        true_u_n_list_array[i].real,
                        u_min,
                        u_max,
                        μ,
                        true_μ,
                        colorbar_u,
                    )
            else:
                plot_array(
                    "w̄ₚ",
                    -1,
                    ax_wpbar_plus[0],
                    wp_bar_array.real,
                    w_min,
                    w_max,
                    μ,
                    None,
                )
                plot_array(
                    "ω₀",
                    -1,
                    ax_wpbar_plus[1],
                    ω0.real,
                    w_min,
                    w_max,
                    μ,
                    None,
                )
                # plot_array("ω̂", -1, ax_wpbar_plus[2], true_ω.real, w_min, w_max)

            # Has to be here after everything, for some reason
            for ax in all_axes():
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                ax.set_aspect("equal")

            plt.suptitle(
                "Convection-Diffusion, iteration %d; len(μ) = %d; μ_min = %f, μ_max = %f, μ_median = %f;"
                % (
                    iteration,
                    len(μ),
                    alpha_mi,
                    alpha_ma,
                    alpha_me,
                ),
                fontsize=14,
            )

            info_text = "(k, c1, c2) = (%g, %g, %g)" % params
            if true_params is not None:
                info_text += "; true (%g, %g, %g)" % true_params
            info.set_text(info_text)
        else:  # simple
            for i, axu, t_idx in zip(frames, axus, times):
                if show_true:
                    plot_array(
                        "û",
                        t_idx,
                        axu,
                        true_u_n_list_array[i].real,
                        u_min,
                        u_max,
                        None,
                        true_μ,
                        colorbar_u,
                    )
                else:
                    plot_array(
                        "u",
                        t_idx,
                        axu,
                        u_n_list_array[i].real,
                        u_min,
                        u_max,
                        μ,
                        true_μ,
                        colorbar_u,
                    )

    def do_save(iteration, show_true):
        """Save the figure as ``res_<n>.pdf`` (or ``true_<n>.pdf``) in ``prefix``."""
        stem = "true" if show_true else "res"
        filename = "%s/%s_%d.pdf" % (prefix, stem, iteration)
        print("Saving ", filename)
        plt.savefig(
            filename,
            dpi=300,
            bbox_inches="tight",
            pad_inches=0,
        )

    if init_iter is None:
        k = 0
    else:
        k = next(
            (idx for idx, (number, _) in enumerate(iter_files) if number == init_iter),
            None,
        )
        if k is None:
            raise ValueError("No res_%d.npz found in %s" % (init_iter, prefix))

    do_plot(iter_files[k], show_true=False)

    if save_only:
        iteration = iter_files[k][0]
        do_save(iteration, False)
        do_plot(iter_files[k], show_true=True)
        do_save(iteration, True)
        return

    # Mutable state shared with the key handler: current file index and view.
    state = {"k": k, "show_true": False}

    # Navigation keys -> step through iter_files (wrapping around).
    steps = {
        "right": 1,
        " ": 1,
        "left": -1,
        "backspace": -1,
        "shift+right": 10,
        "shift+left": -10,
        "up": 100,
        "down": -100,
    }

    def on_key(event):
        """Handle a key press: navigate, toggle the true view, save or quit."""
        k0 = state["k"]
        k = None  # new index to display, or None for "no redraw"
        if event.key in steps:
            k = (k0 + steps[event.key]) % len(iter_files)
        elif event.key == "0":
            k = 0
        elif event.key == "t":
            state["show_true"] = not state["show_true"]
            k = k0
        elif event.key == "q":
            sys.exit()
        elif event.key == "z":
            do_save(iter_files[k0][0], state["show_true"])
        if k is not None:
            state["k"] = k
            do_plot(iter_files[k], state["show_true"])
            fig.canvas.draw()

    fig.canvas.mpl_connect("key_press_event", on_key)

    plt.show()


# Command-line interface. The parser lives at module level; parsing and
# plotting only happen when the file is run as a script, not on import.
parser = argparse.ArgumentParser(
    prog="plot.py",
    description="Plots results of pointsource_pde",
)

parser.add_argument(
    "filename",
    help="directory containing the res_<n>.npz files "
    "(omega_0.npz is read from its parent directory)",
)
parser.add_argument(
    "--simple",
    action="store_true",
    help="only plot the row of u panels",
)
parser.add_argument(
    "--shaded",
    action="store_true",
    help="use Gouraud-shaded tripcolor instead of filled contours",
)
parser.add_argument(
    "--iter",
    type=int,
    action="store",
    help="iteration number <n> to show first (default: the lowest one)",
)
parser.add_argument(
    "--save-only",
    action="store_true",
    help="save res_<n>.pdf and true_<n>.pdf for the initial iteration and exit",
)

if __name__ == "__main__":
    args = parser.parse_args()
    plot(args.filename, args.simple, args.shaded, args.iter, args.save_only)
