plot.py

changeset 1
a4137aedcb3a
child 3
c3a4f4bb87f7
equal deleted inserted replaced
0:7ec1cfe19a24 1:a4137aedcb3a
1 import os
2 import re
3 import sys
4 from itertools import chain
5 from pathlib import Path
6 from statistics import median
7
8 import matplotlib as mpl
9 import matplotlib.cm as cm
10 import matplotlib.colors as colors
11 import matplotlib.pyplot as plt
12 import matplotlib.tri as tri
13 import numpy as np
14 from dolfinx import fem
15
16 from src.convection_diffusion import ConvectionDiffusion
17
18
def find_files(directory):
    """
    Return the ``res_<n>.npz`` result files in *directory*, sorted by ``<n>``.

    Only regular files whose name matches ``res_<digits>.npz`` exactly are
    considered; subdirectories and any other files are ignored.

    :param directory: directory to scan (``str`` or ``os.PathLike``).
    :returns: list of ``(n, path)`` tuples, where ``n`` is the integer parsed
        from the file name and ``path`` is the file's :class:`pathlib.Path`,
        ordered by increasing ``n`` (ties broken by file name).
    """
    directory = Path(directory)
    pattern = re.compile(r"^res_(\d+)\.npz$")

    matches = []
    for path in directory.iterdir():
        if not path.is_file():
            continue
        m = pattern.match(path.name)
        if m:
            matches.append((int(m.group(1)), path))

    # Sort numerically (res_10 after res_2), not lexicographically; the name
    # tie-breaker keeps the order deterministic for e.g. res_01 vs res_1,
    # since iterdir() yields entries in arbitrary order.
    matches.sort(key=lambda item: (item[0], item[1].name))
    return matches
40
41
def plot(prefix):
    """
    Open an interactive viewer for the ``res_<n>.npz`` iterates in *prefix*.

    Each iterate is drawn as a 4 x 5 grid of filled tricontour plots over the
    degree-of-freedom coordinates of ``pde.V``:

    * row 1: ``u_n_list_array`` at the stored time steps,
    * row 2: ``w_n_list_array`` at the same time steps,
    * row 3: ``j_n_list_array`` at the same time steps,
    * row 4: ``wp_bar_array`` and ``omega0`` or, after pressing ``t``, the
      ``true_u_n_list_array`` stored in ``omega_0.npz``.

    Points from ``mu`` (red dots) and ``true_mu`` (black crosses) are overlaid
    on rows 1-3, with marker size scaled by the third column of each entry.

    Keys: right/space (+1), left/backspace (-1), shift+right/left (+/-10),
    up/down (+/-100), ``0`` (first iterate), ``t`` (toggle row 4),
    ``q`` (exit the process).

    :param prefix: directory containing the ``res_<n>.npz`` files;
        ``omega_0.npz`` is read from its parent directory.
    :raises FileNotFoundError: if *prefix* contains no ``res_<n>.npz`` files.
    """
    quantisation = 32  # number of contour levels per panel

    iter_files = find_files(prefix)
    if not iter_files:
        # Fail early with a clear message instead of an IndexError on
        # ``iter_files[0]`` after the PDE and figure have been set up.
        raise FileNotFoundError("no res_<n>.npz files found in %s" % prefix)

    pde = ConvectionDiffusion(
        u0=lambda x: 0.0 * x[1],
        g=lambda x: 0.0 * x[1],  # ((x[1] + 2) / 4) ** 2, # nonzero bcs
        # w0=lambda x: np.cos(5 * x[1]),
        domain_size=[0, 0.5, 0, 0.5],
        # nx=128,
        # ny=128,
    )

    # Triangulate the DOF coordinates so nodal arrays can be contoured directly.
    coords = pde.V.tabulate_dof_coordinates()
    x_coords, y_coords = coords[:, 0], coords[:, 1]
    triang = tri.Triangulation(x_coords, y_coords)

    m = 5  # panels per row
    fig, (axus, axws, axjs, ax_wpbar_plus) = plt.subplots(4, m, figsize=(15, 4))
    plt.tight_layout()
    fig.set_size_inches(13, 8)
    fig.subplots_adjust(
        left=0.02, right=0.95, bottom=0.02, top=0.93, wspace=0.2, hspace=0.2
    )
    for ax in chain(axus, axws, axjs, ax_wpbar_plus):
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect("equal")

    mpl.rcParams["axes.labelsize"] = 9

    # One colorbar per row; each is re-pointed at the latest contour set
    # via update_normal() in plot_array.
    norm = colors.Normalize(vmin=0, vmax=1)
    sm = cm.ScalarMappable(norm=norm, cmap="viridis")
    sm.set_array([])  # REQUIRED (matplotlib quirk)
    colorbar_u = fig.colorbar(sm, ax=axus[-1])
    colorbar_w = fig.colorbar(sm, ax=axws[-1])
    colorbar_j = fig.colorbar(sm, ax=axjs[-1])

    def do_plot(index_and_file, show_true=False):
        """Redraw every panel for one ``(iteration, path)`` entry of ``iter_files``."""
        (iteration, file) = index_and_file

        print("🎨 Plotting iteration %d..." % iteration)

        data0 = np.load("%s/omega_0.npz" % Path(prefix).parent)
        ω0 = data0["omega0"]
        true_μ = data0["true_mu"]
        # true_ω = data0["true_omega"]
        true_u_n_list_array = data0["true_u_n_list_array"]
        ω0_min = ω0.real.min()
        ω0_max = ω0.real.max()

        for ax in chain(axus, axws, axjs, ax_wpbar_plus):
            # ax.clear()
            # Snapshot the artists first: removing an artist mutates the very
            # sequences being iterated, which skips artists and leaves stale
            # contours/markers behind.
            for artist in [*ax.lines, *ax.collections, *ax.images]:
                artist.remove()
            # Titles are not in those lists; clear them so panels that are not
            # redrawn this time (e.g. after toggling ``t`` off) lose stale titles.
            ax.set_title("")

        data = np.load(file)
        u_n_list_array = data["u_n_list_array"]
        w_n_list_array = data["w_n_list_array"]
        j_n_list_array = data["j_n_list_array"]
        # frames = data["frames"]
        times = data["times"]
        mu = data["mu"]
        wp_bar_array = data["wp_bar_array"]

        # Shared colour range per row, taken over all time steps.
        u_min = min(u.real.min() for u in u_n_list_array)
        u_max = max(u.real.max() for u in u_n_list_array)
        if show_true:
            u_min = min(u_min, min(u.real.min() for u in true_u_n_list_array))
            u_max = max(u_max, max(u.real.max() for u in true_u_n_list_array))

        w_min = min(
            min(w.real.min() for w in w_n_list_array), wp_bar_array.real.min(), ω0_min
        )
        w_max = max(
            max(w.real.max() for w in w_n_list_array), wp_bar_array.real.max(), ω0_max
        )
        j_min = min(j.real.min() for j in j_n_list_array)
        j_max = max(j.real.max() for j in j_n_list_array)

        # Keep only entries of ``mu`` with a nonzero, non-NaN weight column.
        μ = list(filter(lambda x: x[2] != 0.0 and not np.isnan(x).any(), mu))
        μ_x = list(map(lambda x: x[0], μ))
        μ_y = list(map(lambda x: x[1], μ))
        μ_alpha = list(map(lambda x: x[2], μ))
        true_μ_x = list(map(lambda x: x[0], true_μ))
        true_μ_y = list(map(lambda x: x[1], true_μ))
        true_μ_alpha = list(map(lambda x: x[2], true_μ))

        if len(μ_alpha) == 0:
            alpha_mi, alpha_ma, alpha_me = 0, 0, 0
        else:
            alpha_mi, alpha_ma, alpha_me = min(μ_alpha), max(μ_alpha), median(μ_alpha)

        if true_μ_alpha:
            true_alpha_mi, true_alpha_ma = min(true_μ_alpha), max(true_μ_alpha)
        else:
            # No entries in ``true_mu``: scale markers from ``mu`` alone instead
            # of crashing on ``min([])``.
            true_alpha_mi, true_alpha_ma = alpha_mi, alpha_ma
        alpha_base = min(true_alpha_mi, alpha_mi)
        alpha_scale = max(true_alpha_ma, alpha_ma) - alpha_base

        def ms(alpha):
            """Marker size for weight *alpha*: 6 if all weights coincide, else 1..11."""
            if alpha_scale == 0:
                return 6
            return int(1 + (alpha - alpha_base) / alpha_scale * 10)

        def plot_array(
            name, t_idx, ax, u_array, u_min, u_max, colorbar=None, measure=False
        ):
            """Contour *u_array* on *ax*; optionally overlay points and update *colorbar*.

            A negative *t_idx* omits the time from the title.
            """
            if u_min == u_max:
                # Constant field: tricontourf needs at least two distinct levels.
                levels = [u_min, u_min + 1e-9]
            else:
                levels = np.linspace(u_min, u_max, quantisation)
            try:
                contour = ax.tricontourf(triang, u_array, levels=levels, cmap="viridis")
                if colorbar:
                    colorbar.update_normal(contour)
            except Exception as e:
                # Best effort: one bad panel should not abort the whole redraw.
                print(e)
            if measure:
                for x, y, a in zip(μ_x, μ_y, μ_alpha):
                    ax.plot([x], [y], "ro", markersize=ms(a), label="Sources")
                for x, y, a in zip(true_μ_x, true_μ_y, true_μ_alpha):
                    ax.plot([x], [y], "kx", markersize=ms(a), label="True sources")
            if t_idx >= 0:
                ax.set_title(f"%s; t = {t_idx:.1f}" % name)
            else:
                ax.set_title("%s" % name)
            ax.set_aspect("equal")

        n = len(u_n_list_array)
        # frames = list(map(lambda i: i * (n - 1) // (m - 1), range(0, m)))
        frames = range(0, n)
        # zip() truncates to the m panels available per row.
        for i, axu, axw, axj, t_idx in zip(
            frames,
            axus,
            axws,
            axjs,
            times,  # map(lambda i: i / (n - 1), frames),
        ):
            plot_array(
                "u", t_idx, axu, u_n_list_array[i].real, u_min, u_max, colorbar_u, True
            )
            plot_array(
                "w", t_idx, axw, w_n_list_array[i].real, w_min, w_max, colorbar_w, True
            )
            plot_array(
                "j", t_idx, axj, j_n_list_array[i].real, j_min, j_max, colorbar_j, True
            )

        if show_true:
            # Label with the same ``times`` as rows 1-3 (the old i / (n - 1)
            # formula disagreed with them and divided by zero when n == 1).
            for i, ax, t_idx in zip(frames, ax_wpbar_plus, times):
                plot_array(
                    "û",
                    t_idx,
                    ax,
                    true_u_n_list_array[i].real,
                    u_min,
                    u_max,
                    colorbar_u,
                    True,
                )
        else:
            plot_array("w̄ₚ", -1, ax_wpbar_plus[0], wp_bar_array.real, w_min, w_max)
            plot_array("ω₀", -1, ax_wpbar_plus[1], ω0.real, w_min, w_max)
            # plot_array("ω̂", -1, ax_wpbar_plus[2], true_ω.real, w_min, w_max)

        plt.suptitle(
            "Convection-Diffusion, iteration %d; len(μ) = %d; μ_min = %f, μ_max = %f, μ_median = %f"
            % (iteration, len(μ), alpha_mi, alpha_ma, alpha_me),
            fontsize=14,
        )
        # plt.savefig("solution_evolution_%d.png" % iteration, dpi=300)

    # Mutable viewer state shared with the key handler.
    state = {"k": 0, "show_true": False}

    def on_key(event):
        """Navigate iterates / toggle row 4 from a matplotlib key press."""
        k0 = state["k"]
        k = None
        if event.key == "right" or event.key == " ":
            k = (k0 + 1) % len(iter_files)
        elif event.key == "left" or event.key == "backspace":
            k = (k0 - 1) % len(iter_files)
        elif event.key == "shift+right":
            k = (k0 + 10) % len(iter_files)
        elif event.key == "shift+left":
            k = (k0 - 10) % len(iter_files)
        elif event.key == "up":
            k = (k0 + 100) % len(iter_files)
        elif event.key == "down":
            k = (k0 - 100) % len(iter_files)
        elif event.key == "0":
            k = 0
        elif event.key == "t":
            state["show_true"] = not state["show_true"]
            k = k0
        elif event.key == "q":
            sys.exit()
        if k is not None:
            state["k"] = k
            do_plot(iter_files[k], state["show_true"])
            fig.canvas.draw()

    do_plot(iter_files[0])

    fig.canvas.mpl_connect("key_press_event", on_key)

    plt.show()
    # # Time evolution
    # fig, ax = plt.subplots(figsize=(10, 4))
    # times = np.arange(len(u_n_list)) * pde.dt
    # max_vals = [np.max(u.x.array.real) for u in u_n_list]

    # ax.plot(times, max_vals, "r-o", linewidth=2, markersize=4)
    # ax.set_xlabel("Time t")
    # ax.set_ylabel("max|u|")
    # ax.grid(True, alpha=0.3)
    # ax.set_title("Solution Evolution")
    # plt.tight_layout()
    # plt.savefig("solution_time.png", dpi=150)
    # plt.show()

    # print("Saved: solution_evolution.png + solution_time.png")
266
267
# Only launch the viewer when run as a script, so the module can be imported
# (e.g. to reuse find_files) without opening a figure or reading sys.argv.
if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit("usage: %s RESULTS_DIR" % sys.argv[0])
    plot(sys.argv[1])

mercurial