--- a/plot.py Thu Feb 26 09:32:12 2026 -0500 +++ b/plot.py Wed Apr 22 22:32:00 2026 -0500 @@ -1,3 +1,5 @@ +#!.venv/bin/python3 +import argparse import os import re import sys @@ -12,16 +14,15 @@ import matplotlib.tri as tri import numpy as np from dolfinx import fem - from src.convection_diffusion import ConvectionDiffusion -def find_files(directory): +def find_files(dirname): """ Return a list of Path objects for files named 'fubar%d.npz' in the given directory, sorted by the integer %d. """ - directory = Path(directory) + directory = Path(dirname) pattern = re.compile(r"^res_(\d+)\.npz$") matches = [] @@ -39,7 +40,7 @@ return matches -def plot(prefix): +def plot(prefix, simple=False, shaded=False, init_iter=None, save_only=False): quantisation = 32 iter_files = find_files(prefix) @@ -59,16 +60,27 @@ triang = tri.Triangulation(x_coords, y_coords) m = 5 - fig, (axus, axws, axjs, ax_wpbar_plus) = plt.subplots(4, m, figsize=(15, 4)) - plt.tight_layout() + if simple: + fig, axus = plt.subplots(1, m) + all_axes = lambda: axus + fig.set_layout_engine("compressed") + else: + fig, (axus, axws, axjs, ax_wpbar_plus) = plt.subplots(4, m, figsize=(15, 4)) + all_axes = lambda: chain(axus, axws, axjs, ax_wpbar_plus) + plt.tight_layout(pad=0.01) + fig.subplots_adjust( + left=0.02, right=0.95, bottom=0.02, top=0.93, wspace=0.2, hspace=0.2 + ) + fig.set_size_inches(13, 8) - fig.subplots_adjust( - left=0.02, right=0.95, bottom=0.02, top=0.93, wspace=0.2, hspace=0.2 - ) - for ax in chain(axus, axws, axjs, ax_wpbar_plus): - ax.set_xticklabels([]) - ax.set_yticklabels([]) - ax.set_aspect("equal") + + if not simple: + info = plt.text( + 1.0, + -0.07, + "", + horizontalalignment="right", + ) mpl.rcParams["axes.labelsize"] = 9 @@ -76,10 +88,12 @@ sm = cm.ScalarMappable(norm=norm, cmap="viridis") sm.set_array([]) # REQUIRED (matplotlib quirk) colorbar_u = fig.colorbar(sm, ax=axus[-1]) - colorbar_w = fig.colorbar(sm, ax=axws[-1]) - colorbar_j = fig.colorbar(sm, ax=axjs[-1]) + if not simple: + colorbar_w = fig.colorbar(sm, ax=axws[-1]) + colorbar_j = fig.colorbar(sm, ax=axjs[-1]) def do_plot(index_and_file, show_true=False): + (iter, file) = index_and_file print("🎨 Plotting iteration %d..." % iter) @@ -92,11 +106,6 @@ ω0_min = ω0.real.min() ω0_max = ω0.real.max() - for ax in chain(axus, axws, axjs, ax_wpbar_plus): - # ax.clear() - for artist in chain(ax.lines, ax.collections, ax.images): - artist.remove() - data = np.load(file) u_n_list_array = data["u_n_list_array"] w_n_list_array = data["w_n_list_array"] @@ -122,11 +131,7 @@ j_max = max(j.real.max() for j in j_n_list_array) μ = list(filter(lambda x: x[2] != 0.0 and not np.isnan(x).any(), mu)) - μ_x = list(map(lambda x: x[0], μ)) - μ_y = list(map(lambda x: x[1], μ)) μ_alpha = list(map(lambda x: x[2], μ)) - true_μ_x = list(map(lambda x: x[0], true_μ)) - true_μ_y = list(map(lambda x: x[1], true_μ)) true_μ_alpha = list(map(lambda x: x[2], true_μ)) if len(μ_alpha) == 0: @@ -143,78 +148,248 @@ ms = lambda m: int(1 + (m - alpha_base) / alpha_scale * 10) def plot_array( - name, t_idx, ax, u_array, u_min, u_max, colorbar=None, measure=False + name, + t_idx, + ax, + u_array, + u_min, + u_max, + μ, + true_μ, + colorbar=None, + measure=False, ): if u_min == u_max: levels = [u_min, u_min + 1e-9] else: levels = np.linspace(u_min, u_max, quantisation) try: - contour = ax.tricontourf(triang, u_array, levels=levels, cmap="viridis") + if shaded: + concentration = ax.tripcolor( + triang, + u_array, + cmap="viridis", + shading="gouraud", + ) + else: + concentration = ax.tricontourf( + triang, + u_array, + levels=levels, + cmap="viridis", + antialiased=False, + zorder=-10, + ) + ax.set_rasterization_zorder(0) + if colorbar: - colorbar.update_normal(contour) + colorbar.update_normal(concentration) + fmt = mpl.ticker.ScalarFormatter(useOffset=False) + fmt.set_scientific(False) + colorbar.ax.yaxis.set_major_formatter(fmt) + colorbar.update_ticks() + colorbar.ax.yaxis.get_offset_text().set_visible(False) + off = colorbar.ax.yaxis.get_offset_text() + off.set_x(10) + except Exception as e: print(e) - if measure: + if μ is not None: + μ_x = list(map(lambda x: x[0], μ)) + μ_y = list(map(lambda x: x[1], μ)) + μ_alpha = list(map(lambda x: x[2], μ)) for x, y, m in zip(μ_x, μ_y, μ_alpha): ax.plot([x], [y], "ro", markersize=ms(m), label="Sources") + if true_μ is not None: + true_μ_x = list(map(lambda x: x[0], true_μ)) + true_μ_y = list(map(lambda x: x[1], true_μ)) + true_μ_alpha = list(map(lambda x: x[2], true_μ)) for x, y, m in zip(true_μ_x, true_μ_y, true_μ_alpha): - ax.plot([x], [y], "kx", markersize=ms(m), label="True sources") + ax.plot([x], [y], "wx", markersize=ms(m), label="True sources") if t_idx >= 0: - ax.set_title(f"%s; t = {t_idx:.1f}" % name) + ax.set_title( + f"$%s_{{{t_idx:.1f}}}$" % name, + fontsize=20 if simple else 14, + ) else: - ax.set_title("%s" % name) + ax.set_title( + "%s" % name, + fontsize=20 if simple else 14, + ) ax.set_aspect("equal") n = len(u_n_list_array) # frames = list(map(lambda i: i * (n - 1) // (m - 1), range(0, m))) frames = range(0, n) - for i, axu, axw, axj, t_idx in zip( - frames, - axus, - axws, - axjs, - times, # map(lambda i: i / (n - 1), frames), - ): - plot_array( - "u", t_idx, axu, u_n_list_array[i].real, u_min, u_max, colorbar_u, True - ) - plot_array( - "w", t_idx, axw, w_n_list_array[i].real, w_min, w_max, colorbar_w, True - ) - plot_array( - "j", t_idx, axj, j_n_list_array[i].real, j_min, j_max, colorbar_j, True + + for ax in all_axes(): + ax.clear() + + if not simple: + for i, axu, axw, axj, t_idx in zip( + frames, + axus, + axws, + axjs, + times, # map(lambda i: i / (n - 1), frames), + ): + plot_array( + "u", + t_idx, + axu, + u_n_list_array[i].real, + u_min, + u_max, + μ, + true_μ, + colorbar_u, + ) + plot_array( + "w", + t_idx, + axw, + w_n_list_array[i].real, + w_min, + w_max, + μ, + true_μ, + colorbar_w, + ) + plot_array( + "j", + t_idx, + axj, + j_n_list_array[i].real, + j_min, + j_max, + μ, + true_μ, + colorbar_j, + ) + + if show_true: + for i, ax, t_idx in zip( + frames, + ax_wpbar_plus, + map(lambda i: i / (n - 1), frames), + ): + plot_array( + "û", + t_idx, + ax, + true_u_n_list_array[i].real, + u_min, + u_max, + μ, + true_μ, + colorbar_u, + ) + else: + plot_array( + "w̄ₚ", + -1, + ax_wpbar_plus[0], + wp_bar_array.real, + w_min, + w_max, + μ, + None, + ) + plot_array( + "ω₀", + -1, + ax_wpbar_plus[1], + ω0.real, + w_min, + w_max, + μ, + None, + ) + # plot_array("ω̂", -1, ax_wpbar_plus[2], true_ω.real, w_min, w_max) + + # Has to be here after everything, for some reason + for ax in all_axes(): + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_aspect("equal") + + plt.suptitle( + "Convection-Diffusion, iteration %d; len(μ) = %d; μ_min = %f, μ_max = %f, μ_median = %f;" + % ( + iter, + len(μ), + alpha_mi, + alpha_ma, + alpha_me, + ), + fontsize=14, ) - if show_true: - for i, ax, t_idx in zip( + str = "(k, c, c1) = (%g, %g, %g)" % (data["k"], data["c1"], data["c2"]) + if "k" in data0: + str += "; true (%g, %g, %g)" % (data0["k"], data0["c1"], data0["c2"]) + info.set_text(str) + else: # simple + for i, axu, t_idx in zip( frames, - ax_wpbar_plus, - map(lambda i: i / (n - 1), frames), + axus, + times, # map(lambda i: i / (n - 1), frames), ): - plot_array( - "û", - t_idx, - ax, - true_u_n_list_array[i].real, - u_min, - u_max, - colorbar_u, - True, - ) + if show_true: + plot_array( + "û", + t_idx, + axu, + true_u_n_list_array[i].real, + u_min, + u_max, + None, + true_μ, + colorbar_u, + ) + else: + plot_array( + "u", + t_idx, + axu, + u_n_list_array[i].real, + u_min, + u_max, + μ, + true_μ, + colorbar_u, + ) + + def do_save(iter, show_true): + if not show_true: + filename = "%s/res_%d.pdf" % (prefix, iter) else: - plot_array("w̄ₚ", -1, ax_wpbar_plus[0], wp_bar_array.real, w_min, w_max) - plot_array("ω₀", -1, ax_wpbar_plus[1], ω0.real, w_min, w_max) - # plot_array("ω̂", -1, ax_wpbar_plus[2], true_ω.real, w_min, w_max) + filename = "%s/true_%d.pdf" % (prefix, iter) + print("Saving ", filename) + plt.savefig( + filename, + dpi=300, + bbox_inches="tight", + pad_inches=0, + ) - plt.suptitle( - "Convection-Diffusion, iteration %d; len(μ) = %d; μ_min = %f, μ_max = %f, μ_median = %f" - % (iter, len(μ), alpha_mi, alpha_ma, alpha_me), - fontsize=14, + if init_iter is None: + k = 0 + else: + k, _ = next( + filter(lambda fubar: fubar[1][0] == init_iter, enumerate(iter_files)) ) - # plt.savefig("solution_evolution_%d.png" % iter, dpi=300) + + do_plot(iter_files[k], show_true=False) - state = {"k": 0, "show_true": False} + if save_only: + iter = iter_files[k][0] + do_save(iter, False) + do_plot(iter_files[k], show_true=True) + do_save(iter, True) + return + + state = {"k": k, "show_true": False} def on_key(event): k0 = state["k"] @@ -238,13 +413,14 @@ k = k0 elif event.key == "q": sys.exit() + elif event.key == "z": + iter, _file = iter_files[k0] + do_save(iter, state["show_true"]) if k is not None: state["k"] = k do_plot(iter_files[k], state["show_true"]) fig.canvas.draw() - do_plot(iter_files[0]) - fig.canvas.mpl_connect("key_press_event", on_key) plt.show() @@ -265,4 +441,17 @@ # print("Saved: solution_evolution.png + solution_time.png") -plot(sys.argv[1]) +parser = argparse.ArgumentParser( + prog="plot.py", + description="Plots results of pointsource_pde", +) + +parser.add_argument("filename") +parser.add_argument("--simple", action="store_true") +parser.add_argument("--shaded", action="store_true") +parser.add_argument("--iter", type=int, action="store") +parser.add_argument("--save-only", action="store_true") + +args = parser.parse_args() + +plot(args.filename, args.simple, args.shaded, args.iter, args.save_only)