diff -r 7ec1cfe19a24 -r a4137aedcb3a plot.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/plot.py Thu Feb 26 09:32:12 2026 -0500 @@ -0,0 +1,268 @@ +import os +import re +import sys +from itertools import chain +from pathlib import Path +from statistics import median + +import matplotlib as mpl +import matplotlib.cm as cm +import matplotlib.colors as colors +import matplotlib.pyplot as plt +import matplotlib.tri as tri +import numpy as np +from dolfinx import fem + +from src.convection_diffusion import ConvectionDiffusion + + +def find_files(directory): + """ + Return a list of Path objects for files named 'fubar%d.npz' + in the given directory, sorted by the integer %d. + """ + directory = Path(directory) + pattern = re.compile(r"^res_(\d+)\.npz$") + + matches = [] + + for path in directory.iterdir(): + if path.is_file(): + m = pattern.match(path.name) + if m: + number = int(m.group(1)) + matches.append((number, path)) + + # Sort numerically by extracted number + matches.sort(key=lambda x: x[0]) + + return matches + + +def plot(prefix): + quantisation = 32 + + iter_files = find_files(prefix) + + pde = ConvectionDiffusion( + u0=lambda x: 0.0 * x[1], + g=lambda x: 0.0 * x[1], # ((x[1] + 2) / 4) ** 2, # nonzero bcs + # w0=lambda x: np.cos(5 * x[1]), + domain_size=[0, 0.5, 0, 0.5], + # nx=128, + # ny=128, + ) + + # Get coordinates + coords = pde.V.tabulate_dof_coordinates() + x_coords, y_coords = coords[:, 0], coords[:, 1] + triang = tri.Triangulation(x_coords, y_coords) + + m = 5 + fig, (axus, axws, axjs, ax_wpbar_plus) = plt.subplots(4, m, figsize=(15, 4)) + plt.tight_layout() + fig.set_size_inches(13, 8) + fig.subplots_adjust( + left=0.02, right=0.95, bottom=0.02, top=0.93, wspace=0.2, hspace=0.2 + ) + for ax in chain(axus, axws, axjs, ax_wpbar_plus): + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_aspect("equal") + + mpl.rcParams["axes.labelsize"] = 9 + + norm = colors.Normalize(vmin=0, vmax=1) + sm = cm.ScalarMappable(norm=norm, cmap="viridis") + sm.set_array([]) # REQUIRED (matplotlib quirk) + colorbar_u = fig.colorbar(sm, ax=axus[-1]) + colorbar_w = fig.colorbar(sm, ax=axws[-1]) + colorbar_j = fig.colorbar(sm, ax=axjs[-1]) + + def do_plot(index_and_file, show_true=False): + (iter, file) = index_and_file + + print("🎨 Plotting iteration %d..." % iter) + + data0 = np.load("%s/omega_0.npz" % Path(prefix).parent) + ω0 = data0["omega0"] + true_μ = data0["true_mu"] + # true_ω = data0["true_omega"] + true_u_n_list_array = data0["true_u_n_list_array"] + ω0_min = ω0.real.min() + ω0_max = ω0.real.max() + + for ax in chain(axus, axws, axjs, ax_wpbar_plus): + # ax.clear() + for artist in chain(ax.lines, ax.collections, ax.images): + artist.remove() + + data = np.load(file) + u_n_list_array = data["u_n_list_array"] + w_n_list_array = data["w_n_list_array"] + j_n_list_array = data["j_n_list_array"] + # frames = data["frames"] + times = data["times"] + mu = data["mu"] + wp_bar_array = data["wp_bar_array"] + + u_min = min(u.real.min() for u in u_n_list_array) + u_max = max(u.real.max() for u in u_n_list_array) + if show_true: + u_min = min(u_min, min(u.real.min() for u in true_u_n_list_array)) + u_max = max(u_max, max(u.real.max() for u in true_u_n_list_array)) + + w_min = min( + min(w.real.min() for w in w_n_list_array), wp_bar_array.real.min(), ω0_min + ) + w_max = max( + max(w.real.max() for w in w_n_list_array), wp_bar_array.real.max(), ω0_max + ) + j_min = min(j.real.min() for j in j_n_list_array) + j_max = max(j.real.max() for j in j_n_list_array) + + μ = list(filter(lambda x: x[2] != 0.0 and not np.isnan(x).any(), mu)) + μ_x = list(map(lambda x: x[0], μ)) + μ_y = list(map(lambda x: x[1], μ)) + μ_alpha = list(map(lambda x: x[2], μ)) + true_μ_x = list(map(lambda x: x[0], true_μ)) + true_μ_y = list(map(lambda x: x[1], true_μ)) + true_μ_alpha = list(map(lambda x: x[2], true_μ)) + + if len(μ_alpha) == 0: + alpha_mi, alpha_ma, alpha_me = 0, 0, 0 + else: + alpha_mi, alpha_ma, alpha_me = min(μ_alpha), max(μ_alpha), median(μ_alpha) + + true_alpha_mi, true_alpha_ma = min(true_μ_alpha), max(true_μ_alpha) + alpha_base = min(true_alpha_mi, alpha_mi) + alpha_scale = max(true_alpha_ma, alpha_ma) - alpha_base + if alpha_scale == 0: + ms = lambda m: 6 + else: + ms = lambda m: int(1 + (m - alpha_base) / alpha_scale * 10) + + def plot_array( + name, t_idx, ax, u_array, u_min, u_max, colorbar=None, measure=False + ): + if u_min == u_max: + levels = [u_min, u_min + 1e-9] + else: + levels = np.linspace(u_min, u_max, quantisation) + try: + contour = ax.tricontourf(triang, u_array, levels=levels, cmap="viridis") + if colorbar: + colorbar.update_normal(contour) + except Exception as e: + print(e) + if measure: + for x, y, m in zip(μ_x, μ_y, μ_alpha): + ax.plot([x], [y], "ro", markersize=ms(m), label="Sources") + for x, y, m in zip(true_μ_x, true_μ_y, true_μ_alpha): + ax.plot([x], [y], "kx", markersize=ms(m), label="True sources") + if t_idx >= 0: + ax.set_title(f"%s; t = {t_idx:.1f}" % name) + else: + ax.set_title("%s" % name) + ax.set_aspect("equal") + + n = len(u_n_list_array) + # frames = list(map(lambda i: i * (n - 1) // (m - 1), range(0, m))) + frames = range(0, n) + for i, axu, axw, axj, t_idx in zip( + frames, + axus, + axws, + axjs, + times, # map(lambda i: i / (n - 1), frames), + ): + plot_array( + "u", t_idx, axu, u_n_list_array[i].real, u_min, u_max, colorbar_u, True + ) + plot_array( + "w", t_idx, axw, w_n_list_array[i].real, w_min, w_max, colorbar_w, True + ) + plot_array( + "j", t_idx, axj, j_n_list_array[i].real, j_min, j_max, colorbar_j, True + ) + + if show_true: + for i, ax, t_idx in zip( + frames, + ax_wpbar_plus, + map(lambda i: i / (n - 1), frames), + ): + plot_array( + "û", + t_idx, + ax, + true_u_n_list_array[i].real, + u_min, + u_max, + colorbar_u, + True, + ) + else: + plot_array("w̄ₚ", -1, ax_wpbar_plus[0], wp_bar_array.real, w_min, w_max) + plot_array("ω₀", -1, ax_wpbar_plus[1], ω0.real, w_min, w_max) + # plot_array("ω̂", -1, ax_wpbar_plus[2], true_ω.real, w_min, w_max) + + plt.suptitle( + "Convection-Diffusion, iteration %d; len(μ) = %d; μ_min = %f, μ_max = %f, μ_median = %f" + % (iter, len(μ), alpha_mi, alpha_ma, alpha_me), + fontsize=14, + ) + # plt.savefig("solution_evolution_%d.png" % iter, dpi=300) + + state = {"k": 0, "show_true": False} + + def on_key(event): + k0 = state["k"] + k = None + if event.key == "right" or event.key == " ": + k = (k0 + 1) % len(iter_files) + elif event.key == "left" or event.key == "backspace": + k = (k0 - 1) % len(iter_files) + elif event.key == "shift+right": + k = (k0 + 10) % len(iter_files) + elif event.key == "shift+left": + k = (k0 - 10) % len(iter_files) + elif event.key == "up": + k = (k0 + 100) % len(iter_files) + elif event.key == "down": + k = (k0 - 100) % len(iter_files) + elif event.key == "0": + k = 0 + elif event.key == "t": + state["show_true"] = not state["show_true"] + k = k0 + elif event.key == "q": + sys.exit() + if k is not None: + state["k"] = k + do_plot(iter_files[k], state["show_true"]) + fig.canvas.draw() + + do_plot(iter_files[0]) + + fig.canvas.mpl_connect("key_press_event", on_key) + + plt.show() + # # Time evolution + # fig, ax = plt.subplots(figsize=(10, 4)) + # times = np.arange(len(u_n_list)) * pde.dt + # max_vals = [np.max(u.x.array.real) for u in u_n_list] + + # ax.plot(times, max_vals, "r-o", linewidth=2, markersize=4) + # ax.set_xlabel("Time t") + # ax.set_ylabel("max|u|") + # ax.grid(True, alpha=0.3) + # ax.set_title("Solution Evolution") + # plt.tight_layout() + # plt.savefig("solution_time.png", dpi=150) + # plt.show() + + # print("Saved: solution_evolution.png + solution_time.png") + + +plot(sys.argv[1])