####################################################################
# Predictive online PDPS for optical flow with known velocity field
####################################################################

# Mark this file's module as precompilable. NOTE(review): precompilation is
# already the default in modern Julia, so this call is presumably kept only
# for compatibility with older versions.
__precompile__()

module Algorithm

# Identifier string for this algorithm variant ("pdps_known" = PDPS with a
# known displacement field). NOTE(review): presumably read by callers to
# label/select the algorithm — confirm against the experiment driver.
identifier = "pdps_known"

using Printf

using AlgTools.Util
import AlgTools.Iterate
using ImageTools.Gradient

using ..OpticalFlow: ImageSize,
                     Image,
                     pdflow!

#########################
# Iterate initialisation
#########################

"""
    init_rest(x::Image)

Build the full iterate tuple around an existing primal image `x`.

The dual variable `y` is a zero array with a leading dimension of 2 followed
by the dimensions of `x`. The workspaces `Δx`, `Δy` and the over-relaxed
primal `x̄` are fresh copies of `x` / `y`.

# Returns
`(x, y, Δx, Δy, x̄)` — note that `x` itself is returned, not a copy.
"""
function init_rest(x::Image)
    dual = zeros(2, size(x)...)
    return x, dual, copy(x), copy(dual), copy(x)
end

# Initialise all iterates from a copy of the given starting image `xinit`
# (the caller's array is never aliased); see `init_rest` for the tuple layout.
init_iterates(xinit::Image) = init_rest(copy(xinit))

# Initialise all iterates from an all-zero primal image of size `dim`;
# see `init_rest` for the tuple layout.
init_iterates(dim::ImageSize) = init_rest(zeros(dim...))

############
# Algorithm
############

"""
    step_lengths(params, γ, R_K²)

Compute the step length parameters `(τ, σ, σ̃, ρ̃)` of the predictive PDPS.

# Arguments
- `params`: parameter collection with fields `τ₀`, `σ₀`, `σ̃₀`, `δ`, `ρ`, `Λ`
  and `dual_flow`. If it also has a `phantom_ρ` field, that value replaces
  `ρ` in the formulas below (only here, not in the iteration itself).
- `γ`: divisor of `τ₀` giving the primal step `τ = τ₀/γ`.
- `R_K²`: squared operator-norm estimate entering the bound on `σ`.

# Returns
The tuple `(τ, σ, σ̃, ρ̃)`; `ρ̃` has the same numeric type in both branches.
The chosen values are also printed to standard output.

# Throws
`AssertionError` if `1 + γτ < Λ`.
"""
function step_lengths(params, γ, R_K²)
    τ₀, σ₀, σ̃₀ = params.τ₀, params.σ₀, params.σ̃₀
    δ = params.δ
    # `phantom_ρ`, when present, overrides ρ for the step length choice only.
    ρ = isdefined(params, :phantom_ρ) ? params.phantom_ρ : params.ρ
    Λ = params.Λ
    Θ = params.dual_flow ? Λ : one(Λ)

    τ = τ₀/γ
    @assert(1+γ*τ ≥ Λ,
            "step_lengths requires 1+γτ ≥ Λ; got 1+γτ=$(1+γ*τ), Λ=$(Λ)")
    # When the bracket of the second term is ≤ 0, 1/max(0, …) is Inf and the
    # first term determines σ.
    σ = σ₀*min(1/(τ*R_K²), 1/max(0, τ*R_K²/((1+γ*τ-Λ)*(1-δ))-ρ))
    q = δ*(1+σ*ρ)/Θ
    if 1 ≥ q
        σ̃ = σ̃₀*σ/q
        # NOTE(review): params.ρ̃₀ is not applied here. An earlier variant was
        #   ρ̃ = ρ̃₀*max(0, ((Θ*σ)/(2*δ*σ̃^2*(1+σ*ρ))+1/(2σ)-1/σ̃))
        ρ̃ = max(0, (1-q)/(2*σ))
    else
        σ̃ = σ̃₀*σ/(q*(1-√(1-1/q)))
        # Use a float zero (not the integer literal 0) so the return type does
        # not depend on which branch was taken.
        ρ̃ = zero(σ)
    end

    println("Step length parameters: τ=$(τ), σ=$(σ), σ̃=$(σ̃), ρ̃=$(ρ̃)")

    return τ, σ, σ̃, ρ̃
end

"""
    solve(DisplacementT; dim, iterate=Iterate.simple_iterate, params)

Run predictive online PDPS for an image sequence with a known displacement
field.

Each step first predicts the iterates by applying `pdflow!` with the known
displacement `v_known`. If `params.prox_predict` is set, it then takes an
extra proximal dual step with `σ̃`, `ρ̃`. Finally it takes one PDPS step for
`½‖x-b‖² + α·γnorm₂₁(∇x, ρ)`.

# Keywords
- `dim`: image size used to initialise the iterates.
- `iterate`: driver called as `iterate(step, params)`. `step` receives
  `(verbose, b, v_known, b_next)` and returns the result of `verbose(...)`.
  NOTE(review): the default was previously `AlgTools.simple_iterate`, but
  `AlgTools` is not bound in this module (only `Iterate` and `Util` are
  imported). It now refers to `Iterate.simple_iterate` — confirm that
  `simple_iterate` is defined there.
- `params`: must provide at least `α`, `ρ`, `init`, `dual_flow`,
  `prox_predict`, plus the fields read by `step_lengths`.

# Returns
`(x, y, v)`: the final primal and dual iterates and whatever `iterate`
returned.
"""
function solve( :: Type{DisplacementT};
               dim :: ImageSize,
               iterate = Iterate.simple_iterate,
               params::NamedTuple) where DisplacementT

    ################################
    # Extract and set up parameters
    ################################

    α, ρ = params.α, params.ρ
    R_K² = ∇₂_norm₂₂_est²
    γ = 1
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄ = init_iterates(dim)
    # Mutated through a Ref rather than by reassigning a captured local, which
    # would force Julia to box the variable inside the closure below.
    init_data = Ref(params.init == :data)

    ####################
    # Run the algorithm
    ####################

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           v_known :: DisplacementT,
                           🚫unused_b_next :: Image

        ##################
        # Prediction step
        ##################
        if init_data[]
            # First frame: start the primal iterate from the data itself.
            x .= b
            init_data[] = false
        end

        pdflow!(x, Δx, y, Δy, v_known, params.dual_flow)

        if params.prox_predict
            ∇₂!(Δy, x)
            @. y = (y + σ̃*Δy)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y, α)
        end

        ############
        # PDPS step
        ############

        ∇₂ᵀ!(Δx, y)                    # primal step:
        @. x̄ = x                       # |  save old x for over-relax
        @. x = (x-τ*(Δx-b))/(1+τ)      # |  prox
        @. x̄ = 2x - x̄                  # over-relax
        ∇₂!(Δy, x̄)                     # dual step: y
        @. y = (y + σ*Δy)/(1 + σ*ρ/α)  # |
        proj_norm₂₁ball!(y, α)         # |  prox

        ################################
        # Give function value if needed
        ################################
        # The verbose() result is the value of this step closure. It is not
        # stored in the outer `v`, which would otherwise become a boxed
        # captured variable.
        verbose() do
            ∇₂!(Δy, x)
            value = norm₂²(b-x)/2 + α*γnorm₂₁(Δy, ρ)
            value, x, [NaN, NaN], nothing
        end
    end

    return x, y, v
end

end # Module


