Mon, 06 May 2024 20:03:44 -0500
Add arXiv link for new paper to README
36 | 1 | ################## |
2 | # Experiment running and interactive visualisation | |
3 | ################## | |
4 | ||
5 | module Run | |
6 | ||
7 | import GR | |
8 | using Setfield | |
9 | using Printf | |
10 | using ImageQualityIndexes: assess_psnr, assess_ssim | |
11 | using FileIO | |
12 | #using JLD2 | |
13 | using DelimitedFiles | |
14 | ||
15 | using AlgTools.Util | |
16 | using AlgTools.StructTools | |
17 | using AlgTools.LinkedLists | |
18 | using AlgTools.Comms | |
19 | ||
20 | using ImageTools.Visualise: secs_ns, grayimg, do_visualise | |
21 | ||
22 | using ..ImGenerate | |
23 | using ..OpticalFlow: identifier, DisplacementFull, DisplacementConstant | |
24 | ||
25 | export run_experiments, | |
26 | Experiment, | |
27 | State, | |
28 | LogEntry, | |
29 | LogEntryHiFi | |
30 | ||
31 | ################ | |
32 | # Experiment | |
33 | ################ | |
34 | ||
"""
    Experiment

Description of one experiment to be executed by `run_experiments`.

# Fields
- `mod`: module implementing the algorithm; `run_experiments` calls `mod.solve`
  and `name` reads `mod.identifier`.
- `DisplacementT`: displacement type handed to the data generator and the solver.
- `imgen`: image generator; this file uses its `f`, `dim`, `name`, `dynrange`
  and `Λ` fields.
- `params`: experiment parameters; `run_experiments` merges keyword overrides
  into them.
"""
struct Experiment
    mod::Module
    DisplacementT::Type
    imgen::ImGen
    params::NamedTuple
end
41 | ||
"""
    Base.show(io::IO, e::Experiment)

Print a multi-line summary of `e` to `io`: the algorithm module, the displacement
type, the image generator name and dimensions, and the parameters. The `kernel`
parameter is replaced by the placeholder `"(not shown)"` in the output.
"""
function Base.show(io::IO, e::Experiment)
    # Readable names for the two known displacement types. Any other type falls
    # back to its default string form instead of raising a `MethodError`.
    displacementname(::Type{DisplacementFull}) = "DisplacementFull"
    displacementname(::Type{DisplacementConstant}) = "DisplacementConstant"
    displacementname(T::Type) = string(T)

    shown_params = e.params ⬿ (kernel = "(not shown)",)
    print(io, "\n",
          "mod: $(e.mod)\n",
          "DisplacementT: $(displacementname(e.DisplacementT))\n",
          "imgen: $(e.imgen.name) $(e.imgen.dim[1])×$(e.imgen.dim[2])\n",
          "params: $(shown_params)\n")
    return nothing
end
52 | ||
53 | ################ | |
54 | # Log | |
55 | ################ | |
56 | ||
"""
    LogEntry

One record of the performance log collected by `iterate_visualise`.

# Fields
- `iter`: iteration number.
- `time`: elapsed time at this iteration, with the time spent on logging and
  visualisation subtracted.
- `function_value`: objective value reported by the algorithm.
- `psnr`, `ssim`: quality of the reconstruction against the true image.
"""
struct LogEntry <: IterableStruct
    iter::Int
    time::Float64
    function_value::Float64
    # Currently disabled fields:
    # v_cumul_true_y, v_cumul_true_x, v_cumul_est_y, v_cumul_est_x
    psnr::Float64
    ssim::Float64
    # Currently disabled fields: psnr_data, ssim_data
end
70 | ||
"""
    LogEntryHiFi

Per-iteration record of the true cumulative displacement, stored by
`iterate_visualise` on iterations that do not produce a full `LogEntry`.

# Fields
- `iter`: iteration number.
- `v_cumul_true_y`, `v_cumul_true_x`: the two components of the true cumulative
  displacement, scaled by the reciprocal of the largest image dimension.
"""
struct LogEntryHiFi <: IterableStruct
    iter::Int
    v_cumul_true_y::Float64
    v_cumul_true_x::Float64
end
76 | ||
77 | ############### | |
78 | # Main routine | |
79 | ############### | |
80 | ||
"""
    State

Bookkeeping threaded through the iterations of one experiment.

# Fields
- `vis`: visualisation channel, `false` when visualisation is off, or `nothing`
  once `run_experiments` has cleared it after the run.
- `start_time`: time stamp taken on the first iteration callback; `nothing`
  before that.
- `wasted_time`: accumulated time spent on logging and visualisation, which is
  subtracted from the logged times.
- `log`: linked list of `LogEntry` records, newest first.
- `log_hifi`: linked list of `LogEntryHiFi` records, newest first.
- `aborted`: set when the run was stopped by a user interrupt.
"""
struct State
    vis::Union{Channel,Bool,Nothing}
    start_time::Union{Real,Nothing}
    wasted_time::Real
    log::LinkedList{LogEntry}
    log_hifi::LinkedList{LogEntryHiFi}
    aborted::Bool
end
89 | ||
"""
    name(e::Experiment, p)

Build the identifier string for experiment `e` run with parameters `p`.

The result has the form `<imgen name>_<id>_<100α>_<10000σ₀>_<10000τ₀>`, where
`<id>` is `identifier(p.predictor)` when `p` has a non-`nothing` `predictor`, and
`e.mod.identifier` otherwise. When `p` has a `variant`, `"_"` and the variant are
appended to `<id>`. `run_experiments` appends the result to its `save_prefix`.

The scaled parameters are rounded to the nearest integer.
"""
function name(e::Experiment, p)
    ig = e.imgen
    id = if haskey(p, :predictor) && !isnothing(p.predictor)
        identifier(p.predictor)
    else
        e.mod.identifier
    end
    if haskey(p, :variant)
        id *= "_" * p.variant
    end
    # `round` instead of a plain `Int64(...)` conversion: products such as
    # `100 * 0.07` are not exactly integral in floating point, and the plain
    # conversion would throw an `InexactError` for them.
    scaled(factor, x) = round(Int64, factor * x)
    α = scaled(100, p.α)
    σ₀ = scaled(10000, p.σ₀)
    τ₀ = scaled(10000, p.τ₀)
    return "$(ig.name)_$(id)_$(α)_$(σ₀)_$(τ₀)"
end
103 | ||
"""
    write_tex(texfile, e_params)

Write the parameters found in `e_params` to `texfile` as TeX definitions, one per
line, of the form `\\def\\EXPPARAM<suffix>{<value>}`. Parameters that are absent
from `e_params` are skipped. When `σ₀` is present, an additional `sigma` macro is
written: `\\sigma_{\\max}`, prefixed by the value of `σ₀` unless that value equals 1.
"""
function write_tex(texfile, e_params)
    # TeX macro suffix => parameter field, in output order.
    exported = (
        "alpha" => :α,
        "sigmazero" => :σ₀,
        "tauzero" => :τ₀,
        "tildetauzero" => :τ̃₀,
        "delta" => :δ,
        "lambda" => :λ,
        "theta" => :θ,
        "maxiter" => :maxiter,
        "noiselevel" => :noise_level,
        "shakenoiselevel" => :shake_noise_level,
        "shake" => :shake,
        "timestep" => :timestep,
        "displacementcount" => :displacementcount,
        "phantomrho" => :phantom_ρ,
    )
    open(texfile, "w") do io
        emit(suffix, value) = println(io, "\\def\\EXPPARAM$(suffix){$(value)}")
        for (suffix, field) in exported
            if isdefined(e_params, field)
                emit(suffix, getfield(e_params, field))
            end
        end
        if isdefined(e_params, :σ₀)
            multiplier = e_params.σ₀ == 1 ? "" : "$(e_params.σ₀)"
            emit("sigma", multiplier * "\\sigma_{\\max}")
        end
    end
end
129 | ||
"""
    run_experiments(; visualise, visfn, datatype, recalculate, experiments,
                    save_prefix, fullscreen, kwargs...)

Run each entry of `experiments` in turn.

# Keywords
- `visualise`: when true, start `bg_visualise_enhanced` as a background task and
  pass its channel to the iterations.
- `visfn`: iteration wrapper, curried with the data channel and the initial
  `State` and given to the solver as `iterate`.
- `datatype`: element type constructor of the data channel; it is applied to the
  experiment's `DisplacementT`.
- `recalculate`: when false, an experiment whose `<save_prefix><name>.txt` file
  already exists is skipped.
- `experiments`: iterable of `Experiment`s (required).
- `save_prefix`: prefix prepended to each experiment name to form its file stem.
- `fullscreen`: forwarded to the visualiser.
- `kwargs...`: merged into every experiment's parameters.

The loop stops early when an experiment reports `aborted`. Returns `nothing`.
"""
function run_experiments(;visualise=true,
                         visfn = iterate_visualise,
                         datatype = OnlineData,
                         recalculate=true,
                         experiments,
                         save_prefix="",
                         fullscreen=false,
                         kwargs...)

    # Background visualiser, fed through a one-slot channel bound to its task
    if visualise
        rc = Channel(1)
        visproc = Threads.@spawn bg_visualise_enhanced(rc, fullscreen=fullscreen)
        bind(rc, visproc)
    end
    vis = visualise ? rc : false

    for e ∈ experiments
        # Per-experiment parameters: experiment defaults, then caller overrides,
        # then values derived from the image generator.
        merged = e.params ⬿ kwargs
        ename = name(e, merged)
        e_params = merged ⬿ (save_prefix = save_prefix * ename,
                             dynrange = e.imgen.dynrange,
                             Λ = e.imgen.Λ)
        perffile = e_params.save_prefix * ".txt"

        if !recalculate && isfile(perffile)
            println("Skipping already computed experiment \"$(ename)\"")
            continue
        end

        println("Running experiment \"$(ename)\"")

        # Data generation runs in its own task and feeds a two-slot channel
        datachannel = Channel{datatype{e.DisplacementT}}(2)
        gentask = Threads.@spawn e.imgen.f(e.DisplacementT, datachannel, e_params)
        bind(datachannel, gentask)

        # Run the algorithm with a fresh state
        initial_state = State(vis, nothing, 0.0, nothing, nothing, false)
        iterfn = curry(visfn, datachannel, initial_state)
        _, _, st = e.mod.solve(e.DisplacementT;
                               dim=e.imgen.dim,
                               iterate=iterfn,
                               params=e_params)

        # The visualisation channel is not saveable; drop it from the state
        st = @set st.vis = nothing

        println("Wasted_time: $(st.wasted_time)s")

        if e_params.save_results
            # Only the performance log is written at present; the hifi log and
            # the TeX parameter export named in the message are disabled.
            println("Saving " * e_params.save_prefix * "(.txt,_hifi.txt,_params.tex)")
            write_log(perffile, st.log, "# params = $(e_params)\n")
        end

        close(datachannel)
        wait(gentask)

        st.aborted && break
    end

    if visualise
        # Tell the visualiser task to finish, and wait for it
        put!(rc, nothing)
        close(rc)
        wait(visproc)
    end

    return
end
217 | ||
218 | ||
219 | ###################################################### | |
220 | # Iterator that does visualisation and log collection | |
221 | ###################################################### | |
222 | ||
"""
    iterate_visualise(datachannel, st, step, params)

Drive up to `params.maxiter` iterations of an algorithm, collecting logs and
feeding the visualiser.

Data items are taken from `datachannel`; their `b_noisy`, `v`, `b_true` and
`v_cumul_true` fields are used. For each iteration `step` is called with the
current noisy image, the current displacement, the next noisy image, and a
callback. The callback receives `calc_objective` and returns the updated `State`.

On verbose iterations, on the first 20 iterations, and on every tenth iteration up
to 200, the callback evaluates the objective, appends a `LogEntry`, optionally
prints progress and sends data to the visualiser, and optionally saves the
reconstruction as a PNG (plus the movement history as text). The time this takes
is added to `wasted_time`. On all other iterations it appends a `LogEntryHiFi`.

Reads `maxiter`, `verbose_iter`, `plot_movement`, `save_images`, the optional
`save_images_iters`, `save_prefix` and `handle_interrupt` from `params`.

An `InterruptException` is caught when `params.handle_interrupt` is true, and the
state is then marked `aborted`; any other exception is rethrown. Returns the final
`State`.
"""
function iterate_visualise(datachannel::Channel{OnlineData{DisplacementT}},
                           st :: State,
                           step :: Function,
                           params :: NamedTuple) where DisplacementT
    try
        scale = nothing

        frame = take!(datachannel)

        for iter in 1:params.maxiter
            next_frame = take!(datachannel)
            st = step(frame.b_noisy, frame.v, next_frame.b_noisy) do calc_objective
                state = st

                if isnothing(state.start_time)
                    # Timing starts only at the first callback, so that
                    # compilation before it is not counted.
                    state = @set state.start_time = secs_ns()
                end

                verbose = params.verbose_iter != 0 && mod(iter, params.verbose_iter) == 0

                # Normalise movement to the image dimensions, so that plotting
                # code does not need to know the pixel size.
                scale = 1.0 / maximum(size(frame.b_true))

                full_log = verbose || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0)

                if !full_log
                    # Cheap path: record only the true cumulative displacement
                    hifientry = LogEntryHiFi(iter,
                                             scale * frame.v_cumul_true[1],
                                             scale * frame.v_cumul_true[2])
                    st = @set st.log_hifi = LinkedListEntry(hifientry, st.log_hifi)
                    return st
                end

                pause_start = secs_ns()
                elapsed = pause_start - state.start_time - state.wasted_time
                value, x, _, vhist = calc_objective()

                entry = LogEntry(iter, elapsed, value,
                                 assess_psnr(x, frame.b_true),
                                 assess_ssim(x, frame.b_true))

                # (**) The log is a singly-linked list, which avoids array
                # resizing while iterating.
                state = @set state.log = LinkedListEntry(entry, state.log)

                if !isnothing(vhist)
                    vhist = vhist .* scale
                end

                if verbose
                    @printf("%d/%d J=%f, PSNR=%f, SSIM=%f, avg. FPS=%f\n",
                            iter, params.maxiter, value, entry.psnr,
                            entry.ssim, entry.iter/entry.time)
                    if isa(state.vis, Channel)
                        put_onlylatest!(state.vis, ((frame.b_noisy, x),
                                                    params.plot_movement,
                                                    state.log, vhist))
                    end
                end

                if params.save_images && (!haskey(params, :save_images_iters) || iter ∈ params.save_images_iters)
                    framefile = (tag, ext) -> "$(params.save_prefix)_$(tag)_frame$(iter).$(ext)"
                    save(File(format"PNG", framefile("reco", "png")), grayimg(x))
                    if !isnothing(vhist)
                        open(framefile("movement", "txt"), "w") do io
                            writedlm(io, ["est_y" "est_x"])
                            writedlm(io, vhist)
                        end
                    end
                end

                # Everything since `pause_start` is bookkeeping, not algorithm time
                state = @set state.wasted_time += (secs_ns() - pause_start)

                return state
            end
            frame = next_frame
        end
    catch ex
        if !(params.handle_interrupt && isa(ex, InterruptException))
            rethrow(ex)
        end
        # On SIGINT (user pressed ^C) stop calling `step()`, which ends the
        # iterations but lets the algorithm finish up. Assuming (**) above
        # occurs atomically, `st.log` should be valid. Results returned by the
        # algorithm itself may be partial, since for efficiency the results of
        # an iteration are not stored until the next iteration has finished.
        printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202)
        st = @set st.aborted = true
    end

    return st
end
329 | ||
"""
    bg_visualise_enhanced(rc; fullscreen=false)

Visualiser loop: handle every item delivered through `rc` by `process_channel`.

Each item is a tuple `(imgs, plot_movement, log, vhist)`. The images are drawn
with `do_visualise`, and a text line with the FPS, SSIM and PSNR of the newest log
entry is overlaid. When `plot_movement` is true, movement paths are drawn on top:
the true and estimated cumulative displacement, when the log entries carry the
corresponding fields, and the displacement history `vhist`, when it is not
`nothing`.

# Keywords
- `fullscreen`: forwarded to `do_visualise`.
"""
function bg_visualise_enhanced(rc; fullscreen=false)
    process_channel(rc) do d
        imgs, plot_movement, log, vhist = d
        do_visualise(imgs, refresh=false, fullscreen=fullscreen)

        # Overlay statistics of the newest log entry
        GR.settextcolorind(5)
        GR.setcharheight(0.015)
        GR.settextpath(GR.TEXT_PATH_RIGHT)
        tx, ty = GR.wctondc(0, 1)
        latest = log.value
        GR.text(tx, ty, @sprintf("FPS %.1f, SSIM %.2f, PSNR %.1f",
                                 latest.iter/latest.time, latest.ssim, latest.psnr))

        # Overlay movement
        if plot_movement
            sc = 1.0
            entries = unfold_linked_list(log)
            GR.setlinewidth(2)

            # Draw the path stored in fields `xf`/`yf` of the log entries. The
            # path is skipped when the entry type lacks those fields; `LogEntry`
            # currently has them disabled, and reading them would make this
            # task fail.
            function draw_logged_path(xf, yf, colour)
                if hasproperty(latest, xf) && hasproperty(latest, yf)
                    x = map(e -> 1.5 + sc*getproperty(e, xf), entries)
                    y = map(e -> 0.5 + sc*getproperty(e, yf), entries)
                    GR.setlinecolorind(colour)
                    GR.polyline(x, y)
                end
                return nothing
            end

            draw_logged_path(:v_cumul_true_x, :v_cumul_true_y, 2)
            draw_logged_path(:v_cumul_est_x, :v_cumul_est_y, 3)

            if !isnothing(vhist)
                GR.setlinecolorind(4)
                x = map(v -> 1.5 + sc*v, vhist[:,2])
                y = map(v -> 0.5 + sc*v, vhist[:,1])
                GR.polyline(x, y)
            end
        end
        GR.updatews()
    end
end
362 | ||
363 | end # module |