##################
# Experiment running and interactive visualisation
##################

module Run

import GR
using Setfield
using Printf
using ImageQualityIndexes: assess_psnr, assess_ssim
using FileIO
#using JLD2
using DelimitedFiles

using AlgTools.Util
using AlgTools.StructTools
using AlgTools.LinkedLists
using AlgTools.Comms

using ImageTools.Visualise: secs_ns, grayimg, do_visualise

using ..ImGenerate
using ..OpticalFlow: identifier, DisplacementFull, DisplacementConstant

export run_experiments,
       Experiment,
       State,
       LogEntry,
       LogEntryHiFi

################
# Experiment
################

"""
    Experiment(mod, DisplacementT, imgen, params)

Description of one experiment to be run by `run_experiments`.

- `mod`: module implementing the algorithm. `run_experiments` calls
  `mod.solve(DisplacementT; dim, iterate, params)`, and `name` falls back to
  `mod.identifier` when the parameters specify no predictor.
- `DisplacementT`: displacement type (e.g. `DisplacementFull` or
  `DisplacementConstant`), passed to both the data generator and the solver.
- `imgen`: image/data generator. Its `f` produces the data; its `name`, `dim`,
  `dynrange` and `Λ` are used for naming and for the experiment parameters.
- `params`: base parameters of the experiment, combined with the keyword
  arguments given to `run_experiments`.
"""
struct Experiment
    mod :: Module
    DisplacementT :: Type
    imgen :: ImGen
    params :: NamedTuple
end

function Base.show(io::IO, e::Experiment)
    # Short, module-unqualified names for the known displacement types. Any other
    # type falls back to its default string representation instead of making
    # `show` fail with a `MethodError`.
    displacementname(::Type{DisplacementFull}) = "DisplacementFull"
    displacementname(::Type{DisplacementConstant}) = "DisplacementConstant"
    displacementname(T) = string(T)
    # `kernel` is replaced by a placeholder in the printed parameters.
    print(io, "
    mod: $(e.mod)
    DisplacementT: $(displacementname(e.DisplacementT))
    imgen: $(e.imgen.name) $(e.imgen.dim[1])×$(e.imgen.dim[2])
    params: $(e.params ⬿ (kernel = "(not shown)",))
    ")
end

################
# Log
################

"""
    LogEntry

Performance log entry recorded by `iterate_visualise` at selected iterations.

- `iter`: iteration number.
- `time`: elapsed time since timing started (measured with `secs_ns()`). It
  excludes the time previously spent in the logging branch (objective
  evaluation, printing, visualisation, saving). `iter/time` is reported as the
  average FPS.
- `function_value`: objective value returned by the algorithm's `calc_objective`.
- `psnr`, `ssim`: PSNR and SSIM of the reconstruction against the true image.

The commented-out fields are currently not recorded.
"""
struct LogEntry <: IterableStruct
    iter :: Int
    time :: Float64
    function_value :: Float64
    #v_cumul_true_y :: Float64
    #v_cumul_true_x :: Float64
    #v_cumul_est_y :: Float64
    #v_cumul_est_x :: Float64
    psnr :: Float64
    ssim :: Float64
    #psnr_data :: Float64
    #ssim_data :: Float64
end

"""
    LogEntryHiFi

High-frequency log entry of the true cumulative displacement, recorded by
`iterate_visualise`.

- `iter`: iteration number.
- `v_cumul_true_y`, `v_cumul_true_x`: components 1 and 2 of the data's
  `v_cumul_true`, scaled by `1/maximum(size(b_true))` so that they are
  independent of the image pixel size.
"""
struct LogEntryHiFi <: IterableStruct
    iter :: Int
    v_cumul_true_y :: Float64
    v_cumul_true_x :: Float64
end

###############
# Main routine
###############

"""
    State

Iteration state threaded through the step callback of `iterate_visualise`.
`run_experiments` reads `vis`, `wasted_time`, `log` and `aborted` from the
state returned by the solver.

- `vis`: channel to the background visualisation task, `false` when
  visualisation is disabled, or `nothing` once cleared after solving.
- `start_time`: `secs_ns()` at the first timed iteration; `nothing` before that.
- `wasted_time`: accumulated time spent in the logging branch, excluded from
  the logged timings.
- `log`, `log_hifi`: singly linked lists of `LogEntry`/`LogEntryHiFi`, newest
  entry first. They start out as `nothing` (presumably the empty `LinkedList`;
  confirm in AlgTools.LinkedLists).
- `aborted`: set when the user interrupted the computation.
"""
# NOTE(review): the abstract field types (`Real`, `Union{Channel,Bool,Nothing}`)
# make field access type-unstable; consider parametrising if this shows up in
# profiles.
struct State
    vis :: Union{Channel,Bool,Nothing}
    start_time :: Union{Real,Nothing}
    wasted_time :: Real
    log :: LinkedList{LogEntry}
    log_hifi :: LinkedList{LogEntryHiFi}
    aborted :: Bool
end

"""
    name(e::Experiment, p)

Return the name under which the results of experiment `e` with parameters `p`
are stored: `<imgen name>_<id>_<100α>_<10000σ₀>_<10000τ₀>`.

`id` is `identifier(p.predictor)` if `p` has a non-`nothing` `predictor`, and
`e.mod.identifier` otherwise. `_<p.variant>` is appended to it when `p` has a
`variant`. The scaled parameters are rounded to the nearest integer.
"""
function name(e::Experiment, p)
    id = if haskey(p, :predictor) && !isnothing(p.predictor)
        identifier(p.predictor)
    else
        e.mod.identifier
    end
    if haskey(p, :variant)
        id *= "_" * p.variant
    end
    # Round instead of converting exactly: e.g. `100*0.07 == 7.000000000000001`
    # in floating point, for which `Int64(...)` throws an `InexactError`.
    scaled(x, factor) = round(Int64, factor * x)
    return "$(e.imgen.name)_$(id)_$(scaled(p.α, 100))_" *
           "$(scaled(p.σ₀, 10000))_$(scaled(p.τ₀, 10000))"
end

"""
    write_tex(texfile, e_params)

Write experiment parameters to `texfile` as LaTeX definitions of the form
`\\def\\EXPPARAM<name>{<value>}`, one per line.

Only the known parameters present in `e_params` are written, in a fixed order.
If `σ₀` is present, a final `sigma` definition expresses it as a multiple of
`\\sigma_{\\max}` (the multiplier is omitted when `σ₀ == 1`).
"""
function write_tex(texfile, e_params)
    # LaTeX macro suffix => parameter field, in output order.
    texnames = ("alpha" => :α,
                "sigmazero" => :σ₀,
                "tauzero" => :τ₀,
                "tildetauzero" => :τ̃₀,
                "delta" => :δ,
                "lambda" => :λ,
                "theta" => :θ,
                "maxiter" => :maxiter,
                "noiselevel" => :noise_level,
                "shakenoiselevel" => :shake_noise_level,
                "shake" => :shake,
                "timestep" => :timestep,
                "displacementcount" => :displacementcount,
                "phantomrho" => :phantom_ρ)
    open(texfile, "w") do io
        texdef(macroname, value) = println(io, "\\def\\EXPPARAM$(macroname){$(value)}")
        for (macroname, field) in texnames
            isdefined(e_params, field) || continue
            texdef(macroname, getfield(e_params, field))
        end
        isdefined(e_params, :σ₀) || return
        σ₀ = e_params.σ₀
        multiplier = σ₀ == 1 ? "" : string(σ₀)
        texdef("sigma", multiplier * "\\sigma_{\\max}")
    end
end

"""
    run_experiments(; experiments, visualise=true, visfn=iterate_visualise,
                    datatype=OnlineData, recalculate=true, save_prefix="",
                    fullscreen=false, kwargs...)

Run each `Experiment` in `experiments` in turn.

The parameters of an experiment are its `params` combined with `kwargs` using
`⬿`. They are then extended with:
- `save_prefix`: the given `save_prefix` followed by the experiment's `name`;
- `dynrange` and `Λ` of the experiment's image generator.

Unless `recalculate` is set, experiments whose `<save_prefix>.txt` already
exists are skipped.

For each experiment, data of type `datatype{e.DisplacementT}` is produced by
`e.imgen.f` on a separate task and passed through a channel to the algorithm
`e.mod.solve`, whose iterations are driven by `visfn`. If the parameters have
`save_results` set, the performance log is written to `<save_prefix>.txt`.
No further experiments are run once one has been aborted.

If `visualise` is set, progress is rendered by `bg_visualise_enhanced` on a
background task, which is told to finish and waited for at the end. If an
experiment throws, the data channel and the visualisation channel are closed
before the error is rethrown, so the background tasks are not left waiting on
them indefinitely.
"""
function run_experiments(;visualise=true,
                          visfn = iterate_visualise,
                          datatype = OnlineData,
                          recalculate=true,
                          experiments,
                          save_prefix="",
                          fullscreen=false,
                          kwargs...)

    # Create visualisation
    if visualise
        rc = Channel(1)
        visproc = Threads.@spawn bg_visualise_enhanced(rc, fullscreen=fullscreen)
        # `bind` closes `rc` when the visualisation task terminates.
        bind(rc, visproc)
        vis = rc
    else
        vis = false
    end

    try
        # Run all experiments
        for e ∈ experiments

            # Parameters for this experiment
            e_params = e.params ⬿ kwargs
            ename = name(e, e_params)
            e_params = e_params ⬿ (save_prefix = save_prefix * ename,
                                    dynrange = e.imgen.dynrange,
                                    Λ = e.imgen.Λ)

            if recalculate || !isfile(e_params.save_prefix * ".txt")
                println("Running experiment \"$(ename)\"")

                # Start data generation task
                datachannel = Channel{datatype{e.DisplacementT}}(2)
                gentask = Threads.@spawn e.imgen.f(e.DisplacementT, datachannel, e_params)
                bind(datachannel, gentask)

                # Run algorithm
                iterate = curry(visfn, datachannel,
                                State(vis, nothing, 0.0, nothing, nothing, false))

                x, y, st = try
                    e.mod.solve(e.DisplacementT;
                                dim=e.imgen.dim,
                                iterate=iterate,
                                params=e_params)
                catch
                    # Do not leave the generator task waiting on the channel.
                    close(datachannel)
                    rethrow()
                end

                # Clear non-saveable things
                st = @set st.vis = nothing

                println("Wasted_time: $(st.wasted_time)s")

                if e_params.save_results
                    println("Saving " * e_params.save_prefix * "(.txt,_hifi.txt,_params.tex)")

                    perffile = e_params.save_prefix * ".txt"
                    hififile = e_params.save_prefix * "_hifi.txt"
                    texfile = e_params.save_prefix * "_params.tex"
                    # datafile = e_params.save_prefix * ".jld2"

                    write_log(perffile, st.log, "# params = $(e_params)\n")
                    #write_log(hififile, st.log_hifi, "# params = $(e_params)\n")
                    #write_tex(texfile, e_params)
                    # @save datafile x y st params
                end

                close(datachannel)
                wait(gentask)

                if st.aborted
                    break
                end
            else
                println("Skipping already computed experiment \"$(ename)\"")
                # texfile = e_params.save_prefix * "_params.tex"
                # write_tex(texfile, e_params)
            end
        end
    catch
        # An experiment failed (or was interrupted without `handle_interrupt`):
        # stop the visualisation task without waiting for it, so that the
        # original error is the one that gets reported.
        visualise && close(rc)
        rethrow()
    end

    if visualise
        # Tell subprocess to finish, and wait
        put!(rc, nothing)
        close(rc)
        wait(visproc)
    end

    return
end


######################################################
# Iterator that does visualisation and log collection
######################################################

"""
    iterate_visualise(datachannel, st, step, params)

Drive the iterations of an algorithm while logging and visualising progress.

Data frames are taken from `datachannel`. For `iter = 1:params.maxiter`,
`step(b_noisy, v, next_b_noisy)` is called with a do-block callback that receives
the algorithm's `calc_objective` and returns the updated `State`.

Some iterations are logged:
- iterations 1–20;
- every 10th iteration up to 200;
- every `params.verbose_iter`th iteration, when that is non-zero.

At a logged iteration the callback:
- evaluates the objective and records a `LogEntry`;
- at verbose iterations, prints progress and sends it to the visualiser;
- saves the reconstruction (and the movement history) under
  `params.save_prefix` if `params.save_images` is set.

The time spent doing all this is added to `wasted_time`. At the remaining
iterations, a `LogEntryHiFi` is recorded instead.

If `params.handle_interrupt` is set, an `InterruptException` ends the
iterations and marks the state as aborted; other exceptions are rethrown.
Returns the final state.
"""
function iterate_visualise(datachannel::Channel{OnlineData{DisplacementT}},
                           st :: State,
                           step :: Function,
                           params :: NamedTuple) where DisplacementT
    try
        # NOTE(review): `sc` is assigned inside the step callback below. Declaring
        # it here makes it a captured (boxed) variable shared with the callback;
        # it could instead be local to the callback.
        sc = nothing

        # Data of the current iteration; `dnext` below is the following frame.
        d = take!(datachannel)

        for iter=1:params.maxiter
            dnext = take!(datachannel)
            # The do-block is passed to `step` as its first argument. Presumably
            # `step` runs one iteration and then calls it with `calc_objective`
            # (TODO confirm in the solver). Its result becomes the new state.
            st = step(d.b_noisy, d.v, dnext.b_noisy) do calc_objective
                stn = st

                if isnothing(stn.start_time)
                    # Start timing only in the callback of the first iteration, i.e.
                    # after the first step has run. The algorithm code is apparently
                    # not precompiled across module boundaries, so this keeps its
                    # compilation out of the timings.
                    stn = @set stn.start_time=secs_ns()
                end

                # Print/visualise every `verbose_iter` iterations; never if it is 0.
                verb = params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0

                # Normalise movement to image dimensions so
                # our TikZ plotting code doesn't need to know
                # the image pixel size.
                sc = 1.0./maximum(size(d.b_true))
                    
                # Log densely at first, then sparsely, and at verbose iterations.
                if verb || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0)
                    verb_start = secs_ns()
                    # Elapsed time, excluding time previously spent in this branch.
                    tm = verb_start - stn.start_time - stn.wasted_time
                    value, x, v, vhist = calc_objective()

                    entry = LogEntry(iter, tm, value,
                                     #sc*d.v_cumul_true[1],
                                     #sc*d.v_cumul_true[2],
                                     #sc*v[1], sc*v[2],
                                     assess_psnr(x, d.b_true),
                                     assess_ssim(x, d.b_true),
                                     #assess_psnr(d.b_noisy, d.b_true),
                                     #assess_ssim(d.b_noisy, d.b_true)
                                     )

                    # (**) Collect a singly-linked list of log to avoid array resizing
                    # while iterating
                    stn = @set stn.log=LinkedListEntry(entry, stn.log)
                    
                    # Scale the movement history like the other movement quantities.
                    if !isnothing(vhist)
                        vhist=vhist.*sc
                    end

                    if verb
                        @printf("%d/%d J=%f, PSNR=%f, SSIM=%f, avg. FPS=%f\n",
                                iter, params.maxiter, value, entry.psnr,
                                entry.ssim, entry.iter/entry.time)
                        # Visualise only when a visualisation channel was given.
                        if isa(stn.vis, Channel)
                            put_onlylatest!(stn.vis, ((d.b_noisy, x),
                                                      params.plot_movement,
                                                      stn.log, vhist))

                        end
                    end

                    # Save all logged iterations, or only those listed in
                    # `save_images_iters` when that parameter exists.
                    if params.save_images && (!haskey(params, :save_images_iters)
                                              || iter ∈ params.save_images_iters)
                        fn = (t, ext) -> "$(params.save_prefix)_$(t)_frame$(iter).$(ext)"
                        # save(File(format"PNG", fn("true", "png")), grayimg(d.b_true))
                        # save(File(format"PNG", fn("data", "png")), grayimg(d.b_noisy))
                        save(File(format"PNG", fn("reco", "png")), grayimg(x))
                        if !isnothing(vhist)
                            open(fn("movement", "txt"), "w") do io
                                writedlm(io, ["est_y" "est_x"])
                                writedlm(io, vhist)
                            end
                        end
                    end

                    # Exclude the time spent in this branch from later timings.
                    stn = @set stn.wasted_time += (secs_ns() - verb_start)

                    return stn
                end

                # NOTE(review): because of the early `return stn` above, HiFi entries
                # are only recorded at iterations that do NOT get a `LogEntry`, and
                # this updates the outer `st` rather than `stn`. Confirm both are
                # intended.
                hifientry = LogEntryHiFi(iter, sc*d.v_cumul_true[1], sc*d.v_cumul_true[2])
                st = @set st.log_hifi=LinkedListEntry(hifientry, st.log_hifi)

                return st
            end
            d=dnext
        end
    catch ex
        if params.handle_interrupt && isa(ex, InterruptException)
            # If SIGINT is received (user pressed ^C), terminate computations,
            # returning current status. Effectively, we do not call `step()` again,
            # ending the iterations, but letting the algorithm finish up.
            # Assuming (**) above occurs atomically, `st.log` should be valid, but
            # any results returned by the algorithm itself may be partial, as for
            # reasons of efficiency we do *not* store results of an iteration until
            # the next iteration is finished.
            printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202)
            st = @set st.aborted = true
        else
            rethrow(ex)
        end
    end
    
    return st
end

"""
    bg_visualise_enhanced(rc; fullscreen=false)

Render the visualisation requests received on channel `rc`.

Each request is a tuple `(imgs, plot_movement, log, vhist)`, as sent by
`iterate_visualise`. Presumably `process_channel` handles each item until told
to stop; `run_experiments` sends `nothing` to finish (confirm in AlgTools.Comms).

For each request:
- the images are drawn with `do_visualise`;
- an overlay shows the FPS, SSIM and PSNR of the newest log entry;
- if `plot_movement` is set, the available movement paths are drawn.
"""
function bg_visualise_enhanced(rc; fullscreen=false)
    process_channel(rc) do d
        imgs, plot_movement, log, vhist = d
        do_visualise(imgs, refresh=false, fullscreen=fullscreen)
        # Overlay statistics of the newest log entry
        latest = log.value
        GR.settextcolorind(5)
        GR.setcharheight(0.015)
        GR.settextpath(GR.TEXT_PATH_RIGHT)
        tx, ty = GR.wctondc(0, 1)
        GR.text(tx, ty, @sprintf "FPS %.1f, SSIM %.2f, PSNR %.1f" (latest.iter/latest.time) latest.ssim latest.psnr)
        if plot_movement
            sc = 1.0
            GR.setlinewidth(2)
            # The cumulative true/estimated displacement fields are not part of
            # the current `LogEntry` definition. Accessing them would throw and
            # kill this task, so only plot those paths when the fields exist.
            if hasfield(LogEntry, :v_cumul_true_x) && hasfield(LogEntry, :v_cumul_est_x)
                p = unfold_linked_list(log)
                # True path
                GR.setlinecolorind(2)
                GR.polyline(map(e -> 1.5+sc*e.v_cumul_true_x, p),
                            map(e -> 0.5+sc*e.v_cumul_true_y, p))
                # Estimated path
                GR.setlinecolorind(3)
                GR.polyline(map(e -> 1.5+sc*e.v_cumul_est_x, p),
                            map(e -> 0.5+sc*e.v_cumul_est_y, p))
            end
            # Movement history of the latest estimate: column 2 is x, column 1 is y.
            if !isnothing(vhist)
                GR.setlinecolorind(4)
                GR.polyline(map(v -> 1.5+sc*v, vhist[:,2]),
                            map(v -> 0.5+sc*v, vhist[:,1]))
            end
        end
        GR.updatews()
    end
end

end # module
