src/PredictPDPS.jl

changeset 0
a55e35d20336
child 3
0cfbe340796d
child 2
be7cab83b14a
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/PredictPDPS.jl	Tue Apr 07 14:19:48 2020 -0500
@@ -0,0 +1,508 @@
+##################
+# Our main module
+##################
+
+__precompile__()
+
+module PredictPDPS
+
+########################
+# Load external modules
+########################
+
+using Printf
+using FileIO
+#using JLD2
+using Setfield
+using ImageQualityIndexes: psnr, ssim
+using DelimitedFiles
+import GR
+
+using AlgTools.Util
+using AlgTools.StructTools
+using AlgTools.LinkedLists
+using AlgTools.Comms
+using ImageTools.Visualise: secs_ns, grayimg, do_visualise 
+using ImageTools.ImFilter: gaussian
+
+#####################
+# Load local modules
+#####################
+
+include("OpticalFlow.jl")
+include("ImGenerate.jl")
+include("Algorithm.jl")
+include("AlgorithmBoth.jl")
+include("AlgorithmBothGreedyV.jl")
+include("AlgorithmBothCumul.jl")
+include("AlgorithmBothMulti.jl")
+include("AlgorithmBothNL.jl")
+include("AlgorithmFB.jl")
+include("AlgorithmFBDual.jl")
+
+import .Algorithm,
+       .AlgorithmBoth,
+       .AlgorithmBothGreedyV,
+       .AlgorithmBothCumul,
+       .AlgorithmBothMulti,
+       .AlgorithmBothNL,
+       .AlgorithmFB,
+       .AlgorithmFBDual
+
+using .ImGenerate
+using .OpticalFlow: DisplacementFull, DisplacementConstant
+
+##############
+# Our exports
+##############
+
+export run_experiments,
+       batchrun_article,
+       demo_known1,
+       demo_known2,
+       demo_known3,
+       demo_unknown1,
+       demo_unknown2,
+       demo_unknown3
+
+###################################
+# Parameterisation and experiments
+###################################
+
+struct Experiment
+    mod :: Module
+    DisplacementT :: Type
+    imgen :: ImGen
+    params :: NamedTuple
+end
+
+function Base.show(io::IO, e::Experiment)
+    displacementname(::Type{DisplacementFull}) = "DisplacementFull"
+    displacementname(::Type{DisplacementConstant}) = "DisplacementConstant"
+    print(io, "
+    mod: $(e.mod)
+    DisplacementT: $(displacementname(e.DisplacementT))
+    imgen: $(e.imgen.name) $(e.imgen.dim[1])×$(e.imgen.dim[2])
+    params: $(e.params ⬿ (kernel = "(not shown)",))
+    ")
+end
+
+const default_save_prefix="img/"
+
+const default_params = (
+    ρ = 0,
+    verbose_iter = 100,
+    maxiter = 10000,
+    save_results = true,
+    save_images = true,
+    save_images_iters = Set([1, 2, 3, 5,
+                             10, 25, 30, 50,
+                             100, 250, 300, 500,
+                             1000, 2500, 3000, 5000,
+                             10000]),
+    pixelwise_displacement=false,
+    dual_flow = true,
+    prox_predict = true,
+    handle_interrupt = true,
+    init = :zero,
+    plot_movement = false,
+)
+
+const square = imgen_square((200, 300))
+const lighthouse = imgen_shake("lighthouse", (200, 300))
+
+const p_known₀ = (
+    noise_level = 0.5,
+    shake_noise_level = 0.05,
+    shake = 2,
+    α = 1,
+    ρ̃₀ = 1,
+    σ̃₀ = 1,
+    δ = 0.9,
+    σ₀ = 1,
+    τ₀ = 0.01,
+)
+
+const p_unknown₀ = (
+    noise_level = 0.3,
+    shake_noise_level = 0.05,
+    shake = 2,
+    α = 0.2,
+    ρ̃₀ = 1,
+    σ̃₀ = 1,
+    σ₀ = 1,
+    δ = 0.9,
+    λ = 1,
+    θ = (300*200)*100^3,
+    kernel = gaussian((3, 3), (11, 11)),
+    timestep = 0.5,
+    displacement_count = 100,
+    τ₀ = 0.01,
+)
+
+const experiments_pdps_known = (
+    Experiment(Algorithm, DisplacementConstant, lighthouse,
+               p_known₀ ⬿ (phantom_ρ = 0,)),
+    Experiment(Algorithm, DisplacementConstant, lighthouse,
+               p_known₀ ⬿ (phantom_ρ = 100,)),
+    Experiment(Algorithm, DisplacementConstant, square,
+               p_known₀ ⬿ (phantom_ρ = 0,))
+)
+
+const experiments_pdps_unknown_multi = (
+    Experiment(AlgorithmBothMulti, DisplacementConstant, lighthouse,
+               p_unknown₀ ⬿ (phantom_ρ = 0,)),
+    Experiment(AlgorithmBothMulti, DisplacementConstant, lighthouse,
+               p_unknown₀ ⬿ (phantom_ρ = 100,)),
+    Experiment(AlgorithmBothMulti, DisplacementConstant, square,
+               p_unknown₀ ⬿ (phantom_ρ = 0,)),
+)
+
+const experiments_fb_known = (
+    Experiment(AlgorithmFB, DisplacementConstant, lighthouse,
+               p_known₀ ⬿ (τ̃₀=0.9, fb_inner_iterations = 10)),
+)
+
+const experiments_all = Iterators.flatten((
+    experiments_pdps_known,
+    experiments_pdps_unknown_multi,
+    experiments_fb_known
+))
+
+################
+# Log
+################
+
+struct LogEntry <: IterableStruct
+    iter :: Int
+    time :: Float64
+    function_value :: Float64
+    v_cumul_true_y :: Float64
+    v_cumul_true_x :: Float64
+    v_cumul_est_y :: Float64
+    v_cumul_est_x :: Float64
+    psnr :: Float64
+    ssim :: Float64
+    psnr_data :: Float64
+    ssim_data :: Float64
+end
+
+struct LogEntryHiFi <: IterableStruct
+    iter :: Int
+    v_cumul_true_y :: Float64
+    v_cumul_true_x :: Float64
+end
+
+###############
+# Main routine
+###############
+
+struct State
+    vis :: Union{Channel,Bool,Nothing}
+    start_time :: Union{Real,Nothing}
+    wasted_time :: Real
+    log :: LinkedList{LogEntry}
+    log_hifi :: LinkedList{LogEntryHiFi}
+    aborted :: Bool
+end
+
+function name(e::Experiment, p)
+    ig = e.imgen
+    return "$(ig.name)_$(e.mod.identifier)_$(@sprintf "%x" hash(p))"
+end
+
+function write_tex(texfile, e_params)
+    open(texfile, "w") do io
+        wp = (n, v) -> println(io, "\\def\\EXPPARAM$(n){$(v)}")
+        wf = (n, s) -> if isdefined(e_params, s)
+                            wp(n, getfield(e_params, s))
+                        end
+        wf("alpha", :α)
+        wf("sigmazero", :σ₀)
+        wf("tauzero", :τ₀)
+        wf("tildetauzero", :τ̃₀)
+        wf("delta", :δ)
+        wf("lambda", :λ)
+        wf("theta", :θ)
+        wf("maxiter", :maxiter)
+        wf("noiselevel", :noise_level)
+        wf("shakenoiselevel", :shake_noise_level)
+        wf("shake", :shake)
+        wf("timestep", :timestep)
+        wf("displacementcount", :displacementcount)
+        wf("phantomrho", :phantom_ρ)
+        if isdefined(e_params, :σ₀)
+            wp("sigma", (e_params.σ₀ == 1 ? "" : "$(e_params.σ₀)") * "\\sigma_{\\max}")
+        end
+    end
+end                
+
+function run_experiments(;visualise=true,
+                          recalculate=true,
+                          experiments,
+                          save_prefix=default_save_prefix,
+                          fullscreen=false,
+                          kwargs...)
+
+    # Create visualisation
+    if visualise
+        rc = Channel(1)
+        visproc = Threads.@spawn bg_visualise_enhanced(rc, fullscreen=fullscreen)
+        bind(rc, visproc)
+        vis = rc
+    else
+        vis = false
+    end
+
+    # Run all experiments
+    for e ∈ experiments
+
+        # Parameters for this experiment
+        e_params = default_params ⬿ e.params ⬿ kwargs
+        ename = name(e, e_params)
+        e_params = e_params ⬿ (save_prefix = save_prefix * ename,
+                                dynrange = e.imgen.dynrange,
+                                Λ = e.imgen.Λ)
+
+        if recalculate || !isfile(e_params.save_prefix * ".txt")
+            println("Running experiment \"$(ename)\"")
+
+            # Start data generation task
+            datachannel = Channel{OnlineData{e.DisplacementT}}(2)
+            gentask = Threads.@spawn e.imgen.f(e.DisplacementT, datachannel, e_params)
+            bind(datachannel, gentask)
+
+            # Run algorithm
+            iterate = curry(iterate_visualise, datachannel,
+                            State(vis, nothing, 0.0, nothing, nothing, false))
+                            
+            x, y, st = e.mod.solve(e.DisplacementT;
+                                   dim=e.imgen.dim,
+                                   iterate=iterate,
+                                   params=e_params)
+
+            # Clear non-saveable things
+            st = @set st.vis = nothing
+
+            println("Wasted_time: $(st.wasted_time)s")
+
+            if e_params.save_results
+                println("Saving " * e_params.save_prefix * "(.txt,_hifi.txt,_params.tex)")
+
+                perffile = e_params.save_prefix * ".txt"
+                hififile = e_params.save_prefix * "_hifi.txt"
+                texfile = e_params.save_prefix * "_params.tex"
+                # datafile = e_params.save_prefix * ".jld2"
+
+                write_log(perffile, st.log, "# params = $(e_params)\n")
+                write_log(hififile, st.log_hifi, "# params = $(e_params)\n")
+                write_tex(texfile, e_params)
+                # @save datafile x y st params
+            end
+
+            close(datachannel)
+            wait(gentask)
+
+            if st.aborted
+                break
+            end
+        else
+            println("Skipping already computed experiment \"$(ename)\"")
+            # texfile = e_params.save_prefix * "_params.tex"
+            # write_tex(texfile, e_params)
+        end
+    end
+
+    if visualise
+        # Tell subprocess to finish, and wait
+        put!(rc, nothing)
+        close(rc)
+        wait(visproc)
+    end
+
+    return
+end
+
+#######################
+# Demos and batch runs
+#######################
+
+function demo(experiment; kwargs...)
+    run_experiments(;experiments=(experiment,),
+                    save_results=false,
+                    save_images=false,
+                    visualise=true,
+                    recalculate=true,
+                    verbose_iter=10,
+                    fullscreen=true,
+                    kwargs...)
+end
+
+demo_known1 = () -> demo(experiments_pdps_known[3])
+demo_known2 = () -> demo(experiments_pdps_known[1])
+demo_known3 = () -> demo(experiments_pdps_known[2])
+demo_unknown1 = () -> demo(experiments_pdps_unknown_multi[3], plot_movement=true)
+demo_unknown2 = () -> demo(experiments_pdps_unknown_multi[1], plot_movement=true)
+demo_unknown3 = () -> demo(experiments_pdps_unknown_multi[2], plot_movement=true)
+
+function batchrun_article(kwargs...)
+    run_experiments(;experiments=experiments_all,
+                    save_results=true,
+                    save_images=true,
+                    visualise=false,
+                    recalculate=false,
+                    kwargs...)
+end
+
+######################################################
+# Iterator that does visualisation and log collection
+######################################################
+
+function iterate_visualise(datachannel::Channel{OnlineData{DisplacementT}},
+                           st :: State,
+                           step :: Function,
+                           params :: NamedTuple) where DisplacementT
+    try
+        sc = nothing
+
+        d = take!(datachannel)
+
+        for iter=1:params.maxiter 
+            dnext = take!(datachannel)
+            st = step(d.b_noisy, d.v, dnext.b_noisy) do calc_objective
+                stn = st
+
+                if isnothing(stn.start_time)
+                    # The Julia precompiler is a miserable joke, apparently not crossing module
+                    # boundaries, so only start timing after the first iteration.
+                    stn = @set stn.start_time=secs_ns()
+                end
+
+                verb = params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0
+
+                # Normalise movement to image dimensions so
+                # our TikZ plotting code doesn't need to know
+                # the image pixel size.
+                sc = 1.0./maximum(size(d.b_true))
+                    
+                if verb || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0)
+                    verb_start = secs_ns()
+                    tm = verb_start - stn.start_time - stn.wasted_time
+                    value, x, v, vhist = calc_objective()
+
+                    entry = LogEntry(iter, tm, value,
+                                     sc*d.v_cumul_true[1],
+                                     sc*d.v_cumul_true[2],
+                                     sc*v[1], sc*v[2], 
+                                     psnr(x, d.b_true),
+                                     ssim(x, d.b_true),
+                                     psnr(d.b_noisy, d.b_true),
+                                     ssim(d.b_noisy, d.b_true))
+
+                    # (**) Collect a singly-linked list of log to avoid array resizing
+                    # while iterating
+                    stn = @set stn.log=LinkedListEntry(entry, stn.log)
+                    
+                    if !isnothing(vhist)
+                        vhist=vhist.*sc
+                    end
+
+                    if verb
+                        @printf("%d/%d J=%f, PSNR=%f, SSIM=%f, avg. FPS=%f\n",
+                                iter, params.maxiter, value, entry.psnr,
+                                entry.ssim, entry.iter/entry.time)
+                        if isa(stn.vis, Channel)
+                            put_onlylatest!(stn.vis, ((d.b_noisy, x),
+                                                      params.plot_movement,
+                                                      stn.log, vhist))
+
+                        end
+                    end
+
+                    if params.save_images && (!haskey(params, :save_images_iters)
+                                              || iter ∈ params.save_images_iters)
+                        fn = (t, ext) -> "$(params.save_prefix)_$(t)_frame$(iter).$(ext)"
+                        save(File(format"PNG", fn("true", "png")), grayimg(d.b_true))
+                        save(File(format"PNG", fn("data", "png")), grayimg(d.b_noisy))
+                        save(File(format"PNG", fn("reco", "png")), grayimg(x))
+                        if !isnothing(vhist)
+                            open(fn("movement", "txt"), "w") do io
+                                writedlm(io, ["est_y" "est_x"])
+                                writedlm(io, vhist)
+                            end
+                        end
+                    end
+
+                    stn = @set stn.wasted_time += (secs_ns() - verb_start)
+
+                    return stn
+                end
+
+                hifientry = LogEntryHiFi(iter, sc*d.v_cumul_true[1], sc*d.v_cumul_true[2])
+                st = @set st.log_hifi=LinkedListEntry(hifientry, st.log_hifi)
+
+                return st
+            end
+            d=dnext
+        end
+    catch ex
+        if params.handle_interrupt && isa(ex, InterruptException)
+            # If SIGINT is received (user pressed ^C), terminate computations,
+            # returning current status. Effectively, we do not call `step()` again,
+            # ending the iterations, but letting the algorithm finish up.
+            # Assuming (**) above occurs atomically, `st.log` should be valid, but
+            # any results returned by the algorithm itself may be partial, as for
+            # reasons of efficiency we do *not* store results of an iteration until
+            # the next iteration is finished.
+            printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202)
+            st = @set st.aborted = true
+        else
+            rethrow(ex)
+        end
+    end
+    
+    return st
+end
+
+function bg_visualise_enhanced(rc; fullscreen=false)
+    process_channel(rc) do d
+        imgs, plot_movement, log, vhist = d
+        do_visualise(imgs, refresh=false, fullscreen=fullscreen)
+        # Overlay movement
+        GR.settextcolorind(5)
+        GR.setcharheight(0.015)
+        GR.settextpath(GR.TEXT_PATH_RIGHT)
+        tx, ty = GR.wctondc(0, 1)
+        GR.text(tx, ty, @sprintf "FPS %.1f, SSIM %.2f, PSNR %.1f" (log.value.iter/log.value.time) log.value.ssim log.value.psnr)
+        if plot_movement
+            sc=1.0
+            p=unfold_linked_list(log)
+            x=map(e -> 1.5+sc*e.v_cumul_true_x, p)
+            y=map(e -> 0.5+sc*e.v_cumul_true_y, p)
+            GR.setlinewidth(2)
+            GR.setlinecolorind(2)
+            GR.polyline(x, y)
+            x=map(e -> 1.5+sc*e.v_cumul_est_x, p)
+            y=map(e -> 0.5+sc*e.v_cumul_est_y, p)
+            GR.setlinecolorind(3)
+            GR.polyline(x, y)
+            if vhist != nothing
+                GR.setlinecolorind(4)
+                x=map(v -> 1.5+sc*v, vhist[:,2])
+                y=map(v -> 0.5+sc*v, vhist[:,1])
+                GR.polyline(x, y)
+            end
+        end
+        GR.updatews() 
+    end
+end
+
+###############
+# Precompiling
+###############
+
+# precompile(Tuple{typeof(GR.drawimage), Float64, Float64, Float64, Float64, Int64, Int64, Array{UInt32, 2}})
+# precompile(Tuple{Type{Plots.Plot{T} where T<:RecipesBase.AbstractBackend}, Plots.GRBackend, Int64, Base.Dict{Symbol, Any}, Base.Dict{Symbol, Any}, Array{Plots.Series, 1}, Nothing, Array{Plots.Subplot{T} where T<:RecipesBase.AbstractBackend, 1}, Base.Dict{Any, Plots.Subplot{T} where T<:RecipesBase.AbstractBackend}, Plots.EmptyLayout, Array{Plots.Subplot{T} where T<:RecipesBase.AbstractBackend, 1}, Bool})
+# precompile(Tuple{typeof(Plots._plot!), Plots.Plot{Plots.GRBackend}, Base.Dict{Symbol, Any}, Tuple{Array{ColorTypes.Gray{Float64}, 2}}})
+
+end # Module

mercurial