####################################################################
# Predictive online forward-backward splitting for optical flow with
# known velocity field
####################################################################

__precompile__()

module AlgorithmFB

# Short string name of this algorithm variant. It is not referenced anywhere
# else in this file; presumably calling code reads it (e.g. to label output)
# — TODO confirm against the callers.
identifier = "fb_known"

using Printf

using AlgTools.Util
import AlgTools.Iterate
using ImageTools.Gradient

using ..OpticalFlow: Image,
                     ImageSize,
                     flow!

#########################
# Iterate initialisation
#########################

"""
    init_rest(x::Image)

Build the full set of iterates and work buffers around the primal variable `x`.

Returns the tuple `(x, y, Δx, Δy, ỹ, y⁻)` where

- `x` is the array that was passed in (it is *not* copied here),
- `y` is a zero array of size `(2, size(x)...)`,
- `Δx` is an independent copy of `x`,
- `Δy`, `ỹ` and `y⁻` are independent copies of `y`.
"""
function init_rest(x::Image)
    # Dual-sized zero array: a leading dimension of 2 in front of the image dims.
    dual = zeros(2, size(x)...)

    # Every buffer gets its own storage so that in-place updates never alias.
    return x, dual, copy(x), copy(dual), copy(dual), copy(dual)
end

"""
    init_iterates(xinit::Image)

Initialise all iterates from a starting image. `xinit` is copied first, so the
caller's array is never modified through the returned primal variable.
Returns the same tuple as `init_rest`.
"""
init_iterates(xinit::Image) = init_rest(copy(xinit))

"""
    init_iterates(dim::ImageSize)

Initialise all iterates with a zero primal variable of size `dim`.
Returns the same tuple as `init_rest`.
"""
init_iterates(dim::ImageSize) = init_rest(zeros(dim...))

############
# Algorithm
############

"""
    solve(::Type{DisplacementT}; dim, iterate, params) where DisplacementT

Predictive forward–backward method for `min_x |x-b|²/2 + α|∇x|`, where the
data `b` and a known displacement are supplied anew on every call of the step
function. Each step

1. predicts `x` by transporting it with the known displacement via `flow!`
   (or, on the very first step when `params.init == :data`, sets `x` to `b`),
2. takes a forward (gradient) step on the quadratic term,
3. approximates the proximal step by `params.fb_inner_iterations` inner
   FISTA iterations on the dual variable `y`, which is kept between steps.

# Arguments
- `::Type{DisplacementT}`: type of the known displacement; only used to
  annotate the `v_known` argument of the step function.

# Keywords
- `dim::ImageSize`: image size; all iterates start as zeros of this size.
- `iterate`: driver called as `iterate(step, params)`. It has to call
  `step(verbose, b, v_known, b_next)`; `b_next` is accepted but ignored.
  `verbose` is in turn called with a zero-argument function that computes the
  diagnostics, and whatever `verbose` returns is returned from `step`.
- `params::NamedTuple`: must provide `α`, `ρ`, `τ₀`, `τ̃₀`, `init` and
  `fb_inner_iterations`.

# Returns
`(x, y, v)`: the final primal variable, the final dual variable, and the
value returned by `iterate`.
"""
function solve( :: Type{DisplacementT};
               dim :: ImageSize,
               # NOTE(review): as far as this file shows, only `Iterate` (via
               # `import AlgTools.Iterate`) and the exports of `AlgTools.Util`
               # are in scope, not `AlgTools` itself. Confirm that this default
               # can actually be evaluated when `iterate` is not passed.
               iterate = AlgTools.simple_iterate,
               params::NamedTuple) where DisplacementT

    ################################
    # Extract and set up parameters
    ################################

    α, ρ = params.α, params.ρ
    τ₀, τ̃₀ = params.τ₀, params.τ̃₀

    # `∇₂_norm₂₂_est²` is an imported constant (presumably the squared operator
    # norm estimate of the discrete gradient from ImageTools.Gradient — confirm).
    R_K² = ∇₂_norm₂₂_est²
    τ̃ = τ̃₀/R_K²   # step length of the inner (dual) iterations
    τ = τ₀         # step length of the outer forward step

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, ỹ, y⁻ = init_iterates(dim)
    # True until the first step has replaced the zero initial `x` by the data.
    init_data = (params.init == :data)

    ####################
    # Run the algorithm
    ####################

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           v_known :: DisplacementT,
                           🚫unused_b_next :: Image

        ##################
        # Prediction step
        ##################

        if init_data
            x .= b
            init_data = false
        else
            # Δx is a temporary storage variable of correct dimensions
            flow!(x, v_known, Δx)
        end

        ##################################################################
        # We need to do forward–backward step on min_x |x-b|^2/2 + α|∇x|.
        # The forward step is easy, the prox requires solving the predual
        # problem of a problem similar to the original.
        ##################################################################

        @. x = x-τ*(x-b)

        ##############
        # Inner FISTA
        ##############

        # Kept as a float from the start so that `t` does not change type in
        # the loop; the computed values are the same as with an integer zero.
        # NOTE(review): textbook FISTA starts from t = 1. Starting from 0 makes
        # the first extrapolation coefficient (t-1)/t⁺ equal to -1, so after
        # the first pass ỹ equals y⁻. Confirm that this is intended.
        t = 0.0
        # Move step length from proximal quadratic term into L1 term.
        α̃ = α*τ
        # Loop-invariant denominator of the dual update.
        dual_scale = 1 + τ̃*ρ/α̃
        @. ỹ = y
        for _ in 1:params.fb_inner_iterations
            # Δx = ∇ᵀỹ - x, then Δy = ∇Δx
            ∇₂ᵀ!(Δx, ỹ)
            Δx .-= x
            ∇₂!(Δy, Δx)
            # Remember the previous dual iterate for the extrapolation below.
            @. y⁻ = y
            @. y = (ỹ - τ̃*Δy)/dual_scale
            proj_norm₂₁ball!(y, α̃)
            # FISTA momentum update and extrapolated point.
            t⁺ = (1+√(1+4*t^2))/2
            @. ỹ = y+((t-1)/t⁺)*(y-y⁻)
            t = t⁺
        end

        # Recover the primal variable from the dual one: x ← x - ∇ᵀy.
        ∇₂ᵀ!(Δx, y)
        @. x = x - Δx

        ################################
        # Give function value if needed
        ################################

        # The result of `verbose` is the return value of this step. It is not
        # stored in a local named `v`: that would assign to (and thereby box)
        # the enclosing function's `v`.
        verbose() do
            # Δy is free to be used as scratch space here.
            ∇₂!(Δy, x)
            value = norm₂²(b-x)/2 + α*γnorm₂₁(Δy, ρ)
            # `x` is handed out by reference, not as a copy.
            value, x, [NaN, NaN], nothing
        end
    end

    return x, y, v
end

end # Module


