plot.py

changeset 3
c3a4f4bb87f7
parent 1
a4137aedcb3a
--- a/plot.py	Thu Feb 26 09:32:12 2026 -0500
+++ b/plot.py	Wed Apr 22 22:32:00 2026 -0500
@@ -1,3 +1,5 @@
+#!.venv/bin/python3
+import argparse
 import os
 import re
 import sys
@@ -12,16 +14,15 @@
 import matplotlib.tri as tri
 import numpy as np
 from dolfinx import fem
-
 from src.convection_diffusion import ConvectionDiffusion
 
 
-def find_files(directory):
+def find_files(dirname):
     """
     Return a list of Path objects for files named 'fubar%d.npz'
     in the given directory, sorted by the integer %d.
     """
-    directory = Path(directory)
+    directory = Path(dirname)
     pattern = re.compile(r"^res_(\d+)\.npz$")
 
     matches = []
@@ -39,7 +40,7 @@
     return matches
 
 
-def plot(prefix):
+def plot(prefix, simple=False, shaded=False, init_iter=None, save_only=False):
     quantisation = 32
 
     iter_files = find_files(prefix)
@@ -59,16 +60,27 @@
     triang = tri.Triangulation(x_coords, y_coords)
 
     m = 5
-    fig, (axus, axws, axjs, ax_wpbar_plus) = plt.subplots(4, m, figsize=(15, 4))
-    plt.tight_layout()
+    if simple:
+        fig, axus = plt.subplots(1, m)
+        all_axes = lambda: axus
+        fig.set_layout_engine("compressed")
+    else:
+        fig, (axus, axws, axjs, ax_wpbar_plus) = plt.subplots(4, m, figsize=(15, 4))
+        all_axes = lambda: chain(axus, axws, axjs, ax_wpbar_plus)
+        plt.tight_layout(pad=0.01)
+        fig.subplots_adjust(
+            left=0.02, right=0.95, bottom=0.02, top=0.93, wspace=0.2, hspace=0.2
+        )
+
     fig.set_size_inches(13, 8)
-    fig.subplots_adjust(
-        left=0.02, right=0.95, bottom=0.02, top=0.93, wspace=0.2, hspace=0.2
-    )
-    for ax in chain(axus, axws, axjs, ax_wpbar_plus):
-        ax.set_xticklabels([])
-        ax.set_yticklabels([])
-        ax.set_aspect("equal")
+
+    if not simple:
+        info = plt.text(
+            1.0,
+            -0.07,
+            "",
+            horizontalalignment="right",
+        )
 
     mpl.rcParams["axes.labelsize"] = 9
 
@@ -76,10 +88,12 @@
     sm = cm.ScalarMappable(norm=norm, cmap="viridis")
     sm.set_array([])  # REQUIRED (matplotlib quirk)
     colorbar_u = fig.colorbar(sm, ax=axus[-1])
-    colorbar_w = fig.colorbar(sm, ax=axws[-1])
-    colorbar_j = fig.colorbar(sm, ax=axjs[-1])
+    if not simple:
+        colorbar_w = fig.colorbar(sm, ax=axws[-1])
+        colorbar_j = fig.colorbar(sm, ax=axjs[-1])
 
     def do_plot(index_and_file, show_true=False):
+
         (iter, file) = index_and_file
 
         print("🎨 Plotting iteration %d..." % iter)
@@ -92,11 +106,6 @@
         ω0_min = ω0.real.min()
         ω0_max = ω0.real.max()
 
-        for ax in chain(axus, axws, axjs, ax_wpbar_plus):
-            # ax.clear()
-            for artist in chain(ax.lines, ax.collections, ax.images):
-                artist.remove()
-
         data = np.load(file)
         u_n_list_array = data["u_n_list_array"]
         w_n_list_array = data["w_n_list_array"]
@@ -122,11 +131,7 @@
         j_max = max(j.real.max() for j in j_n_list_array)
 
         μ = list(filter(lambda x: x[2] != 0.0 and not np.isnan(x).any(), mu))
-        μ_x = list(map(lambda x: x[0], μ))
-        μ_y = list(map(lambda x: x[1], μ))
         μ_alpha = list(map(lambda x: x[2], μ))
-        true_μ_x = list(map(lambda x: x[0], true_μ))
-        true_μ_y = list(map(lambda x: x[1], true_μ))
         true_μ_alpha = list(map(lambda x: x[2], true_μ))
 
         if len(μ_alpha) == 0:
@@ -143,78 +148,248 @@
             ms = lambda m: int(1 + (m - alpha_base) / alpha_scale * 10)
 
         def plot_array(
-            name, t_idx, ax, u_array, u_min, u_max, colorbar=None, measure=False
+            name,
+            t_idx,
+            ax,
+            u_array,
+            u_min,
+            u_max,
+            μ,
+            true_μ,
+            colorbar=None,
+            measure=False,
         ):
             if u_min == u_max:
                 levels = [u_min, u_min + 1e-9]
             else:
                 levels = np.linspace(u_min, u_max, quantisation)
             try:
-                contour = ax.tricontourf(triang, u_array, levels=levels, cmap="viridis")
+                if shaded:
+                    concentration = ax.tripcolor(
+                        triang,
+                        u_array,
+                        cmap="viridis",
+                        shading="gouraud",
+                    )
+                else:
+                    concentration = ax.tricontourf(
+                        triang,
+                        u_array,
+                        levels=levels,
+                        cmap="viridis",
+                        antialiased=False,
+                        zorder=-10,
+                    )
+                    ax.set_rasterization_zorder(0)
+
                 if colorbar:
-                    colorbar.update_normal(contour)
+                    colorbar.update_normal(concentration)
+                    fmt = mpl.ticker.ScalarFormatter(useOffset=False)
+                    fmt.set_scientific(False)
+                    colorbar.ax.yaxis.set_major_formatter(fmt)
+                    colorbar.update_ticks()
+                    colorbar.ax.yaxis.get_offset_text().set_visible(False)
+                    off = colorbar.ax.yaxis.get_offset_text()
+                    off.set_x(10)
+
             except Exception as e:
                 print(e)
-            if measure:
+            if μ is not None:
+                μ_x = list(map(lambda x: x[0], μ))
+                μ_y = list(map(lambda x: x[1], μ))
+                μ_alpha = list(map(lambda x: x[2], μ))
                 for x, y, m in zip(μ_x, μ_y, μ_alpha):
                     ax.plot([x], [y], "ro", markersize=ms(m), label="Sources")
+            if true_μ is not None:
+                true_μ_x = list(map(lambda x: x[0], true_μ))
+                true_μ_y = list(map(lambda x: x[1], true_μ))
+                true_μ_alpha = list(map(lambda x: x[2], true_μ))
                 for x, y, m in zip(true_μ_x, true_μ_y, true_μ_alpha):
-                    ax.plot([x], [y], "kx", markersize=ms(m), label="True sources")
+                    ax.plot([x], [y], "wx", markersize=ms(m), label="True sources")
             if t_idx >= 0:
-                ax.set_title(f"%s; t = {t_idx:.1f}" % name)
+                ax.set_title(
+                    f"$%s_{{{t_idx:.1f}}}$" % name,
+                    fontsize=20 if simple else 14,
+                )
             else:
-                ax.set_title("%s" % name)
+                ax.set_title(
+                    "%s" % name,
+                    fontsize=20 if simple else 14,
+                )
             ax.set_aspect("equal")
 
         n = len(u_n_list_array)
         # frames = list(map(lambda i: i * (n - 1) // (m - 1), range(0, m)))
         frames = range(0, n)
-        for i, axu, axw, axj, t_idx in zip(
-            frames,
-            axus,
-            axws,
-            axjs,
-            times,  # map(lambda i: i / (n - 1), frames),
-        ):
-            plot_array(
-                "u", t_idx, axu, u_n_list_array[i].real, u_min, u_max, colorbar_u, True
-            )
-            plot_array(
-                "w", t_idx, axw, w_n_list_array[i].real, w_min, w_max, colorbar_w, True
-            )
-            plot_array(
-                "j", t_idx, axj, j_n_list_array[i].real, j_min, j_max, colorbar_j, True
+
+        for ax in all_axes():
+            ax.clear()
+
+        if not simple:
+            for i, axu, axw, axj, t_idx in zip(
+                frames,
+                axus,
+                axws,
+                axjs,
+                times,  # map(lambda i: i / (n - 1), frames),
+            ):
+                plot_array(
+                    "u",
+                    t_idx,
+                    axu,
+                    u_n_list_array[i].real,
+                    u_min,
+                    u_max,
+                    μ,
+                    true_μ,
+                    colorbar_u,
+                )
+                plot_array(
+                    "w",
+                    t_idx,
+                    axw,
+                    w_n_list_array[i].real,
+                    w_min,
+                    w_max,
+                    μ,
+                    true_μ,
+                    colorbar_w,
+                )
+                plot_array(
+                    "j",
+                    t_idx,
+                    axj,
+                    j_n_list_array[i].real,
+                    j_min,
+                    j_max,
+                    μ,
+                    true_μ,
+                    colorbar_j,
+                )
+
+            if show_true:
+                for i, ax, t_idx in zip(
+                    frames,
+                    ax_wpbar_plus,
+                    map(lambda i: i / (n - 1), frames),
+                ):
+                    plot_array(
+                        "û",
+                        t_idx,
+                        ax,
+                        true_u_n_list_array[i].real,
+                        u_min,
+                        u_max,
+                        μ,
+                        true_μ,
+                        colorbar_u,
+                    )
+            else:
+                plot_array(
+                    "w̄ₚ",
+                    -1,
+                    ax_wpbar_plus[0],
+                    wp_bar_array.real,
+                    w_min,
+                    w_max,
+                    μ,
+                    None,
+                )
+                plot_array(
+                    "ω₀",
+                    -1,
+                    ax_wpbar_plus[1],
+                    ω0.real,
+                    w_min,
+                    w_max,
+                    μ,
+                    None,
+                )
+                # plot_array("ω̂", -1, ax_wpbar_plus[2], true_ω.real, w_min, w_max)
+
+            # Has to be here after everything, for some reason
+            for ax in all_axes():
+                ax.set_xticklabels([])
+                ax.set_yticklabels([])
+                ax.set_aspect("equal")
+
+            plt.suptitle(
+                "Convection-Diffusion, iteration %d; len(μ) = %d; μ_min = %f, μ_max = %f, μ_median = %f;"
+                % (
+                    iter,
+                    len(μ),
+                    alpha_mi,
+                    alpha_ma,
+                    alpha_me,
+                ),
+                fontsize=14,
             )
 
-        if show_true:
-            for i, ax, t_idx in zip(
+            str = "(k, c, c1) = (%g, %g, %g)" % (data["k"], data["c1"], data["c2"])
+            if "k" in data0:
+                str += "; true (%g, %g, %g)" % (data0["k"], data0["c1"], data0["c2"])
+            info.set_text(str)
+        else:  # simple
+            for i, axu, t_idx in zip(
                 frames,
-                ax_wpbar_plus,
-                map(lambda i: i / (n - 1), frames),
+                axus,
+                times,  # map(lambda i: i / (n - 1), frames),
             ):
-                plot_array(
-                    "û",
-                    t_idx,
-                    ax,
-                    true_u_n_list_array[i].real,
-                    u_min,
-                    u_max,
-                    colorbar_u,
-                    True,
-                )
+                if show_true:
+                    plot_array(
+                        "û",
+                        t_idx,
+                        axu,
+                        true_u_n_list_array[i].real,
+                        u_min,
+                        u_max,
+                        None,
+                        true_μ,
+                        colorbar_u,
+                    )
+                else:
+                    plot_array(
+                        "u",
+                        t_idx,
+                        axu,
+                        u_n_list_array[i].real,
+                        u_min,
+                        u_max,
+                        μ,
+                        true_μ,
+                        colorbar_u,
+                    )
+
+    def do_save(iter, show_true):
+        if not show_true:
+            filename = "%s/res_%d.pdf" % (prefix, iter)
         else:
-            plot_array("w̄ₚ", -1, ax_wpbar_plus[0], wp_bar_array.real, w_min, w_max)
-            plot_array("ω₀", -1, ax_wpbar_plus[1], ω0.real, w_min, w_max)
-            # plot_array("ω̂", -1, ax_wpbar_plus[2], true_ω.real, w_min, w_max)
+            filename = "%s/true_%d.pdf" % (prefix, iter)
+        print("Saving ", filename)
+        plt.savefig(
+            filename,
+            dpi=300,
+            bbox_inches="tight",
+            pad_inches=0,
+        )
 
-        plt.suptitle(
-            "Convection-Diffusion, iteration %d; len(μ) = %d; μ_min = %f, μ_max = %f, μ_median = %f"
-            % (iter, len(μ), alpha_mi, alpha_ma, alpha_me),
-            fontsize=14,
+    if init_iter is None:
+        k = 0
+    else:
+        k, _ = next(
+            filter(lambda fubar: fubar[1][0] == init_iter, enumerate(iter_files))
         )
-        # plt.savefig("solution_evolution_%d.png" % iter, dpi=300)
+
+    do_plot(iter_files[k], show_true=False)
 
-    state = {"k": 0, "show_true": False}
+    if save_only:
+        iter = iter_files[k][0]
+        do_save(iter, False)
+        do_plot(iter_files[k], show_true=True)
+        do_save(iter, True)
+        return
+
+    state = {"k": k, "show_true": False}
 
     def on_key(event):
         k0 = state["k"]
@@ -238,13 +413,14 @@
             k = k0
         elif event.key == "q":
             sys.exit()
+        elif event.key == "z":
+            iter, _file = iter_files[k0]
+            do_save(iter, state["show_true"])
         if k is not None:
             state["k"] = k
             do_plot(iter_files[k], state["show_true"])
             fig.canvas.draw()
 
-    do_plot(iter_files[0])
-
     fig.canvas.mpl_connect("key_press_event", on_key)
 
     plt.show()
@@ -265,4 +441,17 @@
     # print("Saved: solution_evolution.png + solution_time.png")
 
 
-plot(sys.argv[1])
+parser = argparse.ArgumentParser(
+    prog="plot.py",
+    description="Plots results of pointsource_pde",
+)
+
+parser.add_argument("filename")
+parser.add_argument("--simple", action="store_true")
+parser.add_argument("--shaded", action="store_true")
+parser.add_argument("--iter", type=int, action="store")
+parser.add_argument("--save-only", action="store_true")
+
+args = parser.parse_args()
+
+plot(args.filename, args.simple, args.shaded, args.iter, args.save_only)

mercurial