src/PET/PET.jl

Fri, 19 Apr 2024 17:05:54 +0300

author
Neil Dizon <neil.dizon@helsinki.fi>
date
Fri, 19 Apr 2024 17:05:54 +0300
changeset 9
4eaff3502eea
parent 8
e4ad8f7ce671
child 11
97e9c745a000
permissions
-rw-r--r--

readme corrections

__precompile__()

module PET

########################
# Load external modules
########################

using Printf
using FileIO
#using JLD2
using Setfield
using ImageQualityIndexes: assess_psnr, assess_ssim
using DelimitedFiles
import GR

using AlgTools.Util
using AlgTools.StructTools
using AlgTools.LinkedLists
using AlgTools.Comms
using ImageTools.Visualise: secs_ns, grayimg, do_visualise 
using ImageTools.ImFilter: gaussian

# For PET
#using Colors
#using ColorTypes
using ColorSchemes
#using PerceptualColourMaps

#####################
# Load local modules
#####################
include("OpticalFlow.jl")
include("Radon.jl")
include("ImGenerate.jl")
include("AlgorithmDualScaling.jl")
include("AlgorithmGreedy.jl")
include("AlgorithmNoPrediction.jl")
include("AlgorithmPrimalOnly.jl")
include("AlgorithmProximal.jl")
include("AlgorithmRotation.jl")

using .Radon 
using .ImGenerate
using .OpticalFlow
using .AlgorithmDualScaling
using .AlgorithmGreedy
using .AlgorithmNoPrediction
using .AlgorithmPrimalOnly
using .AlgorithmProximal
using .AlgorithmRotation


##############
# Our exports
##############

export demo_pet1, demo_pet2, demo_pet3, demo_pet4, demo_pet5, demo_pet6, batchrun_pet

###################################
# Parameterisation and experiments
###################################

# One runnable experiment: which algorithm to use, on which data, with which parameters.
struct Experiment
    # Algorithm module; `solve` and `identifier` are read from it.
    mod::Module
    # Displacement type handed to the data generator and to `mod.solve`.
    DisplacementT::Type
    # Data generator; `name`, `dim`, `f`, `dynrange` and `Λ` are read from it.
    imgen::ImGen
    # Experiment-specific parameters, merged on top of `default_params`.
    params::NamedTuple
end

# Print a multi-line summary of an experiment. The `kernel` entry of the parameters is
# replaced by a placeholder string so that it is not printed.
function Base.show(io::IO, e::Experiment)
    label(::Type{DisplacementFull}) = "DisplacementFull"
    label(::Type{DisplacementConstant}) = "DisplacementConstant"

    dname = label(e.DisplacementT)
    d1, d2 = e.imgen.dim[1], e.imgen.dim[2]
    shown_params = e.params ⬿ (kernel = "(not shown)",)

    print(io, "
    mod: $(e.mod)
    DisplacementT: $(dname)
    imgen: $(e.imgen.name) $(d1)×$(d2)
    params: $(shown_params)
    ")
end

# Default prefix under which `run_experiments` writes its output files.
const default_save_prefix = "img/"

# Baseline parameters. `run_experiments` merges the experiment's own parameters and
# then its extra keyword arguments on top of these.
const default_params = (
    ρ                      = 0,
    verbose_iter           = 100,
    maxiter                = 4000,
    save_results           = true,
    save_images            = true,
    # Iterations at which images are saved when `save_images` is set.
    save_images_iters      = Set([
        100, 200, 300, 500, 800, 1000, 1300, 1500, 1800, 2000, 4000,
    ]),
    pixelwise_displacement = false,
    dual_flow              = true,
    prox_predict           = true,
    # Whether `iterate_visualise` turns an `InterruptException` into a clean abort.
    handle_interrupt       = true,
    init                   = :zero,
    plot_movement          = false,
)

# Parameters shared by all PET experiments defined below.
const p_known₀_pet = (
    # Data generation
    noise_level          = 0.5,
    shake_noise_level    = 0.1,
    shake                = 1.0,
    rotation_factor      = 0.15,
    rotation_noise_level = 0.0025,
    # Algorithm parameters
    α                    = 1.0,
    ρ̃₀                   = 1.0,
    σ̃₀                   = 0.0001,
    δ                    = 0.9,
    σ₀                   = 0.5,
    τ₀                   = 0.9,
    λ                    = 1,
    # Sizes; `origsize` and `sz` are passed to `imgen_shepplogan_radon`.
    origsize             = 256,
    radondims            = [128, 64],
    sz                   = (256, 256),
    scale                = 1,
    c                    = 1.0,
    sino_sparsity        = 0.5,
    L                    = 300.0,
    L_experiment         = false,
)

# Data generator shared by all experiments below.
const petscan = imgen_shepplogan_radon(p_known₀_pet.origsize, p_known₀_pet.sz)

# One experiment per algorithm module, all on the same data generator. The order matters:
# `demo_pet1` … `demo_pet6` index into this tuple.
const pet_experiments_pdps_known = (
    Experiment(AlgorithmDualScaling, DisplacementConstant, petscan, p_known₀_pet),
    Experiment(AlgorithmGreedy, DisplacementConstant, petscan, p_known₀_pet),
    Experiment(AlgorithmNoPrediction, DisplacementConstant, petscan, p_known₀_pet),
    Experiment(AlgorithmPrimalOnly, DisplacementConstant, petscan, p_known₀_pet),
    # The proximal variant additionally gets a `phantom_ρ` parameter.
    Experiment(
        AlgorithmProximal, DisplacementConstant, petscan,
        p_known₀_pet ⬿ (phantom_ρ = 100,),
    ),
    Experiment(AlgorithmRotation, DisplacementConstant, petscan, p_known₀_pet),
)

# Lazy concatenation of all experiment groups (currently only one); used by
# `batchrun_pet`.
const pet_experiments_all = Iterators.flatten((pet_experiments_pdps_known,))

################
# Log
################

# One record of the performance log, created by `iterate_visualise` on logging
# iterations and written out by `run_experiments`.
struct LogEntry <: IterableStruct
    # Iteration number.
    iter::Int
    # Elapsed time at this iteration, excluding time spent on logging/visualisation.
    time::Float64
    # Objective value returned by the algorithm's `calc_objective` callback.
    function_value::Float64
    # Disabled fields (kept for reference):
    #v_cumul_true_y :: Float64
    #v_cumul_true_x :: Float64
    #v_cumul_est_y :: Float64
    #v_cumul_est_x :: Float64
    # Quality of the current iterate against the true image.
    psnr::Float64
    ssim::Float64
    # Disabled fields (kept for reference):
    #psnr_data :: Float64
    #ssim_data :: Float64
end

# One record of the "hifi" log, created by `iterate_visualise` on the iterations that
# do not produce a `LogEntry`.
struct LogEntryHiFi <: IterableStruct
    # Iteration number.
    iter::Int
    # True cumulative displacement, scaled by 1/(largest image dimension).
    v_cumul_true_y::Float64
    v_cumul_true_x::Float64
end

###############
# Main routine
###############

# Bookkeeping threaded through the iterations by `iterate_visualise`.
struct State
    # Visualisation channel, `false` when visualisation is off, or `nothing` once
    # `run_experiments` has cleared it before saving.
    vis::Union{Channel,Bool,Nothing}
    # Time stamp taken when the iteration callback first runs; `nothing` until then.
    start_time::Union{Real,Nothing}
    # Accumulated time spent on logging, visualisation and image saving.
    wasted_time::Real
    # Performance log; new entries are prepended.
    log::LinkedList{LogEntry}
    # Displacement log; new entries are prepended.
    log_hifi::LinkedList{LogEntryHiFi}
    # Set when the run was cut short by a user interrupt.
    aborted::Bool
end

# Build the identifier of an experiment run; `run_experiments` appends it to the save
# prefix to form output file names. The result has the form
# "<imgen name>_<algorithm identifier>_<10000·σ₀>_<10000·τ₀>".
function name(e::Experiment, p)
    # Use `round` rather than `Int64(...)`: a product such as 10000*0.07 is not exactly
    # integer-valued in floating point, and `Int64(x)` throws `InexactError` for any
    # non-integer `x`. Whenever the old conversion succeeded, the result is identical.
    tag(v) = round(Int64, 10000 * v)
    return "$(e.imgen.name)_$(e.mod.identifier)_$(tag(p.σ₀))_$(tag(p.τ₀))"
end

# Write the experiment parameters to `texfile` as TeX macro definitions, one per line.
# Only parameters actually present in `e_params` are written. When `σ₀` is present, an
# additional "sigma" macro expresses it as a multiple of \sigma_{\max} (the factor is
# omitted when σ₀ == 1).
function write_tex(texfile, e_params)
    # (macro name suffix, parameter field), in output order.
    exported = (
        ("alpha", :α),
        ("sigmazero", :σ₀),
        ("tauzero", :τ₀),
        ("tildetauzero", :τ̃₀),
        ("delta", :δ),
        ("lambda", :λ),
        ("theta", :θ),
        ("maxiter", :maxiter),
        ("noiselevel", :noise_level),
        ("shakenoiselevel", :shake_noise_level),
        ("shake", :shake),
        ("timestep", :timestep),
        ("displacementcount", :displacementcount),
        ("phantomrho", :phantom_ρ),
    )

    open(texfile, "w") do io
        emit(macroname, value) = println(io, "\\def\\EXPPARAM$(macroname){$(value)}")

        for (macroname, field) in exported
            if isdefined(e_params, field)
                emit(macroname, getfield(e_params, field))
            end
        end

        if isdefined(e_params, :σ₀)
            factor = e_params.σ₀ == 1 ? "" : "$(e_params.σ₀)"
            emit("sigma", factor * "\\sigma_{\\max}")
        end
    end
end

# Run a collection of experiments, optionally with live visualisation, and optionally
# save their logs.
#
# Keyword arguments:
#   visualise   - spawn a background task running `bg_visualise_enhanced` and pass its
#                 channel to the iteration state.
#   recalculate - when false, an experiment whose "<save_prefix><name>.txt" file already
#                 exists is skipped.
#   experiments - iterable of `Experiment`s (required).
#   save_prefix - prefix prepended to each experiment's name to form output file names.
#   fullscreen  - forwarded to `bg_visualise_enhanced`.
#   kwargs...   - merged last into the parameters, so they override both
#                 `default_params` and the experiment's own parameters.
#
# Returns nothing. The loop over experiments stops early when a run reports
# `aborted` (user interrupt).
function run_experiments(;visualise=true,
                          recalculate=true,
                          experiments,
                          save_prefix=default_save_prefix,
                          fullscreen=false,
                          kwargs...)

    # Create visualisation
    if visualise
        # Capacity-1 channel; the visualiser task is bound to it, so the channel is
        # closed when that task terminates.
        rc = Channel(1)
        visproc = Threads.@spawn bg_visualise_enhanced(rc, fullscreen=fullscreen)
        bind(rc, visproc)
        vis = rc
    else
        vis = false
    end

    # Run all experiments
    for e ∈ experiments

        # Parameters for this experiment
        e_params = default_params ⬿ e.params ⬿ kwargs
        # The name is computed before the per-run entries below are added.
        ename = name(e, e_params)
        e_params = e_params ⬿ (save_prefix = save_prefix * ename,
                                dynrange = e.imgen.dynrange,
                                Λ = e.imgen.Λ)

        if recalculate || !isfile(e_params.save_prefix * ".txt")
            println("Running experiment \"$(ename)\"")

            # Start data generation task
            # The generator writes into a channel of capacity 2 and is bound to it.
            datachannel = Channel{PetOnlineData{e.DisplacementT}}(2)
            gentask = Threads.@spawn e.imgen.f(e.DisplacementT, datachannel, e_params)
            bind(datachannel, gentask)

            # Run algorithm
            # `iterate` is `iterate_visualise` with the data channel and a fresh state
            # (no start time, zero wasted time, empty logs, not aborted) pre-applied.
            iterate = curry(iterate_visualise, datachannel,
                            State(vis, nothing, 0.0, nothing, nothing, false))
                            
            x, y, st = e.mod.solve(e.DisplacementT;
                                   dim=e.imgen.dim,
                                   iterate=iterate,
                                   params=e_params)

            # Clear non-saveable things
            st = @set st.vis = nothing

            println("Wasted_time: $(st.wasted_time)s")

            if e_params.save_results
                println("Saving " * e_params.save_prefix * "(.txt,_hifi.txt,_params.tex)")

                perffile = e_params.save_prefix * ".txt"
                hififile = e_params.save_prefix * "_hifi.txt"
                texfile = e_params.save_prefix * "_params.tex"
                # datafile = e_params.save_prefix * ".jld2"

                # Both logs get the full parameter set as a header comment line.
                write_log(perffile, st.log, "# params = $(e_params)\n")
                write_log(hififile, st.log_hifi, "# params = $(e_params)\n")
                write_tex(texfile, e_params)
                # @save datafile x y st params
            end

            # Shut down the data generator before moving on.
            close(datachannel)
            wait(gentask)

            # A user interrupt ends the whole batch, not just this experiment.
            if st.aborted
                break
            end
        else
            println("Skipping already computed experiment \"$(ename)\"")
            # texfile = e_params.save_prefix * "_params.tex"
            # write_tex(texfile, e_params)
        end
    end

    if visualise
        # Tell subprocess to finish, and wait
        put!(rc, nothing)
        close(rc)
        wait(visproc)
    end

    return
end

#######################
# Demos and batch runs
#######################

# Run a single experiment interactively: full-screen visualisation, progress output on
# every iteration, nothing saved to disk. Extra keyword arguments are forwarded to
# `run_experiments` and override the settings chosen here.
function demo(experiment; kwargs...)
    return run_experiments(;
        experiments = (experiment,),
        save_results = false,
        save_images = false,
        visualise = true,
        recalculate = true,
        verbose_iter = 1,
        fullscreen = true,
        kwargs...,
    )
end

# Interactive demos, one per algorithm in `pet_experiments_pdps_known`.
#
# These are proper (constant) functions rather than anonymous functions bound to
# non-constant globals. Keyword arguments are forwarded to `demo`, so a demo can be
# customised, e.g. `demo_pet1(maxiter=500)`; calling without arguments behaves as before.
demo_pet1(; kwargs...) = demo(pet_experiments_pdps_known[1]; kwargs...) # Dual scaling
demo_pet2(; kwargs...) = demo(pet_experiments_pdps_known[2]; kwargs...) # Greedy
demo_pet3(; kwargs...) = demo(pet_experiments_pdps_known[3]; kwargs...) # No Prediction
demo_pet4(; kwargs...) = demo(pet_experiments_pdps_known[4]; kwargs...) # Primal only
demo_pet5(; kwargs...) = demo(pet_experiments_pdps_known[5]; kwargs...) # Proximal (old)
demo_pet6(; kwargs...) = demo(pet_experiments_pdps_known[6]; kwargs...) # Rotation


# Run every experiment in `pet_experiments_all` without visualisation, saving results
# and images, and skipping experiments whose result file already exists.
#
# Keyword arguments are forwarded to `run_experiments` and override the settings chosen
# here, e.g. `batchrun_pet(maxiter=1000)`. The original signature took positional
# varargs only, so keyword calls failed with a `MethodError`. Positional `name => value`
# pairs are still accepted for backward compatibility and are applied before the
# keyword arguments.
function batchrun_pet(args...; kwargs...)
    run_experiments(;experiments=pet_experiments_all,
                    save_results=true,
                    save_images=true,
                    visualise=false,
                    recalculate=false,
                    args...,
                    kwargs...)
end

######################################################
# Iterator that does visualisation and log collection
######################################################

# Affinely map the values of `arr` so that its minimum lands on `new_range[1]` and its
# maximum on `new_range[2]`. Returns a new array; `arr` is not modified.
#
# A constant array has no spread to stretch; every element is then mapped to
# `new_range[1]` instead of producing NaNs from a division by zero.
function rescale(arr, new_range)
    old_min = minimum(arr)
    old_max = maximum(arr)

    old_span = old_max - old_min
    new_span = new_range[2] - new_range[1]

    # Guard the degenerate case: 0 * (new_span / 0) would be NaN.
    scale_factor = old_span == 0 ? zero(float(new_span)) : new_span / old_span

    return new_range[1] .+ (arr .- old_min) .* scale_factor
end

# Iteration driver handed to the algorithms (via `curry` in `run_experiments`).
#
# Takes data frames from `datachannel` and calls `step` once per iteration, up to
# `params.maxiter` times. `step` receives a callback (the `do` block) plus the noisy
# sinogram, displacement, angle, true image and `S` of the current frame. The callback
# is given `calc_objective`, which, when called, yields the objective value, the current
# iterate `x`, a displacement `v`, and a displacement history `vhist` (possibly
# `nothing`). The callback returns the updated `State`, which `step` is expected to
# return in turn.
#
# On "logging" iterations (every `params.verbose_iter`-th, the first 20, and every 10th
# up to 200) the callback records a `LogEntry`, optionally prints progress, pushes
# images to the visualiser, and saves images. On all other iterations it records a
# `LogEntryHiFi`.
#
# Returns the final `State`; `aborted` is set if the run was interrupted by the user
# and `params.handle_interrupt` is true. Any other exception is rethrown.
function iterate_visualise(datachannel::Channel{PetOnlineData{DisplacementT}},
                           st :: State,
                           step :: Function,
                           params :: NamedTuple) where DisplacementT
    try
        # Declared here so that the assignment inside the callback below updates this
        # variable rather than creating a callback-local one.
        sc = nothing

        d = take!(datachannel)

        for iter=1:params.maxiter 
            # The next frame is fetched before the current one is processed.
            dnext = take!(datachannel)
            st = step(d.sinogram_noisy, d.v, d.theta, d.b_true, d.S) do calc_objective
                stn = st

                if isnothing(stn.start_time)
                    # Start timing only when the callback first runs, so that whatever
                    # happened before the first iteration completed (in the original
                    # author's words: compilation that precompiling does not cover
                    # across module boundaries) is not counted.
                    stn = @set stn.start_time=secs_ns()
                end

                verb = params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0

                # Normalise movement to image dimensions so
                # our TikZ plotting code doesn't need to know
                # the image pixel size.
                sc = 1.0./maximum(size(d.b_true))
                    
                if verb || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0)
                    # Everything from here to the end of this branch is bookkeeping;
                    # its duration is added to `wasted_time` and excluded from `tm`.
                    verb_start = secs_ns()
                    tm = verb_start - stn.start_time - stn.wasted_time
                    value, x, v, vhist = calc_objective()

                    entry = LogEntry(iter, tm, value,
                                     #sc*d.v_cumul_true[1],
                                     #sc*d.v_cumul_true[2],
                                     #sc*v[1], sc*v[2], 
                                     assess_psnr(x, d.b_true),
                                     assess_ssim(x, d.b_true),
                                     #assess_psnr(d.b_noisy, d.b_true),
                                     #assess_ssim(d.b_noisy, d.b_true)
                                     )

                    # (**) Collect a singly-linked list of log to avoid array resizing
                    # while iterating
                    stn = @set stn.log=LinkedListEntry(entry, stn.log)
                    
                    # Scale the displacement history like the logged displacements.
                    if !isnothing(vhist)
                        vhist=vhist.*sc
                    end

                    if verb
                        @printf("%d/%d J=%f, PSNR=%f, SSIM=%f, avg. FPS=%f\n",
                                iter, params.maxiter, value, entry.psnr,
                                entry.ssim, entry.iter/entry.time)
                        if isa(stn.vis, Channel)
                            # Send (rescaled backprojection of the data, current
                            # iterate) together with the log and the scaled history.
                            put_onlylatest!(stn.vis, ((rescale(backproject!(d.b_true,d.sinogram_noisy),(0.0,params.dynrange)), x),
                                                        params.plot_movement,
                                                        stn.log, vhist))

                        end
                    end

                    # Save on every logging iteration if no iteration set is given,
                    # otherwise only on the listed iterations.
                    if params.save_images && (!haskey(params, :save_images_iters)
                                              || iter ∈ params.save_images_iters)
                        fn = (t, ext) -> "$(params.save_prefix)_$(t)_frame$(iter).$(ext)"
                        normalise = (data) -> data./maximum(data)
                        # save(File(format"PNG", fn("true", "png")), mapped_img(d.b_true, ColorSchemes.cmyk.colors[1:end]))
                        # save(File(format"PNG", fn("true_sinogram", "png")), mapped_img(normalise(d.sinogram_true), ColorSchemes.cmyk.colors[1:end]))
                        # save(File(format"PNG", fn("data_sinogram", "png")), mapped_img(normalise(d.S.*d.sinogram_noisy), ColorSchemes.cmyk.colors[1:end]))
                        save(File(format"PNG", fn("reco", "png")), mapped_img(x, ColorSchemes.cmyk.colors[1:end]))
                        if !isnothing(vhist)
                            open(fn("movement", "txt"), "w") do io
                                writedlm(io, ["est_y" "est_x"])
                                writedlm(io, vhist)
                            end
                        end
                    end

                    stn = @set stn.wasted_time += (secs_ns() - verb_start)

                    return stn
                end

                # Non-logging iteration: only record the scaled true cumulative
                # displacement.
                # NOTE(review): this branch updates the captured `st`, not `stn`. The
                # two differ only if `start_time` was set just above in this same call;
                # confirm that cannot happen here or switch to `stn`.
                hifientry = LogEntryHiFi(iter, sc*d.v_cumul_true[1], sc*d.v_cumul_true[2])
                st = @set st.log_hifi=LinkedListEntry(hifientry, st.log_hifi)

                return st
            end
            d=dnext
        end
    catch ex
        if params.handle_interrupt && isa(ex, InterruptException)
            # If SIGINT is received (user pressed ^C), terminate computations,
            # returning current status. Effectively, we do not call `step()` again,
            # ending the iterations, but letting the algorithm finish up.
            # Assuming (**) above occurs atomically, `st.log` should be valid, but
            # any results returned by the algorithm itself may be partial, as for
            # reasons of efficiency we do *not* store results of an iteration until
            # the next iteration is finished.
            printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202)
            st = @set st.aborted = true
        else
            rethrow(ex)
        end
    end
    
    return st
end

# Visualiser loop, run as a background task by `run_experiments`.
#
# For every item received through `rc` (via `process_channel`) it shows the images,
# overlays FPS/SSIM/PSNR of the newest log entry and, when requested, overlays
# movement paths. Each item is a tuple `(imgs, plot_movement, log, vhist)` as sent by
# `iterate_visualise`; `vhist` may be `nothing`.
#
# `fullscreen` is forwarded to `do_visualise`.
function bg_visualise_enhanced(rc; fullscreen=false)
    process_channel(rc) do d
        imgs, plot_movement, log, vhist = d
        do_visualise(imgs, refresh=false, fullscreen=fullscreen)

        # Overlay statistics of the newest log entry (the head of the linked list)
        GR.settextcolorind(5)
        GR.setcharheight(0.015)
        GR.settextpath(GR.TEXT_PATH_RIGHT)
        tx, ty = GR.wctondc(0, 1)
        latest = log.value
        GR.text(tx, ty, @sprintf("FPS %.1f, SSIM %.2f, PSNR %.1f",
                                 latest.iter/latest.time, latest.ssim, latest.psnr))

        # Overlay movement
        if plot_movement
            sc = 1.0
            p = unfold_linked_list(log)
            GR.setlinewidth(2)

            # The cumulative-displacement fields are currently commented out of
            # `LogEntry`; reading them unconditionally would throw. Draw each path only
            # if the log entries actually carry the fields it needs.
            has_fields(fs...) = !isempty(p) && all(f -> hasproperty(first(p), f), fs)

            if has_fields(:v_cumul_true_x, :v_cumul_true_y)
                GR.setlinecolorind(2)
                GR.polyline(map(e -> 1.5 + sc*e.v_cumul_true_x, p),
                            map(e -> 0.5 + sc*e.v_cumul_true_y, p))
            end

            if has_fields(:v_cumul_est_x, :v_cumul_est_y)
                GR.setlinecolorind(3)
                GR.polyline(map(e -> 1.5 + sc*e.v_cumul_est_x, p),
                            map(e -> 0.5 + sc*e.v_cumul_est_y, p))
            end

            # History supplied by the algorithm: column 2 is drawn as x, column 1 as y.
            if !isnothing(vhist)
                GR.setlinecolorind(4)
                GR.polyline(map(v -> 1.5 + sc*v, vhist[:, 2]),
                            map(v -> 0.5 + sc*v, vhist[:, 1]))
            end
        end
        GR.updatews()
    end
end

# Clip image values to the allowed range [0, 1].
#
# Defined as a named function, which makes `clip` a constant binding; the previous
# anonymous function lived in a non-constant global, so calls to it could not be
# specialised by the compiler.
clip(x) = min(max(x, 0.0), 1.0)

# Apply a colourmap to raw image data. Each value of `im` is clipped with `clip` and
# then mapped linearly onto the indices of `cmap`, rounding to the nearest entry: 0
# selects the first colour and 1 the last.
function mapped_img(im, cmap)
    last_offset = length(cmap) - 1
    colour_at(t) = cmap[1 + round(UInt16, clip(t) * last_offset)]
    return colour_at.(im)
end

###############
# Precompiling
###############

# precompile(Tuple{typeof(GR.drawimage), Float64, Float64, Float64, Float64, Int64, Int64, Array{UInt32, 2}})
# precompile(Tuple{Type{Plots.Plot{T} where T<:RecipesBase.AbstractBackend}, Plots.GRBackend, Int64, Base.Dict{Symbol, Any}, Base.Dict{Symbol, Any}, Array{Plots.Series, 1}, Nothing, Array{Plots.Subplot{T} where T<:RecipesBase.AbstractBackend, 1}, Base.Dict{Any, Plots.Subplot{T} where T<:RecipesBase.AbstractBackend}, Plots.EmptyLayout, Array{Plots.Subplot{T} where T<:RecipesBase.AbstractBackend, 1}, Bool})
# precompile(Tuple{typeof(Plots._plot!), Plots.Plot{Plots.GRBackend}, Base.Dict{Symbol, Any}, Tuple{Array{ColorTypes.Gray{Float64}, 2}}})

end # Module

mercurial