Thu, 25 Apr 2024 14:20:38 -0500
merge
36 | 1 | ################## |
2 | # Experiment running and interactive visualisation | |
3 | ################## | |
4 | ||
5 | module Run | |
6 | ||
7 | import GR | |
8 | using Setfield | |
9 | using Printf | |
10 | using ImageQualityIndexes: assess_psnr, assess_ssim | |
11 | using FileIO | |
12 | #using JLD2 | |
13 | using DelimitedFiles | |
14 | ||
15 | using AlgTools.Util | |
16 | using AlgTools.StructTools | |
17 | using AlgTools.LinkedLists | |
18 | using AlgTools.Comms | |
19 | ||
20 | using ImageTools.Visualise: secs_ns, grayimg, do_visualise | |
21 | ||
22 | using ..ImGenerate | |
23 | using ..OpticalFlow: identifier, DisplacementFull, DisplacementConstant | |
24 | ||
25 | export run_experiments, | |
26 | Experiment, | |
27 | State, | |
28 | LogEntry, | |
29 | LogEntryHiFi | |
30 | ||
31 | ################ | |
32 | # Experiment | |
33 | ################ | |
34 | ||
"""
    Experiment

Description of one experiment for `run_experiments`.

Fields:
- `mod`: module providing `solve` (called to run the algorithm) and
  `identifier` (used in the experiment name unless a `predictor` parameter
  is given).
- `DisplacementT`: displacement type passed to `solve` and to the data
  generator, and used to parametrise the data channel.
- `imgen`: image generator; its `name`, `dim`, `dynrange`, `Λ` and `f` are
  used by `run_experiments`.
- `params`: default parameters, overridden by keyword arguments of
  `run_experiments`.
"""
struct Experiment
    mod::Module
    DisplacementT::Type
    imgen::ImGen
    params::NamedTuple
end
41 | ||
# Pretty-print an experiment: its module, displacement type, image generator
# (name and dimensions) and parameters (with `kernel` hidden).
function Base.show(io::IO, e::Experiment)
    displacementname(::Type{DisplacementFull}) = "DisplacementFull"
    displacementname(::Type{DisplacementConstant}) = "DisplacementConstant"
    # Fallback: `DisplacementT` is only constrained to `::Type`, so without this
    # method showing an experiment with any other displacement type would throw
    # a MethodError.
    displacementname(T) = string(T)
    print(io, "\nmod: $(e.mod)\nDisplacementT: $(displacementname(e.DisplacementT))\nimgen: $(e.imgen.name) $(e.imgen.dim[1])×$(e.imgen.dim[2])\nparams: $(e.params ⬿ (kernel = "(not shown)",))\n")
end
52 | ||
53 | ################ | |
54 | # Log | |
55 | ################ | |
56 | ||
"""
    LogEntry

One record of the performance log collected by `iterate_visualise` on logged
iterations.
"""
struct LogEntry <: IterableStruct
    iter::Int                # iteration number
    time::Float64            # secs_ns()-based elapsed time, minus accumulated logging overhead
    function_value::Float64  # objective value returned by calc_objective()
    # Currently disabled fields:
    #v_cumul_true_y :: Float64
    #v_cumul_true_x :: Float64
    #v_cumul_est_y :: Float64
    #v_cumul_est_x :: Float64
    psnr::Float64            # PSNR of the reconstruction against the true image
    ssim::Float64            # SSIM of the reconstruction against the true image
    #psnr_data :: Float64
    #ssim_data :: Float64
end
70 | ||
"""
    LogEntryHiFi

Record of the true cumulative displacement at one iteration. Components are
scaled by the reciprocal of the largest image dimension.
"""
struct LogEntryHiFi <: IterableStruct
    iter::Int               # iteration number
    v_cumul_true_y::Float64 # scaled true cumulative displacement, first component
    v_cumul_true_x::Float64 # scaled true cumulative displacement, second component
end
76 | ||
77 | ############### | |
78 | # Main routine | |
79 | ############### | |
80 | ||
"""
    State

Iteration state threaded through `iterate_visualise` and returned by the
algorithm's `solve`.
"""
struct State
    vis::Union{Channel,Bool,Nothing}   # visualiser channel; `false` when disabled, `nothing` once cleared for saving
    start_time::Union{Real,Nothing}    # secs_ns() at the first iteration; `nothing` before it
    wasted_time::Real                  # accumulated time spent on logging/visualisation/saving
    log::LinkedList{LogEntry}          # performance log, newest entry at the head
    log_hifi::LinkedList{LogEntryHiFi} # displacement log, newest entry at the head
    aborted::Bool                      # set when the run was interrupted by the user
end
89 | ||
"""
    name(e::Experiment, p)

Build the file-name stem for experiment `e` run with parameters `p`:
`<imgen name>_<id>_<100α>_<10000σ₀>_<10000τ₀>`. Here `id` is
`identifier(p.predictor)` when `p` has a non-`nothing` `predictor`, and
`e.mod.identifier` otherwise.

Throws `InexactError` if a scaled parameter is not (up to floating-point
noise) an integer, since rounding it could make two different parameter
values share a file name.
"""
function name(e::Experiment, p)
    ig = e.imgen
    id = if haskey(p, :predictor) && !isnothing(p.predictor)
        identifier(p.predictor)
    else
        e.mod.identifier
    end
    # Scale a parameter to an integer for the file name. `Int64(s*x)` alone is
    # too strict: e.g. 100*0.07 == 7.000000000000001 in Float64, which made
    # `Int64` throw for perfectly ordinary parameter values.
    function scaled(x, s)
        v = s*x
        r = round(Int64, v)
        isapprox(v, r; atol=1e-6) || throw(InexactError(:name, Int64, v))
        return r
    end
    # return "$(ig.name)_$(e.mod.identifier)_$(@sprintf "%x" hash(p))"
    return "$(ig.name)_$(id)_$(scaled(p.α, 100))_$(scaled(p.σ₀, 10000))_$(scaled(p.τ₀, 10000))"
end
100 | ||
"""
    write_tex(texfile, e_params)

Write the experiment parameters to `texfile` as TeX macro definitions named
`EXPPARAM<name>`, one per line. Only parameters present in `e_params` are
written. If `σ₀` is present, an extra `sigma` macro is also emitted, giving
`σ₀` as a multiple of `sigma_max` (the multiplier is omitted when `σ₀ == 1`).
"""
function write_tex(texfile, e_params)
    # (macro-name suffix, parameter field) in output order
    fields = (("alpha", :α), ("sigmazero", :σ₀), ("tauzero", :τ₀),
              ("tildetauzero", :τ̃₀), ("delta", :δ), ("lambda", :λ),
              ("theta", :θ), ("maxiter", :maxiter),
              ("noiselevel", :noise_level),
              ("shakenoiselevel", :shake_noise_level),
              ("shake", :shake), ("timestep", :timestep),
              ("displacementcount", :displacementcount),
              ("phantomrho", :phantom_ρ))
    open(texfile, "w") do io
        emit(macroname, value) = println(io, "\\def\\EXPPARAM$(macroname){$(value)}")
        for (macroname, field) in fields
            isdefined(e_params, field) && emit(macroname, getfield(e_params, field))
        end
        if isdefined(e_params, :σ₀)
            σ₀ = e_params.σ₀
            multiplier = σ₀ == 1 ? "" : string(σ₀)
            emit("sigma", multiplier * "\\sigma_{\\max}")
        end
    end
end
126 | ||
"""
    run_experiments(; experiments, visualise=true, visfn=iterate_visualise,
                      datatype=OnlineData, recalculate=true, save_prefix="",
                      fullscreen=false, kwargs...)

Run each experiment in `experiments` in turn.

For each experiment `e`, the parameters are `e.params` overridden by `kwargs`
(via `⬿`). They are then extended with `save_prefix` (the given prefix plus the
name from `name`) and with `e.imgen.dynrange` and `e.imgen.Λ`. The experiment
is skipped when `recalculate` is false and `<save_prefix>.txt` already exists.

Otherwise:
- a task running `e.imgen.f` is spawned to fill a
  `Channel{datatype{e.DisplacementT}}`;
- `e.mod.solve` is called with an `iterate` function built from `visfn`;
- the performance log is written to `<save_prefix>.txt` when `save_results`
  is set.

The loop stops early if the returned state has `aborted` set.

If `visualise` is true, a background task running `bg_visualise_enhanced`
receives updates on a channel. It is told to finish, and waited for, once all
experiments are done, including when an experiment throws.
"""
function run_experiments(;visualise=true,
                          visfn = iterate_visualise,
                          datatype = OnlineData,
                          recalculate=true,
                          experiments,
                          save_prefix="",
                          fullscreen=false,
                          kwargs...)

    # Create visualisation
    if visualise
        rc = Channel(1)
        visproc = Threads.@spawn bg_visualise_enhanced(rc, fullscreen=fullscreen)
        bind(rc, visproc)
        vis = rc
    else
        vis = false
    end

    try
        # Run all experiments
        for e ∈ experiments

            # Parameters for this experiment
            e_params = e.params ⬿ kwargs
            ename = name(e, e_params)
            e_params = e_params ⬿ (save_prefix = save_prefix * ename,
                                   dynrange = e.imgen.dynrange,
                                   Λ = e.imgen.Λ)

            if recalculate || !isfile(e_params.save_prefix * ".txt")
                println("Running experiment \"$(ename)\"")

                # Start data generation task
                datachannel = Channel{datatype{e.DisplacementT}}(2)
                gentask = Threads.@spawn e.imgen.f(e.DisplacementT, datachannel, e_params)
                bind(datachannel, gentask)

                # Run algorithm
                iterate = curry(visfn, datachannel,
                                State(vis, nothing, 0.0, nothing, nothing, false))

                x, y, st = try
                    e.mod.solve(e.DisplacementT;
                                dim=e.imgen.dim,
                                iterate=iterate,
                                params=e_params)
                catch
                    # Close the data channel so the generator task is not
                    # left blocked in `put!` forever, then propagate the error.
                    close(datachannel)
                    rethrow()
                end

                # Clear non-saveable things
                st = @set st.vis = nothing

                println("Wasted_time: $(st.wasted_time)s")

                if e_params.save_results
                    println("Saving " * e_params.save_prefix * "(.txt,_hifi.txt,_params.tex)")

                    perffile = e_params.save_prefix * ".txt"
                    hififile = e_params.save_prefix * "_hifi.txt"   # used by the disabled line below
                    texfile = e_params.save_prefix * "_params.tex"  # used by the disabled line below
                    # datafile = e_params.save_prefix * ".jld2"

                    write_log(perffile, st.log, "# params = $(e_params)\n")
                    #write_log(hififile, st.log_hifi, "# params = $(e_params)\n")
                    #write_tex(texfile, e_params)
                    # @save datafile x y st params
                end

                close(datachannel)
                wait(gentask)

                if st.aborted
                    break
                end
            else
                println("Skipping already computed experiment \"$(ename)\"")
                # texfile = e_params.save_prefix * "_params.tex"
                # write_tex(texfile, e_params)
            end
        end
    finally
        if visualise
            # Tell the visualiser task to finish, and wait for it. This runs even
            # when an experiment throws, so the task is not left blocked on `rc`.
            # If the task has already died, `bind` has closed `rc`; skip the
            # `put!` and let `wait` report the task's failure.
            isopen(rc) && put!(rc, nothing)
            close(rc)
            wait(visproc)
        end
    end

    return
end
214 | ||
215 | ||
216 | ###################################################### | |
217 | # Iterator that does visualisation and log collection | |
218 | ###################################################### | |
219 | ||
"""
    iterate_visualise(datachannel, st, step, params)

Drive an algorithm over the data frames in `datachannel`, collecting logs and
feeding the visualiser.

For each `iter` in `1:params.maxiter`, the next frame is taken and
`step(d.b_noisy, d.v, dnext.b_noisy)` is called with a callback. The callback
receives `calc_objective` and returns the updated state; presumably `step`
performs one algorithm iteration and returns the callback's result. TODO
confirm against the solvers.

The callback:
- sets `start_time` on the first call;
- on logged iterations (`verbose_iter` multiples, the first 20, and every 10th
  up to 200), calls `calc_objective()` and prepends a `LogEntry` with the
  PSNR/SSIM against `d.b_true`;
- on logged iterations it also, when verbose, prints progress and sends an
  update to the visualiser, and optionally saves images and movement history.
  The time spent there is added to `wasted_time`;
- on every iteration, prepends a `LogEntryHiFi` with the scaled true
  cumulative displacement.

If `params.handle_interrupt` is set, an `InterruptException` marks the state
`aborted` instead of propagating. Returns the final state.
"""
function iterate_visualise(datachannel::Channel{OnlineData{DisplacementT}},
                           st :: State,
                           step :: Function,
                           params :: NamedTuple) where DisplacementT
    try
        sc = nothing

        d = take!(datachannel)

        for iter=1:params.maxiter
            dnext = take!(datachannel)
            st = step(d.b_noisy, d.v, dnext.b_noisy) do calc_objective
                stn = st

                if isnothing(stn.start_time)
                    # The Julia precompiler is a miserable joke, apparently not crossing module
                    # boundaries, so only start timing after the first iteration.
                    stn = @set stn.start_time=secs_ns()
                end

                verb = params.verbose_iter != 0 && mod(iter, params.verbose_iter) == 0

                # Normalise movement to image dimensions so
                # our TikZ plotting code doesn't need to know
                # the image pixel size.
                sc = 1.0./maximum(size(d.b_true))

                if verb || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0)
                    verb_start = secs_ns()
                    tm = verb_start - stn.start_time - stn.wasted_time
                    value, x, v, vhist = calc_objective()

                    entry = LogEntry(iter, tm, value,
                                     #sc*d.v_cumul_true[1],
                                     #sc*d.v_cumul_true[2],
                                     #sc*v[1], sc*v[2],
                                     assess_psnr(x, d.b_true),
                                     assess_ssim(x, d.b_true),
                                     #assess_psnr(d.b_noisy, d.b_true),
                                     #assess_ssim(d.b_noisy, d.b_true)
                                     )

                    # (**) Collect a singly-linked list of log to avoid array resizing
                    # while iterating
                    stn = @set stn.log=LinkedListEntry(entry, stn.log)

                    if !isnothing(vhist)
                        vhist=vhist.*sc
                    end

                    if verb
                        @printf("%d/%d J=%f, PSNR=%f, SSIM=%f, avg. FPS=%f\n",
                                iter, params.maxiter, value, entry.psnr,
                                entry.ssim, entry.iter/entry.time)
                        if isa(stn.vis, Channel)
                            put_onlylatest!(stn.vis, ((d.b_noisy, x),
                                                      params.plot_movement,
                                                      stn.log, vhist))
                        end
                    end

                    if params.save_images && (!haskey(params, :save_images_iters)
                                              || iter ∈ params.save_images_iters)
                        fn = (t, ext) -> "$(params.save_prefix)_$(t)_frame$(iter).$(ext)"
                        # save(File(format"PNG", fn("true", "png")), grayimg(d.b_true))
                        # save(File(format"PNG", fn("data", "png")), grayimg(d.b_noisy))
                        save(File(format"PNG", fn("reco", "png")), grayimg(x))
                        if !isnothing(vhist)
                            open(fn("movement", "txt"), "w") do io
                                writedlm(io, ["est_y" "est_x"])
                                writedlm(io, vhist)
                            end
                        end
                    end

                    # Exclude the logging/visualisation/saving cost from timings
                    stn = @set stn.wasted_time += (secs_ns() - verb_start)
                end

                # Record the true cumulative displacement on *every* iteration.
                # (Previously an early `return` on logged iterations skipped this,
                # leaving gaps in the hi-fi log, e.g. for all of the first 20 iterations.)
                hifientry = LogEntryHiFi(iter, sc*d.v_cumul_true[1], sc*d.v_cumul_true[2])
                stn = @set stn.log_hifi=LinkedListEntry(hifientry, stn.log_hifi)

                return stn
            end
            d=dnext
        end
    catch ex
        if params.handle_interrupt && isa(ex, InterruptException)
            # If SIGINT is received (user pressed ^C), terminate computations,
            # returning current status. Effectively, we do not call `step()` again,
            # ending the iterations, but letting the algorithm finish up.
            # Assuming (**) above occurs atomically, `st.log` should be valid, but
            # any results returned by the algorithm itself may be partial, as for
            # reasons of efficiency we do *not* store results of an iteration until
            # the next iteration is finished.
            printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202)
            st = @set st.aborted = true
        else
            rethrow(ex)
        end
    end

    return st
end
326 | ||
"""
    bg_visualise_enhanced(rc; fullscreen=false)

Visualiser loop, run as a background task. Each item taken from `rc` (via
`process_channel`) is a tuple `(imgs, plot_movement, log, vhist)`.

For each item it:
- draws `imgs` with `do_visualise`;
- overlays FPS/SSIM/PSNR of the newest `log` entry;
- when `plot_movement` is set, also draws the movement curves: the true and
  estimated cumulative displacement (only if the log entries carry those
  fields) and the movement history `vhist` (if not `nothing`);
- refreshes the GR workstation.
"""
function bg_visualise_enhanced(rc; fullscreen=false)
    process_channel(rc) do d
        imgs, plot_movement, log, vhist = d
        do_visualise(imgs, refresh=false, fullscreen=fullscreen)
        latest = log.value
        # Overlay statistics of the newest log entry
        GR.settextcolorind(5)
        GR.setcharheight(0.015)
        GR.settextpath(GR.TEXT_PATH_RIGHT)
        tx, ty = GR.wctondc(0, 1)
        GR.text(tx, ty, @sprintf "FPS %.1f, SSIM %.2f, PSNR %.1f" (latest.iter/latest.time) latest.ssim latest.psnr)
        # Overlay movement
        if plot_movement
            sc=1.0
            GR.setlinewidth(2)
            # `LogEntry` currently has its cumulative-displacement fields disabled;
            # reading them unconditionally threw, which killed this task and (via
            # `bind`) the whole run. Only draw these curves when the fields exist.
            if hasproperty(latest, :v_cumul_true_x) && hasproperty(latest, :v_cumul_est_x)
                p=unfold_linked_list(log)
                x=map(e -> 1.5+sc*e.v_cumul_true_x, p)
                y=map(e -> 0.5+sc*e.v_cumul_true_y, p)
                GR.setlinecolorind(2)
                GR.polyline(x, y)
                x=map(e -> 1.5+sc*e.v_cumul_est_x, p)
                y=map(e -> 0.5+sc*e.v_cumul_est_y, p)
                GR.setlinecolorind(3)
                GR.polyline(x, y)
            end
            if !isnothing(vhist)
                GR.setlinecolorind(4)
                x=map(v -> 1.5+sc*v, vhist[:,2])
                y=map(v -> 0.5+sc*v, vhist[:,1])
                GR.polyline(x, y)
            end
        end
        GR.updatews()
    end
end
359 | ||
360 | end # module |