__precompile__()

module PET

########################
# Load external modules
########################

using Printf
using FileIO
#using JLD2
using Setfield
using ImageQualityIndexes: assess_psnr, assess_ssim
using DelimitedFiles
import GR

using AlgTools.Util
using AlgTools.StructTools
using AlgTools.LinkedLists
using AlgTools.Comms
using ImageTools.Visualise: secs_ns, grayimg, do_visualise
using ImageTools.ImFilter: gaussian

# For PET
using ColorSchemes

#####################
# Load local modules
#####################
include("AlgorithmNew.jl")
include("AlgorithmProximal.jl")
#include("PlotResults.jl")

import .AlgorithmNew
import .AlgorithmProximal

using ..Radon: backproject!
using ..ImGenerate
using ..OpticalFlow
using ..Run
#using .PlotResults


##############
# Our exports
##############

export demo_petS1, demo_petS2, demo_petS3,
       demo_petS4, demo_petS5, demo_petS6, demo_petS7,
       demo_petB1, demo_petB2, demo_petB3,
       demo_petB4, demo_petB5, demo_petB6, demo_petB7,
       batchrun_shepplogan, batchrun_brainphantom, batchrun_pet
       #plot_pet

###################################
# Parameterisation and experiments
###################################

# Default `save_prefix` handed to `run_experiments` by the demo and batch-run entry
# points defined further down in this file.
const default_save_prefix="img/"

# Baseline parameters shared by every experiment in this module. The PET parameter
# sets further down extend them with `⬿` (presumably a merge in which the right-hand
# fields take precedence — confirm against AlgTools.Util).
const default_params = (
    # NOTE(review): not read anywhere in this file; presumably consumed by the
    # algorithms or the run driver — confirm.
    ρ = 0,
    # `iterate_visualise_pet` prints progress every `verbose_iter` iterations;
    # 0 disables printing.
    verbose_iter = 100,
    # Number of iterations executed by `iterate_visualise_pet`.
    maxiter = 4000,
    save_results = true,
    # When true, `iterate_visualise_pet` saves the reconstruction (and the movement
    # history, when one is available) at the iterations listed in `save_images_iters`.
    save_images = true,
    save_images_iters = Set([100, 300, 500, 800, 1000,
                             1300, 1500, 1800, 2000,
                             2300, 2500, 2800, 3000,
                             3300, 3500, 3800, 4000]),
    # The next flags are not read in this file; presumably consumed by the
    # algorithms / run driver.
    pixelwise_displacement=false,
    dual_flow = true,
    prox_predict = true,
    # When true, ^C ends the iterations gracefully instead of propagating the
    # InterruptException (see `iterate_visualise_pet`).
    handle_interrupt = true,
    init = :zero,
    # Forwarded to the visualiser together with each displayed frame.
    plot_movement = false,
    # NOTE: `Set(0)` is the one-element set {0}, not an empty set. Iterations in
    # `iterate_visualise_pet` start at 1, so this presumably means "no stable
    # interval" — TODO confirm with the consumer of this parameter.
    stable_interval = Set(0),
)

# PET parameter set used by every Shepp–Logan experiment below. It is built by
# extending `default_params` with `⬿` (presumably a merge in which these fields take
# precedence — confirm against AlgTools.Util). `sz` is also the image size used to
# generate both phantoms.
const p_known₀_pet = default_params ⬿ (
    noise_level = 0.5,
    shake_noise_level = 0.1,
    shake = 1.0,
    rotation_factor = 0.075,
    rotation_noise_level = 0.0075,
    α = 0.15,
    ρ̃₀ = 1.0,
    σ̃₀ = 1.0,
    δ = 0.9,
    σ₀ = 1.0,
    τ₀ = 0.9,
    λ = 1,
    radondims = [128,64],
    sz = (256,256),
    scale = 1,
    c = 1.0,
    sino_sparsity = 0.5,
    L = 300.0,
    L_experiment = false,
    #stable_interval = Set(0),
    # Overrides the default with iterations 1000–2000 and 3500–4000.
    stable_interval = union(Set(1000:2000),Set(3500:4000)),
)

# Parameters for the brain-phantom experiments. They are currently identical to
# `p_known₀_pet`; the seed override in the trailing comment is disabled.
const p_known₀_petb = p_known₀_pet# ⬿ ( seed = 313159, )

# Phantom data sources, created once at module load with image size `p_known₀_pet.sz`.
# NOTE(review): the exact type returned by `imgen_*_radon` is defined in ImGenerate.
const shepplogan = imgen_shepplogan_radon(p_known₀_pet.sz)

const brainphantom = imgen_brainphantom_radon(p_known₀_pet.sz)

"""
    _pdps_known_experiments(phantom, p)

Return the tuple of seven PET experiments run on `phantom` with base parameters `p`.

The order is significant, because the numbered `demo_pet*` entry points index into
this tuple:

 1. `AlgorithmNew` with the `DualScaling()` predictor
 2. `AlgorithmNew` with the `Greedy()` predictor
 3. `AlgorithmNew` without a predictor (`predictor=nothing`)
 4. `AlgorithmNew` with the `PrimalOnly()` predictor
 5. `AlgorithmProximal` with `phantom_ρ = 100`
 6. `AlgorithmNew` with the `Rotation()` predictor
 7. `AlgorithmNew` with the `ZeroDual()` predictor
"""
function _pdps_known_experiments(phantom, p)
    # All but one experiment differ only in the predictor passed to AlgorithmNew.
    new_alg(predictor) = Experiment(AlgorithmNew, DisplacementConstant, phantom,
                                    p ⬿ (predictor=predictor,))
    return (
        new_alg(DualScaling()),
        new_alg(Greedy()),
        new_alg(nothing),
        new_alg(PrimalOnly()),
        Experiment(AlgorithmProximal, DisplacementConstant, phantom,
                   p ⬿ (phantom_ρ = 100,)),
        new_alg(Rotation()),
        new_alg(ZeroDual()),
    )
end

# Both phantoms run the same experiment suite, so demo_petSk and demo_petBk always
# refer to the same algorithm/predictor combination.
const shepplogan_experiments_pdps_known =
    _pdps_known_experiments(shepplogan, p_known₀_pet)

const brainphantom_experiments_pdps_known =
    _pdps_known_experiments(brainphantom, p_known₀_petb)


const shepplogan_experiments_all = Iterators.flatten((
    shepplogan_experiments_pdps_known,
))

const brainphantom_experiments_all = Iterators.flatten((
    brainphantom_experiments_pdps_known,
))

#######################
# Demos and batch runs
#######################

"""
    demo(experiment; kwargs...)

Run a single `experiment` interactively through `run_experiments`.

The defaults are: PET visualiser and data type, visualisation on in fullscreen,
results and images not saved, `recalculate=true` and `verbose_iter=50`. Any keyword
supplied in `kwargs` overrides the corresponding default.
"""
function demo(experiment; kwargs...)
    interactive_defaults = (
        experiments = (experiment,),
        save_prefix = default_save_prefix,
        visfn = iterate_visualise_pet,
        datatype = PetOnlineData,
        save_results = false,
        save_images = false,
        visualise = true,
        recalculate = true,
        verbose_iter = 50,
        fullscreen = true,
    )
    # In a keyword call the rightmost occurrence of a name wins, so `kwargs`
    # (splatted last) overrides the interactive defaults.
    return run_experiments(; interactive_defaults..., kwargs...)
end

# Interactive demos, one per entry of the experiment tuples. `demo_petSk` runs
# experiment k on the Shepp–Logan phantom and `demo_petBk` runs it on the brain
# phantom. These are proper (constant) functions rather than lambdas bound to
# non-const globals. Keyword arguments are forwarded to `demo`, and from there to
# `run_experiments`.
demo_petS1(; kwargs...) = demo(shepplogan_experiments_pdps_known[1]; kwargs...) # Dual scaling
demo_petS2(; kwargs...) = demo(shepplogan_experiments_pdps_known[2]; kwargs...) # Greedy
demo_petS3(; kwargs...) = demo(shepplogan_experiments_pdps_known[3]; kwargs...) # No Prediction
demo_petS4(; kwargs...) = demo(shepplogan_experiments_pdps_known[4]; kwargs...) # Primal only
demo_petS5(; kwargs...) = demo(shepplogan_experiments_pdps_known[5]; kwargs...) # Proximal (old)
demo_petS6(; kwargs...) = demo(shepplogan_experiments_pdps_known[6]; kwargs...) # Rotation
demo_petS7(; kwargs...) = demo(shepplogan_experiments_pdps_known[7]; kwargs...) # Zero dual

demo_petB1(; kwargs...) = demo(brainphantom_experiments_pdps_known[1]; kwargs...) # Dual scaling
demo_petB2(; kwargs...) = demo(brainphantom_experiments_pdps_known[2]; kwargs...) # Greedy
demo_petB3(; kwargs...) = demo(brainphantom_experiments_pdps_known[3]; kwargs...) # No Prediction
demo_petB4(; kwargs...) = demo(brainphantom_experiments_pdps_known[4]; kwargs...) # Primal only
demo_petB5(; kwargs...) = demo(brainphantom_experiments_pdps_known[5]; kwargs...) # Proximal (old)
demo_petB6(; kwargs...) = demo(brainphantom_experiments_pdps_known[6]; kwargs...) # Rotation
demo_petB7(; kwargs...) = demo(brainphantom_experiments_pdps_known[7]; kwargs...) # Zero dual


"""
    _batchrun(experiments; kwargs...)

Run `experiments` non-interactively through `run_experiments`.

It uses the PET visualiser and data type with `save_prefix=default_save_prefix`,
`save_results=true`, `save_images=true`, `visualise=false` and `recalculate=false`.
Keywords in `kwargs` are splatted last and therefore override these defaults.
"""
function _batchrun(experiments; kwargs...)
    return run_experiments(; experiments=experiments,
                           visfn=iterate_visualise_pet,
                           datatype=PetOnlineData,
                           save_prefix=default_save_prefix,
                           save_results=true,
                           save_images=true,
                           visualise=false,
                           recalculate=false,
                           kwargs...)
end

"""
    batchrun_shepplogan(; kwargs...)

Batch-run every Shepp–Logan experiment; `kwargs` override the `_batchrun` defaults.
"""
batchrun_shepplogan(; kwargs...) = _batchrun(shepplogan_experiments_all; kwargs...)

"""
    batchrun_brainphantom(; kwargs...)

Batch-run every brain-phantom experiment; `kwargs` override the `_batchrun` defaults.
"""
batchrun_brainphantom(; kwargs...) = _batchrun(brainphantom_experiments_all; kwargs...)

"""
    batchrun_pet(; kwargs...)

Batch-run the Shepp–Logan experiments and then the brain-phantom experiments, passing
`kwargs` to both. Returns the result of the brain-phantom run.
"""
function batchrun_pet(; kwargs...)
    batchrun_shepplogan(; kwargs...)
    return batchrun_brainphantom(; kwargs...)
end

######################################################
# Iterator that does visualisation and log collection
######################################################

"""
    rescale(arr, new_range)

Affinely map the values of `arr` so that its minimum becomes `new_range[1]` and its
maximum becomes `new_range[2]`. The result is returned as a new array.

If `arr` is constant (zero value span), every entry is mapped to `new_range[1]`.
Without this guard the division by zero would fill the result with `NaN`.
Throws like `extrema` does when `arr` is empty.
"""
function rescale(arr, new_range)
    old_min, old_max = extrema(arr)
    span = old_max - old_min
    target_span = new_range[2] - new_range[1]
    # Both branches have the same type, which keeps the function type-stable.
    scale_factor = iszero(span) ? zero(target_span / oneunit(span)) : target_span / span
    return new_range[1] .+ (arr .- old_min) .* scale_factor
end

"""
    iterate_visualise_pet(datachannel, st, step, params)

Drive the online PET reconstruction for `params.maxiter` iterations, collecting logs
and optionally visualising and saving intermediate results.

Each iteration takes the next data frame from `datachannel` and calls
`step(sinogram_noisy, v, theta, b_true, S) do calc_objective … end`. The callback
receives the current state and does one of two things:

  * On iterations 1–20, on every 10th iteration up to 200, and on every
    `params.verbose_iter`-th iteration, it:
      - evaluates `calc_objective()`;
      - prepends a `LogEntry` (objective, PSNR and SSIM against `b_true`) to the log;
      - on verbose iterations, prints progress and sends a frame to the visualiser
        channel `st.vis`;
      - saves the reconstruction and movement history when `params.save_images` is
        set.
  * On all other iterations it prepends a `LogEntryHiFi` with the normalised true
    cumulative displacement to `log_hifi`.

Time spent logging is accumulated in `wasted_time` and subtracted from the elapsed
time recorded in each log entry.

If `params.handle_interrupt` is set, ^C (`InterruptException`) stops the iterations
and returns the current state with `aborted = true`. Other exceptions are rethrown.

Returns the final state.
"""
function iterate_visualise_pet(datachannel::Channel{PetOnlineData{DisplacementT}},
                               st :: State,
                               step :: Function,
                               params :: NamedTuple) where DisplacementT
    try
        sc = nothing

        d = take!(datachannel)

        for iter=1:params.maxiter
            # Prefetch the next frame; `d` remains the frame used by this iteration.
            dnext = take!(datachannel)
            st = step(d.sinogram_noisy, d.v, d.theta, d.b_true, d.S) do calc_objective
                stn = st

                if isnothing(stn.start_time)
                    # Precompilation apparently does not cross module boundaries, so
                    # the clock starts lazily here, at the first callback, rather
                    # than before the loop.
                    stn = @set stn.start_time=secs_ns()
                end

                verb = params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0

                # Normalise movement to image dimensions so
                # our TikZ plotting code doesn't need to know
                # the image pixel size.
                sc = 1.0./maximum(size(d.b_true))

                if verb || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0)
                    verb_start = secs_ns()
                    tm = verb_start - stn.start_time - stn.wasted_time
                    value, x, v, vhist = calc_objective()

                    entry = LogEntry(iter, tm, value,
                                     #sc*d.v_cumul_true[1],
                                     #sc*d.v_cumul_true[2],
                                     #sc*v[1], sc*v[2],
                                     assess_psnr(x, d.b_true),
                                     assess_ssim(x, d.b_true),
                                     #assess_psnr(d.b_noisy, d.b_true),
                                     #assess_ssim(d.b_noisy, d.b_true)
                                     )

                    # (**) Collect a singly-linked list of log to avoid array resizing
                    # while iterating
                    stn = @set stn.log=LinkedListEntry(entry, stn.log)

                    if !isnothing(vhist)
                        vhist=vhist.*sc
                    end

                    if verb
                        @printf("%d/%d J=%f, PSNR=%f, SSIM=%f, avg. FPS=%f\n",
                                iter, params.maxiter, value, entry.psnr,
                                entry.ssim, entry.iter/entry.time)
                        if isa(stn.vis, Channel)
                            # By Julia's `!` convention `backproject!` writes into its
                            # first argument. Use a fresh zeroed buffer so the
                            # ground-truth image `d.b_true` (used for PSNR/SSIM) is
                            # not overwritten.
                            backprojection = backproject!(zero(d.b_true), d.sinogram_noisy)
                            put_onlylatest!(stn.vis,
                                            ((rescale(backprojection, (0.0, params.dynrange)), x),
                                             params.plot_movement,
                                             stn.log, vhist))
                        end
                    end

                    if params.save_images && (!haskey(params, :save_images_iters)
                                              || iter ∈ params.save_images_iters)
                        fn = (t, ext) -> "$(params.save_prefix)_$(t)_frame$(iter).$(ext)"
                        normalise = (data) -> data./maximum(data)
                        # save(File(format"PNG", fn("true", "png")), mapped_img(d.b_true, ColorSchemes.cmyk.colors[1:end]))
                        # save(File(format"PNG", fn("true_sinogram", "png")), mapped_img(normalise(d.sinogram_true), ColorSchemes.cmyk.colors[1:end]))
                        # save(File(format"PNG", fn("data_sinogram", "png")), mapped_img(normalise(d.S.*d.sinogram_noisy), ColorSchemes.cmyk.colors[1:end]))
                        save(File(format"PNG", fn("reco", "png")), mapped_img(x, ColorSchemes.cmyk.colors[1:end]))
                        if !isnothing(vhist)
                            open(fn("movement", "txt"), "w") do io
                                writedlm(io, ["est_y" "est_x"])
                                writedlm(io, vhist)
                            end
                        end
                    end

                    stn = @set stn.wasted_time += (secs_ns() - verb_start)

                    return stn
                end

                hifientry = LogEntryHiFi(iter, sc*d.v_cumul_true[1], sc*d.v_cumul_true[2])
                # Update and return the callback-local state `stn` (which carries any
                # start_time set above), as the logging branch does, instead of
                # rebinding the captured outer `st`.
                stn = @set stn.log_hifi=LinkedListEntry(hifientry, stn.log_hifi)

                return stn
            end
            d=dnext
        end
    catch ex
        if params.handle_interrupt && isa(ex, InterruptException)
            # If SIGINT is received (user pressed ^C), terminate computations,
            # returning current status. Effectively, we do not call `step()` again,
            # ending the iterations, but letting the algorithm finish up.
            # Assuming (**) above occurs atomically, `st.log` should be valid, but
            # any results returned by the algorithm itself may be partial, as for
            # reasons of efficiency we do *not* store results of an iteration until
            # the next iteration is finished.
            printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202)
            st = @set st.aborted = true
        else
            rethrow(ex)
        end
    end

    return st
end

# Clip image values to allowed range
"""
    clip(x)

Clip image values to the allowed range [0, 1], comparing against the `Float64` bounds
`0.0` and `1.0`.

This is defined as a constant function rather than a lambda bound to a non-const
global. The non-const binding is untyped, which forces dynamic dispatch on every
per-pixel call from `mapped_img`.
"""
clip(x) = min(max(x, 0.0), 1.0)

# Apply a colourmap (vector of RGB objects) to raw image data
"""
    mapped_img(im, cmap)

Apply the colour map `cmap` (a vector of colours) to the raw image data `im`.

Each value is clipped to [0, 1] and scaled to an index into `cmap`, rounded to the
nearest entry. A value of 0 selects the first colour and 1 selects the last.
"""
function mapped_img(im, cmap)
    last_step = length(cmap) - 1
    colour_of(t) = cmap[1 + round(UInt16, clip(t) * last_step)]
    return colour_of.(im)
end


#########################
# Plotting SSIM and PSNR
#########################

#function plot_pet(kwargs...)
#    ssim_plot("shepplogan")
#    psnr_plot("shepplogan")
#    fv_plot("shepplogan")
#    ssim_plot("brainphantom")
#    psnr_plot("brainphantom")
#    fv_plot("brainphantom")
#end

end # Module
