########################################################
# TV reconstruction via primal–dual proximal splitting
########################################################

# Mark this module as precompilable (explicit form of the default argument).
__precompile__(true)

module TVRecon

using AlgTools.Util
using AlgTools.LinOps
import AlgTools.Iterate
using ..Gradient

##############
# Our exports
##############

export recon_pdps

#############
# Data types
#############

# Type aliases used in the method signatures of this module. They are `const`
# so the bindings cannot be reassigned and every reference to them is
# type-stable (non-const globals are treated as `Any`).
const ImageSize = Tuple{Integer,Integer}   # presumably (rows, cols) — not used in this section
const Image = Array{Float64,2}             # 2-D Float64 image
const Primal = Image                       # type of the primal iterate `x`
const Dual = Array{Float64,3}              # 3-D Float64 array, e.g. a 2-component field over an image

#########################
# Iterate initialisation
#########################

"""
    init_rest(x, b)

Build every PDPS iterate and work buffer that is derived from the primal
iterate `x` and the data `b`.

Returns the tuple `(x, y, λ, Δx₁, Δx₂, Δy, Δλ, x̄)`:
- `x` is returned as given (not copied);
- `y` is a zero array of size `(2, size(x)...)`;
- `λ` is a zero array of the same size as `b`;
- `Δx₁`, `Δx₂` and `x̄` are independent copies of `x`, `Δy` a copy of `y`,
  and `Δλ` a copy of `λ`.
"""
function init_rest(x::Primal, b::Array{Float64, N}) where N
    # Dual iterates start at zero.
    y = zeros(2, size(x)...)
    λ = zeros(size(b))

    # Scratch buffers, each an independent array with the iterate's size.
    Δx₁, Δx₂, x̄ = copy(x), copy(x), copy(x)
    Δy, Δλ = copy(y), copy(λ)

    return x, y, λ, Δx₁, Δx₂, Δy, Δλ, x̄
end

"""
    init_primal(xinit)

Return an independent copy of the initial image `xinit`, for use as the primal
iterate; the caller's array is left untouched.
"""
init_primal(xinit::Image) = copy(xinit)

############
# Algorithm
############

"""
    recon_pdps(b, op; xinit, iterate = Iterate.simple_iterate, params)

Reconstruct an image from data `b` and forward operator `op` with the
primal–dual proximal splitting (PDPS) method.

The objective reported to the `verbose` callback is

    ½‖op x − b‖² + α · γnorm₂₁(∇₂x, ρ)

(`γnorm₂₁` is presumably a Huber-smoothed 2,1-norm — confirm in `Gradient`).
The method uses a dual variable `y` for the gradient term and `λ` for the
data term.

# Arguments
- `b`: measured data, an element of the codomain of `op`.
- `op::LinOp{Image,Data}`: forward operator; its adjoint `op'` is also applied.

# Keywords
- `xinit`: initial image; it is copied and never modified. It must be given:
  the default `nothing` raises an `ArgumentError`, because the image size
  cannot be determined here.
- `iterate`: iteration driver, called as `iterate(step, params)`, where
  `step(verbose)` performs one PDPS iteration and returns the result of
  `verbose(f)`. The function `f` yields `(objective value, x)`.
- `params::NamedTuple`: must provide `α`, `ρ`, `τ₀`, `σ₀` and `accel`. It is
  also passed to `iterate`, which may read further fields.

# Returns
`(x, y, v)`: the final primal image, the final dual `y`, and the value
returned by `iterate`.

# Throws
- `ArgumentError` if `xinit === nothing` or if `τ₀*σ₀ < 1` does not hold.
"""
function recon_pdps(b :: Data, op :: LinOp{Image, Data};
                    xinit :: Union{Image, Nothing} = nothing,
                    iterate = Iterate.simple_iterate,
                    params::NamedTuple) where Data

    ################################
    # Extract and set up parameters
    ################################

    α, ρ = params.α, params.ρ
    τ₀, σ₀ = params.τ₀, params.σ₀
    accel = params.accel

    # Validate user input with real exceptions; `@assert` may be disabled.
    if xinit === nothing
        throw(ArgumentError("recon_pdps: `xinit` must be given; the image " *
                            "size cannot be inferred from `b` and `op`"))
    end
    if !(τ₀*σ₀ < 1)
        throw(ArgumentError("recon_pdps: step length parameters must satisfy " *
                            "τ₀*σ₀ < 1, got τ₀=$τ₀, σ₀=$σ₀"))
    end

    # Scale the steps by R_K = √(‖∇₂‖² + ‖op‖²) ≥ ‖[∇₂; op]‖, so that
    # τσR_K² = τ₀σ₀ < 1. This assumes both estimates bound their operator norms.
    R_K = √(∇₂_norm₂₂_est^2 + opnorm_estimate(op)^2)

    # Strong-convexity factor used by the acceleration rule.
    # NOTE(review): the primal update below applies no proximal map (G ≡ 0), so
    # primal-side acceleration with γ = 1 is not backed by strong convexity of
    # the primal term. Confirm that `accel = true` is intended for this problem.
    γ = 1.0

    # Step lengths are held in `Ref`s. The iteration closure updates them when
    # accelerating, and rebinding captured locals would box them (type-unstable).
    τ = Ref(τ₀/R_K)
    σ = Ref(σ₀/R_K)

    ######################
    # Initialise iterates
    ######################

    x, y, λ, Δx₁, Δx₂, Δy, Δλ, x̄ = init_rest(init_primal(xinit), b)

    ####################
    # Run the algorithm
    ####################

    v = iterate(params) do verbose :: Function
        τₖ, σₖ = τ[], σ[]
        ω = accel ? 1/√(1+2*γ*τₖ) : 1.0

        ∇₂ᵀ!(Δx₁, y)                        # primal step: x
        inplace!(Δx₂, op', λ)               # |
        @. x̄ = x                            # |  (x̄ temporarily holds x_old)
        @. x = x - τₖ*(Δx₁ + Δx₂)           # |
        @. x̄ = (1+ω)*x - ω*x̄                # over-relax: x̄ = x + ω(x - x_old)
        ∇₂!(Δy, x̄)                          # dual step: y, λ
        inplace!(Δλ, op, x̄)                 # |
        @. y = (y + σₖ*Δy)/(1 + σₖ*ρ/α)     # |
        proj_norm₂₁ball!(y, α)              # |
        @. λ = (λ + σₖ*(Δλ - b))/(1 + σₖ)   # |

        if accel
            τ[] = τₖ*ω
            σ[] = σₖ/ω
        end

        ################################
        # Give function value if needed
        ################################
        # Δy and Δλ are reused as scratch here; both are recomputed at the
        # start of the next dual step.
        return verbose() do
            ∇₂!(Δy, x)
            inplace!(Δλ, op, x)
            value = norm₂²(Δλ - b)/2 + α*γnorm₂₁(Δy, ρ)
            value, x
        end
    end

    return x, y, v
end

end # Module


