######################################################################
# Predictive online PDPS for optical flow with unknown velocity field
######################################################################

__precompile__()

module AlgorithmBothGreedyV

# Identifier string for this algorithm variant (presumably used by callers to
# select or label it). Declared `const` so the global binding is type-stable.
const identifier = "pdps_unknown_greedyv"

using Printf

using AlgTools.Util
import AlgTools.Iterate
using ImageTools.Gradient

using ..OpticalFlow: Image,
                     ImageSize,
                     DisplacementConstant,
                     DisplacementFull,
                     pdflow!,
                     horn_schunck_reg_prox!,
                     pointwise_gradiprod_2d!,
                     filter_hs

using ..Algorithm: step_lengths

#########################
# Iterate initialisation
#########################

# Pair the starting image with a zero displacement whose shape is chosen by
# dispatch on the displacement type. The image is passed through unchanged.

# Constant displacement: a single 2-vector for the whole image.
init_displ(xinit::Image, ::Type{DisplacementConstant}) = (xinit, zeros(2))

# Full displacement field: one 2-vector per pixel, i.e. size (2, size(xinit)...).
init_displ(xinit::Image, ::Type{DisplacementFull}) = (xinit, zeros(2, size(xinit)...))

# Allocate the remaining PDPS iterates around the primal image `x` and the
# displacement `u`, returning `(x, y, Δx, Δy, x̄, u)`:
#   y      – zero dual variable of size (2, size(x)...)
#   Δx, x̄  – independent copies of x (work buffer and over-relaxation point)
#   Δy     – independent copy of y (work buffer)
# `x` and `u` themselves are returned as-is, without copying.
function init_rest(x::Image, u::DisplacementT) where DisplacementT
    dual = zeros(2, size(x)...)
    return x, dual, copy(x), copy(dual), copy(x), u
end

# Build the complete iterate tuple `(x, y, Δx, Δy, x̄, u)` for the displacement
# representation `DisplacementT`.

# Start from a copy of `xinit`, so the caller's image is never aliased by `x`.
function init_iterates( :: Type{DisplacementT}, xinit::Image) where DisplacementT
    start = copy(xinit)
    return init_rest(init_displ(start, DisplacementT)...)
end

# Start from an all-zero image of size `dim`.
function init_iterates( :: Type{DisplacementT}, dim::ImageSize) where DisplacementT
    start = zeros(dim...)
    return init_rest(init_displ(start, DisplacementT)...)
end

############
# Algorithm
############

# Run the predictive online PDPS method for optical flow with an unknown
# displacement, estimated jointly with the image.
#
# Arguments
# - `DisplacementT`: displacement representation (`DisplacementConstant` or
#   `DisplacementFull`). It selects how `u` is allocated and whether the
#   cumulative movement `ucumul` is accumulated.
# - `dim`: image size; the primal iterate starts as a zero image of this size.
# - `iterate`: driver invoked as `iterate(step, params)`. It calls `step` with
#   `(verbose, b, v_known, b_next)`; the known displacement is ignored here.
# - `params`: named tuple; this function reads `init`, `α`, `ρ`, `λ`, `θ`,
#   `timestep`, `kernel`, `dual_flow`, `prox_predict`, plus whatever
#   `step_lengths` reads.
#
# Returns `(x, y, v)`: the final primal image, the final dual variable, and
# the value returned by `iterate`.
function solve( :: Type{DisplacementT};
               dim :: ImageSize,
               # NOTE(review): the imports bind `Iterate` and `Util`, which does
               # not appear to bind `AlgTools` itself. Callers presumably always
               # pass `iterate` explicitly, so confirm before relying on this default.
               iterate = AlgTools.simple_iterate,
               params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄, u = init_iterates(DisplacementT, dim)
    # Overwrite x with the first data frame on the first step, then never again.
    init_data = (params.init == :data)

    # … for tracking cumulative movement. It is only accumulated for a constant
    # displacement; a full field reports NaN.
    if DisplacementT == DisplacementConstant
        ucumul = [0.0, 0.0]
    else
        ucumul = [NaN, NaN]
    end

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep
    R_K² = ∇₂_norm₂₂_est²
    γ = 1
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel

    ####################
    # Run the algorithm
    ####################

    # Filtered next frame. The step closure reassigns it, so it is passed back
    # into `filter_hs` on the following step (presumably to reuse the filtering).
    b_next_filt=nothing

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           🚫unused_v_known :: DisplacementT,
                           b_next :: Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ##################
        # Prediction step
        ##################

        if init_data
            x .= b
            init_data = false
        end

        pdflow!(x, Δx, y, Δy, u, params.dual_flow)

        # Predict zero displacement
        u .= 0
        if params.prox_predict
            # The gradient goes into the scratch buffer Δy: y must keep its
            # previous value for the proximal update y⁺ = prox(y + σ̃∇x).
            ∇₂!(Δy, x)
            @. y = (y + σ̃*Δy)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y, α)
        end

        ############
        # PDPS step
        ############

        ∇₂ᵀ!(Δx, y)                       # primal step:
        @. x̄ = x                          # |  save old x for over-relax
        @. x = (x-τ*(Δx-b))/(1+τ)         # |  prox
        horn_schunck_reg_prox!(u, b_next_filt, b_filt, θ, λ, T, τ)
        @. x̄ = 2x - x̄                     # over-relax
        ∇₂!(Δy, x̄)                        # dual step: Δy = ∇x̄ (keep old y)
        @. y = (y + σ*Δy)/(1 + σ*ρ/α)     # |  y + σ∇x̄, scaled
        proj_norm₂₁ball!(y, α)            # |  prox

        if DisplacementT == DisplacementConstant
            ucumul .+= u
        end

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################
        # Returned directly rather than assigned to `v`. Assigning would rebind
        # the outer `v` captured by this closure and force it to be boxed.
        return verbose() do
            ∇₂!(Δy, x)
            # Evaluate the TV term while Δy is known to hold ∇x. Δy is handed to
            # pointwise_gradiprod_2d! below, presumably as scratch (confirm).
            tv = γnorm₂₁(Δy, ρ)
            tmp = zeros(size(b_filt))
            pointwise_gradiprod_2d!(tmp, Δy, u, b_filt)
            value = (norm₂²(b-x)/2 + θ*norm₂²((b_next_filt-b_filt)./T+tmp)
                     + λ*norm₂²(u)/2 + α*tv)

            value, x, ucumul, nothing
        end
    end

    return x, y, v
end

end # Module


