src/PredictPDPS.jl

changeset 0
a55e35d20336
child 3
0cfbe340796d
child 2
be7cab83b14a
equal deleted inserted replaced
-1:000000000000 0:a55e35d20336
1 ##################
2 # Our main module
3 ##################
4
5 __precompile__()
6
7 module PredictPDPS
8
9 ########################
10 # Load external modules
11 ########################
12
13 using Printf
14 using FileIO
15 #using JLD2
16 using Setfield
17 using ImageQualityIndexes: psnr, ssim
18 using DelimitedFiles
19 import GR
20
21 using AlgTools.Util
22 using AlgTools.StructTools
23 using AlgTools.LinkedLists
24 using AlgTools.Comms
25 using ImageTools.Visualise: secs_ns, grayimg, do_visualise
26 using ImageTools.ImFilter: gaussian
27
28 #####################
29 # Load local modules
30 #####################
31
32 include("OpticalFlow.jl")
33 include("ImGenerate.jl")
34 include("Algorithm.jl")
35 include("AlgorithmBoth.jl")
36 include("AlgorithmBothGreedyV.jl")
37 include("AlgorithmBothCumul.jl")
38 include("AlgorithmBothMulti.jl")
39 include("AlgorithmBothNL.jl")
40 include("AlgorithmFB.jl")
41 include("AlgorithmFBDual.jl")
42
43 import .Algorithm,
44 .AlgorithmBoth,
45 .AlgorithmBothGreedyV,
46 .AlgorithmBothCumul,
47 .AlgorithmBothMulti,
48 .AlgorithmBothNL,
49 .AlgorithmFB,
50 .AlgorithmFBDual
51
52 using .ImGenerate
53 using .OpticalFlow: DisplacementFull, DisplacementConstant
54
55 ##############
56 # Our exports
57 ##############
58
59 export run_experiments,
60 batchrun_article,
61 demo_known1,
62 demo_known2,
63 demo_known3,
64 demo_unknown1,
65 demo_unknown2,
66 demo_unknown3
67
68 ###################################
69 # Parameterisation and experiments
70 ###################################
71
"""
    Experiment

Description of a single experiment run by `run_experiments`.

# Fields
- `mod`: algorithm module; `run_experiments` calls `mod.solve`, and `name` reads
  `mod.identifier`.
- `DisplacementT`: displacement type passed to the data generator and the solver
  (`DisplacementFull` or `DisplacementConstant` in this file).
- `imgen`: image/data generator; its `name`, `dim`, `dynrange`, `Λ` and `f` are used.
- `params`: experiment-specific parameters, combined (via `⬿`) with
  `default_params` and caller keywords in `run_experiments`.
"""
struct Experiment
    mod :: Module
    DisplacementT :: Type
    imgen :: ImGen
    params :: NamedTuple
end
78
"""
    show(io::IO, e::Experiment)

Print a multi-line summary of `e`: the algorithm module, the displacement type,
the image generator's name and size, and the parameters (with the bulky `kernel`
entry replaced by a placeholder).
"""
function Base.show(io::IO, e::Experiment)
    # Bare type name, e.g. "DisplacementConstant". Previously only the two known
    # displacement types were handled and any other type threw a MethodError;
    # types without a name (e.g. unions) now fall back to `string`.
    displacementname(T::Union{DataType,UnionAll}) = string(nameof(T))
    displacementname(T) = string(T)
    print(io, "
mod: $(e.mod)
DisplacementT: $(displacementname(e.DisplacementT))
imgen: $(e.imgen.name) $(e.imgen.dim[1])×$(e.imgen.dim[2])
params: $(e.params ⬿ (kernel = "(not shown)",))
")
end
89
# Default path prefix for result files; `run_experiments` appends `name(e, params)`.
const default_save_prefix="img/"

# Baseline parameters shared by all experiments. `run_experiments` combines them
# (via `⬿`) with each experiment's `params` and any caller keywords.
# NOTE: these values are part of the parameters hashed by `name`, so changing them
# (or their order) changes the result file names.
# Entries not read in this file (ρ, pixelwise_displacement, dual_flow,
# prox_predict, init) are presumably consumed by the algorithm modules.
const default_params = (
    ρ = 0,
    verbose_iter = 100,        # print progress every this many iterations (0 disables)
    maxiter = 10000,           # number of iterations run by `iterate_visualise`
    save_results = true,       # write the logs and the LaTeX parameter file
    save_images = true,        # save PNG frames (and movement) on logged iterations
    # Iterations at which images are saved when `save_images` is set.
    save_images_iters = Set([1, 2, 3, 5,
                             10, 25, 30, 50,
                             100, 250, 300, 500,
                             1000, 2500, 3000, 5000,
                             10000]),
    pixelwise_displacement=false,
    dual_flow = true,
    prox_predict = true,
    handle_interrupt = true,   # on ^C, stop iterating and return the partial state
    init = :zero,
    plot_movement = false,     # overlay movement paths in the visualiser
    )
110
# Test data generators (from `ImGenerate`), both of size (200, 300).
const square = imgen_square((200, 300))
const lighthouse = imgen_shake("lighthouse", (200, 300))

# Base parameters for the "known" experiments (run with `Algorithm` and
# `AlgorithmFB` below).
const p_known₀ = (
    noise_level = 0.5,
    shake_noise_level = 0.05,
    shake = 2,
    α = 1,
    ρ̃₀ = 1,
    σ̃₀ = 1,
    δ = 0.9,
    σ₀ = 1,
    τ₀ = 0.01,
    )

# Base parameters for the "unknown" experiments (run with `AlgorithmBothMulti`
# below).
const p_unknown₀ = (
    noise_level = 0.3,
    shake_noise_level = 0.05,
    shake = 2,
    α = 0.2,
    ρ̃₀ = 1,
    σ̃₀ = 1,
    σ₀ = 1,
    δ = 0.9,
    λ = 1,
    θ = (300*200)*100^3,
    kernel = gaussian((3, 3), (11, 11)),   # hidden when an Experiment is printed
    timestep = 0.5,
    displacement_count = 100,
    τ₀ = 0.01,
    )

# PDPS experiments: lighthouse with phantom_ρ = 0 and 100, and square.
# The order matters: the `demo_known*` functions index into this tuple.
const experiments_pdps_known = (
    Experiment(Algorithm, DisplacementConstant, lighthouse,
               p_known₀ ⬿ (phantom_ρ = 0,)),
    Experiment(Algorithm, DisplacementConstant, lighthouse,
               p_known₀ ⬿ (phantom_ρ = 100,)),
    Experiment(Algorithm, DisplacementConstant, square,
               p_known₀ ⬿ (phantom_ρ = 0,))
    )

# Same three configurations with `AlgorithmBothMulti` (indexed by `demo_unknown*`).
const experiments_pdps_unknown_multi = (
    Experiment(AlgorithmBothMulti, DisplacementConstant, lighthouse,
               p_unknown₀ ⬿ (phantom_ρ = 0,)),
    Experiment(AlgorithmBothMulti, DisplacementConstant, lighthouse,
               p_unknown₀ ⬿ (phantom_ρ = 100,)),
    Experiment(AlgorithmBothMulti, DisplacementConstant, square,
               p_unknown₀ ⬿ (phantom_ρ = 0,)),
    )

# Forward-backward experiment on the lighthouse image.
const experiments_fb_known = (
    Experiment(AlgorithmFB, DisplacementConstant, lighthouse,
               p_known₀ ⬿ (τ̃₀=0.9, fb_inner_iterations = 10)),
    )

# All experiments run by `batchrun_article`, chained lazily with `Iterators.flatten`.
const experiments_all = Iterators.flatten((
    experiments_pdps_known,
    experiments_pdps_unknown_multi,
    experiments_fb_known
))
171
172 ################
173 # Log
174 ################
175
"""
    LogEntry

One row of the main log. `iterate_visualise` records it on selected iterations,
and `run_experiments` writes the list with `write_log` to `<save_prefix>.txt`.
Displacements are scaled by `1/maximum(size(b_true))`.
"""
struct LogEntry <: IterableStruct
    iter :: Int                 # iteration number
    time :: Float64             # time since timing started, minus `wasted_time`
    function_value :: Float64   # objective value from `calc_objective`
    v_cumul_true_y :: Float64   # true cumulative displacement, component 1 (scaled)
    v_cumul_true_x :: Float64   # true cumulative displacement, component 2 (scaled)
    v_cumul_est_y :: Float64    # algorithm's displacement `v[1]` (scaled)
    v_cumul_est_x :: Float64    # algorithm's displacement `v[2]` (scaled)
    psnr :: Float64             # PSNR of the reconstruction against the ground truth
    ssim :: Float64             # SSIM of the reconstruction against the ground truth
    psnr_data :: Float64        # PSNR of the noisy data against the ground truth
    ssim_data :: Float64        # SSIM of the noisy data against the ground truth
end
189
"""
    LogEntryHiFi

Lightweight per-iteration log row holding only the true cumulative displacement,
scaled like `LogEntry`. `run_experiments` writes it to `<save_prefix>_hifi.txt`.
"""
struct LogEntryHiFi <: IterableStruct
    iter :: Int                 # iteration number
    v_cumul_true_y :: Float64   # true cumulative displacement, component 1 (scaled)
    v_cumul_true_x :: Float64   # true cumulative displacement, component 2 (scaled)
end
195
196 ###############
197 # Main routine
198 ###############
199
"""
    State

Iteration state threaded through `iterate_visualise` and returned by the solver.
It is updated immutably with `@set`.
"""
struct State
    vis :: Union{Channel,Bool,Nothing}    # visualiser channel; `false` if disabled, `nothing` after solving
    start_time :: Union{Real,Nothing}     # `secs_ns()` at the first callback; `nothing` before that
    wasted_time :: Real                   # time spent on logging/visualisation/saving, excluded from timings
    log :: LinkedList{LogEntry}           # main log, newest entry first
    log_hifi :: LinkedList{LogEntryHiFi}  # per-iteration displacement log, newest first
    aborted :: Bool                       # set when the user interrupted the run
end
208
"""
    name(e::Experiment, p)

Return the base name for the output files of experiment `e` run with parameters
`p`: `"<image name>_<algorithm identifier>_<hash of p in lowercase hex>"`.
"""
function name(e::Experiment, p)
    hexhash = string(hash(p); base = 16)
    return join((e.imgen.name, e.mod.identifier, hexhash), "_")
end
213
"""
    write_tex(texfile, e_params)

Write experiment parameters to `texfile` as LaTeX macro definitions of the form
`\\def\\EXPPARAM<name>{<value>}`, one per line. Parameters missing from
`e_params` are skipped. If `σ₀` is present, an extra `\\EXPPARAMsigma` macro
expresses it as a multiple of `\\sigma_{\\max}` (omitting the factor when it is 1).
"""
function write_tex(texfile, e_params)
    # (macro suffix => parameter key), in output order.
    # Note: the key is `displacement_count` (as in the parameter sets); looking
    # up `displacementcount` meant this macro was never written.
    fields = (
        "alpha" => :α,
        "sigmazero" => :σ₀,
        "tauzero" => :τ₀,
        "tildetauzero" => :τ̃₀,
        "delta" => :δ,
        "lambda" => :λ,
        "theta" => :θ,
        "maxiter" => :maxiter,
        "noiselevel" => :noise_level,
        "shakenoiselevel" => :shake_noise_level,
        "shake" => :shake,
        "timestep" => :timestep,
        "displacementcount" => :displacement_count,
        "phantomrho" => :phantom_ρ,
    )
    open(texfile, "w") do io
        writeparam(n, v) = println(io, "\\def\\EXPPARAM$(n){$(v)}")
        for (n, s) in fields
            if hasproperty(e_params, s)
                writeparam(n, getproperty(e_params, s))
            end
        end
        if hasproperty(e_params, :σ₀)
            σ₀ = e_params.σ₀
            writeparam("sigma", (σ₀ == 1 ? "" : "$(σ₀)") * "\\sigma_{\\max}")
        end
    end
    return nothing
end
239
"""
    run_experiments(; visualise=true, recalculate=true, experiments,
                    save_prefix=default_save_prefix, fullscreen=false, kwargs...)

Run each `Experiment` in `experiments` in turn.

For each experiment, the parameters are built by combining `default_params`,
`e.params` and the extra keywords `kwargs` with `⬿`. A background task running
the data generator `e.imgen.f` is started, and the algorithm `e.mod.solve` is run
with `iterate_visualise` as its iteration driver. If `save_results` is set, the
log, the per-iteration ("hifi") log and a LaTeX parameter file are written under
`save_prefix * name(e, params)`.

# Keywords
- `visualise`: start a background visualisation task fed through a channel.
- `recalculate`: if `false`, skip experiments whose `.txt` log already exists.
- `experiments`: iterable of `Experiment`s (required).
- `save_prefix`: path prefix for output files.
- `fullscreen`: passed to the visualiser.
- `kwargs...`: extra parameters combined into every experiment's parameters.

If an experiment is aborted by the user (see `iterate_visualise`), the remaining
experiments are skipped.
"""
function run_experiments(;visualise=true,
                          recalculate=true,
                          experiments,
                          save_prefix=default_save_prefix,
                          fullscreen=false,
                          kwargs...)

    # Create visualisation
    # The visualiser runs as a separate task fed through `rc`; `bind` ties the
    # channel's lifetime to that task.
    if visualise
        rc = Channel(1)
        visproc = Threads.@spawn bg_visualise_enhanced(rc, fullscreen=fullscreen)
        bind(rc, visproc)
        vis = rc
    else
        vis = false
    end

    # Run all experiments
    for e ∈ experiments

        # Parameters for this experiment
        e_params = default_params ⬿ e.params ⬿ kwargs
        # The name (and hence the output file names) is derived from the combined
        # parameters before the per-run entries below are added.
        ename = name(e, e_params)
        e_params = e_params ⬿ (save_prefix = save_prefix * ename,
                               dynrange = e.imgen.dynrange,
                               Λ = e.imgen.Λ)

        # An existing `.txt` log is taken to mean the experiment was already run.
        if recalculate || !isfile(e_params.save_prefix * ".txt")
            println("Running experiment \"$(ename)\"")

            # Start data generation task
            # (buffer of 2, presumably because `iterate_visualise` holds the
            # current and the next frame at once)
            datachannel = Channel{OnlineData{e.DisplacementT}}(2)
            gentask = Threads.@spawn e.imgen.f(e.DisplacementT, datachannel, e_params)
            bind(datachannel, gentask)

            # Run algorithm
            # `curry` presumably fixes the first two arguments of
            # `iterate_visualise`, leaving `(step, params)` to the solver — verify.
            iterate = curry(iterate_visualise, datachannel,
                            State(vis, nothing, 0.0, nothing, nothing, false))

            x, y, st = e.mod.solve(e.DisplacementT;
                                   dim=e.imgen.dim,
                                   iterate=iterate,
                                   params=e_params)

            # Clear non-saveable things
            st = @set st.vis = nothing

            println("Wasted_time: $(st.wasted_time)s")

            if e_params.save_results
                println("Saving " * e_params.save_prefix * "(.txt,_hifi.txt,_params.tex)")

                perffile = e_params.save_prefix * ".txt"
                hififile = e_params.save_prefix * "_hifi.txt"
                texfile = e_params.save_prefix * "_params.tex"
                # datafile = e_params.save_prefix * ".jld2"

                write_log(perffile, st.log, "# params = $(e_params)\n")
                write_log(hififile, st.log_hifi, "# params = $(e_params)\n")
                write_tex(texfile, e_params)
                # @save datafile x y st params
            end

            # Stop the data generator.
            # NOTE(review): the generator may still be blocked in `put!` on the
            # channel being closed here; whether `wait(gentask)` then throws
            # depends on how `e.imgen.f` handles that — verify.
            close(datachannel)
            wait(gentask)

            # Interrupted by the user: skip the remaining experiments.
            if st.aborted
                break
            end
        else
            println("Skipping already computed experiment \"$(ename)\"")
            # texfile = e_params.save_prefix * "_params.tex"
            # write_tex(texfile, e_params)
        end
    end

    # NOTE(review): there is no try/finally around the loop, so if `solve` (or
    # saving) throws, neither the visualisation task nor the current
    # data-generation task is shut down.
    if visualise
        # Tell subprocess to finish, and wait
        put!(rc, nothing)
        close(rc)
        wait(visproc)
    end

    return
end
325
326 #######################
327 # Demos and batch runs
328 #######################
329
"""
    demo(experiment; kwargs...)

Run a single `experiment` interactively: full-screen visualisation, progress
printed every 10 iterations, always recomputed, and nothing saved to disk.
Extra keyword arguments are forwarded to `run_experiments`.
"""
function demo(experiment; kwargs...)
    settings = (experiments = (experiment,),
                save_results = false,
                save_images = false,
                visualise = true,
                recalculate = true,
                verbose_iter = 10,
                fullscreen = true)
    return run_experiments(; settings..., kwargs...)
end
340
# Interactive demos (exported).
# - `demo_known*` run experiments from `experiments_pdps_known` (algorithm `Algorithm`).
# - `demo_unknown*` run experiments from `experiments_pdps_unknown_multi`
#   (algorithm `AlgorithmBothMulti`) and also plot the movement paths.
# They are defined as functions rather than non-`const` globals holding
# anonymous closures, so the bindings are constant and calls are type-stable;
# the call syntax `demo_known1()` is unchanged.
demo_known1() = demo(experiments_pdps_known[3])      # square
demo_known2() = demo(experiments_pdps_known[1])      # lighthouse, phantom_ρ = 0
demo_known3() = demo(experiments_pdps_known[2])      # lighthouse, phantom_ρ = 100
demo_unknown1() = demo(experiments_pdps_unknown_multi[3], plot_movement=true)  # square
demo_unknown2() = demo(experiments_pdps_unknown_multi[1], plot_movement=true)  # lighthouse, phantom_ρ = 0
demo_unknown3() = demo(experiments_pdps_unknown_multi[2], plot_movement=true)  # lighthouse, phantom_ρ = 100
347
"""
    batchrun_article(; kwargs...)

Run every experiment in `experiments_all` non-interactively. Results and images
are saved, and experiments whose results already exist are skipped. Keyword
arguments are forwarded to `run_experiments`, e.g.
`batchrun_article(maxiter=100)`.

Positional `Symbol => value` pairs are also accepted and forwarded the same way,
for compatibility with the previous varargs-only signature.
"""
function batchrun_article(args...; kwargs...)
    # The old signature `batchrun_article(kwargs...)` collected *positional*
    # varargs, so ordinary keyword calls threw a MethodError.
    return run_experiments(; experiments = experiments_all,
                           save_results = true,
                           save_images = true,
                           visualise = false,
                           recalculate = false,
                           args...,
                           kwargs...)
end
356
357 ######################################################
358 # Iterator that does visualisation and log collection
359 ######################################################
360
"""
    iterate_visualise(datachannel, st, step, params)

Iteration driver handed (via `curry`) to the algorithms' `solve` as `iterate`.

Runs `params.maxiter` iterations. Each iteration takes the next data frame `dnext`
from `datachannel` and calls `step(callback, d.b_noisy, d.v, dnext.b_noisy)` with
a `do`-block `callback(calc_objective)`. The value returned by `step` becomes the
new state `st`. Inside the callback:

- On iterations 1–20, every 10th iteration up to 200, and every
  `params.verbose_iter`-th iteration (unless that is 0), `calc_objective()` is
  evaluated to obtain `(value, x, v, vhist)`. A `LogEntry` is then prepended to
  `st.log`, progress is printed (and sent to the visualiser) on `verbose_iter`
  iterations, and images are optionally saved. The time spent on this is added
  to `st.wasted_time`.
- On all other iterations, a `LogEntryHiFi` is prepended to `st.log_hifi`.

If `params.handle_interrupt` is set, an `InterruptException` (^C) stops the
iterations and marks the state as `aborted`; other exceptions are rethrown.

Returns the final `State`.
"""
function iterate_visualise(datachannel::Channel{OnlineData{DisplacementT}},
                           st :: State,
                           step :: Function,
                           params :: NamedTuple) where DisplacementT
    try
        # Assigned inside the callback; shared with the hifi branch below.
        sc = nothing

        d = take!(datachannel)

        for iter=1:params.maxiter
            dnext = take!(datachannel)
            st = step(d.b_noisy, d.v, dnext.b_noisy) do calc_objective
                stn = st

                if isnothing(stn.start_time)
                    # Start timing only after the first iteration so that JIT
                    # compilation time is excluded; precompilation apparently does
                    # not cover calls across module boundaries.
                    stn = @set stn.start_time=secs_ns()
                end

                verb = params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0

                # Normalise movement to image dimensions so
                # our TikZ plotting code doesn't need to know
                # the image pixel size.
                sc = 1.0./maximum(size(d.b_true))

                if verb || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0)
                    verb_start = secs_ns()
                    tm = verb_start - stn.start_time - stn.wasted_time
                    value, x, v, vhist = calc_objective()

                    entry = LogEntry(iter, tm, value,
                                     sc*d.v_cumul_true[1],
                                     sc*d.v_cumul_true[2],
                                     sc*v[1], sc*v[2],
                                     psnr(x, d.b_true),
                                     ssim(x, d.b_true),
                                     psnr(d.b_noisy, d.b_true),
                                     ssim(d.b_noisy, d.b_true))

                    # (**) Collect a singly-linked list of log to avoid array resizing
                    # while iterating
                    stn = @set stn.log=LinkedListEntry(entry, stn.log)

                    if !isnothing(vhist)
                        vhist=vhist.*sc
                    end

                    if verb
                        @printf("%d/%d J=%f, PSNR=%f, SSIM=%f, avg. FPS=%f\n",
                                iter, params.maxiter, value, entry.psnr,
                                entry.ssim, entry.iter/entry.time)
                        if isa(stn.vis, Channel)
                            put_onlylatest!(stn.vis, ((d.b_noisy, x),
                                                      params.plot_movement,
                                                      stn.log, vhist))

                        end
                    end

                    if params.save_images && (!haskey(params, :save_images_iters)
                                              || iter ∈ params.save_images_iters)
                        fn = (t, ext) -> "$(params.save_prefix)_$(t)_frame$(iter).$(ext)"
                        save(File(format"PNG", fn("true", "png")), grayimg(d.b_true))
                        save(File(format"PNG", fn("data", "png")), grayimg(d.b_noisy))
                        save(File(format"PNG", fn("reco", "png")), grayimg(x))
                        if !isnothing(vhist)
                            open(fn("movement", "txt"), "w") do io
                                writedlm(io, ["est_y" "est_x"])
                                writedlm(io, vhist)
                            end
                        end
                    end

                    # Exclude logging/visualisation/saving time from the timings.
                    stn = @set stn.wasted_time += (secs_ns() - verb_start)

                    # NOTE(review): returning here means logged iterations get no
                    # `LogEntryHiFi`, so `_hifi.txt` skips them — confirm intended.
                    return stn
                end

                # NOTE: this branch updates and returns the captured `st`, not `stn`.
                # That is safe only because iteration 1 always takes the branch
                # above, so `start_time` is already stored in `st` (assuming `step`
                # returns the callback's result).
                hifientry = LogEntryHiFi(iter, sc*d.v_cumul_true[1], sc*d.v_cumul_true[2])
                st = @set st.log_hifi=LinkedListEntry(hifientry, st.log_hifi)

                return st
            end
            d=dnext
        end
    catch ex
        if params.handle_interrupt && isa(ex, InterruptException)
            # If SIGINT is received (user pressed ^C), terminate computations,
            # returning current status. Effectively, we do not call `step()` again,
            # ending the iterations, but letting the algorithm finish up.
            # Assuming (**) above occurs atomically, `st.log` should be valid, but
            # any results returned by the algorithm itself may be partial, as for
            # reasons of efficiency we do *not* store results of an iteration until
            # the next iteration is finished.
            printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202)
            st = @set st.aborted = true
        else
            rethrow(ex)
        end
    end

    return st
end
466
"""
    bg_visualise_enhanced(rc; fullscreen=false)

Background visualiser task started by `run_experiments`.

Items on `rc` are handled through `process_channel` (presumably until the
`nothing` that `run_experiments` sends to finish). Each item is the tuple
`(imgs, plot_movement, log, vhist)` sent by `iterate_visualise`, where `log` is
the non-empty log list with the newest entry first. For each item:

- the images are drawn with `do_visualise`;
- the FPS, SSIM and PSNR of the newest log entry are overlaid as text;
- if `plot_movement` is set, the logged true (line colour 2) and estimated
  (colour 3) cumulative displacement paths are drawn, plus the displacement
  history `vhist` (colour 4) when it is not `nothing`.
"""
function bg_visualise_enhanced(rc; fullscreen=false)
    process_channel(rc) do d
        imgs, plot_movement, log, vhist = d
        do_visualise(imgs, refresh=false, fullscreen=fullscreen)

        # Overlay statistics of the most recent log entry
        latest = log.value
        GR.settextcolorind(5)
        GR.setcharheight(0.015)
        GR.settextpath(GR.TEXT_PATH_RIGHT)
        tx, ty = GR.wctondc(0, 1)
        GR.text(tx, ty, @sprintf("FPS %.1f, SSIM %.2f, PSNR %.1f",
                                 latest.iter/latest.time, latest.ssim, latest.psnr))

        if plot_movement
            # Movement is already normalised by the image size when logged. The
            # paths are drawn with origin at world coordinates (1.5, 0.5) — presumably
            # the centre of the displayed images; verify against `do_visualise`.
            p = unfold_linked_list(log)
            GR.setlinewidth(2)
            GR.setlinecolorind(2)
            GR.polyline(map(e -> 1.5 + e.v_cumul_true_x, p),
                        map(e -> 0.5 + e.v_cumul_true_y, p))
            GR.setlinecolorind(3)
            GR.polyline(map(e -> 1.5 + e.v_cumul_est_x, p),
                        map(e -> 0.5 + e.v_cumul_est_y, p))
            if !isnothing(vhist)
                GR.setlinecolorind(4)
                GR.polyline(1.5 .+ view(vhist, :, 2),
                            0.5 .+ view(vhist, :, 1))
            end
        end
        GR.updatews()
    end
end
499
500 ###############
501 # Precompiling
502 ###############
503
504 # precompile(Tuple{typeof(GR.drawimage), Float64, Float64, Float64, Float64, Int64, Int64, Array{UInt32, 2}})
505 # precompile(Tuple{Type{Plots.Plot{T} where T<:RecipesBase.AbstractBackend}, Plots.GRBackend, Int64, Base.Dict{Symbol, Any}, Base.Dict{Symbol, Any}, Array{Plots.Series, 1}, Nothing, Array{Plots.Subplot{T} where T<:RecipesBase.AbstractBackend, 1}, Base.Dict{Any, Plots.Subplot{T} where T<:RecipesBase.AbstractBackend}, Plots.EmptyLayout, Array{Plots.Subplot{T} where T<:RecipesBase.AbstractBackend, 1}, Bool})
506 # precompile(Tuple{typeof(Plots._plot!), Plots.Plot{Plots.GRBackend}, Base.Dict{Symbol, Any}, Tuple{Array{ColorTypes.Gray{Float64}, 2}}})
507
508 end # Module

mercurial