| |
1 import os |
| |
2 import re |
| |
3 import sys |
| |
4 from itertools import chain |
| |
5 from pathlib import Path |
| |
6 from statistics import median |
| |
7 |
| |
8 import matplotlib as mpl |
| |
9 import matplotlib.cm as cm |
| |
10 import matplotlib.colors as colors |
| |
11 import matplotlib.pyplot as plt |
| |
12 import matplotlib.tri as tri |
| |
13 import numpy as np |
| |
14 from dolfinx import fem |
| |
15 |
| |
16 from src.convection_diffusion import ConvectionDiffusion |
| |
17 |
| |
18 |
| |
def find_files(directory):
    """Return the numbered result files in ``directory``, sorted by number.

    Only regular files whose name matches ``res_<k>.npz`` (``<k>`` being one
    or more digits) are considered; sub-directories and other files are
    ignored.

    Args:
        directory: Directory to scan (``str`` or ``os.PathLike``).

    Returns:
        A list of ``(k, path)`` tuples, where ``k`` is the integer parsed from
        the file name and ``path`` is the file's ``Path``. The list is sorted by
        ``k``; ties (e.g. ``res_1.npz`` and ``res_01.npz``) are broken by file
        name so the order does not depend on directory listing order.

    Raises:
        OSError: (e.g. ``FileNotFoundError``) if ``directory`` cannot be listed.
    """
    pattern = re.compile(r"^res_(\d+)\.npz$")

    matches = []
    for path in Path(directory).iterdir():
        m = pattern.match(path.name)
        # Cheap name check first; only stat entries that could be results.
        if m and path.is_file():
            matches.append((int(m.group(1)), path))

    return sorted(matches, key=lambda item: (item[0], item[1].name))
40 |
| |
41 |
| |
def plot(prefix):
    """Open an interactive viewer for the saved iterates of a reconstruction run.

    ``prefix`` is the directory holding the ``res_<k>.npz`` iterate files
    (found with ``find_files``); the reference data ``omega_0.npz`` is read
    from its parent directory. The figure is a 4 x 5 grid of panels:

    * rows 0-2: ``u``, ``w`` and ``j`` at the stored times of the current
      iterate, overlaid with the reconstructed sources (red dots) and the true
      sources (black crosses); marker size grows with each source's third
      column;
    * row 3: ``w̄ₚ`` and ``ω₀`` or, after pressing ``t``, the true ``u``
      snapshots (``û``).

    Keys: right/space and left/backspace step by one iterate, shift+right and
    shift+left by 10, up/down by 100 (all wrapping around), ``0`` jumps to the
    first iterate, ``t`` toggles the true-solution row, ``q`` closes the window.

    Raises:
        FileNotFoundError: if ``prefix`` contains no ``res_<k>.npz`` files.
    """
    quantisation = 32  # number of filled-contour levels per panel
    n_columns = 5

    iter_files = find_files(prefix)
    if not iter_files:
        raise FileNotFoundError(f"no res_<k>.npz files found in {prefix!r}")

    pde = ConvectionDiffusion(
        u0=lambda x: 0.0 * x[1],
        g=lambda x: 0.0 * x[1],  # alternative nonzero BCs: ((x[1] + 2) / 4) ** 2
        domain_size=[0, 0.5, 0, 0.5],
    )

    # Triangulate the dof coordinates; the saved arrays are assumed to be in
    # the same dof order as pde.V — TODO confirm against the writer.
    coords = pde.V.tabulate_dof_coordinates()
    triang = tri.Triangulation(coords[:, 0], coords[:, 1])

    # Set before the axes exist: axis labels read this when they are created.
    mpl.rcParams["axes.labelsize"] = 9

    fig, (axus, axws, axjs, ax_wpbar_plus) = plt.subplots(
        4, n_columns, figsize=(13, 8)
    )
    fig.subplots_adjust(
        left=0.02, right=0.95, bottom=0.02, top=0.93, wspace=0.2, hspace=0.2
    )
    all_axes = list(chain(axus, axws, axjs, ax_wpbar_plus))
    for ax in all_axes:
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect("equal")

    # One colourbar per field row, created from a placeholder mappable and
    # re-pointed at each new contour set in plot_array (update_normal).
    placeholder = cm.ScalarMappable(
        norm=colors.Normalize(vmin=0, vmax=1), cmap="viridis"
    )
    placeholder.set_array([])  # REQUIRED (matplotlib quirk)
    colorbar_u = fig.colorbar(placeholder, ax=axus[-1])
    colorbar_w = fig.colorbar(placeholder, ax=axws[-1])
    colorbar_j = fig.colorbar(placeholder, ax=axjs[-1])

    # Reference data shared by every iterate: read once, close the archive.
    with np.load(Path(prefix).parent / "omega_0.npz") as data0:
        ω0 = data0["omega0"]
        true_μ = data0["true_mu"]
        true_u_n_list_array = data0["true_u_n_list_array"]
    ω0_min = ω0.real.min()
    ω0_max = ω0.real.max()

    def clear_axes():
        """Remove the previous iterate's lines, contours, images and titles."""
        for ax in all_axes:
            # Materialise first: removing from the live artist lists while
            # iterating them would skip every other artist.
            for artist in list(chain(ax.lines, ax.collections, ax.images)):
                artist.remove()
            # Panels left unused by this iterate must not keep stale titles.
            ax.set_title("")

    def do_plot(index_and_file, show_true=False):
        """Redraw all panels for one ``(iteration, path)`` entry of iter_files.

        With ``show_true`` the bottom row shows the true ``u`` snapshots and the
        ``u`` colour range also covers them; otherwise it shows ``w̄ₚ`` and ``ω₀``.
        """
        iteration, path = index_and_file

        print("🎨 Plotting iteration %d..." % iteration)

        clear_axes()

        with np.load(path) as data:
            u_n_list_array = data["u_n_list_array"]
            w_n_list_array = data["w_n_list_array"]
            j_n_list_array = data["j_n_list_array"]
            times = data["times"]
            mu = data["mu"]
            wp_bar_array = data["wp_bar_array"]

        # Colour ranges are shared across all time snapshots of a field.
        u_min = min(u.real.min() for u in u_n_list_array)
        u_max = max(u.real.max() for u in u_n_list_array)
        if show_true:
            u_min = min(u_min, min(u.real.min() for u in true_u_n_list_array))
            u_max = max(u_max, max(u.real.max() for u in true_u_n_list_array))

        w_min = min(
            min(w.real.min() for w in w_n_list_array), wp_bar_array.real.min(), ω0_min
        )
        w_max = max(
            max(w.real.max() for w in w_n_list_array), wp_bar_array.real.max(), ω0_max
        )
        j_min = min(j.real.min() for j in j_n_list_array)
        j_max = max(j.real.max() for j in j_n_list_array)

        # Source rows: columns 0/1 are plotted as x/y, column 2 sets the marker
        # size. Rows with a zero third column or any NaN are dropped.
        μ = [s for s in mu if s[2] != 0.0 and not np.isnan(s).any()]
        μ_x = [s[0] for s in μ]
        μ_y = [s[1] for s in μ]
        μ_alpha = [s[2] for s in μ]
        true_μ_x = [s[0] for s in true_μ]
        true_μ_y = [s[1] for s in true_μ]
        true_μ_alpha = [s[2] for s in true_μ]

        if μ_alpha:
            alpha_mi, alpha_ma, alpha_me = min(μ_alpha), max(μ_alpha), median(μ_alpha)
        else:
            alpha_mi, alpha_ma, alpha_me = 0, 0, 0

        # One size scale for reconstructed and true markers so they are
        # comparable; also works when there are no true sources.
        alpha_values = [alpha_mi, alpha_ma, *true_μ_alpha]
        alpha_base = min(alpha_values)
        alpha_scale = max(alpha_values) - alpha_base

        def marker_size(alpha):
            """Map a source's third column linearly onto marker sizes 1..11."""
            if alpha_scale == 0:
                return 6
            return int(1 + (alpha - alpha_base) / alpha_scale * 10)

        def plot_array(
            name, t, ax, u_array, u_min, u_max, colorbar=None, measure=False
        ):
            """Draw ``u_array`` as filled contours on ``ax`` over [u_min, u_max].

            A negative ``t`` means "no time" and only ``name`` is used as title.
            With ``measure`` the reconstructed and true sources are overlaid.
            """
            if u_min == u_max:
                # Constant field: tricontourf needs two strictly increasing
                # levels, and a fixed 1e-9 offset vanishes for large |u_min|.
                levels = [u_min, u_min + max(1e-9, 1e-9 * abs(u_min))]
            else:
                levels = np.linspace(u_min, u_max, quantisation)
            try:
                contour = ax.tricontourf(triang, u_array, levels=levels, cmap="viridis")
                if colorbar is not None:
                    colorbar.update_normal(contour)
            except Exception as e:
                # Best effort: one unplottable panel must not kill the viewer.
                print(e)
            if measure:
                for x, y, alpha in zip(μ_x, μ_y, μ_alpha):
                    ax.plot([x], [y], "ro", markersize=marker_size(alpha), label="Sources")
                for x, y, alpha in zip(true_μ_x, true_μ_y, true_μ_alpha):
                    ax.plot(
                        [x], [y], "kx", markersize=marker_size(alpha), label="True sources"
                    )
            if t >= 0:
                ax.set_title(f"{name}; t = {t:.1f}")
            else:
                ax.set_title(name)
            ax.set_aspect("equal")

        n = len(u_n_list_array)
        for i, axu, axw, axj, t in zip(range(n), axus, axws, axjs, times):
            plot_array(
                "u", t, axu, u_n_list_array[i].real, u_min, u_max, colorbar_u, True
            )
            plot_array(
                "w", t, axw, w_n_list_array[i].real, w_min, w_max, colorbar_w, True
            )
            plot_array(
                "j", t, axj, j_n_list_array[i].real, j_min, j_max, colorbar_j, True
            )

        if show_true:
            # Same time labels as the rows above (the shared index i already
            # assumes the true snapshots use the same time grid); this also
            # avoids dividing by zero when only one snapshot is stored.
            for i, ax, t in zip(range(n), ax_wpbar_plus, times):
                plot_array(
                    "û",
                    t,
                    ax,
                    true_u_n_list_array[i].real,
                    u_min,
                    u_max,
                    colorbar_u,
                    True,
                )
        else:
            plot_array("w̄ₚ", -1, ax_wpbar_plus[0], wp_bar_array.real, w_min, w_max)
            plot_array("ω₀", -1, ax_wpbar_plus[1], ω0.real, w_min, w_max)

        fig.suptitle(
            "Convection-Diffusion, iteration %d; len(μ) = %d; μ_min = %f, μ_max = %f, μ_median = %f"
            % (iteration, len(μ), alpha_mi, alpha_ma, alpha_me),
            fontsize=14,
        )

    # Navigation keys -> step through iter_files (wrapping around).
    steps = {
        "right": 1,
        " ": 1,
        "left": -1,
        "backspace": -1,
        "shift+right": 10,
        "shift+left": -10,
        "up": 100,
        "down": -100,
    }
    current_index = 0
    showing_true = False

    def on_key(event):
        """Keyboard handler: navigate iterates, toggle the true row, or quit."""
        nonlocal current_index, showing_true
        k = None
        if event.key in steps:
            k = (current_index + steps[event.key]) % len(iter_files)
        elif event.key == "0":
            k = 0
        elif event.key == "t":
            showing_true = not showing_true
            k = current_index
        elif event.key == "q":
            # Closing the figure ends plt.show(); raising SystemExit from a
            # GUI callback is handled inconsistently across backends.
            plt.close(fig)
        if k is not None:
            current_index = k
            do_plot(iter_files[k], showing_true)
            fig.canvas.draw()

    do_plot(iter_files[0])

    fig.canvas.mpl_connect("key_press_event", on_key)

    plt.show()
| |
266 |
| |
267 |
| |
if __name__ == "__main__":
    # Usage: <script> RESULTS_DIR — RESULTS_DIR holds the res_<k>.npz iterates
    # and its parent holds omega_0.npz. Guarded so importing this module does
    # not start the viewer; a missing argument gets a usage message instead
    # of a bare IndexError.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {Path(sys.argv[0]).name} RESULTS_DIR")
    plot(sys.argv[1])