Thu, 25 Apr 2024 14:48:54 -0500
oops
| 36 | 1 | ################## | 
| 2 | # Experiment running and interactive visualisation | |
| 3 | ################## | |
| 4 | ||
| 5 | module Run | |
| 6 | ||
| 7 | import GR | |
| 8 | using Setfield | |
| 9 | using Printf | |
| 10 | using ImageQualityIndexes: assess_psnr, assess_ssim | |
| 11 | using FileIO | |
| 12 | #using JLD2 | |
| 13 | using DelimitedFiles | |
| 14 | ||
| 15 | using AlgTools.Util | |
| 16 | using AlgTools.StructTools | |
| 17 | using AlgTools.LinkedLists | |
| 18 | using AlgTools.Comms | |
| 19 | ||
| 20 | using ImageTools.Visualise: secs_ns, grayimg, do_visualise | |
| 21 | ||
| 22 | using ..ImGenerate | |
| 23 | using ..OpticalFlow: identifier, DisplacementFull, DisplacementConstant | |
| 24 | ||
| 25 | export run_experiments, | |
| 26 | Experiment, | |
| 27 | State, | |
| 28 | LogEntry, | |
| 29 | LogEntryHiFi | |
| 30 | ||
| 31 | ################ | |
| 32 | # Experiment | |
| 33 | ################ | |
| 34 | ||
"""
    Experiment(mod, DisplacementT, imgen, params)

One experiment configuration: the algorithm module `mod` (whose `solve` and
`identifier` are used by `run_experiments` and `name`), the displacement type
passed to `solve` and to the image generator, the image generator `imgen`, and
the default algorithm parameters `params` (overridable by keyword arguments to
`run_experiments`).
"""
struct Experiment
    mod::Module
    DisplacementT::Type
    imgen::ImGen
    params::NamedTuple
end
| 41 | ||
# Multi-line summary of an experiment. The (possibly large) `kernel` parameter
# is replaced by a placeholder so that it is not dumped in full.
function Base.show(io::IO, e::Experiment)
    # Short names for the known displacement types. Any other type falls back to
    # its printed form; previously only the two listed types had a method, so
    # showing an experiment with another `DisplacementT` threw a `MethodError`.
    T = e.DisplacementT
    dname = if T == DisplacementFull
        "DisplacementFull"
    elseif T == DisplacementConstant
        "DisplacementConstant"
    else
        string(T)
    end
    print(io, "
mod: $(e.mod)
DisplacementT: $(dname)
imgen: $(e.imgen.name) $(e.imgen.dim[1])×$(e.imgen.dim[2])
params: $(e.params ⬿ (kernel = "(not shown)",))
")
end
| 52 | ||
| 53 | ################ | |
| 54 | # Log | |
| 55 | ################ | |
| 56 | ||
"""
Sampled log record written by `iterate_visualise`: iteration number, elapsed
time (excluding time accounted as wasted on logging/visualisation), objective
value, and PSNR/SSIM of the iterate `x` from `calc_objective` against the
ground-truth image.
"""
struct LogEntry <: IterableStruct
    iter::Int
    time::Float64
    function_value::Float64
    # Currently disabled: cumulative true/estimated displacement.
    #v_cumul_true_y::Float64
    #v_cumul_true_x::Float64
    #v_cumul_est_y::Float64
    #v_cumul_est_x::Float64
    psnr::Float64
    ssim::Float64
    # Currently disabled: PSNR/SSIM of the noisy data itself.
    #psnr_data::Float64
    #ssim_data::Float64
end
| 70 | ||
"""
Per-iteration log record of the true cumulative displacement (`y` and `x`
components), as filled in by `iterate_visualise` after scaling by the inverse
of the largest image dimension.
"""
struct LogEntryHiFi <: IterableStruct
    iter::Int
    v_cumul_true_y::Float64
    v_cumul_true_x::Float64
end
| 76 | ||
| 77 | ############### | |
| 78 | # Main routine | |
| 79 | ############### | |
| 80 | ||
"""
Iteration state threaded through `iterate_visualise` and returned to
`run_experiments` by the algorithm's `solve`.
"""
struct State
    # Visualisation channel; `false` when visualisation is off, `nothing` once
    # cleared before saving.
    vis::Union{Channel,Bool,Nothing}
    # `nothing` until the first iteration starts the clock.
    start_time::Union{Real,Nothing}
    # Time spent on logging/visualisation, subtracted from logged times.
    wasted_time::Real
    # Sampled log, newest entry first (entries are prepended).
    log::LinkedList{LogEntry}
    # Per-iteration log, newest entry first.
    log_hifi::LinkedList{LogEntryHiFi}
    # Set when the user interrupted the run.
    aborted::Bool
end
| 89 | ||
"""
    name(e::Experiment, p)

Identifier of experiment `e` under parameters `p`, used as the file-name stem
for saved results: `<image>_<algorithm id>_<100α>_<10000σ₀>_<10000τ₀>`. The
algorithm id comes from `p.predictor` when it is given and not `nothing`,
otherwise from `e.mod.identifier`.
"""
function name(e::Experiment, p)
    ig = e.imgen
    id = if haskey(p, :predictor) && !isnothing(p.predictor)
        identifier(p.predictor)
    else
        e.mod.identifier
    end
    # Scale to integers with `round`: a plain `Int64(...)` conversion throws
    # `InexactError` whenever the product is not exactly integral in floating
    # point, e.g. 100*0.07 == 7.000000000000001.
    α = round(Int64, 100*p.α)
    σ₀ = round(Int64, 10000*p.σ₀)
    τ₀ = round(Int64, 10000*p.τ₀)
    # return "$(ig.name)_$(e.mod.identifier)_$(@sprintf "%x" hash(p))"
    return "$(ig.name)_$(id)_$(α)_$(σ₀)_$(τ₀)"
end
| 100 | ||
"""
    write_tex(texfile, e_params)

Write those experiment parameters that are present in `e_params` to `texfile`
as TeX macro definitions of the form `\\def\\EXPPARAM<name>{<value>}`, one per
line, in a fixed order. Absent parameters are skipped. When `σ₀` is present, an
extra `\\EXPPARAMsigma` macro expresses it as a multiple of `\\sigma_{\\max}`
(with no factor written when `σ₀ == 1`).
"""
function write_tex(texfile, e_params)
    # (macro suffix => parameter field), in output order.
    texnames = ("alpha" => :α,
                "sigmazero" => :σ₀,
                "tauzero" => :τ₀,
                "tildetauzero" => :τ̃₀,
                "delta" => :δ,
                "lambda" => :λ,
                "theta" => :θ,
                "maxiter" => :maxiter,
                "noiselevel" => :noise_level,
                "shakenoiselevel" => :shake_noise_level,
                "shake" => :shake,
                "timestep" => :timestep,
                "displacementcount" => :displacementcount,
                "phantomrho" => :phantom_ρ)
    open(texfile, "w") do io
        emit(key, val) = println(io, "\\def\\EXPPARAM$(key){$(val)}")
        for (key, field) in texnames
            isdefined(e_params, field) && emit(key, getfield(e_params, field))
        end
        if isdefined(e_params, :σ₀)
            σ₀ = e_params.σ₀
            factor = σ₀ == 1 ? "" : "$(σ₀)"
            emit("sigma", factor * "\\sigma_{\\max}")
        end
    end
end
| 126 | ||
"""
    run_experiments(; experiments, visualise=true, visfn=iterate_visualise,
                    datatype=OnlineData, recalculate=true, save_prefix="",
                    fullscreen=false, kwargs...)

Run each `Experiment` in `experiments` in order. Extra keyword arguments
override the experiment's default parameters. When `recalculate` is false, an
experiment whose `<save_prefix><name>.txt` log already exists is skipped. When
`visualise` is true, a background task running `bg_visualise_enhanced` shows
progress. Stops early if a run reports that it was aborted by the user.
"""
function run_experiments(;visualise=true,
                          visfn = iterate_visualise,
                          datatype = OnlineData,
                          recalculate=true,
                          experiments,
                          save_prefix="",
                          fullscreen=false,
                          kwargs...)

    # Create visualisation
    if visualise
        rc = Channel(1)
        visproc = Threads.@spawn bg_visualise_enhanced(rc, fullscreen=fullscreen)
        bind(rc, visproc)
        vis = rc
    else
        vis = false
    end

    try
        # Run all experiments, stopping early if one was aborted by the user.
        for e ∈ experiments
            aborted = _run_experiment(e, vis, visfn, datatype,
                                      recalculate, save_prefix, kwargs)
            if aborted
                break
            end
        end
    catch
        # An experiment failed: close the visualisation channel rather than
        # leaving the background visualiser waiting on it (assumes
        # `process_channel` stops on a closed channel — verify). We do not
        # `wait` for the task here so that the original error is what gets
        # reported.
        if visualise
            close(rc)
        end
        rethrow()
    end

    if visualise
        # Tell subprocess to finish, and wait
        put!(rc, nothing)
        close(rc)
        wait(visproc)
    end

    return
end

# Run one experiment `e`, or skip it when `recalculate` is false and its log file
# already exists. Returns `true` if the run was aborted by the user.
function _run_experiment(e::Experiment, vis, visfn, datatype, recalculate,
                         save_prefix, kwargs)
    # Parameters for this experiment: keyword overrides on top of the
    # experiment defaults, plus the save prefix and image-generator data.
    e_params = e.params ⬿ kwargs
    ename = name(e, e_params)
    e_params = e_params ⬿ (save_prefix = save_prefix * ename,
                           dynrange = e.imgen.dynrange,
                           Λ = e.imgen.Λ)

    if !recalculate && isfile(e_params.save_prefix * ".txt")
        println("Skipping already computed experiment \"$(ename)\"")
        # texfile = e_params.save_prefix * "_params.tex"
        # write_tex(texfile, e_params)
        return false
    end

    println("Running experiment \"$(ename)\"")

    # Start data generation task
    datachannel = Channel{datatype{e.DisplacementT}}(2)
    gentask = Threads.@spawn e.imgen.f(e.DisplacementT, datachannel, e_params)
    bind(datachannel, gentask)

    # Run algorithm. If it fails, close the data channel so the generator task
    # is not left blocked on it, then propagate the error.
    x, y, st = try
        iterate = curry(visfn, datachannel,
                        State(vis, nothing, 0.0, nothing, nothing, false))
        e.mod.solve(e.DisplacementT;
                    dim=e.imgen.dim,
                    iterate=iterate,
                    params=e_params)
    catch
        close(datachannel)
        rethrow()
    end

    # Clear non-saveable things
    st = @set st.vis = nothing

    println("Wasted_time: $(st.wasted_time)s")

    if e_params.save_results
        println("Saving " * e_params.save_prefix * "(.txt,_hifi.txt,_params.tex)")

        perffile = e_params.save_prefix * ".txt"
        hififile = e_params.save_prefix * "_hifi.txt"
        texfile = e_params.save_prefix * "_params.tex"
        # datafile = e_params.save_prefix * ".jld2"

        write_log(perffile, st.log, "# params = $(e_params)\n")
        #write_log(hififile, st.log_hifi, "# params = $(e_params)\n")
        #write_tex(texfile, e_params)
        # @save datafile x y st params
    end

    close(datachannel)
    wait(gentask)

    return st.aborted
end
| 214 | ||
| 215 | ||
| 216 | ###################################################### | |
| 217 | # Iterator that does visualisation and log collection | |
| 218 | ###################################################### | |
| 219 | ||
"""
    iterate_visualise(datachannel, st, step, params)

Iteration driver passed (curried with `datachannel` and the initial `st`) as
`iterate` to an algorithm's `solve`. For `params.maxiter` iterations it takes
the next data frame from `datachannel`, calls `step` on the current and next
noisy data, and, inside `step`'s callback, updates the logs and visualisation.
Returns the final `State`. If `params.handle_interrupt` is set, a user
interrupt stops the iterations and the state is returned with
`aborted = true`; other exceptions are rethrown.

The `State` value is updated by building new copies with `@set`.
"""
function iterate_visualise(datachannel::Channel{OnlineData{DisplacementT}},
                           st :: State,
                           step :: Function,
                           params :: NamedTuple) where DisplacementT
    try
        d = take!(datachannel)

        for iter=1:params.maxiter
            dnext = take!(datachannel)
            st = step(d.b_noisy, d.v, dnext.b_noisy) do calc_objective
                stn = st

                if isnothing(stn.start_time)
                    # The Julia precompiler is a miserable joke, apparently not crossing module
                    # boundaries, so only start timing after the first iteration.
                    stn = @set stn.start_time=secs_ns()
                end

                verb = params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0

                # Normalise movement to image dimensions so
                # our TikZ plotting code doesn't need to know
                # the image pixel size.
                sc = 1.0./maximum(size(d.b_true))

                # Sampled logging: every iteration up to 20, every 10th up to
                # 200, and every verbose iteration.
                if verb || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0)
                    verb_start = secs_ns()
                    tm = verb_start - stn.start_time - stn.wasted_time
                    value, x, v, vhist = calc_objective()

                    entry = LogEntry(iter, tm, value,
                                     #sc*d.v_cumul_true[1],
                                     #sc*d.v_cumul_true[2],
                                     #sc*v[1], sc*v[2],
                                     assess_psnr(x, d.b_true),
                                     assess_ssim(x, d.b_true),
                                     #assess_psnr(d.b_noisy, d.b_true),
                                     #assess_ssim(d.b_noisy, d.b_true)
                                     )

                    # (**) Collect a singly-linked list of log to avoid array resizing
                    # while iterating
                    stn = @set stn.log=LinkedListEntry(entry, stn.log)

                    if !isnothing(vhist)
                        vhist=vhist.*sc
                    end

                    if verb
                        @printf("%d/%d J=%f, PSNR=%f, SSIM=%f, avg. FPS=%f\n",
                                iter, params.maxiter, value, entry.psnr,
                                entry.ssim, entry.iter/entry.time)
                        if isa(stn.vis, Channel)
                            put_onlylatest!(stn.vis, ((d.b_noisy, x),
                                                      params.plot_movement,
                                                      stn.log, vhist))
                        end
                    end

                    if params.save_images && (!haskey(params, :save_images_iters)
                                              || iter ∈ params.save_images_iters)
                        fn = (t, ext) -> "$(params.save_prefix)_$(t)_frame$(iter).$(ext)"
                        # save(File(format"PNG", fn("true", "png")), grayimg(d.b_true))
                        # save(File(format"PNG", fn("data", "png")), grayimg(d.b_noisy))
                        save(File(format"PNG", fn("reco", "png")), grayimg(x))
                        if !isnothing(vhist)
                            open(fn("movement", "txt"), "w") do io
                                writedlm(io, ["est_y" "est_x"])
                                writedlm(io, vhist)
                            end
                        end
                    end

                    # Exclude the logging/visualisation time above from the
                    # timings of later log entries.
                    stn = @set stn.wasted_time += (secs_ns() - verb_start)
                end

                # Record the true cumulative displacement at *every* iteration.
                # Previously the sampled branch above returned early, so exactly
                # the sampled iterations were missing from the hi-fi log.
                hifientry = LogEntryHiFi(iter, sc*d.v_cumul_true[1], sc*d.v_cumul_true[2])
                stn = @set stn.log_hifi=LinkedListEntry(hifientry, stn.log_hifi)

                return stn
            end
            d=dnext
        end
    catch ex
        if params.handle_interrupt && isa(ex, InterruptException)
            # If SIGINT is received (user pressed ^C), terminate computations,
            # returning current status. Effectively, we do not call `step()` again,
            # ending the iterations, but letting the algorithm finish up.
            # Assuming (**) above occurs atomically, `st.log` should be valid, but
            # any results returned by the algorithm itself may be partial, as for
            # reasons of efficiency we do *not* store results of an iteration until
            # the next iteration is finished.
            printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202)
            st = @set st.aborted = true
        else
            rethrow()
        end
    end

    return st
end
| 326 | ||
"""
    bg_visualise_enhanced(rc; fullscreen=false)

Visualisation loop run as a background task by `run_experiments`. For each
`(imgs, plot_movement, log, vhist)` tuple received on `rc` (via
`process_channel`), it displays `imgs` and overlays the FPS/SSIM/PSNR of the
newest entry of `log`. When `plot_movement` is set, it also draws the movement
trajectories that are available.
"""
function bg_visualise_enhanced(rc; fullscreen=false)
    process_channel(rc) do d
        imgs, plot_movement, log, vhist = d
        do_visualise(imgs, refresh=false, fullscreen=fullscreen)
        # Overlay statistics of the newest log entry (the log is built by
        # prepending, so its head is the latest).
        GR.settextcolorind(5)
        GR.setcharheight(0.015)
        GR.settextpath(GR.TEXT_PATH_RIGHT)
        tx, ty = GR.wctondc(0, 1)
        latest = log.value
        GR.text(tx, ty, @sprintf "FPS %.1f, SSIM %.2f, PSNR %.1f" (latest.iter/latest.time) latest.ssim latest.psnr)
        # Overlay movement
        if plot_movement
            sc=1.0
            GR.setlinewidth(2)
            # `LogEntry` currently has its cumulative-displacement fields
            # disabled. Reading them unconditionally threw inside this task,
            # which ends the task and (since `rc` is bound to it) closes the
            # channel. Only draw these trajectories when the entries carry them.
            if hasproperty(latest, :v_cumul_true_x) && hasproperty(latest, :v_cumul_est_x)
                p=unfold_linked_list(log)
                x=map(e -> 1.5+sc*e.v_cumul_true_x, p)
                y=map(e -> 0.5+sc*e.v_cumul_true_y, p)
                GR.setlinecolorind(2)
                GR.polyline(x, y)
                x=map(e -> 1.5+sc*e.v_cumul_est_x, p)
                y=map(e -> 0.5+sc*e.v_cumul_est_y, p)
                GR.setlinecolorind(3)
                GR.polyline(x, y)
            end
            if !isnothing(vhist)
                GR.setlinecolorind(4)
                x=map(v -> 1.5+sc*v, vhist[:,2])
                y=map(v -> 0.5+sc*v, vhist[:,1])
                GR.polyline(x, y)
            end
        end
        GR.updatews()
    end
end
| 359 | ||
| 360 | end # module |