plot.py

Thu, 26 Feb 2026 09:32:12 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 26 Feb 2026 09:32:12 -0500
changeset 1
a4137aedcb3a
child 3
c3a4f4bb87f7
permissions
-rw-r--r--

Initial working version for known convectivity and diffusivity

import os
import re
import sys
from itertools import chain
from pathlib import Path
from statistics import median

import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import numpy as np
from dolfinx import fem

from src.convection_diffusion import ConvectionDiffusion


def find_files(directory):
    """
    Return the iterate files ``res_<n>.npz`` found directly in *directory*.

    Only regular files whose whole name matches ``res_<digits>.npz`` are
    included; subdirectories are not searched.

    Returns a list of ``(n, path)`` tuples, where ``n`` is the integer parsed
    from the file name and ``path`` is the file's ``Path``, sorted by ``n`` in
    increasing numeric order (so ``res_10.npz`` comes after ``res_2.npz``).
    """
    directory = Path(directory)
    pattern = re.compile(r"^res_(\d+)\.npz$")

    matches = []
    for path in directory.iterdir():
        m = pattern.match(path.name)
        # Test the name first (cheap), then that it is a regular file.
        if m and path.is_file():
            matches.append((int(m.group(1)), path))

    # Sort numerically, not lexicographically, by the extracted number.
    matches.sort(key=lambda item: item[0])
    return matches


def plot(prefix):
    """
    Open an interactive viewer for the saved iterates in directory *prefix*.

    *prefix* must contain files named ``res_<n>.npz`` (see ``find_files``),
    each holding ``u_n_list_array``, ``w_n_list_array``, ``j_n_list_array``,
    ``times``, ``mu`` and ``wp_bar_array``. The parent directory of *prefix*
    must contain ``omega_0.npz`` with ``omega0``, ``true_mu`` and
    ``true_u_n_list_array``.

    The figure has four rows of ``m = 5`` panels. The first three rows show
    the real parts of ``u``, ``w`` and ``j`` for up to the first ``m`` stored
    snapshots, with estimated sources (red dots) and true sources (black
    crosses) overlaid. The fourth row shows ``w̄ₚ`` and ``ω₀``, or, after
    toggling with ``t``, the true ``u`` snapshots.

    Keys: right/space and left/backspace step one iterate, shift+right and
    shift+left step 10, up/down step 100, ``0`` jumps to the first iterate,
    ``t`` toggles the fourth row, and ``q`` exits.
    """
    quantisation = 32  # number of contour levels per panel

    iter_files = find_files(prefix)
    if not iter_files:
        # Fail fast with a clear message instead of an IndexError after the
        # (expensive) PDE setup below.
        sys.exit("No res_<n>.npz files found in %s" % prefix)

    pde = ConvectionDiffusion(
        u0=lambda x: 0.0 * x[1],
        g=lambda x: 0.0 * x[1],  # ((x[1] + 2) / 4) ** 2,  # nonzero bcs
        # w0=lambda x: np.cos(5 * x[1]),
        domain_size=[0, 0.5, 0, 0.5],
        # nx=128,
        # ny=128,
    )

    # Triangulate the DOF coordinates so nodal values can be drawn with
    # tricontourf.
    coords = pde.V.tabulate_dof_coordinates()
    triang = tri.Triangulation(coords[:, 0], coords[:, 1])

    m = 5  # panels per row
    fig, (axus, axws, axjs, ax_wpbar_plus) = plt.subplots(4, m, figsize=(15, 4))
    plt.tight_layout()
    fig.set_size_inches(13, 8)
    fig.subplots_adjust(
        left=0.02, right=0.95, bottom=0.02, top=0.93, wspace=0.2, hspace=0.2
    )
    # Materialised once because it is iterated on every redraw.
    all_axes = list(chain(axus, axws, axjs, ax_wpbar_plus))
    for ax in all_axes:
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect("equal")

    mpl.rcParams["axes.labelsize"] = 9

    # One colorbar for each of the u/w/j rows; plot_array re-points it at the
    # row's latest contour set via update_normal.
    norm = colors.Normalize(vmin=0, vmax=1)
    sm = cm.ScalarMappable(norm=norm, cmap="viridis")
    sm.set_array([])  # REQUIRED (matplotlib quirk)
    colorbar_u = fig.colorbar(sm, ax=axus[-1])
    colorbar_w = fig.colorbar(sm, ax=axws[-1])
    colorbar_j = fig.colorbar(sm, ax=axjs[-1])

    def do_plot(index_and_file, show_true=False):
        """Redraw every panel for one ``(iteration, path)`` entry of iter_files."""
        (iteration, path) = index_and_file

        print("🎨 Plotting iteration %d..." % iteration)

        # Use the NpzFile as a context manager so its file handle is closed
        # once the arrays have been read into memory.
        with np.load(Path(prefix).parent / "omega_0.npz") as data0:
            ω0 = data0["omega0"]
            true_μ = data0["true_mu"]
            # true_ω = data0["true_omega"]
            true_u_n_list_array = data0["true_u_n_list_array"]
        ω0_min = ω0.real.min()
        ω0_max = ω0.real.max()

        for ax in all_axes:
            # remove() mutates the axes' artist containers, so iterate over a
            # snapshot rather than the live containers.
            for artist in [*ax.lines, *ax.collections, *ax.images]:
                artist.remove()
            # Clear the title too, so panels not redrawn below (e.g. after
            # toggling the true-solution row off) do not keep a stale one.
            ax.set_title("")

        with np.load(path) as data:
            u_n_list_array = data["u_n_list_array"]
            w_n_list_array = data["w_n_list_array"]
            j_n_list_array = data["j_n_list_array"]
            times = data["times"]
            mu = data["mu"]
            wp_bar_array = data["wp_bar_array"]

        # Shared colour range per row, so the time panels are comparable.
        u_min = min(u.real.min() for u in u_n_list_array)
        u_max = max(u.real.max() for u in u_n_list_array)
        if show_true:
            u_min = min(u_min, min(u.real.min() for u in true_u_n_list_array))
            u_max = max(u_max, max(u.real.max() for u in true_u_n_list_array))

        w_min = min(
            min(w.real.min() for w in w_n_list_array), wp_bar_array.real.min(), ω0_min
        )
        w_max = max(
            max(w.real.max() for w in w_n_list_array), wp_bar_array.real.max(), ω0_max
        )
        j_min = min(j.real.min() for j in j_n_list_array)
        j_max = max(j.real.max() for j in j_n_list_array)

        # Source rows are indexed as (x, y, alpha). Estimated sources with a
        # zero weight or any NaN entry are dropped.
        μ = [p for p in mu if p[2] != 0.0 and not np.isnan(p).any()]
        μ_x = [p[0] for p in μ]
        μ_y = [p[1] for p in μ]
        μ_alpha = [p[2] for p in μ]
        true_μ_x = [p[0] for p in true_μ]
        true_μ_y = [p[1] for p in true_μ]
        true_μ_alpha = [p[2] for p in true_μ]

        if μ_alpha:
            alpha_mi, alpha_ma, alpha_me = min(μ_alpha), max(μ_alpha), median(μ_alpha)
        else:
            alpha_mi, alpha_ma, alpha_me = 0, 0, 0

        if true_μ_alpha:
            true_alpha_mi, true_alpha_ma = min(true_μ_alpha), max(true_μ_alpha)
        else:
            # No true sources stored: scale the markers by the estimates only.
            true_alpha_mi, true_alpha_ma = alpha_mi, alpha_ma
        alpha_base = min(true_alpha_mi, alpha_mi)
        alpha_scale = max(true_alpha_ma, alpha_ma) - alpha_base

        def marker_size(alpha):
            """Map a source weight linearly onto marker sizes 1..11 (6 if all equal)."""
            if alpha_scale == 0:
                return 6
            return int(1 + (alpha - alpha_base) / alpha_scale * 10)

        def plot_array(
            name, t_idx, ax, u_array, u_min, u_max, colorbar=None, measure=False
        ):
            """
            Draw a filled contour of u_array on ax over levels in [u_min, u_max].

            A negative t_idx omits the time from the title. If colorbar is
            given, it is updated to the new contour set. If measure is true,
            the estimated and true sources are overlaid.
            """
            if u_min == u_max:
                # Constant field: use a tiny two-level range so the levels
                # passed to tricontourf are still increasing.
                levels = [u_min, u_min + 1e-9]
            else:
                levels = np.linspace(u_min, u_max, quantisation)
            try:
                contour = ax.tricontourf(triang, u_array, levels=levels, cmap="viridis")
                if colorbar is not None:
                    colorbar.update_normal(contour)
            except Exception as e:
                # Best effort: report the problem and keep the viewer usable.
                print(e)
            if measure:
                for x, y, a in zip(μ_x, μ_y, μ_alpha):
                    ax.plot([x], [y], "ro", markersize=marker_size(a), label="Sources")
                for x, y, a in zip(true_μ_x, true_μ_y, true_μ_alpha):
                    ax.plot(
                        [x], [y], "kx", markersize=marker_size(a), label="True sources"
                    )
            if t_idx >= 0:
                ax.set_title(f"{name}; t = {t_idx:.1f}")
            else:
                ax.set_title(name)
            ax.set_aspect("equal")

        n = len(u_n_list_array)
        # zip() stops at the m axes of a row, so at most m snapshots are shown.
        frames = range(n)
        for i, axu, axw, axj, t in zip(frames, axus, axws, axjs, times):
            plot_array(
                "u", t, axu, u_n_list_array[i].real, u_min, u_max, colorbar_u, True
            )
            plot_array(
                "w", t, axw, w_n_list_array[i].real, w_min, w_max, colorbar_w, True
            )
            plot_array(
                "j", t, axj, j_n_list_array[i].real, j_min, j_max, colorbar_j, True
            )

        if show_true:
            # Label with the stored times, like the rows above. Computing
            # i / (n - 1) here would divide by zero when n == 1.
            for i, ax, t in zip(frames, ax_wpbar_plus, times):
                plot_array(
                    "û",
                    t,
                    ax,
                    true_u_n_list_array[i].real,
                    u_min,
                    u_max,
                    colorbar_u,
                    True,
                )
        else:
            plot_array("w̄ₚ", -1, ax_wpbar_plus[0], wp_bar_array.real, w_min, w_max)
            plot_array("ω₀", -1, ax_wpbar_plus[1], ω0.real, w_min, w_max)
            # plot_array("ω̂", -1, ax_wpbar_plus[2], true_ω.real, w_min, w_max)

        fig.suptitle(
            "Convection-Diffusion, iteration %d; len(μ) = %d; μ_min = %f, μ_max = %f, μ_median = %f"
            % (iteration, len(μ), alpha_mi, alpha_ma, alpha_me),
            fontsize=14,
        )
        # plt.savefig("solution_evolution_%d.png" % iteration, dpi=300)

    state = {"k": 0, "show_true": False}

    # Relative moves through iter_files for each navigation key; the index
    # wraps around modulo len(iter_files).
    steps = {
        "right": 1,
        " ": 1,
        "left": -1,
        "backspace": -1,
        "shift+right": 10,
        "shift+left": -10,
        "up": 100,
        "down": -100,
    }

    def on_key(event):
        """Handle a key press; the bindings are listed in the plot() docstring."""
        k0 = state["k"]
        key = event.key
        if key in steps:
            k = (k0 + steps[key]) % len(iter_files)
        elif key == "0":
            k = 0
        elif key == "t":
            state["show_true"] = not state["show_true"]
            k = k0
        elif key == "q":
            sys.exit()
        else:
            return
        state["k"] = k
        do_plot(iter_files[k], state["show_true"])
        fig.canvas.draw()

    do_plot(iter_files[0])

    fig.canvas.mpl_connect("key_press_event", on_key)

    plt.show()


if __name__ == "__main__":
    # Run only as a script, so importing this module does not open the viewer.
    if len(sys.argv) < 2:
        sys.exit("usage: %s <results-directory>" % sys.argv[0])
    plot(sys.argv[1])

mercurial