src/PredictPDPS.jl

Wed, 24 Apr 2024 17:01:58 +0300

author
Neil Dizon <neil.dizon@helsinki.fi>
date
Wed, 24 Apr 2024 17:01:58 +0300
changeset 31
cbedcfcc0557
parent 29
6a0ca7047f68
child 34
aca9c90f151c
permissions
-rw-r--r--

removed img folder

##################
# Our main module
##################

__precompile__()

module PredictPDPS

########################
# Load external modules
########################

using Printf
using FileIO
#using JLD2
using Setfield
using ImageQualityIndexes: assess_psnr, assess_ssim
using DelimitedFiles
import GR

using AlgTools.Util
using AlgTools.StructTools
using AlgTools.LinkedLists
using AlgTools.Comms
using ImageTools.Visualise: secs_ns, grayimg, do_visualise 
using ImageTools.ImFilter: gaussian

#####################
# Load local modules
#####################

include("OpticalFlow.jl")
include("Radon.jl")
include("ImGenerate.jl")
include("Algorithm.jl")
include("AlgorithmBoth.jl")
include("AlgorithmBothGreedyV.jl")
include("AlgorithmBothCumul.jl")
include("AlgorithmBothMulti.jl")
include("AlgorithmBothNL.jl")
include("AlgorithmFB.jl")
include("AlgorithmFBDual.jl")
include("PlotResults.jl")


# Additional
include("AlgorithmProximal.jl")
include("AlgorithmGreedy.jl")
include("AlgorithmRotation.jl")
include("AlgorithmNoPrediction.jl")
include("AlgorithmPrimalOnly.jl")
include("AlgorithmDualScaling.jl")
include("AlgorithmZeroDual.jl")
include("PET/PET.jl")


import .Algorithm,
       .AlgorithmBoth,
       .AlgorithmBothGreedyV,
       .AlgorithmBothCumul,
       .AlgorithmBothMulti,
       .AlgorithmBothNL,
       .AlgorithmFB,
       .AlgorithmFBDual,
       .AlgorithmProximal,
       .AlgorithmGreedy,
       .AlgorithmRotation,
       .AlgorithmNoPrediction,
       .AlgorithmPrimalOnly,
       .AlgorithmDualScaling,
       .AlgorithmZeroDual

using .ImGenerate
using .OpticalFlow: DisplacementFull, DisplacementConstant
using .PlotResults
using .PET

##############
# Our exports
##############

export run_experiments,
       batchrun_article,
       demo_known1, demo_known2, demo_known3,
       demo_unknown1,demo_unknown2,demo_unknown3,
       batchrun_denoising,
       batchrun_predictors,
       demo_denoising1, demo_denoising2, demo_denoising3, 
       demo_denoising4, demo_denoising5, demo_denoising6, demo_denoising7,
       demo_petS1, demo_petS2, demo_petS3, 
       demo_petS4, demo_petS5, demo_petS6, demo_petS7,
       demo_petB1, demo_petB2, demo_petB3, 
       demo_petB4, demo_petB5, demo_petB6, demo_petB7,
       batchrun_shepplogan, batchrun_brainphantom, batchrun_pet,
       plot_denoising, plot_pet, calculate_statistics
       
###################################
# Parameterisation and experiments
###################################

# One experiment: an algorithm module, a displacement model and an image generator,
# together with experiment-specific parameters.
#
# Fields:
# - `mod`: algorithm module; `run_experiments` calls `mod.solve`, and `name` uses
#   `mod.identifier` in output file names.
# - `DisplacementT`: displacement type passed to `mod.solve` and to the image
#   generator (e.g. `DisplacementConstant` or `DisplacementFull`).
# - `imgen`: image generator; its `name`, `dim`, `f`, `dynrange` and `Λ` are used.
# - `params`: experiment-specific parameters, merged over `default_params` and in
#   turn overridden by keyword arguments of `run_experiments`.
struct Experiment
    mod :: Module
    DisplacementT :: Type
    imgen :: ImGen
    params :: NamedTuple
end

# Print a multi-line summary of an `Experiment`. The (large) `kernel` parameter is
# replaced by a placeholder so the output stays readable.
function Base.show(io::IO, e::Experiment)
    # Short names for the displacement models used in this module ...
    displacementname(::Type{DisplacementFull}) = "DisplacementFull"
    displacementname(::Type{DisplacementConstant}) = "DisplacementConstant"
    # ... and a fallback so experiments using any other model can still be shown
    # instead of raising a MethodError from inside `show`.
    displacementname(T) = string(T)
    print(io, "
    mod: $(e.mod)
    DisplacementT: $(displacementname(e.DisplacementT))
    imgen: $(e.imgen.name) $(e.imgen.dim[1])×$(e.imgen.dim[2])
    params: $(e.params ⬿ (kernel = "(not shown)",))
    ")
end

# Default prefix (here a directory) prepended to every output file name.
const default_save_prefix="img/"

# Parameters shared by all experiments. `run_experiments` overrides them with the
# experiment's own `params` and then with its keyword arguments.
const default_params = (
    # NOTE(review): not read in this file; presumably used by the algorithm modules.
    ρ = 0,
    # Print progress and update the visualiser every this many iterations (0 = never).
    verbose_iter = 100,
    # Number of iterations run by `iterate_visualise`.
    maxiter = 10000,
    # Write the `.txt` performance log after each experiment.
    save_results = true,
    # Save reconstruction frames (and the movement history, if any).
    save_images = true,
    # Iterations at which frames are saved when `save_images` is set.
    save_images_iters = Set([1, 2, 3, 5,
                             10, 25, 30, 50,
                             100, 250, 300, 500,
                             1000, 2000, 2500, 3000, 4000, 5000,
                             6000, 7000, 7500, 8000, 9000, 10000, 8700]),
    # The following four are not read in this file; presumably algorithm options.
    pixelwise_displacement=false,
    dual_flow = true,
    prox_predict = true,
    # On ^C, stop iterating and return the partial state instead of rethrowing.
    handle_interrupt = true,
    # NOTE(review): initialisation option for the algorithms — confirm semantics.
    init = :zero,
    # Forwarded to the visualiser to overlay movement paths.
    plot_movement = false,
    # NOTE(review): not read in this file; presumably iterations treated as "stable"
    # by the image generator or algorithms — confirm.
    stable_interval = Set(0),
)

# Image generators for the two test phantoms, both created with size (200, 300).
const square = imgen_square((200, 300))
const lighthouse = imgen_shake("lighthouse", (200, 300))

# Parameter set for the known-displacement denoising experiments.
const p_known₀_denoising = (
    noise_level = 0.5,
    shake_noise_level = 0.025,
    shake = 2.0,
    α = 1.0,
    ρ̃₀ = 1.0,
    σ̃₀ = 1.0,
    δ = 0.9,
    σ₀ = 1.0,
    τ₀ = 0.01,
    # Overrides the default `stable_interval`; the previous value is kept for reference.
    #stable_interval = Set(0),
    stable_interval = union(Set(2500:5000),Set(8700:10000)),
)

# Parameter set for the known-displacement PDPS/FB experiments.
const p_known₀ = (
    noise_level = 0.5,
    shake_noise_level = 0.05,
    shake = 2,
    α = 0.15,
    ρ̃₀ = 1,
    σ̃₀ = 1,
    δ = 0.9,
    σ₀ = 1,
    τ₀ = 0.01,
)

# Parameter set for the unknown-displacement (displacement-estimating) experiments.
const p_unknown₀ = (
    noise_level = 0.3,
    shake_noise_level = 0.05,
    shake = 2,
    α = 0.2,
    ρ̃₀ = 1,
    σ̃₀ = 1,
    σ₀ = 1,
    δ = 0.9,
    λ = 1,
    # NOTE(review): scaled by the 300×200 image size; exact role defined by the algorithm.
    θ = (300*200)*100^3,
    # Not printed by `show(::Experiment)`; presumably a smoothing kernel for the algorithm.
    kernel = gaussian((3, 3), (11, 11)),
    timestep = 0.5,
    displacement_count = 100,
    τ₀ = 0.01,
)


# PDPS experiments with known displacement: lighthouse denoising setup, lighthouse
# with `phantom_ρ = 100`, and the square phantom.
const experiments_pdps_known = (
    Experiment(Algorithm, DisplacementConstant, lighthouse,
               p_known₀_denoising ⬿ (phantom_ρ = 0,)),
    Experiment(Algorithm, DisplacementConstant, lighthouse,
               p_known₀ ⬿ (phantom_ρ = 100,)),
    Experiment(Algorithm, DisplacementConstant, square,
               p_known₀ ⬿ (phantom_ρ = 0,))
)

# Unknown-displacement experiments using `AlgorithmBothMulti`, same three phantoms.
const experiments_pdps_unknown_multi = (
    Experiment(AlgorithmBothMulti, DisplacementConstant, lighthouse,
               p_unknown₀ ⬿ (phantom_ρ = 0,)),
    Experiment(AlgorithmBothMulti, DisplacementConstant, lighthouse,
               p_unknown₀ ⬿ (phantom_ρ = 100,)),
    Experiment(AlgorithmBothMulti, DisplacementConstant, square,
               p_unknown₀ ⬿ (phantom_ρ = 0,)),
)

# Forward–backward experiment with known displacement.
const experiments_fb_known = (
    Experiment(AlgorithmFB, DisplacementConstant, lighthouse,
               p_known₀ ⬿ (τ̃₀=0.9, fb_inner_iterations = 10)),
)

# All article experiments, as one lazy iterator.
const experiments_all = Iterators.flatten((
    experiments_pdps_known,
    experiments_pdps_unknown_multi,
    experiments_fb_known
))

# Lighthouse denoising with each predictor variant; indices are used by `demo_denoising*`.
const denoising_experiments_pdps_known = (
    Experiment(AlgorithmDualScaling, DisplacementConstant, lighthouse,
               p_known₀_denoising),
    Experiment(AlgorithmGreedy, DisplacementConstant, lighthouse,
               p_known₀_denoising),
    Experiment(AlgorithmNoPrediction, DisplacementConstant, lighthouse,
               p_known₀_denoising),
    Experiment(AlgorithmPrimalOnly, DisplacementConstant, lighthouse,
               p_known₀_denoising),
    Experiment(AlgorithmProximal, DisplacementConstant, lighthouse,
               p_known₀_denoising ⬿ (phantom_ρ = 100,)),
    Experiment(AlgorithmRotation, DisplacementConstant, lighthouse,
               p_known₀_denoising),
    Experiment(AlgorithmZeroDual, DisplacementConstant, lighthouse,
               p_known₀_denoising),
)

# All denoising experiments, as one lazy iterator.
const denoising_experiments_all = Iterators.flatten((
    denoising_experiments_pdps_known,
))

################
# Log
################

# One entry of the performance log, recorded by `iterate_visualise` on selected
# iterations and written out with `write_log`.
struct LogEntry <: IterableStruct
    # Iteration number.
    iter :: Int
    # Elapsed computation time, excluding time spent on logging/visualisation.
    time :: Float64
    # Objective value as returned by the algorithm's `calc_objective`.
    function_value :: Float64
    # Cumulative true/estimated displacement fields, currently disabled. Re-enabling
    # them requires matching arguments in the `LogEntry(...)` constructor call.
    #v_cumul_true_y :: Float64
    #v_cumul_true_x :: Float64
    #v_cumul_est_y :: Float64
    #v_cumul_est_x :: Float64
    # Quality of the reconstruction against the ground-truth frame.
    psnr :: Float64
    ssim :: Float64
    # Quality of the noisy data against the ground truth, currently disabled.
    #psnr_data :: Float64
    #ssim_data :: Float64
end

# High-frequency log entry: the true cumulative displacement (normalised by the
# largest image dimension) at a given iteration.
struct LogEntryHiFi <: IterableStruct
    iter :: Int
    v_cumul_true_y :: Float64
    v_cumul_true_x :: Float64
end

###############
# Main routine
###############

# Iteration state threaded through `iterate_visualise` and returned by `solve`.
struct State
    # Visualiser channel, `false` when visualisation is disabled; cleared to
    # `nothing` after the run because it cannot be saved.
    vis :: Union{Channel,Bool,Nothing}
    # Time of the first iteration callback (`nothing` until then).
    start_time :: Union{Real,Nothing}
    # Accumulated time spent on logging/visualisation, subtracted from timings.
    wasted_time :: Real
    # Performance log as a singly-linked list, newest entry first.
    log :: LinkedList{LogEntry}
    # High-frequency displacement log, newest entry first.
    log_hifi :: LinkedList{LogEntryHiFi}
    # Set when the user interrupted the run.
    aborted :: Bool
end

"""
    name(e::Experiment, p)

Return the output file-name stem for experiment `e` run with parameters `p`:
the image name, the algorithm identifier, and `100α`, `10000σ₀`, `10000τ₀`
rounded to integers, separated by underscores.
"""
function name(e::Experiment, p)
    ig = e.imgen
    # Round rather than convert: `Int64(x)` throws `InexactError` whenever the scaled
    # float is not exactly integral (e.g. `100 * 0.57 == 56.99999999999999`).
    α_tag = round(Int64, 100 * p.α)
    σ_tag = round(Int64, 10000 * p.σ₀)
    τ_tag = round(Int64, 10000 * p.τ₀)
    return "$(ig.name)_$(e.mod.identifier)_$(α_tag)_$(σ_tag)_$(τ_tag)"
end

"""
    write_tex(texfile, e_params)

Write selected experiment parameters from `e_params` to `texfile` as LaTeX
definitions `\\def\\EXPPARAM<name>{<value>}`, one per line. Parameters absent
from `e_params` are skipped. If `σ₀` is present, an extra `sigma` macro expressing
the step as a multiple of `\\sigma_{\\max}` is written.
"""
function write_tex(texfile, e_params)
    open(texfile, "w") do io
        wp = (n, v) -> println(io, "\\def\\EXPPARAM$(n){$(v)}")
        wf = (n, s) -> if isdefined(e_params, s)
                            wp(n, getfield(e_params, s))
                        end
        wf("alpha", :α)
        wf("sigmazero", :σ₀)
        wf("tauzero", :τ₀)
        wf("tildetauzero", :τ̃₀)
        wf("delta", :δ)
        wf("lambda", :λ)
        wf("theta", :θ)
        wf("maxiter", :maxiter)
        wf("noiselevel", :noise_level)
        wf("shakenoiselevel", :shake_noise_level)
        wf("shake", :shake)
        wf("timestep", :timestep)
        # The parameter is named `displacement_count` (see `p_unknown₀`).
        wf("displacementcount", :displacement_count)
        wf("phantomrho", :phantom_ρ)
        if isdefined(e_params, :σ₀)
            wp("sigma", (e_params.σ₀ == 1 ? "" : "$(e_params.σ₀)") * "\\sigma_{\\max}")
        end
    end
end

"""
    run_experiments(; experiments, visualise=true, recalculate=true,
                    save_prefix=default_save_prefix, fullscreen=false, kwargs...)

Run every `Experiment` in `experiments` in turn.

For each experiment, the parameters are `default_params`, overridden by the
experiment's own `params`, overridden in turn by `kwargs`. The save prefix,
dynamic range and `Λ` are then filled in from `save_prefix` and the experiment's
image generator.

# Keywords
- `experiments`: iterable of `Experiment`s (required).
- `visualise`: run a background visualiser task fed through a channel.
- `recalculate`: if `false`, skip experiments whose `.txt` log already exists.
- `save_prefix`: prefix (typically a directory) prepended to all output files;
  its directory part is created if saving is enabled.
- `fullscreen`: passed on to the visualiser.
- `kwargs...`: parameter overrides applied to every experiment.

Stops early if an experiment's final state has `aborted` set. Returns `nothing`.
"""
function run_experiments(;visualise=true,
                          recalculate=true,
                          experiments,
                          save_prefix=default_save_prefix,
                          fullscreen=false,
                          kwargs...)

    # Create visualisation
    if visualise
        rc = Channel(1)
        visproc = Threads.@spawn bg_visualise_enhanced(rc, fullscreen=fullscreen)
        bind(rc, visproc)
        vis = rc
    else
        vis = false
    end

    # Run all experiments
    for e ∈ experiments

        # Parameters for this experiment
        e_params = default_params ⬿ e.params ⬿ kwargs
        ename = name(e, e_params)
        e_params = e_params ⬿ (save_prefix = save_prefix * ename,
                                dynrange = e.imgen.dynrange,
                                Λ = e.imgen.Λ)

        if recalculate || !isfile(e_params.save_prefix * ".txt")
            println("Running experiment \"$(ename)\"")

            # Logs and images are written under the directory part of the save prefix
            # (by default "img/"), which need not exist in a fresh checkout.
            if e_params.save_results || e_params.save_images
                savedir = dirname(e_params.save_prefix)
                isempty(savedir) || mkpath(savedir)
            end

            # Start data generation task
            datachannel = Channel{OnlineData{e.DisplacementT}}(2)
            gentask = Threads.@spawn e.imgen.f(e.DisplacementT, datachannel, e_params)
            bind(datachannel, gentask)

            # Run algorithm
            iterate = curry(iterate_visualise, datachannel,
                            State(vis, nothing, 0.0, nothing, nothing, false))

            x, y, st = e.mod.solve(e.DisplacementT;
                                   dim=e.imgen.dim,
                                   iterate=iterate,
                                   params=e_params)

            # Clear non-saveable things
            st = @set st.vis = nothing

            println("Wasted_time: $(st.wasted_time)s")

            if e_params.save_results
                perffile = e_params.save_prefix * ".txt"
                # Report only the files actually written below.
                println("Saving " * perffile)

                write_log(perffile, st.log, "# params = $(e_params)\n")
                # Currently disabled outputs; re-enable as needed:
                # write_log(e_params.save_prefix * "_hifi.txt", st.log_hifi, "# params = $(e_params)\n")
                # write_tex(e_params.save_prefix * "_params.tex", e_params)
                # @save e_params.save_prefix * ".jld2" x y st params
            end

            close(datachannel)
            wait(gentask)

            if st.aborted
                break
            end
        else
            println("Skipping already computed experiment \"$(ename)\"")
            # texfile = e_params.save_prefix * "_params.tex"
            # write_tex(texfile, e_params)
        end
    end

    if visualise
        # Tell subprocess to finish, and wait
        put!(rc, nothing)
        close(rc)
        wait(visproc)
    end

    return
end

#######################
# Demos and batch runs
#######################

"""
    demo(experiment; kwargs...)

Run a single `experiment` interactively: full-screen visualisation, progress every
50 iterations, always recomputed, and nothing saved to disk. Any of these settings
(or experiment parameters) can be overridden through `kwargs...`.
"""
demo(experiment; kwargs...) = run_experiments(; experiments = (experiment,),
                                               save_results = false,
                                               save_images = false,
                                               visualise = true,
                                               recalculate = true,
                                               verbose_iter = 50,
                                               fullscreen = true,
                                               kwargs...)

# Interactive demos (see `demo`). Defined as ordinary functions rather than
# non-constant globals bound to anonymous functions, so the exported bindings are
# constant and calls through them are type-stable.

# Known displacement, PDPS.
demo_known1() = demo(experiments_pdps_known[3])   # square phantom
demo_known2() = demo(experiments_pdps_known[1])   # lighthouse, denoising parameters
demo_known3() = demo(experiments_pdps_known[2])   # lighthouse, phantom_ρ = 100

# Unknown displacement, with the movement overlay enabled.
demo_unknown1() = demo(experiments_pdps_unknown_multi[3], plot_movement=true)  # square
demo_unknown2() = demo(experiments_pdps_unknown_multi[1], plot_movement=true)  # lighthouse
demo_unknown3() = demo(experiments_pdps_unknown_multi[2], plot_movement=true)  # lighthouse, phantom_ρ = 100

# Lighthouse denoising with the different predictors.
demo_denoising1() = demo(denoising_experiments_pdps_known[1]) # Dual scaling
demo_denoising2() = demo(denoising_experiments_pdps_known[2]) # Greedy
demo_denoising3() = demo(denoising_experiments_pdps_known[3]) # No Prediction
demo_denoising4() = demo(denoising_experiments_pdps_known[4]) # Primal Only
demo_denoising5() = demo(denoising_experiments_pdps_known[5]) # Proximal (old)
demo_denoising6() = demo(denoising_experiments_pdps_known[6]) # Rotation
demo_denoising7() = demo(denoising_experiments_pdps_known[7]) # Zero dual

"""
    batchrun_article(; kwargs...)

Run all article experiments non-interactively. Logs and images are saved, and
experiments whose results already exist are skipped. `kwargs...` override these
settings or experiment parameters.

For backward compatibility, overrides may also be passed positionally as
`Symbol => value` pairs; keyword arguments take precedence over them.
"""
function batchrun_article(pair_kwargs...; kwargs...)
    # The original signature took positional varargs named `kwargs`, so ordinary
    # keyword calls raised a MethodError; both forms are forwarded here.
    run_experiments(;experiments=experiments_all,
                    save_results=true,
                    save_images=true,
                    visualise=false,
                    recalculate=false,
                    pair_kwargs...,
                    kwargs...)
end

"""
    batchrun_denoising(; kwargs...)

Run every denoising experiment without visualisation, saving logs and images and
skipping experiments whose results are already on disk. `kwargs...` take
precedence over these settings.
"""
batchrun_denoising(; kwargs...) =
    run_experiments(; experiments = denoising_experiments_all,
                    save_results = true,
                    save_images = true,
                    visualise = false,
                    recalculate = false,
                    kwargs...)


"""
    batchrun_predictors(; kwargs...)

Run the denoising batch followed by the PET batch, forwarding `kwargs...` to
both. Returns whatever the PET batch returns.
"""
function batchrun_predictors(; kwargs...)
    batchrun_denoising(; kwargs...)
    return batchrun_pet(; kwargs...)
end

######################################################
# Iterator that does visualisation and log collection
######################################################

"""
    iterate_visualise(datachannel, st, step, params)

Iteration driver handed, curried over `datachannel` and `st`, to an algorithm's
`solve`.

For `params.maxiter` iterations it takes the next frame from `datachannel` and
calls `step` with the current noisy data, the current displacement and the next
noisy data. In the callback passed to `step` it records timing, the performance
log, the high-frequency displacement log, visualiser updates and saved frames.

Returns the final `State`. On a user interrupt (when `params.handle_interrupt`),
it returns the state reached so far with `aborted = true`.
"""
function iterate_visualise(datachannel::Channel{OnlineData{DisplacementT}},
                           st :: State,
                           step :: Function,
                           params :: NamedTuple) where DisplacementT
    try
        d = take!(datachannel)

        for iter=1:params.maxiter
            dnext = take!(datachannel)

            # Normalise movement to image dimensions so
            # our TikZ plotting code doesn't need to know
            # the image pixel size.
            sc = 1.0 / maximum(size(d.b_true))

            st = step(d.b_noisy, d.v, dnext.b_noisy) do calc_objective
                stn = st

                if isnothing(stn.start_time)
                    # The Julia precompiler is a miserable joke, apparently not crossing module
                    # boundaries, so only start timing after the first iteration.
                    stn = @set stn.start_time=secs_ns()
                end

                verb = params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0

                if verb || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0)
                    verb_start = secs_ns()
                    tm = verb_start - stn.start_time - stn.wasted_time
                    value, x, v, vhist = calc_objective()

                    entry = LogEntry(iter, tm, value,
                                     assess_psnr(x, d.b_true),
                                     assess_ssim(x, d.b_true))

                    # (**) Collect a singly-linked list of log to avoid array resizing
                    # while iterating
                    stn = @set stn.log=LinkedListEntry(entry, stn.log)

                    if !isnothing(vhist)
                        vhist=vhist.*sc
                    end

                    if verb
                        @printf("%d/%d J=%f, PSNR=%f, SSIM=%f, avg. FPS=%f\n",
                                iter, params.maxiter, value, entry.psnr,
                                entry.ssim, entry.iter/entry.time)
                        if isa(stn.vis, Channel)
                            put_onlylatest!(stn.vis, ((d.b_noisy, x),
                                                      params.plot_movement,
                                                      stn.log, vhist))
                        end
                    end

                    if params.save_images && (!haskey(params, :save_images_iters)
                                              || iter ∈ params.save_images_iters)
                        fn = (t, ext) -> "$(params.save_prefix)_$(t)_frame$(iter).$(ext)"
                        save(File(format"PNG", fn("reco", "png")), grayimg(x))
                        if !isnothing(vhist)
                            open(fn("movement", "txt"), "w") do io
                                writedlm(io, ["est_y" "est_x"])
                                writedlm(io, vhist)
                            end
                        end
                    end

                    # Time spent on logging/visualisation is excluded from timings.
                    stn = @set stn.wasted_time += (secs_ns() - verb_start)
                end

                # High-frequency log of the true cumulative displacement. Recorded on
                # every iteration, including those that also produced a `LogEntry`
                # above, and built from `stn` so no update made above is lost.
                hifientry = LogEntryHiFi(iter, sc*d.v_cumul_true[1], sc*d.v_cumul_true[2])
                stn = @set stn.log_hifi=LinkedListEntry(hifientry, stn.log_hifi)

                return stn
            end
            d=dnext
        end
    catch ex
        if params.handle_interrupt && isa(ex, InterruptException)
            # If SIGINT is received (user pressed ^C), terminate computations,
            # returning current status. Effectively, we do not call `step()` again,
            # ending the iterations, but letting the algorithm finish up.
            # Assuming (**) above occurs atomically, `st.log` should be valid, but
            # any results returned by the algorithm itself may be partial, as for
            # reasons of efficiency we do *not* store results of an iteration until
            # the next iteration is finished.
            printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202)
            st = @set st.aborted = true
        else
            rethrow(ex)
        end
    end

    return st
end

"""
    bg_visualise_enhanced(rc; fullscreen=false)

Visualiser loop run as a background task by `run_experiments`.

Each message on `rc` is `(imgs, plot_movement, log, vhist)`. The images are drawn,
then FPS/SSIM/PSNR of the newest log entry are overlaid. When `plot_movement` is
set, movement paths are overlaid as well.
"""
function bg_visualise_enhanced(rc; fullscreen=false)
    process_channel(rc) do d
        imgs, plot_movement, log, vhist = d
        do_visualise(imgs, refresh=false, fullscreen=fullscreen)
        # Overlay statistics of the newest log entry (the head of the linked list).
        GR.settextcolorind(5)
        GR.setcharheight(0.015)
        GR.settextpath(GR.TEXT_PATH_RIGHT)
        tx, ty = GR.wctondc(0, 1)
        GR.text(tx, ty, @sprintf "FPS %.1f, SSIM %.2f, PSNR %.1f" (log.value.iter/log.value.time) log.value.ssim log.value.psnr)
        if plot_movement
            sc=1.0
            p=unfold_linked_list(log)
            GR.setlinewidth(2)
            # True (colour 2) and estimated (colour 3) cumulative displacement paths.
            # `LogEntry` currently has these fields commented out, so draw them only
            # when present instead of throwing and killing the visualiser task.
            movement_fields = (:v_cumul_true_x, :v_cumul_true_y,
                               :v_cumul_est_x, :v_cumul_est_y)
            has_paths = !isempty(p) && all(f -> hasproperty(first(p), f), movement_fields)
            if has_paths
                x=map(e -> 1.5+sc*e.v_cumul_true_x, p)
                y=map(e -> 0.5+sc*e.v_cumul_true_y, p)
                GR.setlinecolorind(2)
                GR.polyline(x, y)
                x=map(e -> 1.5+sc*e.v_cumul_est_x, p)
                y=map(e -> 0.5+sc*e.v_cumul_est_y, p)
                GR.setlinecolorind(3)
                GR.polyline(x, y)
            end
            # Displacement history supplied by the algorithm (colour 4); column 2 is
            # plotted horizontally and column 1 vertically.
            if !isnothing(vhist)
                GR.setlinecolorind(4)
                x=map(v -> 1.5+sc*v, vhist[:,2])
                y=map(v -> 0.5+sc*v, vhist[:,1])
                GR.polyline(x, y)
            end
        end
        GR.updatews()
    end
end

#########################
# Plotting SSIM and PSNR
#########################

"""
    plot_denoising(args...; kwargs...)

Produce the SSIM, PSNR and function-value plots for the lighthouse denoising
experiments. Arguments are accepted for interface consistency with the batch-run
entry points, but are currently ignored.
"""
function plot_denoising(args...; kwargs...)
    # The original signature took positional varargs named `kwargs`, so keyword
    # calls raised a MethodError; both forms are now accepted.
    ssim_plot("lighthouse")
    psnr_plot("lighthouse")
    fv_plot("lighthouse")
end


###############
# Precompiling
###############

# precompile(Tuple{typeof(GR.drawimage), Float64, Float64, Float64, Float64, Int64, Int64, Array{UInt32, 2}})
# precompile(Tuple{Type{Plots.Plot{T} where T<:RecipesBase.AbstractBackend}, Plots.GRBackend, Int64, Base.Dict{Symbol, Any}, Base.Dict{Symbol, Any}, Array{Plots.Series, 1}, Nothing, Array{Plots.Subplot{T} where T<:RecipesBase.AbstractBackend, 1}, Base.Dict{Any, Plots.Subplot{T} where T<:RecipesBase.AbstractBackend}, Plots.EmptyLayout, Array{Plots.Subplot{T} where T<:RecipesBase.AbstractBackend, 1}, Bool})
# precompile(Tuple{typeof(Plots._plot!), Plots.Plot{Plots.GRBackend}, Base.Dict{Symbol, Any}, Tuple{Array{ColorTypes.Gray{Float64}, 2}}})

end # Module

mercurial