###################
# Image generation
###################

module ImGenerate

using ColorTypes: Gray
import TestImages
# We don't *directly* depend on QuartzImageIO. The import here is merely a
# workaround to suppress warnings emitted when loading TestImages; something
# in those packages is broken.
import QuartzImageIO

using AlgTools.Util
using AlgTools.Comms
using ImageTools.Translate

using ..OpticalFlow: Image, DisplacementConstant, DisplacementFull

##############
# Our exports
##############

export ImGen,
       OnlineData,
       imgen_square,
       imgen_shake

##################
# Data structures
##################

"""
    ImGen

Description of a synthetic image-data generator, as built by `imgen_square`
and `imgen_shake`.

Fields:
- `f`: the generator function. The constructors in this module set it to
  `curry(generate_square, sz)` or `curry(generate_shake_image, im, sz)`; the
  remaining arguments (displacement type, data channel, parameters) are
  presumably supplied later by the caller — confirm against `curry`.
- `dim`: the image size `sz` used for every generated frame.
- `Λ`: set to `1` by both constructors in this file; its meaning is not
  visible here — TODO(review): document.
- `dynrange`: presumably the dynamic range of pixel values: `1` for the
  synthetic square, the `maximum` pixel value for test photos.
- `name`: human-readable label, e.g. `"square64x64"`.
"""
struct ImGen
    f :: Function
    dim :: Tuple{Int64,Int64}
    Λ :: Float64
    dynrange :: Float64
    name :: String
end

"""
    OnlineData{DisplacementT}

One frame of the synthetic online data stream that the generators in this
module send through their data channel.

Fields (as filled in by `generate_square` / `generate_shake_image`):
- `b_true`: the noise-free frame.
- `b_noisy`: `b_true` plus Gaussian noise scaled by `params.noise_level`.
- `v`: this step's displacement `v_true` perturbed by multiplicative Gaussian
  noise scaled by `params.shake_noise_level`.
- `v_true`: the true displacement drawn for this step.
- `v_cumul_true`: the true cumulative displacement (sum of all `v_true` drawn
  so far), from which `b_true` is generated.
"""
struct OnlineData{DisplacementT}
    b_true :: Image
    b_noisy :: Image
    v :: DisplacementT
    v_true :: DisplacementT
    v_cumul_true :: DisplacementT
end

###################
# Shake generation
###################

"""
    make_const_v(displ, sz)

Build a constant displacement field: a `Float64` array of size `(2, sz...)`
(intended for a two-element `sz`) whose first component slice holds
`displ[1]` everywhere and whose second component slice holds `displ[2]`.
Only the first two entries of `displ` are read.
"""
function make_const_v(displ, sz)
    field = zeros(2, sz...)
    for component in 1:2
        # Broadcast the scalar displacement component over the whole slice.
        field[component, :, :] .= displ[component]
    end
    return field
end

"""
    shake(params)

Return a zero-argument sampler that draws a random 2-vector displacement.

`params.shake` sets the scale. The optional `params.shaketype` selects the
distribution, and defaults to `:gaussian` when the key is absent:
- `:gaussian`: independent normal components scaled by `params.shake`;
- `:disk`: uniform over the disk of radius `params.shake`;
- `:circle`: uniform over the circle of radius `params.shake`.

Any other `shaketype` raises an error immediately, not when the sampler is
called.
"""
function shake(params)
    kind = haskey(params, :shaketype) ? params.shaketype : :gaussian

    if kind == :gaussian
        return () -> params.shake .* randn(2)
    elseif kind == :disk
        return function ()
            angle = 2π * rand(Float64)
            # sqrt of a uniform draw makes the radius area-uniform on the disk
            radius = params.shake * sqrt(rand(Float64))
            return [radius * cos(angle), radius * sin(angle)]
        end
    elseif kind == :circle
        return function ()
            angle = 2π * rand(Float64)
            return [params.shake * cos(angle), params.shake * sin(angle)]
        end
    end

    error("Unknown shaketype $(params.shaketype)")
end

"""
    pixelwise(shakefn, sz)

Wrap the displacement sampler `shakefn` (a zero-argument function returning a
2-vector) so that each call of the returned closure draws one displacement and
expands it into a full per-pixel field via `make_const_v(displacement, sz)`.
"""
# The closure previously called the undefined `make_const_u`, so every
# `DisplacementFull` run failed with an UndefVarError. Defined as a function
# rather than an untyped, non-const global binding.
pixelwise(shakefn, sz) = () -> make_const_v(shakefn(), sz)

################
# Moving square
################

"""
    generate_square(sz, DisplacementT, datachannel, params)

Produce an endless stream of noisy frames showing a square of ones shifted by a
random walk of displacements. Each frame is sent to `datachannel` as an
`OnlineData{DisplacementT}`. The function returns once `put_unless_closed!`
returns `false` (presumably because the channel was closed).

`DisplacementT` must be `DisplacementConstant` (one displacement vector per
step) or `DisplacementFull` (that vector expanded to a per-pixel field);
anything else throws an `ArgumentError`.

`params` must provide `noise_level` and `shake_noise_level`, plus the fields
read by `shake` (`shake` and optionally `shaketype`).
"""
function generate_square(sz,
                         :: Type{DisplacementT},
                         datachannel :: Channel{OnlineData{DisplacementT}},
                         params) where DisplacementT

    if DisplacementT == DisplacementFull
        nextv = pixelwise(shake(params), sz)
    elseif DisplacementT == DisplacementConstant
        nextv = shake(params)
    else
        # Fail loudly: merely logging left `nextv` undefined, so the first
        # `nextv()` below died with an unrelated UndefVarError.
        throw(ArgumentError("Invalid DisplacementT: $(DisplacementT)"))
    end

    # Constant linear displacement everywhere has Jacobian determinant one
    # (modulo the boundaries which we ignore here)
    # The square is an (m+1)×(m+1) block of ones with m = round(sz[1]/5).
    m = round(Int, sz[1]/5)
    b_orig = zeros(sz...)
    b_orig[sz[1].-(2*m:3*m), 2*m:3*m] .= 1

    v_true = nextv()
    v_cumul = copy(v_true)

    while true
        # Flow original data and add noise
        b_true = zeros(sz...)
        translate_image!(b_true, b_orig, v_cumul; threads=true)
        b = b_true .+ params.noise_level.*randn(sz...)
        # Noisy measurement of this step's displacement
        v = v_true.*(1.0 .+ params.shake_noise_level.*randn(size(v_true)...))
        # Pass true data to iteration routine
        data = OnlineData{DisplacementT}(b_true, b, v, v_true, v_cumul)
        if !put_unless_closed!(datachannel, data)
            return
        end
        # Next step shake. Rebind `v_cumul` to a fresh array rather than
        # updating it in place: `data` holds a reference to the current array,
        # and mutating it would silently change the cumulative displacement
        # already handed to the consumer.
        v_true = nextv()
        v_cumul = v_cumul .+ v_true
    end
end

"""
    imgen_square(sz)

Return an `ImGen` for the synthetic moving-square sequence with frames of size
`sz`. Both `Λ` and the dynamic range are set to `1`.
"""
function imgen_square(sz)
    generator = curry(generate_square, sz)
    label = "square$(sz[1])x$(sz[2])"
    return ImGen(generator, sz, 1.0, 1.0, label)
end

################
# Shake a photo
################

"""
    generate_shake_image(im, sz, DisplacementConstant, datachannel, params)

Produce an endless stream of noisy `sz`-sized subwindows of the image `im`,
positioned by a random walk of constant displacements. Each frame is sent to
`datachannel` as an `OnlineData{DisplacementConstant}`. The function returns
once `put_unless_closed!` returns `false` (presumably because the channel was
closed).

`params` must provide `noise_level` and `shake_noise_level`, plus the fields
read by `shake` (`shake` and optionally `shaketype`).
"""
function generate_shake_image(im, sz,
                              :: Type{DisplacementConstant},
                              datachannel :: Channel{OnlineData{DisplacementConstant}},
                              params :: NamedTuple)

    nextv = shake(params)
    v_true = nextv()
    v_cumul = copy(v_true)

    while true
        # Extract subwindow of original image at the current cumulative
        # displacement and add noise
        b_true = zeros(sz...)
        extract_subimage!(b_true, im, v_cumul; threads=true)
        b = b_true .+ params.noise_level.*randn(sz...)
        # Noisy measurement of this step's displacement
        v = v_true.*(1.0 .+ params.shake_noise_level.*randn(size(v_true)...))
        # Pass data to iteration routine
        data = OnlineData{DisplacementConstant}(b_true, b, v, v_true, v_cumul)
        if !put_unless_closed!(datachannel, data)
            return
        end
        # Next step shake. Rebind `v_cumul` to a fresh array rather than
        # updating it in place: `data` holds a reference to the current array,
        # and mutating it would silently change the cumulative displacement
        # already handed to the consumer.
        v_true = nextv()
        v_cumul = v_cumul .+ v_true
    end
end

"""
    imgen_shake(imname, sz)

Return an `ImGen` that shakes the TestImages photo `imname`. The photo is
converted to grayscale `Float64` values. Frames have size `sz`, and the
dynamic range is set to the photo's maximum pixel value.
"""
function imgen_shake(imname, sz)
    photo = TestImages.testimage(imname)
    pixels = Float64.(Gray.(photo))
    generator = curry(generate_shake_image, pixels, sz)
    return ImGen(generator, sz, 1, maximum(pixels), "$(imname)$(sz[1])x$(sz[2])")
end

end # Module
