src/AlgorithmBoth.jl

Fri, 19 Apr 2024 17:00:37 +0300

author
Neil Dizon <neil.dizon@helsinki.fi>
date
Fri, 19 Apr 2024 17:00:37 +0300
changeset 8
e4ad8f7ce671
parent 0
a55e35d20336
permissions
-rw-r--r--

Added PET and updated README

######################################################################
# Predictive online PDPS for optical flow with unknown velocity field
######################################################################

__precompile__()

module AlgorithmBoth

"""
    identifier

String tag naming this algorithm variant (predictive online PDPS for optical
flow with unknown velocity field). Declared `const` so that reads of it are
type-stable; presumably used by callers to label runs/outputs — confirm.
"""
const identifier = "pdps_unknown_basic"

using Printf

using AlgTools.Util
import AlgTools.Iterate
using ImageTools.Gradient

using ..OpticalFlow: ImageSize,
                     Image,
                     Gradient,
                     DisplacementConstant,
                     DisplacementFull,
                     pdflow!,
                     pointwise_gradiprod_2d!,
                     pointwise_gradiprod_2dᵀ!,
                     filter_hs

using ..Algorithm: step_lengths

#############
# Data types
#############

"""
    Primal{DisplacementT}

Primal iterate of the algorithm: the image estimate `x` together with the
displacement `u`, whose representation is the type parameter `DisplacementT`
(see `init_primal` for how each displacement type is initialised).
"""
struct Primal{DisplacementT}
    x :: Image
    u :: DisplacementT
end

# Uninitialised primal with the same displacement type and storage shapes as `p`.
Base.similar(p::Primal{D}) where D = Primal{D}(Base.similar(p.x), Base.similar(p.u))

# New primal whose image and displacement are copies of those of `p`.
Base.copy(p::Primal{D}) where D = Primal{D}(Base.copy(p.x), Base.copy(p.u))

"""
    Dual

Dual iterate: `tv` is the dual variable of the total-variation term and `flow`
the dual variable of the optical-flow term.
"""
struct Dual
    tv :: Gradient
    flow :: Image
end

# Uninitialised dual with storage shaped like that of `d`.
Base.similar(d::Dual) = Dual(Base.similar(d.tv), Base.similar(d.flow))

# New dual whose components are copies of those of `d`.
Base.copy(d::Dual) = Dual(Base.copy(d.tv), Base.copy(d.flow))

#########################
# Iterate initialisation
#########################

# Primal iterate holding the image `xinit` (not copied) and a constant
# displacement initialised to the zero 2-vector.
init_primal(xinit::Image, ::Type{DisplacementConstant}) =
    Primal{DisplacementConstant}(xinit, zeros(2))

# Primal iterate holding the image `xinit` (not copied) and a full displacement
# field of size `(2, size(xinit)...)`, initialised to zero.
init_primal(xinit::Image, ::Type{DisplacementFull}) =
    Primal{DisplacementFull}(xinit, zeros(2, size(xinit)...))

"""
    init_rest(x::Primal)

Complete the primal iterate `x` into the full algorithm state and return
`(x, y, Δx, Δy, x̄)`: `y` is a zero `Dual` whose `tv` part has size
`(2, size(x.x)...)` and whose `flow` part has the size of `x.x`; `Δx` and `x̄`
are copies of `x`; `Δy` is a copy of `y`.
"""
function init_rest(x::Primal{DisplacementT}) where DisplacementT
    sz = size(x.x)
    y = Dual(zeros(2, sz...), zeros(sz))
    # Tuple elements are evaluated left to right: Δx, then Δy, then x̄.
    return x, y, copy(x), copy(y), copy(x)
end

# Start from a given primal iterate; it is copied first, so the algorithm does
# not mutate the caller's value.
init_iterates(::Type{DisplacementT}, xinit::Primal{DisplacementT}) where DisplacementT =
    init_rest(copy(xinit))

# Start from a copy of the image `xinit`, with zero displacement.
init_iterates(::Type{DisplacementT}, xinit::Image) where DisplacementT =
    init_rest(init_primal(copy(xinit), DisplacementT))

# Start from an all-zero image of size `dim`, with zero displacement.
init_iterates(::Type{DisplacementT}, dim::ImageSize) where DisplacementT =
    init_rest(init_primal(zeros(dim...), DisplacementT))

##############################################
# Weighting for different displacements types
##############################################

"""
    norm²weight(DisplacementT, sz)

Weight of the squared displacement norm relative to the full-image squared norm
for images of size `sz`: a constant displacement is weighted by the number of
pixels `prod(sz)`, while a full displacement field has weight `1`.
"""
function norm²weight end

norm²weight(::Type{DisplacementConstant}, sz) = prod(sz)
norm²weight(::Type{DisplacementFull}, _) = 1

############
# Algorithm
############

"""
    solve(DisplacementT; dim, iterate = Iterate.simple_iterate, params)

Run the predictive online primal–dual proximal splitting (PDPS) method for
optical flow with an unknown velocity field.

For every step, the closure handed to `iterate` receives the current frame `b`
and the next frame `b_next`. It smooths them with `filter_hs` and then runs a
prediction step: `pdflow!`, a reset of the displacement `x.u` to zero, and,
if `params.prox_predict`, a proximal dual prediction. Finally it performs one
PDPS update of the image `x.x`, the displacement `x.u` and the dual variables.

# Arguments
- `DisplacementT`: displacement representation, `DisplacementConstant` or
  `DisplacementFull` (the types with `init_primal`/`norm²weight` methods here).
- `dim`: image size; the image iterate starts as zeros of this size, or as the
  first frame when `params.init == :data`.
- `iterate`: driver called as `iterate(step, params)`, which must invoke
  `step(verbose, b, v_known, b_next)`; `v_known` is ignored by this algorithm.
  Callers normally supply a driver that provides the frames.
- `params`: named tuple; this function reads `α`, `ρ`, `λ`, `θ`, `timestep`,
  `dynrange`, `kernel`, `init`, `prox_predict` and `dual_flow`, and also passes
  the whole tuple to `step_lengths` and to `iterate`.

Within each step, the `verbose` do-block yields
`(function value, x.x, ucumul, nothing)`. `ucumul` accumulates `x.u` for a
constant displacement and is `[NaN, NaN]` otherwise.

# Returns
`(x, y, v)`: the final `Primal` and `Dual` iterates, and the value returned by
`iterate`.
"""
function solve( :: Type{DisplacementT};
               dim :: ImageSize,
               # Only `Iterate` is bound here (via `import AlgTools.Iterate`);
               # `AlgTools` itself is not, so reach the driver through `Iterate`.
               iterate = Iterate.simple_iterate,
               params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄ = init_iterates(DisplacementT, dim)
    # With `init = :data`, the image iterate is overwritten by the first frame.
    init_data = (params.init == :data)

    # … for tracking cumulative movement (only tracked for constant displacement)
    if DisplacementT == DisplacementConstant
        ucumul = [0.0, 0.0]
    else
        ucumul = [NaN, NaN]
    end

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep
    # Presumably an estimate of ‖K‖² (TV gradient part vs. flow part scaled by
    # the data dynamic range) — see `step_lengths` for how it is used.
    R_K² = max(∇₂_norm₂₂_est², ∇₂_norm₂∞_est²*params.dynrange^2)
    γ = min(1, λ*norm²weight(DisplacementT, size(x.x)))
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel

    ####################
    # Run the algorithm
    ####################

    # Filtered next frame, carried between steps and passed back into
    # `filter_hs` (presumably so it can be reused as the next current frame).
    b_next_filt = nothing

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           🚫unused_v_known :: DisplacementT,
                           b_next :: Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ############################
        # Construct K for this step
        ############################

        K! = (yʹ, xʹ) -> begin
            # Optical flow part: ⟨⟨u, ∇b_k⟩⟩.
            # Use y.tv as temporary gradient storage.
            pointwise_gradiprod_2d!(yʹ.flow, yʹ.tv, xʹ.u, b_filt)
            # TV part; must come after the flow part, which used yʹ.tv as scratch.
            ∇₂!(yʹ.tv, xʹ.x)
        end
        Kᵀ! = (xʹ, yʹ) -> begin
            # Optical flow part: ∇b_k ⋅ y
            pointwise_gradiprod_2dᵀ!(xʹ.u, yʹ.flow, b_filt)
            # TV part
            ∇₂ᵀ!(xʹ.x, yʹ.tv)
        end

        ##################
        # Prediction step
        ##################

        # `x` is a `Primal`; the first frame initialises its image component.
        if init_data
            x.x .= b
            init_data = false
        end

        pdflow!(x.x, Δx.x, y.tv, Δy.tv, y.flow, Δy.flow, x.u, params.dual_flow)

        # Predict zero displacement
        x.u .= 0
        if params.prox_predict
            K!(Δy, x)
            @. y.tv = (y.tv + σ̃*Δy.tv)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y.tv, α)
            @. y.flow = (y.flow+σ̃*((b_next_filt-b_filt)/T+Δy.flow))/(1+σ̃*(ρ̃+1/θ))
        end

        ############
        # PDPS step
        #
        # NOTE: For DisplacementConstant, the x.u update is supposed to be with
        # respect to the 𝟙^*𝟙 norm/inner product that makes the norm equivalent
        # to full-space norm when restricted to constant displacements. Since
        # `OpticalFlow.pointwise_gradiprod_2dᵀ!` already uses this inner product,
        # and the λ-weighted term in the problem is with respect to this norm,
        # all the norm weights disappear in this update.
        ############

        Kᵀ!(Δx, y)                             # primal step:
        @. x̄.x = x.x                           # |  save old x for over-relax
        @. x̄.u = x.u                           # |
        @. x.x = (x.x-τ*(Δx.x-b))/(1+τ)        # |  prox
        @. x.u = (x.u-τ*Δx.u)/(1+τ*λ)          # |
        @. x̄.x = 2x.x - x̄.x                    # over-relax
        @. x̄.u = 2x.u - x̄.u                    # |
        K!(Δy, x̄)                              # dual step: y
        @. y.tv = (y.tv + σ*Δy.tv)/(1 + σ*ρ/α) # |
        proj_norm₂₁ball!(y.tv, α)              # |  prox
        @. y.flow = (y.flow+σ*((b_next_filt-b_filt)/T+Δy.flow))/(1+σ/θ)

        if DisplacementT == DisplacementConstant
            ucumul .+= x.u
        end

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################
        # NOTE(review): the y.flow prox above divides by (1+σ/θ), which matches
        # a (θ/2)‖·‖² flow term, whereas the value below uses θ‖·‖² — confirm
        # which scaling is intended for the reported value.
        # `report` is local to this closure; reusing the name `v` would rebind
        # (and box) the enclosing function's `v`.
        report = verbose() do
            K!(Δy, x)
            value = (norm₂²(b-x.x)/2 + θ*norm₂²((b_next_filt-b_filt)./T+Δy.flow)
                     + λ*norm₂²(x.u)/2 + α*γnorm₂₁(Δy.tv, ρ))

            value, x.x, ucumul, nothing
        end

        return report
    end

    return x, y, v
end

end # Module

mercurial