plot.py

changeset 3
c3a4f4bb87f7
parent 1
a4137aedcb3a
equal deleted inserted replaced
1:a4137aedcb3a 3:c3a4f4bb87f7
1 #!.venv/bin/python3
2 import argparse
1 import os 3 import os
2 import re 4 import re
3 import sys 5 import sys
4 from itertools import chain 6 from itertools import chain
5 from pathlib import Path 7 from pathlib import Path
10 import matplotlib.colors as colors 12 import matplotlib.colors as colors
11 import matplotlib.pyplot as plt 13 import matplotlib.pyplot as plt
12 import matplotlib.tri as tri 14 import matplotlib.tri as tri
13 import numpy as np 15 import numpy as np
14 from dolfinx import fem 16 from dolfinx import fem
15
16 from src.convection_diffusion import ConvectionDiffusion 17 from src.convection_diffusion import ConvectionDiffusion
17 18
18 19
19 def find_files(directory): 20 def find_files(dirname):
20 """ 21 """
21 Return a list of Path objects for files named 'fubar%d.npz' 22 Return a list of Path objects for files named 'fubar%d.npz'
22 in the given directory, sorted by the integer %d. 23 in the given directory, sorted by the integer %d.
23 """ 24 """
24 directory = Path(directory) 25 directory = Path(dirname)
25 pattern = re.compile(r"^res_(\d+)\.npz$") 26 pattern = re.compile(r"^res_(\d+)\.npz$")
26 27
27 matches = [] 28 matches = []
28 29
29 for path in directory.iterdir(): 30 for path in directory.iterdir():
37 matches.sort(key=lambda x: x[0]) 38 matches.sort(key=lambda x: x[0])
38 39
39 return matches 40 return matches
40 41
41 42
42 def plot(prefix): 43 def plot(prefix, simple=False, shaded=False, init_iter=None, save_only=False):
43 quantisation = 32 44 quantisation = 32
44 45
45 iter_files = find_files(prefix) 46 iter_files = find_files(prefix)
46 47
47 pde = ConvectionDiffusion( 48 pde = ConvectionDiffusion(
57 coords = pde.V.tabulate_dof_coordinates() 58 coords = pde.V.tabulate_dof_coordinates()
58 x_coords, y_coords = coords[:, 0], coords[:, 1] 59 x_coords, y_coords = coords[:, 0], coords[:, 1]
59 triang = tri.Triangulation(x_coords, y_coords) 60 triang = tri.Triangulation(x_coords, y_coords)
60 61
61 m = 5 62 m = 5
62 fig, (axus, axws, axjs, ax_wpbar_plus) = plt.subplots(4, m, figsize=(15, 4)) 63 if simple:
63 plt.tight_layout() 64 fig, axus = plt.subplots(1, m)
65 all_axes = lambda: axus
66 fig.set_layout_engine("compressed")
67 else:
68 fig, (axus, axws, axjs, ax_wpbar_plus) = plt.subplots(4, m, figsize=(15, 4))
69 all_axes = lambda: chain(axus, axws, axjs, ax_wpbar_plus)
70 plt.tight_layout(pad=0.01)
71 fig.subplots_adjust(
72 left=0.02, right=0.95, bottom=0.02, top=0.93, wspace=0.2, hspace=0.2
73 )
74
64 fig.set_size_inches(13, 8) 75 fig.set_size_inches(13, 8)
65 fig.subplots_adjust( 76
66 left=0.02, right=0.95, bottom=0.02, top=0.93, wspace=0.2, hspace=0.2 77 if not simple:
67 ) 78 info = plt.text(
68 for ax in chain(axus, axws, axjs, ax_wpbar_plus): 79 1.0,
69 ax.set_xticklabels([]) 80 -0.07,
70 ax.set_yticklabels([]) 81 "",
71 ax.set_aspect("equal") 82 horizontalalignment="right",
83 )
72 84
73 mpl.rcParams["axes.labelsize"] = 9 85 mpl.rcParams["axes.labelsize"] = 9
74 86
75 norm = colors.Normalize(vmin=0, vmax=1) 87 norm = colors.Normalize(vmin=0, vmax=1)
76 sm = cm.ScalarMappable(norm=norm, cmap="viridis") 88 sm = cm.ScalarMappable(norm=norm, cmap="viridis")
77 sm.set_array([]) # REQUIRED (matplotlib quirk) 89 sm.set_array([]) # REQUIRED (matplotlib quirk)
78 colorbar_u = fig.colorbar(sm, ax=axus[-1]) 90 colorbar_u = fig.colorbar(sm, ax=axus[-1])
79 colorbar_w = fig.colorbar(sm, ax=axws[-1]) 91 if not simple:
80 colorbar_j = fig.colorbar(sm, ax=axjs[-1]) 92 colorbar_w = fig.colorbar(sm, ax=axws[-1])
93 colorbar_j = fig.colorbar(sm, ax=axjs[-1])
81 94
82 def do_plot(index_and_file, show_true=False): 95 def do_plot(index_and_file, show_true=False):
96
83 (iter, file) = index_and_file 97 (iter, file) = index_and_file
84 98
85 print("🎨 Plotting iteration %d..." % iter) 99 print("🎨 Plotting iteration %d..." % iter)
86 100
87 data0 = np.load("%s/omega_0.npz" % Path(prefix).parent) 101 data0 = np.load("%s/omega_0.npz" % Path(prefix).parent)
89 true_μ = data0["true_mu"] 103 true_μ = data0["true_mu"]
90 # true_ω = data0["true_omega"] 104 # true_ω = data0["true_omega"]
91 true_u_n_list_array = data0["true_u_n_list_array"] 105 true_u_n_list_array = data0["true_u_n_list_array"]
92 ω0_min = ω0.real.min() 106 ω0_min = ω0.real.min()
93 ω0_max = ω0.real.max() 107 ω0_max = ω0.real.max()
94
95 for ax in chain(axus, axws, axjs, ax_wpbar_plus):
96 # ax.clear()
97 for artist in chain(ax.lines, ax.collections, ax.images):
98 artist.remove()
99 108
100 data = np.load(file) 109 data = np.load(file)
101 u_n_list_array = data["u_n_list_array"] 110 u_n_list_array = data["u_n_list_array"]
102 w_n_list_array = data["w_n_list_array"] 111 w_n_list_array = data["w_n_list_array"]
103 j_n_list_array = data["j_n_list_array"] 112 j_n_list_array = data["j_n_list_array"]
120 ) 129 )
121 j_min = min(j.real.min() for j in j_n_list_array) 130 j_min = min(j.real.min() for j in j_n_list_array)
122 j_max = max(j.real.max() for j in j_n_list_array) 131 j_max = max(j.real.max() for j in j_n_list_array)
123 132
124 μ = list(filter(lambda x: x[2] != 0.0 and not np.isnan(x).any(), mu)) 133 μ = list(filter(lambda x: x[2] != 0.0 and not np.isnan(x).any(), mu))
125 μ_x = list(map(lambda x: x[0], μ))
126 μ_y = list(map(lambda x: x[1], μ))
127 μ_alpha = list(map(lambda x: x[2], μ)) 134 μ_alpha = list(map(lambda x: x[2], μ))
128 true_μ_x = list(map(lambda x: x[0], true_μ))
129 true_μ_y = list(map(lambda x: x[1], true_μ))
130 true_μ_alpha = list(map(lambda x: x[2], true_μ)) 135 true_μ_alpha = list(map(lambda x: x[2], true_μ))
131 136
132 if len(μ_alpha) == 0: 137 if len(μ_alpha) == 0:
133 alpha_mi, alpha_ma, alpha_me = 0, 0, 0 138 alpha_mi, alpha_ma, alpha_me = 0, 0, 0
134 else: 139 else:
141 ms = lambda m: 6 146 ms = lambda m: 6
142 else: 147 else:
143 ms = lambda m: int(1 + (m - alpha_base) / alpha_scale * 10) 148 ms = lambda m: int(1 + (m - alpha_base) / alpha_scale * 10)
144 149
145 def plot_array( 150 def plot_array(
146 name, t_idx, ax, u_array, u_min, u_max, colorbar=None, measure=False 151 name,
152 t_idx,
153 ax,
154 u_array,
155 u_min,
156 u_max,
157 μ,
158 true_μ,
159 colorbar=None,
160 measure=False,
147 ): 161 ):
148 if u_min == u_max: 162 if u_min == u_max:
149 levels = [u_min, u_min + 1e-9] 163 levels = [u_min, u_min + 1e-9]
150 else: 164 else:
151 levels = np.linspace(u_min, u_max, quantisation) 165 levels = np.linspace(u_min, u_max, quantisation)
152 try: 166 try:
153 contour = ax.tricontourf(triang, u_array, levels=levels, cmap="viridis") 167 if shaded:
168 concentration = ax.tripcolor(
169 triang,
170 u_array,
171 cmap="viridis",
172 shading="gouraud",
173 )
174 else:
175 concentration = ax.tricontourf(
176 triang,
177 u_array,
178 levels=levels,
179 cmap="viridis",
180 antialiased=False,
181 zorder=-10,
182 )
183 ax.set_rasterization_zorder(0)
184
154 if colorbar: 185 if colorbar:
155 colorbar.update_normal(contour) 186 colorbar.update_normal(concentration)
187 fmt = mpl.ticker.ScalarFormatter(useOffset=False)
188 fmt.set_scientific(False)
189 colorbar.ax.yaxis.set_major_formatter(fmt)
190 colorbar.update_ticks()
191 colorbar.ax.yaxis.get_offset_text().set_visible(False)
192 off = colorbar.ax.yaxis.get_offset_text()
193 off.set_x(10)
194
156 except Exception as e: 195 except Exception as e:
157 print(e) 196 print(e)
158 if measure: 197 if μ is not None:
198 μ_x = list(map(lambda x: x[0], μ))
199 μ_y = list(map(lambda x: x[1], μ))
200 μ_alpha = list(map(lambda x: x[2], μ))
159 for x, y, m in zip(μ_x, μ_y, μ_alpha): 201 for x, y, m in zip(μ_x, μ_y, μ_alpha):
160 ax.plot([x], [y], "ro", markersize=ms(m), label="Sources") 202 ax.plot([x], [y], "ro", markersize=ms(m), label="Sources")
203 if true_μ is not None:
204 true_μ_x = list(map(lambda x: x[0], true_μ))
205 true_μ_y = list(map(lambda x: x[1], true_μ))
206 true_μ_alpha = list(map(lambda x: x[2], true_μ))
161 for x, y, m in zip(true_μ_x, true_μ_y, true_μ_alpha): 207 for x, y, m in zip(true_μ_x, true_μ_y, true_μ_alpha):
162 ax.plot([x], [y], "kx", markersize=ms(m), label="True sources") 208 ax.plot([x], [y], "wx", markersize=ms(m), label="True sources")
163 if t_idx >= 0: 209 if t_idx >= 0:
164 ax.set_title(f"%s; t = {t_idx:.1f}" % name) 210 ax.set_title(
211 f"$%s_{{{t_idx:.1f}}}$" % name,
212 fontsize=20 if simple else 14,
213 )
165 else: 214 else:
166 ax.set_title("%s" % name) 215 ax.set_title(
216 "%s" % name,
217 fontsize=20 if simple else 14,
218 )
167 ax.set_aspect("equal") 219 ax.set_aspect("equal")
168 220
169 n = len(u_n_list_array) 221 n = len(u_n_list_array)
170 # frames = list(map(lambda i: i * (n - 1) // (m - 1), range(0, m))) 222 # frames = list(map(lambda i: i * (n - 1) // (m - 1), range(0, m)))
171 frames = range(0, n) 223 frames = range(0, n)
172 for i, axu, axw, axj, t_idx in zip( 224
173 frames, 225 for ax in all_axes():
174 axus, 226 ax.clear()
175 axws, 227
176 axjs, 228 if not simple:
177 times, # map(lambda i: i / (n - 1), frames), 229 for i, axu, axw, axj, t_idx in zip(
178 ):
179 plot_array(
180 "u", t_idx, axu, u_n_list_array[i].real, u_min, u_max, colorbar_u, True
181 )
182 plot_array(
183 "w", t_idx, axw, w_n_list_array[i].real, w_min, w_max, colorbar_w, True
184 )
185 plot_array(
186 "j", t_idx, axj, j_n_list_array[i].real, j_min, j_max, colorbar_j, True
187 )
188
189 if show_true:
190 for i, ax, t_idx in zip(
191 frames, 230 frames,
192 ax_wpbar_plus, 231 axus,
193 map(lambda i: i / (n - 1), frames), 232 axws,
233 axjs,
234 times, # map(lambda i: i / (n - 1), frames),
194 ): 235 ):
195 plot_array( 236 plot_array(
196 "û", 237 "u",
197 t_idx, 238 t_idx,
198 ax, 239 axu,
199 true_u_n_list_array[i].real, 240 u_n_list_array[i].real,
200 u_min, 241 u_min,
201 u_max, 242 u_max,
243 μ,
244 true_μ,
202 colorbar_u, 245 colorbar_u,
203 True, 246 )
204 ) 247 plot_array(
248 "w",
249 t_idx,
250 axw,
251 w_n_list_array[i].real,
252 w_min,
253 w_max,
254 μ,
255 true_μ,
256 colorbar_w,
257 )
258 plot_array(
259 "j",
260 t_idx,
261 axj,
262 j_n_list_array[i].real,
263 j_min,
264 j_max,
265 μ,
266 true_μ,
267 colorbar_j,
268 )
269
270 if show_true:
271 for i, ax, t_idx in zip(
272 frames,
273 ax_wpbar_plus,
274 map(lambda i: i / (n - 1), frames),
275 ):
276 plot_array(
277 "û",
278 t_idx,
279 ax,
280 true_u_n_list_array[i].real,
281 u_min,
282 u_max,
283 μ,
284 true_μ,
285 colorbar_u,
286 )
287 else:
288 plot_array(
289 "w̄ₚ",
290 -1,
291 ax_wpbar_plus[0],
292 wp_bar_array.real,
293 w_min,
294 w_max,
295 μ,
296 None,
297 )
298 plot_array(
299 "ω₀",
300 -1,
301 ax_wpbar_plus[1],
302 ω0.real,
303 w_min,
304 w_max,
305 μ,
306 None,
307 )
308 # plot_array("ω̂", -1, ax_wpbar_plus[2], true_ω.real, w_min, w_max)
309
310 # Has to be here after everything, for some reason
311 for ax in all_axes():
312 ax.set_xticklabels([])
313 ax.set_yticklabels([])
314 ax.set_aspect("equal")
315
316 plt.suptitle(
317 "Convection-Diffusion, iteration %d; len(μ) = %d; μ_min = %f, μ_max = %f, μ_median = %f;"
318 % (
319 iter,
320 len(μ),
321 alpha_mi,
322 alpha_ma,
323 alpha_me,
324 ),
325 fontsize=14,
326 )
327
328 str = "(k, c, c1) = (%g, %g, %g)" % (data["k"], data["c1"], data["c2"])
329 if "k" in data0:
330 str += "; true (%g, %g, %g)" % (data0["k"], data0["c1"], data0["c2"])
331 info.set_text(str)
332 else: # simple
333 for i, axu, t_idx in zip(
334 frames,
335 axus,
336 times, # map(lambda i: i / (n - 1), frames),
337 ):
338 if show_true:
339 plot_array(
340 "û",
341 t_idx,
342 axu,
343 true_u_n_list_array[i].real,
344 u_min,
345 u_max,
346 None,
347 true_μ,
348 colorbar_u,
349 )
350 else:
351 plot_array(
352 "u",
353 t_idx,
354 axu,
355 u_n_list_array[i].real,
356 u_min,
357 u_max,
358 μ,
359 true_μ,
360 colorbar_u,
361 )
362
363 def do_save(iter, show_true):
364 if not show_true:
365 filename = "%s/res_%d.pdf" % (prefix, iter)
205 else: 366 else:
206 plot_array("w̄ₚ", -1, ax_wpbar_plus[0], wp_bar_array.real, w_min, w_max) 367 filename = "%s/true_%d.pdf" % (prefix, iter)
207 plot_array("ω₀", -1, ax_wpbar_plus[1], ω0.real, w_min, w_max) 368 print("Saving ", filename)
208 # plot_array("ω̂", -1, ax_wpbar_plus[2], true_ω.real, w_min, w_max) 369 plt.savefig(
209 370 filename,
210 plt.suptitle( 371 dpi=300,
211 "Convection-Diffusion, iteration %d; len(μ) = %d; μ_min = %f, μ_max = %f, μ_median = %f" 372 bbox_inches="tight",
212 % (iter, len(μ), alpha_mi, alpha_ma, alpha_me), 373 pad_inches=0,
213 fontsize=14, 374 )
214 ) 375
215 # plt.savefig("solution_evolution_%d.png" % iter, dpi=300) 376 if init_iter is None:
216 377 k = 0
217 state = {"k": 0, "show_true": False} 378 else:
379 k, _ = next(
380 filter(lambda fubar: fubar[1][0] == init_iter, enumerate(iter_files))
381 )
382
383 do_plot(iter_files[k], show_true=False)
384
385 if save_only:
386 iter = iter_files[k][0]
387 do_save(iter, False)
388 do_plot(iter_files[k], show_true=True)
389 do_save(iter, True)
390 return
391
392 state = {"k": k, "show_true": False}
218 393
219 def on_key(event): 394 def on_key(event):
220 k0 = state["k"] 395 k0 = state["k"]
221 k = None 396 k = None
222 if event.key == "right" or event.key == " ": 397 if event.key == "right" or event.key == " ":
236 elif event.key == "t": 411 elif event.key == "t":
237 state["show_true"] = not state["show_true"] 412 state["show_true"] = not state["show_true"]
238 k = k0 413 k = k0
239 elif event.key == "q": 414 elif event.key == "q":
240 sys.exit() 415 sys.exit()
416 elif event.key == "z":
417 iter, _file = iter_files[k0]
418 do_save(iter, state["show_true"])
241 if k is not None: 419 if k is not None:
242 state["k"] = k 420 state["k"] = k
243 do_plot(iter_files[k], state["show_true"]) 421 do_plot(iter_files[k], state["show_true"])
244 fig.canvas.draw() 422 fig.canvas.draw()
245
246 do_plot(iter_files[0])
247 423
248 fig.canvas.mpl_connect("key_press_event", on_key) 424 fig.canvas.mpl_connect("key_press_event", on_key)
249 425
250 plt.show() 426 plt.show()
251 # # Time evolution 427 # # Time evolution
263 # plt.show() 439 # plt.show()
264 440
265 # print("Saved: solution_evolution.png + solution_time.png") 441 # print("Saved: solution_evolution.png + solution_time.png")
266 442
267 443
268 plot(sys.argv[1]) 444 parser = argparse.ArgumentParser(
445 prog="plot.py",
446 description="Plots results of pointsource_pde",
447 )
448
449 parser.add_argument("filename")
450 parser.add_argument("--simple", action="store_true")
451 parser.add_argument("--shaded", action="store_true")
452 parser.add_argument("--iter", type=int, action="store")
453 parser.add_argument("--save-only", action="store_true")
454
455 args = parser.parse_args()
456
457 plot(args.filename, args.simple, args.shaded, args.iter, args.save_only)

mercurial