--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/PredictPDPS.jl Tue Apr 07 14:19:48 2020 -0500 @@ -0,0 +1,508 @@ +################## +# Our main module +################## + +__precompile__() + +module PredictPDPS + +######################## +# Load external modules +######################## + +using Printf +using FileIO +#using JLD2 +using Setfield +using ImageQualityIndexes: psnr, ssim +using DelimitedFiles +import GR + +using AlgTools.Util +using AlgTools.StructTools +using AlgTools.LinkedLists +using AlgTools.Comms +using ImageTools.Visualise: secs_ns, grayimg, do_visualise +using ImageTools.ImFilter: gaussian + +##################### +# Load local modules +##################### + +include("OpticalFlow.jl") +include("ImGenerate.jl") +include("Algorithm.jl") +include("AlgorithmBoth.jl") +include("AlgorithmBothGreedyV.jl") +include("AlgorithmBothCumul.jl") +include("AlgorithmBothMulti.jl") +include("AlgorithmBothNL.jl") +include("AlgorithmFB.jl") +include("AlgorithmFBDual.jl") + +import .Algorithm, + .AlgorithmBoth, + .AlgorithmBothGreedyV, + .AlgorithmBothCumul, + .AlgorithmBothMulti, + .AlgorithmBothNL, + .AlgorithmFB, + .AlgorithmFBDual + +using .ImGenerate +using .OpticalFlow: DisplacementFull, DisplacementConstant + +############## +# Our exports +############## + +export run_experiments, + batchrun_article, + demo_known1, + demo_known2, + demo_known3, + demo_unknown1, + demo_unknown2, + demo_unknown3 + +################################### +# Parameterisation and experiments +################################### + +struct Experiment + mod :: Module + DisplacementT :: Type + imgen :: ImGen + params :: NamedTuple +end + +function Base.show(io::IO, e::Experiment) + displacementname(::Type{DisplacementFull}) = "DisplacementFull" + displacementname(::Type{DisplacementConstant}) = "DisplacementConstant" + print(io, " + mod: $(e.mod) + DisplacementT: $(displacementname(e.DisplacementT)) + imgen: $(e.imgen.name) $(e.imgen.dim[1])×$(e.imgen.dim[2]) + params: $(e.params ⬿ (kernel = "(not shown)",)) + ") +end + +const default_save_prefix="img/" + +const default_params = ( + ρ = 0, + verbose_iter = 100, + maxiter = 10000, + save_results = true, + save_images = true, + save_images_iters = Set([1, 2, 3, 5, + 10, 25, 30, 50, + 100, 250, 300, 500, + 1000, 2500, 3000, 5000, + 10000]), + pixelwise_displacement=false, + dual_flow = true, + prox_predict = true, + handle_interrupt = true, + init = :zero, + plot_movement = false, +) + +const square = imgen_square((200, 300)) +const lighthouse = imgen_shake("lighthouse", (200, 300)) + +const p_known₀ = ( + noise_level = 0.5, + shake_noise_level = 0.05, + shake = 2, + α = 1, + ρ̃₀ = 1, + σ̃₀ = 1, + δ = 0.9, + σ₀ = 1, + τ₀ = 0.01, +) + +const p_unknown₀ = ( + noise_level = 0.3, + shake_noise_level = 0.05, + shake = 2, + α = 0.2, + ρ̃₀ = 1, + σ̃₀ = 1, + σ₀ = 1, + δ = 0.9, + λ = 1, + θ = (300*200)*100^3, + kernel = gaussian((3, 3), (11, 11)), + timestep = 0.5, + displacement_count = 100, + τ₀ = 0.01, +) + +const experiments_pdps_known = ( + Experiment(Algorithm, DisplacementConstant, lighthouse, + p_known₀ ⬿ (phantom_ρ = 0,)), + Experiment(Algorithm, DisplacementConstant, lighthouse, + p_known₀ ⬿ (phantom_ρ = 100,)), + Experiment(Algorithm, DisplacementConstant, square, + p_known₀ ⬿ (phantom_ρ = 0,)) +) + +const experiments_pdps_unknown_multi = ( + Experiment(AlgorithmBothMulti, DisplacementConstant, lighthouse, + p_unknown₀ ⬿ (phantom_ρ = 0,)), + Experiment(AlgorithmBothMulti, DisplacementConstant, lighthouse, + p_unknown₀ ⬿ (phantom_ρ = 100,)), + Experiment(AlgorithmBothMulti, DisplacementConstant, square, + p_unknown₀ ⬿ (phantom_ρ = 0,)), +) + +const experiments_fb_known = ( + Experiment(AlgorithmFB, DisplacementConstant, lighthouse, + p_known₀ ⬿ (τ̃₀=0.9, fb_inner_iterations = 10)), +) + +const experiments_all = Iterators.flatten(( + experiments_pdps_known, + experiments_pdps_unknown_multi, + experiments_fb_known +)) + +################ +# Log +################ + +struct LogEntry <: IterableStruct + iter :: Int + time :: Float64 + function_value :: Float64 + v_cumul_true_y :: Float64 + v_cumul_true_x :: Float64 + v_cumul_est_y :: Float64 + v_cumul_est_x :: Float64 + psnr :: Float64 + ssim :: Float64 + psnr_data :: Float64 + ssim_data :: Float64 +end + +struct LogEntryHiFi <: IterableStruct + iter :: Int + v_cumul_true_y :: Float64 + v_cumul_true_x :: Float64 +end + +############### +# Main routine +############### + +struct State + vis :: Union{Channel,Bool,Nothing} + start_time :: Union{Real,Nothing} + wasted_time :: Real + log :: LinkedList{LogEntry} + log_hifi :: LinkedList{LogEntryHiFi} + aborted :: Bool +end + +function name(e::Experiment, p) + ig = e.imgen + return "$(ig.name)_$(e.mod.identifier)_$(@sprintf "%x" hash(p))" +end + +function write_tex(texfile, e_params) + open(texfile, "w") do io + wp = (n, v) -> println(io, "\\def\\EXPPARAM$(n){$(v)}") + wf = (n, s) -> if isdefined(e_params, s) + wp(n, getfield(e_params, s)) + end + wf("alpha", :α) + wf("sigmazero", :σ₀) + wf("tauzero", :τ₀) + wf("tildetauzero", :τ̃₀) + wf("delta", :δ) + wf("lambda", :λ) + wf("theta", :θ) + wf("maxiter", :maxiter) + wf("noiselevel", :noise_level) + wf("shakenoiselevel", :shake_noise_level) + wf("shake", :shake) + wf("timestep", :timestep) + wf("displacementcount", :displacementcount) + wf("phantomrho", :phantom_ρ) + if isdefined(e_params, :σ₀) + wp("sigma", (e_params.σ₀ == 1 ? "" : "$(e_params.σ₀)") * "\\sigma_{\\max}") + end + end +end + +function run_experiments(;visualise=true, + recalculate=true, + experiments, + save_prefix=default_save_prefix, + fullscreen=false, + kwargs...) + + # Create visualisation + if visualise + rc = Channel(1) + visproc = Threads.@spawn bg_visualise_enhanced(rc, fullscreen=fullscreen) + bind(rc, visproc) + vis = rc + else + vis = false + end + + # Run all experiments + for e ∈ experiments + + # Parameters for this experiment + e_params = default_params ⬿ e.params ⬿ kwargs + ename = name(e, e_params) + e_params = e_params ⬿ (save_prefix = save_prefix * ename, + dynrange = e.imgen.dynrange, + Λ = e.imgen.Λ) + + if recalculate || !isfile(e_params.save_prefix * ".txt") + println("Running experiment \"$(ename)\"") + + # Start data generation task + datachannel = Channel{OnlineData{e.DisplacementT}}(2) + gentask = Threads.@spawn e.imgen.f(e.DisplacementT, datachannel, e_params) + bind(datachannel, gentask) + + # Run algorithm + iterate = curry(iterate_visualise, datachannel, + State(vis, nothing, 0.0, nothing, nothing, false)) + + x, y, st = e.mod.solve(e.DisplacementT; + dim=e.imgen.dim, + iterate=iterate, + params=e_params) + + # Clear non-saveable things + st = @set st.vis = nothing + + println("Wasted_time: $(st.wasted_time)s") + + if e_params.save_results + println("Saving " * e_params.save_prefix * "(.txt,_hifi.txt,_params.tex)") + + perffile = e_params.save_prefix * ".txt" + hififile = e_params.save_prefix * "_hifi.txt" + texfile = e_params.save_prefix * "_params.tex" + # datafile = e_params.save_prefix * ".jld2" + + write_log(perffile, st.log, "# params = $(e_params)\n") + write_log(hififile, st.log_hifi, "# params = $(e_params)\n") + write_tex(texfile, e_params) + # @save datafile x y st params + end + + close(datachannel) + wait(gentask) + + if st.aborted + break + end + else + println("Skipping already computed experiment \"$(ename)\"") + # texfile = e_params.save_prefix * "_params.tex" + # write_tex(texfile, e_params) + end + end + + if visualise + # Tell subprocess to finish, and wait + put!(rc, nothing) + close(rc) + wait(visproc) + end + + return +end + +####################### +# Demos and batch runs +####################### + +function demo(experiment; kwargs...) + run_experiments(;experiments=(experiment,), + save_results=false, + save_images=false, + visualise=true, + recalculate=true, + verbose_iter=10, + fullscreen=true, + kwargs...) +end + +demo_known1 = () -> demo(experiments_pdps_known[3]) +demo_known2 = () -> demo(experiments_pdps_known[1]) +demo_known3 = () -> demo(experiments_pdps_known[2]) +demo_unknown1 = () -> demo(experiments_pdps_unknown_multi[3], plot_movement=true) +demo_unknown2 = () -> demo(experiments_pdps_unknown_multi[1], plot_movement=true) +demo_unknown3 = () -> demo(experiments_pdps_unknown_multi[2], plot_movement=true) + +function batchrun_article(kwargs...) + run_experiments(;experiments=experiments_all, + save_results=true, + save_images=true, + visualise=false, + recalculate=false, + kwargs...) +end + +###################################################### +# Iterator that does visualisation and log collection +###################################################### + +function iterate_visualise(datachannel::Channel{OnlineData{DisplacementT}}, + st :: State, + step :: Function, + params :: NamedTuple) where DisplacementT + try + sc = nothing + + d = take!(datachannel) + + for iter=1:params.maxiter + dnext = take!(datachannel) + st = step(d.b_noisy, d.v, dnext.b_noisy) do calc_objective + stn = st + + if isnothing(stn.start_time) + # The Julia precompiler is a miserable joke, apparently not crossing module + # boundaries, so only start timing after the first iteration. + stn = @set stn.start_time=secs_ns() + end + + verb = params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0 + + # Normalise movement to image dimensions so + # our TikZ plotting code doesn't need to know + # the image pixel size. + sc = 1.0./maximum(size(d.b_true)) + + if verb || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0) + verb_start = secs_ns() + tm = verb_start - stn.start_time - stn.wasted_time + value, x, v, vhist = calc_objective() + + entry = LogEntry(iter, tm, value, + sc*d.v_cumul_true[1], + sc*d.v_cumul_true[2], + sc*v[1], sc*v[2], + psnr(x, d.b_true), + ssim(x, d.b_true), + psnr(d.b_noisy, d.b_true), + ssim(d.b_noisy, d.b_true)) + + # (**) Collect a singly-linked list of log to avoid array resizing + # while iterating + stn = @set stn.log=LinkedListEntry(entry, stn.log) + + if !isnothing(vhist) + vhist=vhist.*sc + end + + if verb + @printf("%d/%d J=%f, PSNR=%f, SSIM=%f, avg. FPS=%f\n", + iter, params.maxiter, value, entry.psnr, + entry.ssim, entry.iter/entry.time) + if isa(stn.vis, Channel) + put_onlylatest!(stn.vis, ((d.b_noisy, x), + params.plot_movement, + stn.log, vhist)) + + end + end + + if params.save_images && (!haskey(params, :save_images_iters) + || iter ∈ params.save_images_iters) + fn = (t, ext) -> "$(params.save_prefix)_$(t)_frame$(iter).$(ext)" + save(File(format"PNG", fn("true", "png")), grayimg(d.b_true)) + save(File(format"PNG", fn("data", "png")), grayimg(d.b_noisy)) + save(File(format"PNG", fn("reco", "png")), grayimg(x)) + if !isnothing(vhist) + open(fn("movement", "txt"), "w") do io + writedlm(io, ["est_y" "est_x"]) + writedlm(io, vhist) + end + end + end + + stn = @set stn.wasted_time += (secs_ns() - verb_start) + + return stn + end + + hifientry = LogEntryHiFi(iter, sc*d.v_cumul_true[1], sc*d.v_cumul_true[2]) + st = @set st.log_hifi=LinkedListEntry(hifientry, st.log_hifi) + + return st + end + d=dnext + end + catch ex + if params.handle_interrupt && isa(ex, InterruptException) + # If SIGINT is received (user pressed ^C), terminate computations, + # returning current status. Effectively, we do not call `step()` again, + # ending the iterations, but letting the algorithm finish up. + # Assuming (**) above occurs atomically, `st.log` should be valid, but + # any results returned by the algorithm itself may be partial, as for + # reasons of efficiency we do *not* store results of an iteration until + # the next iteration is finished. + printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202) + st = @set st.aborted = true + else + rethrow(ex) + end + end + + return st +end + +function bg_visualise_enhanced(rc; fullscreen=false) + process_channel(rc) do d + imgs, plot_movement, log, vhist = d + do_visualise(imgs, refresh=false, fullscreen=fullscreen) + # Overlay movement + GR.settextcolorind(5) + GR.setcharheight(0.015) + GR.settextpath(GR.TEXT_PATH_RIGHT) + tx, ty = GR.wctondc(0, 1) + GR.text(tx, ty, @sprintf "FPS %.1f, SSIM %.2f, PSNR %.1f" (log.value.iter/log.value.time) log.value.ssim log.value.psnr) + if plot_movement + sc=1.0 + p=unfold_linked_list(log) + x=map(e -> 1.5+sc*e.v_cumul_true_x, p) + y=map(e -> 0.5+sc*e.v_cumul_true_y, p) + GR.setlinewidth(2) + GR.setlinecolorind(2) + GR.polyline(x, y) + x=map(e -> 1.5+sc*e.v_cumul_est_x, p) + y=map(e -> 0.5+sc*e.v_cumul_est_y, p) + GR.setlinecolorind(3) + GR.polyline(x, y) + if vhist != nothing + GR.setlinecolorind(4) + x=map(v -> 1.5+sc*v, vhist[:,2]) + y=map(v -> 0.5+sc*v, vhist[:,1]) + GR.polyline(x, y) + end + end + GR.updatews() + end +end + +############### +# Precompiling +############### + +# precompile(Tuple{typeof(GR.drawimage), Float64, Float64, Float64, Float64, Int64, Int64, Array{UInt32, 2}}) +# precompile(Tuple{Type{Plots.Plot{T} where T<:RecipesBase.AbstractBackend}, Plots.GRBackend, Int64, Base.Dict{Symbol, Any}, Base.Dict{Symbol, Any}, Array{Plots.Series, 1}, Nothing, Array{Plots.Subplot{T} where T<:RecipesBase.AbstractBackend, 1}, Base.Dict{Any, Plots.Subplot{T} where T<:RecipesBase.AbstractBackend}, Plots.EmptyLayout, Array{Plots.Subplot{T} where T<:RecipesBase.AbstractBackend, 1}, Bool}) +# precompile(Tuple{typeof(Plots._plot!), Plots.Plot{Plots.GRBackend}, Base.Dict{Symbol, Any}, Tuple{Array{ColorTypes.Gray{Float64}, 2}}}) + +end # Module