##################################################
# Simple (and fast for small filters compared to
# ImageFiltering) image filtering
##################################################

__precompile__()

module ImFilter

using OffsetArrays
using AlgTools.Util: @threadsif, norm₁
using AlgTools.LinOps

##########
# Exports
##########

export simple_imfilter,
       simple_imfilter!,
       simple_imfilter_adjoint,
       simple_imfilter_adjoint!,
       gaussian,
       FilterKernel

##############
# The routine
##############

# Concrete array types used throughout this module. They are declared `const`
# so that the bindings cannot be reassigned and are not untyped globals.
const Image = Array{Float64,2}
# A kernel is an `OffsetArray` wrapping an `Image`.
const Kernel = OffsetArray{Float64,2,Image}

# Range of offsets `p` around index `i` such that `i + p` stays within
# `[i - aʹ, i + bʹ]` and also within the bounds `[a, b]`.
@inline function inside(i, aʹ, bʹ, a, b)
    lo = max(a, i - aʹ)
    hi = min(b, i + bʹ)
    return (lo - i):(hi - i)
end

"""
    simple_imfilter!(res, b, kernel; threads=true)

Filter the image `b` with `kernel`, writing the result into `res` and returning `res`.

For every pixel `(i, j)` this computes `Σ kernel[p, q] * b[i+p, j+q]` over those kernel
offsets `(p, q)` for which `(i+p, j+q)` lies inside the image; terms that would fall
outside the image are skipped.

# Arguments
- `res`: output image, must have the same size as `b`.
- `b`: input image.
- `kernel`: filter kernel with odd size in both dimensions and axes centred on zero
  (i.e. `-m:m` in each dimension).

# Keywords
- `threads`: passed to `@threadsif` to control whether the loop over rows is threaded.

# Throws
- `AssertionError` if the kernel size is not odd, the kernel is not centred on zero,
  or `res` and `b` differ in size.
"""
function simple_imfilter!(res::Image, b::Image, kernel::Kernel; threads::Bool=true)
    n, m = size(b)
    k, 𝓁 = size(kernel)
    o₁, o₂ = kernel.offsets
    # Kernel axes are (1+o):(k+o); for a centred kernel a == b == (k-1)/2.
    a₁, a₂ = k + o₁, 𝓁 + o₂
    b₁, b₂ = -1 - o₁, -1 - o₂
    kp = kernel.parent

    @assert(isodd(k) && isodd(𝓁) && size(res)==size(b))
    # The index arithmetic below is only valid for a kernel centred on zero; a
    # non-centred kernel would index `kp` out of bounds under `@inbounds`.
    @assert(a₁ == b₁ && a₂ == b₂, "kernel axes must be centred on zero")

    @threadsif threads for i=1:n
        @inbounds for j=1:m
            tmp = 0.0
            # Kernel offsets clipped so that (i+p, j+q) stays inside the image.
            it₁ = inside(i, a₁, b₁, 1, n)
            it₂ = inside(j, a₂, b₂, 1, m)
            for p=it₁
                @simd for q=it₂
                    # kp[p-o₁, q-o₂] is kernel[p, q] expressed in parent indices.
                    tmp += kp[p-o₁, q-o₂]*b[i+p,j+q]
                end
            end
            res[i, j] = tmp
        end
    end

    return res
end

"""
    simple_imfilter(b, kernel; threads=true)

Allocating variant of `simple_imfilter!`: filter the image `b` with `kernel` and return
the result in a newly allocated image of the same size as `b`.

The `threads` keyword is forwarded to `simple_imfilter!`.
"""
function simple_imfilter(b::Image, kernel::Kernel; threads::Bool=true)
    res = similar(b)
    # Forward `threads`; previously the keyword was accepted but silently ignored.
    return simple_imfilter!(res, b, kernel; threads=threads)
end

"""
    simple_imfilter_adjoint!(res, b, kernel; threads=true)

Apply the adjoint of the filtering performed by `simple_imfilter!` to the image `b`,
writing the result into `res` and returning `res`.

For every pixel `(i, j)` this computes `Σ kernel[-p, -q] * b[i+p, j+q]` over those
offsets `(p, q)` for which `(i+p, j+q)` lies inside the image. Each output pixel is
accumulated locally and written by exactly one iteration of the outer loop, so no two
iterations update the same element of `res`.

`res` must not alias `b`, since `b` is read while `res` is being written.

# Arguments
- `res`: output image, must have the same size as `b`.
- `b`: input image.
- `kernel`: filter kernel with odd size in both dimensions and axes centred on zero
  (i.e. `-m:m` in each dimension).

# Keywords
- `threads`: passed to `@threadsif` to control whether the loop over rows is threaded.

# Throws
- `AssertionError` if the kernel size is not odd, the kernel is not centred on zero,
  or `res` and `b` differ in size.
"""
function simple_imfilter_adjoint!(res::Image, b::Image, kernel::Kernel; threads::Bool=true)
    n, m = size(b)
    k, 𝓁 = size(kernel)
    o₁, o₂ = kernel.offsets
    # Kernel axes are (1+o):(k+o); for a centred kernel a == b == (k-1)/2.
    a₁, a₂ = k + o₁, 𝓁 + o₂
    b₁, b₂ = -1 - o₁, -1 - o₂
    kp = kernel.parent

    @assert(isodd(k) && isodd(𝓁) && size(res)==size(b))
    # The index arithmetic below is only valid for a kernel centred on zero; a
    # non-centred kernel would index `kp` out of bounds under `@inbounds`.
    @assert(a₁ == b₁ && a₂ == b₂, "kernel axes must be centred on zero")

    # Gather formulation: the scatter update `res[i+p, j+q] += ...` lets different
    # rows `i` write to the same elements of `res`, which is a data race when the
    # loop over `i` runs on several threads. Here every res[i, j] is owned by a
    # single iteration. (Summation order differs from the scatter form, so results
    # may differ in the last floating-point bits.)
    @threadsif threads for i=1:n
        @inbounds for j=1:m
            tmp = 0.0
            # Offsets of the source pixels (i+p, j+q) that contribute to (i, j);
            # the roles of the two kernel extents are swapped relative to the
            # forward filter.
            it₁ = inside(i, b₁, a₁, 1, n)
            it₂ = inside(j, b₂, a₂, 1, m)
            for p=it₁
                @simd for q=it₂
                    # kp[-p-o₁, -q-o₂] is kernel[-p, -q] expressed in parent indices.
                    tmp += kp[-p-o₁, -q-o₂]*b[i+p,j+q]
                end
            end
            res[i, j] = tmp
        end
    end

    return res
end

"""
    simple_imfilter_adjoint(b, kernel; threads=true)

Allocating variant of `simple_imfilter_adjoint!`: apply the adjoint filter to the image
`b` and return the result in a newly allocated image of the same size as `b`.

The `threads` keyword is forwarded to `simple_imfilter_adjoint!`.
"""
function simple_imfilter_adjoint(b::Image, kernel::Kernel; threads::Bool=true)
    res = similar(b)
    # Forward `threads`; previously the keyword was accepted but silently ignored.
    return simple_imfilter_adjoint!(res, b, kernel; threads=threads)
end

###########################
# Abstract linear operator
###########################

# Linear operator from `Image` to `Image` defined by a filter kernel.
# The methods defined below in this module apply it via `simple_imfilter`
# and compute its adjoint via `simple_imfilter_adjoint`.
struct FilterKernel <: AdjointableOp{Image, Image}
    # The filter kernel passed to the filtering routines.
    kernel::Kernel
end

# Applying the operator to an image filters it with the stored kernel,
# returning a newly allocated result.
(op::FilterKernel)(b::Image) = simple_imfilter(b, op.kernel)

# In-place application: filter `x` with the operator's kernel into `y`; returns `y`.
LinOps.inplace!(y::Image, op::FilterKernel, x::Image) =
    simple_imfilter!(y, x, op.kernel)

# Adjoint application: returns a newly allocated image holding the adjoint
# filter applied to `y`.
LinOps.calc_adjoint(op::FilterKernel, y::Image) =
    simple_imfilter_adjoint(y, op.kernel)

# In-place adjoint application: writes the adjoint filter applied to `y`
# into `res` and returns `res`.
LinOps.calc_adjoint!(res::Image, op::FilterKernel, y::Image) =
    simple_imfilter_adjoint!(res, y, op.kernel)

# Operator norm estimate given by `norm₁` of the kernel,
# justified by the bound |f * g|_p ≤ |f|_p|g|_1.
LinOps.opnorm_estimate(op::FilterKernel) = norm₁(op.kernel)

######################################################
# Distributions. Just to avoid the long load times of
# ImageFiltering and heavy dependencies on FFTW etc.
######################################################

"""
    gaussian(σ, n)

Build a Gaussian kernel of size `n` (one odd entry per dimension) as an `OffsetArray`
whose axes are `-h:h` in each dimension, with `h = (n - 1) / 2`.

The entry at offset `x` is `exp(-Σ x_d^2 / (2σ_d^2))`, where `σ` is broadcast against the
offset tuple; the array is then divided by the sum of its entries.

Throws an `AssertionError` if any entry of `n` is not odd.
"""
function gaussian(σ, n)
    @assert(all(isodd.(n)))
    # Half-widths of the kernel in each dimension.
    halfwidths = convert.(Integer, (n .- 1) ./ 2)
    ranges = [(-h):h for h in halfwidths]
    g = OffsetArray{Float64}(undef, ranges...)
    for idx in CartesianIndices(g)
        x = Tuple(idx)
        g[idx] = exp(-sum(@. x^2 / (2 * σ^2)))
    end
    # Normalise in place so that the entries sum to one.
    g ./= sum(g)
    return g
end

end # Module

