######################################################################
# Predictive online PDPS for optical flow with unknown velocity field
######################################################################

# Declare the module below safe to precompile.
__precompile__()

module AlgorithmBothCumul

# Short name identifying this algorithm variant. It is a `const`, so it is not
# an untyped (`Any`) global.
const identifier = "pdps_unknown_cumul"

using Printf

using AlgTools.Util
import AlgTools.Iterate
using ImageTools.Gradient
using ImageTools.ImFilter

using ..OpticalFlow: Image,
                     ImageSize,
                     DisplacementConstant,
                     pdflow!,
                     horn_schunck_reg_prox!,
                     pointwise_gradiprod_2d!

using ..AlgorithmBothGreedyV: init_iterates
using ..Algorithm: step_lengths

############
# Algorithm
############

"""
    solve(DisplacementT; dim, iterate = Iterate.simple_iterate, params)

Run the predictive online PDPS method for optical flow with an unknown
displacement, in its *cumulative* variant. The displacement `u` is fitted
against the filtered *first* frame `b₀_filt` rather than the previous frame.

For every frame `b` delivered by `iterate`, the method does the following:
1. Optionally smooths `b` with `params.kernel` for the Horn–Schunck term.
2. Predicts `x`, `y` with `pdflow!` using only the increment `u - u_prev`
   accumulated since the previous frame. When `params.prox_predict` is set,
   this is followed by an extra dual proximal step.
3. Takes one PDPS step on `x`, `y`, together with a `horn_schunck_reg_prox!`
   update of `u` against `b₀_filt`.

If `params.init == :data`, `x` is set to the first frame before the first
prediction.

# Arguments
- `DisplacementT`: displacement type given to `init_iterates`. It is also the
  declared type of the (ignored) known-displacement callback argument.
- `dim::ImageSize`: image size given to `init_iterates`.
- `iterate`: driver called as `iterate(params) do verbose, b, v_known, b_next … end`.
  The callback returns whatever `verbose` returns.
- `params::NamedTuple`: the fields read here are `init`, `α`, `ρ`, `λ`, `θ`,
  `timestep`, `kernel`, `dual_flow` and `prox_predict`, plus whatever
  `step_lengths` reads.

# Returns
`(x, y, v)`: the final primal and dual iterates, and the value returned by
`iterate`.
"""
function solve( :: Type{DisplacementT};
               dim :: ImageSize,
               # Only `Iterate` is bound by `import AlgTools.Iterate`; `AlgTools`
               # itself is not in scope here. Assumes `simple_iterate` lives in
               # `AlgTools.Iterate`.
               iterate = Iterate.simple_iterate,
               params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄, u = init_iterates(DisplacementT, dim)
    init_data = (params.init == :data)

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep
    R_K² = ∇₂_norm₂₂_est²
    γ = 1
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel

    ####################
    # Run the algorithm
    ####################

    # Filtered first frame; set on the first callback and never changed.
    b₀_filt = nothing
    # Displacement before the previous frame's PDPS update, and a reusable
    # buffer for the increment u - u_prev fed to the prediction.
    u_prev = zeros(size(u))
    Δu = similar(u_prev)

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           🚫unused_v_known :: DisplacementT,
                           🚫unused_b_next :: Image

        #########################################################
        # Smoothen data for Horn–Schunck term; zero initial data
        #########################################################

        b_filt = (kernel === nothing ? b : simple_imfilter(b, kernel))

        if b₀_filt === nothing
            b₀_filt = b_filt
        end

        ################################################
        # Prediction step
        # We leave u as-is in this cumulative version
        ################################################

        if init_data
            x .= b
            init_data = false
        end

        # Predict using only the displacement gained since the previous frame.
        @. Δu = u - u_prev
        pdflow!(x, Δx, y, Δy, Δu, params.dual_flow)

        if params.prox_predict
            ∇₂!(Δy, x)
            @. y = (y + σ̃*Δy)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y, α)
        end

        # Remember u before this frame's update, so that the next prediction
        # uses only the increment.
        u_prev .= u

        ############
        # PDPS step
        ############

        ∇₂ᵀ!(Δx, y)                      # primal step:
        @. x̄ = x                         # |  save old x for over-relax
        @. x = (x-τ*(Δx-b))/(1+τ)        # |  prox
        horn_schunck_reg_prox!(u, b_filt, b₀_filt, τ, θ, λ, T)
        @. x̄ = 2x - x̄                    # over-relax
        ∇₂!(Δy, x̄)                       # dual step: y
        @. y = (y + σ*Δy)/(1 + σ*ρ/α)    # |
        proj_norm₂₁ball!(y, α)           # |  prox

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################

        # Return directly rather than assigning to `v`: an assignment here
        # would rebind the enclosing function's `v` and box it.
        return verbose() do
            ∇₂!(Δy, x)
            # Take the TV term while Δy still holds ∇x. Δy is then handed to
            # pointwise_gradiprod_2d!, which may use it as workspace.
            tv = γnorm₂₁(Δy, ρ)
            tmp = zeros(size(b_filt))
            pointwise_gradiprod_2d!(tmp, Δy, u, b₀_filt)
            value = (norm₂²(b-x)/2 + θ*norm₂²((b_filt-b₀_filt)./T+tmp)
                     + λ*norm₂²(u)/2 + α*tv)

            value, x, u, nothing
        end
    end

    return x, y, v
end

end # Module


