Tue, 07 Apr 2020 14:19:48 -0500
Initialise independent repo
0 | 1 | ################## |
2 | # Our main module | |
3 | ################## | |
4 | ||
5 | __precompile__() | |
6 | ||
7 | module PredictPDPS | |
8 | ||
9 | ######################## | |
10 | # Load external modules | |
11 | ######################## | |
12 | ||
13 | using Printf | |
14 | using FileIO | |
15 | #using JLD2 | |
16 | using Setfield | |
17 | using ImageQualityIndexes: psnr, ssim | |
18 | using DelimitedFiles | |
19 | import GR | |
20 | ||
21 | using AlgTools.Util | |
22 | using AlgTools.StructTools | |
23 | using AlgTools.LinkedLists | |
24 | using AlgTools.Comms | |
25 | using ImageTools.Visualise: secs_ns, grayimg, do_visualise | |
26 | using ImageTools.ImFilter: gaussian | |
27 | ||
28 | ##################### | |
29 | # Load local modules | |
30 | ##################### | |
31 | ||
32 | include("OpticalFlow.jl") | |
33 | include("ImGenerate.jl") | |
34 | include("Algorithm.jl") | |
35 | include("AlgorithmBoth.jl") | |
36 | include("AlgorithmBothGreedyV.jl") | |
37 | include("AlgorithmBothCumul.jl") | |
38 | include("AlgorithmBothMulti.jl") | |
39 | include("AlgorithmBothNL.jl") | |
40 | include("AlgorithmFB.jl") | |
41 | include("AlgorithmFBDual.jl") | |
42 | ||
43 | import .Algorithm, | |
44 | .AlgorithmBoth, | |
45 | .AlgorithmBothGreedyV, | |
46 | .AlgorithmBothCumul, | |
47 | .AlgorithmBothMulti, | |
48 | .AlgorithmBothNL, | |
49 | .AlgorithmFB, | |
50 | .AlgorithmFBDual | |
51 | ||
52 | using .ImGenerate | |
53 | using .OpticalFlow: DisplacementFull, DisplacementConstant | |
54 | ||
55 | ############## | |
56 | # Our exports | |
57 | ############## | |
58 | ||
59 | export run_experiments, | |
60 | batchrun_article, | |
61 | demo_known1, | |
62 | demo_known2, | |
63 | demo_known3, | |
64 | demo_unknown1, | |
65 | demo_unknown2, | |
66 | demo_unknown3 | |
67 | ||
68 | ################################### | |
69 | # Parameterisation and experiments | |
70 | ################################### | |
71 | ||
"""
    Experiment

A single experiment configuration.

- `mod`: algorithm module; this file calls its `solve` function and reads its
  `identifier`.
- `DisplacementT`: displacement model type handed to both the image generator and
  the solver.
- `imgen`: the `ImGen` image generator supplying the data.
- `params`: experiment-specific parameters, merged over `default_params` when run.
"""
struct Experiment
    mod::Module
    DisplacementT::Type
    imgen::ImGen
    params::NamedTuple
end
78 | ||
"""
    show(io::IO, e::Experiment)

Print a multi-line summary of `e`. The summary covers:

- the algorithm module;
- the displacement type;
- the image generator's name and dimensions;
- the parameters.

A `kernel` parameter, when present, is replaced by a placeholder to keep the output
short.
"""
function Base.show(io::IO, e::Experiment)
    # Short, unqualified names for the displacement types used by this module.
    displacementname(::Type{DisplacementFull}) = "DisplacementFull"
    displacementname(::Type{DisplacementConstant}) = "DisplacementConstant"
    # Any other displacement type: use its default string form instead of
    # throwing a MethodError while displaying.
    displacementname(T::Type) = string(T)

    # Only mask the kernel if there is one; don't add a bogus field otherwise.
    shown_params = haskey(e.params, :kernel) ?
        e.params ⬿ (kernel = "(not shown)",) :
        e.params

    print(io, "\n",
          "    mod: $(e.mod)\n",
          "    DisplacementT: $(displacementname(e.DisplacementT))\n",
          "    imgen: $(e.imgen.name) $(e.imgen.dim[1])×$(e.imgen.dim[2])\n",
          "    params: $(shown_params)\n")
end
89 | ||
# Directory prefix (relative) under which experiment results are written.
const default_save_prefix = "img/"

# Baseline parameters for every experiment. An experiment's own `params` and any
# keyword arguments to `run_experiments` are merged on top of these.
# NOTE: keep field order, values and literal types stable: the merged parameter
# tuple is hashed to build the result file names (see `name`).
const default_params = (
    ρ                      = 0,
    verbose_iter           = 100,    # progress printout period (0 disables)
    maxiter                = 10000,
    save_results           = true,
    save_images            = true,
    # Iterations at which image frames are saved when `save_images` is on.
    save_images_iters      = Set([1, 2, 3, 5,
                                  10, 25, 30, 50,
                                  100, 250, 300, 500,
                                  1000, 2500, 3000, 5000,
                                  10000]),
    pixelwise_displacement = false,
    dual_flow              = true,
    prox_predict           = true,
    handle_interrupt       = true,   # turn ^C into a clean abort
    init                   = :zero,
    plot_movement          = false,
)
110 | ||
# Image generators shared by the experiments, both with dimensions (200, 300):
# a synthetic square (`imgen_square`) and the "lighthouse" image (`imgen_shake`).
const square     = imgen_square((200, 300))
const lighthouse = imgen_shake("lighthouse", (200, 300))
113 | ||
# Base parameters for the known-displacement experiments.
# Literal types matter (e.g. `α = 1` is an Int): they enter the parameter hash.
const p_known₀ = (
    noise_level       = 0.5,
    shake_noise_level = 0.05,
    shake             = 2,
    α                 = 1,
    ρ̃₀                = 1,
    σ̃₀                = 1,
    δ                 = 0.9,
    σ₀                = 1,
    τ₀                = 0.01,
)
125 | ||
# Base parameters for the unknown-displacement experiments.
# Literal types matter: they enter the parameter hash used for file names.
const p_unknown₀ = (
    noise_level        = 0.3,
    shake_noise_level  = 0.05,
    shake              = 2,
    α                  = 0.2,
    ρ̃₀                 = 1,
    σ̃₀                 = 1,
    σ₀                 = 1,
    δ                  = 0.9,
    λ                  = 1,
    # The 300*200 factor matches the image size of the generators above.
    θ                  = (300 * 200) * 100^3,
    kernel             = gaussian((3, 3), (11, 11)),
    timestep           = 0.5,
    displacement_count = 100,
    τ₀                 = 0.01,
)
142 | ||
# Article experiments, grouped by algorithm. The `demo_*` entry points index
# into the first two groups, so the order inside each tuple is significant.

# PDPS, known displacement: lighthouse (phantom_ρ 0 and 100), then square.
const experiments_pdps_known = (
    Experiment(Algorithm, DisplacementConstant, lighthouse, p_known₀ ⬿ (phantom_ρ = 0,)),
    Experiment(Algorithm, DisplacementConstant, lighthouse, p_known₀ ⬿ (phantom_ρ = 100,)),
    Experiment(Algorithm, DisplacementConstant, square, p_known₀ ⬿ (phantom_ρ = 0,)),
)

# PDPS, unknown displacement (multi variant): same image/phantom layout as above.
const experiments_pdps_unknown_multi = (
    Experiment(AlgorithmBothMulti, DisplacementConstant, lighthouse,
               p_unknown₀ ⬿ (phantom_ρ = 0,)),
    Experiment(AlgorithmBothMulti, DisplacementConstant, lighthouse,
               p_unknown₀ ⬿ (phantom_ρ = 100,)),
    Experiment(AlgorithmBothMulti, DisplacementConstant, square,
               p_unknown₀ ⬿ (phantom_ρ = 0,)),
)

# Forward-backward, known displacement.
const experiments_fb_known = (
    Experiment(AlgorithmFB, DisplacementConstant, lighthouse,
               p_known₀ ⬿ (τ̃₀ = 0.9, fb_inner_iterations = 10)),
)

# Every experiment above, as one lazy sequence (used by `batchrun_article`).
const experiments_all = Iterators.flatten((experiments_pdps_known,
                                           experiments_pdps_unknown_multi,
                                           experiments_fb_known))
171 | ||
172 | ################ | |
173 | # Log | |
174 | ################ | |
175 | ||
"""
    LogEntry

One row of the main log, recorded on logging iterations by `iterate_visualise`.

- `iter`: iteration number.
- `time`: elapsed time, excluding bookkeeping (`wasted_time`).
- `function_value`: objective value.
- `v_cumul_true_y`, `v_cumul_true_x`: true cumulative movement, scaled by
  1/maximum(image size).
- `v_cumul_est_y`, `v_cumul_est_x`: estimated movement, scaled the same way.
- `psnr`, `ssim`: reconstruction quality against the ground truth.
- `psnr_data`, `ssim_data`: quality of the noisy data against the ground truth.
"""
struct LogEntry <: IterableStruct
    iter::Int
    time::Float64
    function_value::Float64
    v_cumul_true_y::Float64
    v_cumul_true_x::Float64
    v_cumul_est_y::Float64
    v_cumul_est_x::Float64
    psnr::Float64
    ssim::Float64
    psnr_data::Float64
    ssim_data::Float64
end
189 | ||
"""
    LogEntryHiFi

Compact log row holding only the true cumulative movement (y, x) at iteration
`iter`. It has the same scaling as the corresponding `LogEntry` fields.
"""
struct LogEntryHiFi <: IterableStruct
    iter::Int
    v_cumul_true_y::Float64
    v_cumul_true_x::Float64
end
195 | ||
196 | ############### | |
197 | # Main routine | |
198 | ############### | |
199 | ||
"""
    State

Immutable state threaded through an algorithm run by `iterate_visualise`.
Updated copies are made with `@set`.

- `vis`: channel to the background visualiser, or `false`/`nothing` when unused.
- `start_time`: timing origin; `nothing` until the first iteration.
- `wasted_time`: accumulated bookkeeping time, excluded from reported times.
- `log`, `log_hifi`: linked lists of log rows, with new rows prepended.
- `aborted`: set when a user interrupt stopped the run.
"""
struct State
    vis         :: Union{Channel,Bool,Nothing}
    start_time  :: Union{Real,Nothing}
    wasted_time :: Real
    log         :: LinkedList{LogEntry}
    log_hifi    :: LinkedList{LogEntryHiFi}
    aborted     :: Bool
end
208 | ||
"""
    name(e::Experiment, p)

File-name stem for experiment `e` run with parameters `p`. The stem has the form
`<image name>_<algorithm identifier>_<hash(p) in lowercase hex>`.
"""
function name(e::Experiment, p)
    hexhash = @sprintf("%x", hash(p))
    return string(e.imgen.name, "_", e.mod.identifier, "_", hexhash)
end
213 | ||
"""
    write_tex(texfile, e_params)

Write selected parameters from the NamedTuple `e_params` to `texfile` as LaTeX
macro definitions `\\def\\EXPPARAM<name>{<value>}`, one per line. The file is
overwritten, and parameters missing from `e_params` are skipped.

When `σ₀` is present, an extra macro `\\EXPPARAMsigma` is written. It expresses
`σ₀` as a multiple of `\\sigma_{\\max}`, omitting the factor when it is 1.
"""
function write_tex(texfile, e_params)
    open(texfile, "w") do io
        # Emit one `\def\EXPPARAM<n>{<v>}` line.
        wp = (n, v) -> println(io, "\\def\\EXPPARAM$(n){$(v)}")
        # Emit parameter `s` under macro name `n`, but only if `e_params` has it.
        wf = (n, s) -> if haskey(e_params, s)
            wp(n, getfield(e_params, s))
        end

        wf("alpha", :α)
        wf("sigmazero", :σ₀)
        wf("tauzero", :τ₀)
        wf("tildetauzero", :τ̃₀)
        wf("delta", :δ)
        wf("lambda", :λ)
        wf("theta", :θ)
        wf("maxiter", :maxiter)
        wf("noiselevel", :noise_level)
        wf("shakenoiselevel", :shake_noise_level)
        wf("shake", :shake)
        wf("timestep", :timestep)
        # The parameter sets spell this `displacement_count`. Looking up
        # `:displacementcount` alone never matched; it is kept only as a fallback.
        wf("displacementcount",
           haskey(e_params, :displacement_count) ? :displacement_count : :displacementcount)
        wf("phantomrho", :phantom_ρ)

        if haskey(e_params, :σ₀)
            wp("sigma", (e_params.σ₀ == 1 ? "" : "$(e_params.σ₀)") * "\\sigma_{\\max}")
        end
    end
end
239 | ||
"""
    run_experiments(; experiments, visualise=true, recalculate=true,
                      save_prefix=default_save_prefix, fullscreen=false, kwargs...)

Run each `Experiment` in `experiments`, one after another.

Parameters are built in layers: `default_params`, then the experiment's own
`params`, then `kwargs`, each overriding the previous. The result files share the
stem `save_prefix * name(experiment, params)`. Unless `recalculate` is set, an
experiment whose `<stem>.txt` already exists is skipped.

Frames come from `experiment.imgen.f`, running on a spawned task that feeds a
channel. The algorithm `experiment.mod.solve` is driven through
`iterate_visualise`. When the `save_results` parameter is set, the main log, the
hi-fi log and a LaTeX parameter file are written afterwards. The loop stops early
if a run reports `aborted` (user interrupt).

With `visualise`, a background `bg_visualise_enhanced` task is started first (with
`fullscreen` passed on) and shut down at the end.
"""
function run_experiments(;visualise=true,
                         recalculate=true,
                         experiments,
                         save_prefix=default_save_prefix,
                         fullscreen=false,
                         kwargs...)
    # Visualisation sink handed to every run: the visualiser's channel, or `false`.
    vis = false
    if visualise
        rc = Channel(1)
        visproc = Threads.@spawn bg_visualise_enhanced(rc, fullscreen=fullscreen)
        bind(rc, visproc)
        vis = rc
    end

    for experiment ∈ experiments
        params = default_params ⬿ experiment.params ⬿ kwargs
        stem = name(experiment, params)
        params = params ⬿ (save_prefix = save_prefix * stem,
                           dynrange = experiment.imgen.dynrange,
                           Λ = experiment.imgen.Λ)

        if !recalculate && isfile(params.save_prefix * ".txt")
            println("Skipping already computed experiment \"$(stem)\"")
            continue
        end

        println("Running experiment \"$(stem)\"")

        # Frames are produced concurrently into a channel buffering two frames.
        D = experiment.DisplacementT
        datachannel = Channel{OnlineData{D}}(2)
        gentask = Threads.@spawn experiment.imgen.f(D, datachannel, params)
        bind(datachannel, gentask)

        driver = curry(iterate_visualise, datachannel,
                       State(vis, nothing, 0.0, nothing, nothing, false))
        _, _, st = experiment.mod.solve(D;
                                        dim=experiment.imgen.dim,
                                        iterate=driver,
                                        params=params)

        # The visualiser channel is not something we want to keep around.
        st = @set st.vis = nothing

        println("Wasted_time: $(st.wasted_time)s")

        if params.save_results
            println("Saving " * params.save_prefix * "(.txt,_hifi.txt,_params.tex)")
            header = "# params = $(params)\n"
            write_log(params.save_prefix * ".txt", st.log, header)
            write_log(params.save_prefix * "_hifi.txt", st.log_hifi, header)
            write_tex(params.save_prefix * "_params.tex", params)
            # (Full JLD2 data dumps are currently disabled.)
        end

        close(datachannel)
        wait(gentask)

        st.aborted && break
    end

    if visualise
        # Ask the visualiser to finish (a `nothing` message), then wait for it.
        put!(rc, nothing)
        close(rc)
        wait(visproc)
    end

    return nothing
end
325 | ||
326 | ####################### | |
327 | # Demos and batch runs | |
328 | ####################### | |
329 | ||
"""
    demo(experiment; kwargs...)

Run a single `experiment` interactively, with these settings:

- fullscreen visualisation;
- nothing saved;
- always recomputed;
- progress printed every 10 iterations.

`kwargs` are forwarded to `run_experiments` and take precedence over these settings.
"""
function demo(experiment; kwargs...)
    interactive = (save_results=false,
                   save_images=false,
                   visualise=true,
                   recalculate=true,
                   verbose_iter=10,
                   fullscreen=true)
    return run_experiments(; experiments=(experiment,), interactive..., kwargs...)
end
340 | ||
# Exported interactive demos, each running one preset experiment through `demo`.
# Defined as functions rather than non-const globals holding closures; they are
# called the same way (`demo_known1()`, ...). The "unknown" demos also overlay
# the true and estimated movement paths on the visualisation.
demo_known1() = demo(experiments_pdps_known[3])
demo_known2() = demo(experiments_pdps_known[1])
demo_known3() = demo(experiments_pdps_known[2])
demo_unknown1() = demo(experiments_pdps_unknown_multi[3], plot_movement=true)
demo_unknown2() = demo(experiments_pdps_unknown_multi[1], plot_movement=true)
demo_unknown3() = demo(experiments_pdps_unknown_multi[2], plot_movement=true)
347 | ||
"""
    batchrun_article(; kwargs...)

Run all article experiments (`experiments_all`) in batch mode:

- results and images are saved;
- there is no visualisation;
- experiments with existing results are skipped.

Keyword arguments are forwarded to `run_experiments` and override these settings,
e.g. `batchrun_article(maxiter=100)`. For backward compatibility, `Symbol => value`
pairs passed positionally are still accepted.
"""
function batchrun_article(args...; kwargs...)
    # The original signature took positional varargs only, so keyword calls hit a
    # MethodError; both forms are splatted as keywords, rightmost taking precedence.
    run_experiments(;experiments=experiments_all,
                    save_results=true,
                    save_images=true,
                    visualise=false,
                    recalculate=false,
                    args...,
                    kwargs...)
end
356 | ||
357 | ###################################################### | |
358 | # Iterator that does visualisation and log collection | |
359 | ###################################################### | |
360 | ||
"""
    iterate_visualise(datachannel, st, step, params) -> State

Iteration driver passed to an algorithm's `solve`, partially applied to
`datachannel` and the initial state `st`.

# Per-iteration behaviour

For `iter = 1:params.maxiter` it takes the next frame from `datachannel` and calls
`step(b_noisy, v, b_noisy_next)` with a callback. The callback receives
`calc_objective`, which returns `(value, x, v, vhist)`. The callback's result
(returned through `step`) becomes the new state.

# Logging iterations

Logging iterations are every `params.verbose_iter`-th iteration, the first 20, and
every 10th up to 200. On these, the callback:

- evaluates the objective and prepends a `LogEntry`;
- if verbose, prints progress and forwards images to the visualiser;
- if `params.save_images` (and `iter ∈ params.save_images_iters` when that is
  given), writes PNG frames and the movement history.

Time spent on this bookkeeping is added to `wasted_time`.

# Hi-fi log and scaling

On every iteration a `LogEntryHiFi` with the true cumulative movement is prepended
to `log_hifi`. Movement values are scaled by `1/maximum(size(b_true))`.

# Interrupts

If `params.handle_interrupt` is set, an `InterruptException` ends the iterations;
the last state is returned with `aborted = true`. Other exceptions are rethrown.
"""
function iterate_visualise(datachannel::Channel{OnlineData{DisplacementT}},
                           st :: State,
                           step :: Function,
                           params :: NamedTuple) where DisplacementT
    try
        d = take!(datachannel)

        for iter=1:params.maxiter
            dnext = take!(datachannel)
            st = step(d.b_noisy, d.v, dnext.b_noisy) do calc_objective
                # Work on a local copy; the new state reaches `st` via `step`'s result.
                stn = st

                if isnothing(stn.start_time)
                    # The Julia precompiler apparently does not cross module
                    # boundaries, so start timing only once the first iteration
                    # is under way, keeping compilation out of the measurements.
                    stn = @set stn.start_time=secs_ns()
                end

                verb = params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0

                # Normalise movement to image dimensions so our TikZ plotting code
                # doesn't need to know the image pixel size.
                sc = 1.0./maximum(size(d.b_true))

                if verb || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0)
                    verb_start = secs_ns()
                    tm = verb_start - stn.start_time - stn.wasted_time
                    value, x, v, vhist = calc_objective()

                    entry = LogEntry(iter, tm, value,
                                     sc*d.v_cumul_true[1],
                                     sc*d.v_cumul_true[2],
                                     sc*v[1], sc*v[2],
                                     psnr(x, d.b_true),
                                     ssim(x, d.b_true),
                                     psnr(d.b_noisy, d.b_true),
                                     ssim(d.b_noisy, d.b_true))

                    # (**) Collect a singly-linked list of log to avoid array resizing
                    # while iterating
                    stn = @set stn.log=LinkedListEntry(entry, stn.log)

                    if !isnothing(vhist)
                        vhist = vhist.*sc
                    end

                    if verb
                        @printf("%d/%d J=%f, PSNR=%f, SSIM=%f, avg. FPS=%f\n",
                                iter, params.maxiter, value, entry.psnr,
                                entry.ssim, entry.iter/entry.time)
                        if isa(stn.vis, Channel)
                            put_onlylatest!(stn.vis, ((d.b_noisy, x),
                                                      params.plot_movement,
                                                      stn.log, vhist))
                        end
                    end

                    save_now = params.save_images &&
                        (!haskey(params, :save_images_iters) || iter ∈ params.save_images_iters)
                    if save_now
                        fn = (t, ext) -> "$(params.save_prefix)_$(t)_frame$(iter).$(ext)"
                        save(File(format"PNG", fn("true", "png")), grayimg(d.b_true))
                        save(File(format"PNG", fn("data", "png")), grayimg(d.b_noisy))
                        save(File(format"PNG", fn("reco", "png")), grayimg(x))
                        if !isnothing(vhist)
                            open(fn("movement", "txt"), "w") do io
                                writedlm(io, ["est_y" "est_x"])
                                writedlm(io, vhist)
                            end
                        end
                    end

                    stn = @set stn.wasted_time += (secs_ns() - verb_start)
                end

                # Record the true movement on *every* iteration, logging ones
                # included, so that the hi-fi log has no gaps.
                hifientry = LogEntryHiFi(iter, sc*d.v_cumul_true[1], sc*d.v_cumul_true[2])
                stn = @set stn.log_hifi=LinkedListEntry(hifientry, stn.log_hifi)

                return stn
            end
            d = dnext
        end
    catch ex
        if params.handle_interrupt && isa(ex, InterruptException)
            # If SIGINT is received (user pressed ^C), terminate computations,
            # returning current status. Effectively, we do not call `step()` again,
            # ending the iterations, but letting the algorithm finish up.
            # Assuming (**) above occurs atomically, `st.log` should be valid, but
            # any results returned by the algorithm itself may be partial, as for
            # reasons of efficiency we do *not* store results of an iteration until
            # the next iteration is finished.
            printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202)
            st = @set st.aborted = true
        else
            rethrow()
        end
    end

    return st
end
466 | ||
"""
    bg_visualise_enhanced(rc; fullscreen=false)

Background visualisation loop, fed through `process_channel(rc)`. For each message
`(imgs, plot_movement, log, vhist)` it:

1. draws `imgs` with `do_visualise`;
2. overlays the FPS, SSIM and PSNR of the newest log entry (`log.value`);
3. if `plot_movement` is set, draws the true (colour 2) and estimated (colour 3)
   cumulative movement paths from the log, plus the history `vhist` (colour 4)
   when it is not `nothing`.
"""
function bg_visualise_enhanced(rc; fullscreen=false)
    process_channel(rc) do d
        # `loglist` rather than `log`, so as not to shadow `Base.log`.
        imgs, plot_movement, loglist, vhist = d
        do_visualise(imgs, refresh=false, fullscreen=fullscreen)

        # Statistics text, anchored at world coordinates (0, 1).
        latest = loglist.value
        GR.settextcolorind(5)
        GR.setcharheight(0.015)
        GR.settextpath(GR.TEXT_PATH_RIGHT)
        tx, ty = GR.wctondc(0, 1)
        GR.text(tx, ty, @sprintf("FPS %.1f, SSIM %.2f, PSNR %.1f",
                                 latest.iter/latest.time, latest.ssim, latest.psnr))

        if plot_movement
            # Movement paths, offset by (1.5, 0.5) in world coordinates (x, y).
            sc = 1.0
            entries = unfold_linked_list(loglist)

            GR.setlinewidth(2)
            GR.setlinecolorind(2)
            GR.polyline(map(e -> 1.5+sc*e.v_cumul_true_x, entries),
                        map(e -> 0.5+sc*e.v_cumul_true_y, entries))

            GR.setlinecolorind(3)
            GR.polyline(map(e -> 1.5+sc*e.v_cumul_est_x, entries),
                        map(e -> 0.5+sc*e.v_cumul_est_y, entries))

            if !isnothing(vhist)
                # Column 2 of `vhist` is x, column 1 is y.
                GR.setlinecolorind(4)
                GR.polyline(map(v -> 1.5+sc*v, vhist[:,2]),
                            map(v -> 0.5+sc*v, vhist[:,1]))
            end
        end

        GR.updatews()
    end
end
499 | ||
500 | ############### | |
501 | # Precompiling | |
502 | ############### | |
503 | ||
504 | # precompile(Tuple{typeof(GR.drawimage), Float64, Float64, Float64, Float64, Int64, Int64, Array{UInt32, 2}}) | |
505 | # precompile(Tuple{Type{Plots.Plot{T} where T<:RecipesBase.AbstractBackend}, Plots.GRBackend, Int64, Base.Dict{Symbol, Any}, Base.Dict{Symbol, Any}, Array{Plots.Series, 1}, Nothing, Array{Plots.Subplot{T} where T<:RecipesBase.AbstractBackend, 1}, Base.Dict{Any, Plots.Subplot{T} where T<:RecipesBase.AbstractBackend}, Plots.EmptyLayout, Array{Plots.Subplot{T} where T<:RecipesBase.AbstractBackend, 1}, Bool}) | |
506 | # precompile(Tuple{typeof(Plots._plot!), Plots.Plot{Plots.GRBackend}, Base.Dict{Symbol, Any}, Tuple{Array{ColorTypes.Gray{Float64}, 2}}}) | |
507 | ||
508 | end # Module |