plot.py

changeset 1
a4137aedcb3a
child 3
c3a4f4bb87f7
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/plot.py	Thu Feb 26 09:32:12 2026 -0500
@@ -0,0 +1,268 @@
+import os
+import re
+import sys
+from itertools import chain
+from pathlib import Path
+from statistics import median
+
+import matplotlib as mpl
+import matplotlib.cm as cm
+import matplotlib.colors as colors
+import matplotlib.pyplot as plt
+import matplotlib.tri as tri
+import numpy as np
+from dolfinx import fem
+
+from src.convection_diffusion import ConvectionDiffusion
+
+
+def find_files(directory):
+    """
+    Return a list of Path objects for files named 'fubar%d.npz'
+    in the given directory, sorted by the integer %d.
+    """
+    directory = Path(directory)
+    pattern = re.compile(r"^res_(\d+)\.npz$")
+
+    matches = []
+
+    for path in directory.iterdir():
+        if path.is_file():
+            m = pattern.match(path.name)
+            if m:
+                number = int(m.group(1))
+                matches.append((number, path))
+
+    # Sort numerically by extracted number
+    matches.sort(key=lambda x: x[0])
+
+    return matches
+
+
+def plot(prefix):
+    quantisation = 32
+
+    iter_files = find_files(prefix)
+
+    pde = ConvectionDiffusion(
+        u0=lambda x: 0.0 * x[1],
+        g=lambda x: 0.0 * x[1],  # ((x[1] + 2) / 4) ** 2,  # nonzero bcs
+        # w0=lambda x: np.cos(5 * x[1]),
+        domain_size=[0, 0.5, 0, 0.5],
+        # nx=128,
+        # ny=128,
+    )
+
+    # Get coordinates
+    coords = pde.V.tabulate_dof_coordinates()
+    x_coords, y_coords = coords[:, 0], coords[:, 1]
+    triang = tri.Triangulation(x_coords, y_coords)
+
+    m = 5
+    fig, (axus, axws, axjs, ax_wpbar_plus) = plt.subplots(4, m, figsize=(15, 4))
+    plt.tight_layout()
+    fig.set_size_inches(13, 8)
+    fig.subplots_adjust(
+        left=0.02, right=0.95, bottom=0.02, top=0.93, wspace=0.2, hspace=0.2
+    )
+    for ax in chain(axus, axws, axjs, ax_wpbar_plus):
+        ax.set_xticklabels([])
+        ax.set_yticklabels([])
+        ax.set_aspect("equal")
+
+    mpl.rcParams["axes.labelsize"] = 9
+
+    norm = colors.Normalize(vmin=0, vmax=1)
+    sm = cm.ScalarMappable(norm=norm, cmap="viridis")
+    sm.set_array([])  # REQUIRED (matplotlib quirk)
+    colorbar_u = fig.colorbar(sm, ax=axus[-1])
+    colorbar_w = fig.colorbar(sm, ax=axws[-1])
+    colorbar_j = fig.colorbar(sm, ax=axjs[-1])
+
+    def do_plot(index_and_file, show_true=False):
+        (iter, file) = index_and_file
+
+        print("🎨 Plotting iteration %d..." % iter)
+
+        data0 = np.load("%s/omega_0.npz" % Path(prefix).parent)
+        ω0 = data0["omega0"]
+        true_μ = data0["true_mu"]
+        # true_ω = data0["true_omega"]
+        true_u_n_list_array = data0["true_u_n_list_array"]
+        ω0_min = ω0.real.min()
+        ω0_max = ω0.real.max()
+
+        for ax in chain(axus, axws, axjs, ax_wpbar_plus):
+            # ax.clear()
+            for artist in chain(ax.lines, ax.collections, ax.images):
+                artist.remove()
+
+        data = np.load(file)
+        u_n_list_array = data["u_n_list_array"]
+        w_n_list_array = data["w_n_list_array"]
+        j_n_list_array = data["j_n_list_array"]
+        # frames = data["frames"]
+        times = data["times"]
+        mu = data["mu"]
+        wp_bar_array = data["wp_bar_array"]
+
+        u_min = min(u.real.min() for u in u_n_list_array)
+        u_max = max(u.real.max() for u in u_n_list_array)
+        if show_true:
+            u_min = min(u_min, min(u.real.min() for u in true_u_n_list_array))
+            u_max = max(u_max, max(u.real.max() for u in true_u_n_list_array))
+
+        w_min = min(
+            min(w.real.min() for w in w_n_list_array), wp_bar_array.real.min(), ω0_min
+        )
+        w_max = max(
+            max(w.real.max() for w in w_n_list_array), wp_bar_array.real.max(), ω0_max
+        )
+        j_min = min(j.real.min() for j in j_n_list_array)
+        j_max = max(j.real.max() for j in j_n_list_array)
+
+        μ = list(filter(lambda x: x[2] != 0.0 and not np.isnan(x).any(), mu))
+        μ_x = list(map(lambda x: x[0], μ))
+        μ_y = list(map(lambda x: x[1], μ))
+        μ_alpha = list(map(lambda x: x[2], μ))
+        true_μ_x = list(map(lambda x: x[0], true_μ))
+        true_μ_y = list(map(lambda x: x[1], true_μ))
+        true_μ_alpha = list(map(lambda x: x[2], true_μ))
+
+        if len(μ_alpha) == 0:
+            alpha_mi, alpha_ma, alpha_me = 0, 0, 0
+        else:
+            alpha_mi, alpha_ma, alpha_me = min(μ_alpha), max(μ_alpha), median(μ_alpha)
+
+        true_alpha_mi, true_alpha_ma = min(true_μ_alpha), max(true_μ_alpha)
+        alpha_base = min(true_alpha_mi, alpha_mi)
+        alpha_scale = max(true_alpha_ma, alpha_ma) - alpha_base
+        if alpha_scale == 0:
+            ms = lambda m: 6
+        else:
+            ms = lambda m: int(1 + (m - alpha_base) / alpha_scale * 10)
+
+        def plot_array(
+            name, t_idx, ax, u_array, u_min, u_max, colorbar=None, measure=False
+        ):
+            if u_min == u_max:
+                levels = [u_min, u_min + 1e-9]
+            else:
+                levels = np.linspace(u_min, u_max, quantisation)
+            try:
+                contour = ax.tricontourf(triang, u_array, levels=levels, cmap="viridis")
+                if colorbar:
+                    colorbar.update_normal(contour)
+            except Exception as e:
+                print(e)
+            if measure:
+                for x, y, m in zip(μ_x, μ_y, μ_alpha):
+                    ax.plot([x], [y], "ro", markersize=ms(m), label="Sources")
+                for x, y, m in zip(true_μ_x, true_μ_y, true_μ_alpha):
+                    ax.plot([x], [y], "kx", markersize=ms(m), label="True sources")
+            if t_idx >= 0:
+                ax.set_title(f"%s; t = {t_idx:.1f}" % name)
+            else:
+                ax.set_title("%s" % name)
+            ax.set_aspect("equal")
+
+        n = len(u_n_list_array)
+        # frames = list(map(lambda i: i * (n - 1) // (m - 1), range(0, m)))
+        frames = range(0, n)
+        for i, axu, axw, axj, t_idx in zip(
+            frames,
+            axus,
+            axws,
+            axjs,
+            times,  # map(lambda i: i / (n - 1), frames),
+        ):
+            plot_array(
+                "u", t_idx, axu, u_n_list_array[i].real, u_min, u_max, colorbar_u, True
+            )
+            plot_array(
+                "w", t_idx, axw, w_n_list_array[i].real, w_min, w_max, colorbar_w, True
+            )
+            plot_array(
+                "j", t_idx, axj, j_n_list_array[i].real, j_min, j_max, colorbar_j, True
+            )
+
+        if show_true:
+            for i, ax, t_idx in zip(
+                frames,
+                ax_wpbar_plus,
+                map(lambda i: i / (n - 1), frames),
+            ):
+                plot_array(
+                    "û",
+                    t_idx,
+                    ax,
+                    true_u_n_list_array[i].real,
+                    u_min,
+                    u_max,
+                    colorbar_u,
+                    True,
+                )
+        else:
+            plot_array("w̄ₚ", -1, ax_wpbar_plus[0], wp_bar_array.real, w_min, w_max)
+            plot_array("ω₀", -1, ax_wpbar_plus[1], ω0.real, w_min, w_max)
+            # plot_array("ω̂", -1, ax_wpbar_plus[2], true_ω.real, w_min, w_max)
+
+        plt.suptitle(
+            "Convection-Diffusion, iteration %d; len(μ) = %d; μ_min = %f, μ_max = %f, μ_median = %f"
+            % (iter, len(μ), alpha_mi, alpha_ma, alpha_me),
+            fontsize=14,
+        )
+        # plt.savefig("solution_evolution_%d.png" % iter, dpi=300)
+
+    state = {"k": 0, "show_true": False}
+
+    def on_key(event):
+        k0 = state["k"]
+        k = None
+        if event.key == "right" or event.key == " ":
+            k = (k0 + 1) % len(iter_files)
+        elif event.key == "left" or event.key == "backspace":
+            k = (k0 - 1) % len(iter_files)
+        elif event.key == "shift+right":
+            k = (k0 + 10) % len(iter_files)
+        elif event.key == "shift+left":
+            k = (k0 - 10) % len(iter_files)
+        elif event.key == "up":
+            k = (k0 + 100) % len(iter_files)
+        elif event.key == "down":
+            k = (k0 - 100) % len(iter_files)
+        elif event.key == "0":
+            k = 0
+        elif event.key == "t":
+            state["show_true"] = not state["show_true"]
+            k = k0
+        elif event.key == "q":
+            sys.exit()
+        if k is not None:
+            state["k"] = k
+            do_plot(iter_files[k], state["show_true"])
+            fig.canvas.draw()
+
+    do_plot(iter_files[0])
+
+    fig.canvas.mpl_connect("key_press_event", on_key)
+
+    plt.show()
+    # # Time evolution
+    # fig, ax = plt.subplots(figsize=(10, 4))
+    # times = np.arange(len(u_n_list)) * pde.dt
+    # max_vals = [np.max(u.x.array.real) for u in u_n_list]
+
+    # ax.plot(times, max_vals, "r-o", linewidth=2, markersize=4)
+    # ax.set_xlabel("Time t")
+    # ax.set_ylabel("max|u|")
+    # ax.grid(True, alpha=0.3)
+    # ax.set_title("Solution Evolution")
+    # plt.tight_layout()
+    # plt.savefig("solution_time.png", dpi=150)
+    # plt.show()
+
+    # print("Saved: solution_evolution.png + solution_time.png")
+
+
+plot(sys.argv[1])

mercurial