##################################################
# Visualising and data-collecting iteration tools
##################################################

module Visualise

using Printf
using Distributed
using FileIO
using Setfield
using Images, Plots, Measures

using AlgTools.Util
using AlgTools.StructTools
using AlgTools.LinkedLists

##############
# Our exports
##############

export LogEntry,
       bg_visualise,
       visualise,
       clip,
       grayimg,
       secs_ns,
       iterate_visualise,
       initialise_visualisation,
       finalise_visualisation

##################
# Data structures
##################

"""
    LogEntry(iter, time, function_value)

One record of the iteration log collected by `iterate_visualise`.

- `iter`: iteration number at which the entry was recorded.
- `time`: seconds elapsed since timing started, excluding time spent on
  logging/visualisation (`st.wasted_time`).
- `function_value`: objective value reported by the algorithm at that iteration.

Subtypes `IterableStruct` from `AlgTools.StructTools` — presumably so entries can be
iterated/destructured field by field; verify there.
"""
struct LogEntry <: IterableStruct
    iter :: Int
    time :: Float64
    function_value :: Float64
end

"""
    State(vis, visproc, start_time, wasted_time, log)

Immutable state threaded through `iterate_visualise` (updated via `Setfield.@set`).

- `vis`: visualisation target. Either a `RemoteChannel` feeding the background
  visualiser, `true` to visualise in this process, or `false`/`nothing` to disable.
- `visproc`: `Future` of the background visualiser, or `nothing` if none.
- `start_time`: clock reading in seconds (from `secs_ns`) when timing started, or
  `nothing` before the first iteration.
- `wasted_time`: accumulated seconds spent on logging/visualisation, subtracted from
  the logged times.
- `log`: collected `LogEntry` records, prepended as they are recorded.
"""
struct State
    vis :: Union{Distributed.RemoteChannel,Bool,Nothing}
    visproc :: Union{Nothing,Future}
    # Concrete `Float64` rather than abstract `Real`: it keeps field access type-stable
    # and matches what `secs_ns` returns.
    start_time :: Union{Float64,Nothing}
    wasted_time :: Float64
    log :: LinkedList{LogEntry}
end

##################
# Helper routines
##################

# Read the nanosecond clock (`time_ns`, arbitrary origin) and express it in seconds as
# a `Float64`; only differences between two readings are meaningful.
@inline secs_ns() = Float64(time_ns()) * 1e-9

# Clamp a scalar to the unit interval [0, 1]. The `0.0`/`1.0` bounds promote the result
# to (at least) Float64. This is a named function rather than a non-`const` global
# lambda, so calls are type-stable and can be inlined.
clip(x) = min(max(x, 0.0), 1.0)

# Convert an array of intensities to a grayscale image, clipping each value to [0, 1]
# first.
grayimg(im) = Gray.(clip.(im))

################
# Visualisation
################

"""
    bg_visualise(rc)

Background visualisation loop. Repeatedly takes image collections from the channel
`rc`, skipping to the most recently queued one, and displays it with `do_visualise`.
Returns `nothing` once the sentinel value `nothing` is received.
"""
function bg_visualise(rc)
    while true
        latest = take!(rc)
        # Skip anything stale: only the newest queued item is worth drawing.
        while isready(rc)
            latest = take!(rc)
        end
        # A `nothing` sentinel (sent by `finalise_visualisation`) ends the loop.
        latest === nothing && return
        do_visualise(latest)
    end
end

"""
    do_visualise(imgs)

Display every image in the collection `imgs` as a subplot of a single Plots figure.
Each image is first converted with `grayimg` (values clipped to [0, 1]). Returns
whatever `display` returns.
"""
function do_visualise(imgs)
    # One borderless, equal-aspect grayscale panel per image.
    panel(im) = plot(grayimg(im); showaxis=false, grid=false,
                     aspect_ratio=:equal, margin=2mm)
    # Iterate the collection directly instead of indexing over 1:length(imgs).
    return display(plot(map(panel, imgs)...; reuse=true, margin=0mm))
end

"""
    visualise(rc_or_vis, imgs)

Visualise `imgs` according to the target `rc_or_vis`:

- a `RemoteChannel`: discard any images still queued in the channel, then queue
  `imgs` for the background visualiser (`bg_visualise`). Returns the channel.
- `true`: display `imgs` immediately in this process via `do_visualise`.
- anything else (e.g. `false` or `nothing`): do nothing and return `nothing`.
"""
function visualise end

function visualise(rc :: RemoteChannel, imgs)
    # Drop stale, not-yet-displayed images so only the latest one is shown and `put!`
    # does not have to wait on a full channel.
    # NOTE(review): check-then-take! is not atomic. If the background visualiser takes
    # the queued item between `isready` and `take!`, this `take!` could block
    # indefinitely, since this side is the only producer. Confirm this race is
    # acceptable.
    while isready(rc)
        take!(rc)
    end
    return put!(rc, imgs)
end

visualise(vis :: Bool, imgs) = vis ? do_visualise(imgs) : nothing

visualise(::Any, imgs) = nothing

######################################################
# Iterator that does visualisation and log collection
######################################################

"""
    iterate_visualise(st::State, step::Function, params::NamedTuple) -> State

Drive an iterative algorithm for up to `params.maxiter` iterations, collecting a log
and optionally printing, visualising and saving intermediate results.

On each iteration, `step` is called with a callback (via `do`-block syntax). The
callback receives `calc_objective`, which must return `(value, x)`: the objective
value and the current iterate. The return value of `step` becomes the new state, so
`step` is presumably expected to return what the callback returns.

Iterations are logged when any of the following holds:
- the iteration is a multiple of `params.verbose_iter` (when that is non-zero); these
  are also printed and visualised;
- it is one of the first 20 iterations;
- it is a multiple of 10 up to iteration 200.

Time spent on logging is accumulated in `st.wasted_time` and excluded from the
logged times.

Fields read from `params`: `maxiter`, `verbose_iter`, `save_iterations`, and
`save_prefix` (the last only when `save_iterations` is true).

A user interrupt (`InterruptException`) stops iterating and returns the current
state. Any other exception is rethrown with its original backtrace.
"""
function iterate_visualise(st :: State,
                           step :: Function,
                           params :: NamedTuple)
    try
        for iter = 1:params.maxiter
            st = step() do calc_objective
                if isnothing(st.start_time)
                    # Start the clock at the first callback rather than before the loop.
                    # This keeps compilation of the first iteration out of the timings;
                    # precompilation does not appear to help across module boundaries.
                    st = @set st.start_time = secs_ns()
                end

                verb = params.verbose_iter != 0 && mod(iter, params.verbose_iter) == 0

                if verb || iter ≤ 20 || (iter ≤ 200 && mod(iter, 10) == 0)
                    verb_start = secs_ns()
                    tm = verb_start - st.start_time - st.wasted_time
                    value, x = calc_objective()

                    entry = LogEntry(iter, tm, value)

                    # (**) Collect a singly-linked list of log entries to avoid array
                    # resizing while iterating.
                    st = @set st.log = LinkedListEntry(entry, st.log)

                    if verb
                        @printf("%d/%d J=%f\n", iter, params.maxiter, value)
                        visualise(st.vis, (x,))
                    end

                    if params.save_iterations
                        fn = t -> "$(params.save_prefix)_$(t)_iter$(iter).png"
                        save(File(format"PNG", fn("reco")), grayimg(x))
                    end

                    # Everything since `verb_start` is bookkeeping, not algorithm time.
                    st = @set st.wasted_time += (secs_ns() - verb_start)
                end

                return st
            end
        end
    catch ex
        if ex isa InterruptException
            # If SIGINT is received (user pressed ^C), terminate computations,
            # returning current status. Effectively, we do not call `step()` again,
            # ending the iterations, but letting the algorithm finish up.
            # Assuming (**) above occurs atomically, `st.log` should be valid, but
            # any results returned by the algorithm itself may be partial, as for
            # reasons of efficiency we do *not* store results of an iteration until
            # the next iteration is finished.
            printstyled("\rUser interrupt—finishing up.\n", bold=true, color=202)
        else
            # `rethrow()` preserves the original backtrace, unlike `throw(ex)`.
            rethrow()
        end
    end

    return st
end

####################
# Launcher routines
####################

"""
    initialise_visualisation(visualise; iterator=iterate_visualise)

Create the initial `State` and return `(st, iterate)`, where `iterate` is `iterator`
curried with `st` via `curry` from `AlgTools.Util`.

If `visualise` is true, a background `bg_visualise` task is spawned and fed through a
new `RemoteChannel` stored in `st.vis`. Otherwise visualisation is disabled
(`st.vis == false`).
"""
function initialise_visualisation(visualise; iterator=iterate_visualise)
    if visualise
        rc = RemoteChannel()
        # NOTE(review): `Distributed.@spawn` is deprecated in newer Julia versions in
        # favour of `@spawnat :any`; kept here for compatibility.
        visproc = @spawn bg_visualise(rc)
        vis = rc
        # Presumably gives the background task a moment to start up.
        sleep(0.01)
    else
        vis = false
        visproc = nothing
    end

    st = State(vis, visproc, nothing, 0.0, nothing)
    # Honour the `iterator` keyword; it was previously ignored in favour of a
    # hard-coded `iterate_visualise`.
    iterate = curry(iterator, st)

    return st, iterate
end

"""
    finalise_visualisation(st)

Shut down background visualisation set up by `initialise_visualisation`.

When `st.vis` is a `RemoteChannel`, send the `nothing` sentinel that makes
`bg_visualise` exit, then wait for the background task `st.visproc` to finish.
Otherwise do nothing. Returns `nothing`.
"""
function finalise_visualisation(st)
    # The channel is stored in the `vis` field; `State` has no `rc` field, so reading
    # `st.rc` failed on every call.
    if st.vis isa RemoteChannel
        put!(st.vis, nothing)
        wait(st.visproc)
    end
    return nothing
end

end # Module