src/AlgorithmBothMulti.jl

Sun, 21 Apr 2024 12:46:04 +0300

author
Neil Dizon <neil.dizon@helsinki.fi>
date
Sun, 21 Apr 2024 12:46:04 +0300
changeset 14
c286925c0f35
parent 0
a55e35d20336
child 36
e4a8f662a1ac
permissions
-rw-r--r--

Change rand to randn in PET

######################################################################
# Predictive online PDPS for optical flow with unknown velocity field
######################################################################

__precompile__()

module AlgorithmBothMulti

# Short string name for this algorithm variant. It is not referenced anywhere
# else in this file; presumably callers use it to label output — TODO confirm.
identifier = "pdps_unknownmulti"

using Printf

using AlgTools.Util
import AlgTools.Iterate
using ImageTools.Gradient
using ImageTools.ImFilter

using ..OpticalFlow: Image,
                     ImageSize,
                     DisplacementConstant,
                     pdflow!,
                     horn_schunck_reg_prox!,
                     pointwise_gradiprod_2d!,
                     horn_schunck_reg_prox_op!,
                     mldivide_step_plus_sym2x2!,
                     ConstantDisplacementHornSchunckData,
                     filter_hs

using ..Algorithm: step_lengths

#########################
# Iterate initialisation
#########################

# Pair the given image iterate with a fresh displacement iterate.
#
# For `DisplacementConstant` the displacement iterate is an `n`×2 array of
# zeros: row `j` holds the two components of the `j`-th displacement.
# `xinit` is passed through as-is (not copied).
function init_displ(xinit::Image, ::Type{DisplacementConstant}, n::Integer)
    u₀ = zeros(n, 2)
    return (xinit, u₀)
end

# function init_displ(xinit::Image, ::Type{DisplacementFull}, n::Integer)
#     return xinit, zeros(n, 2, size(xinit)...)
# end

# Allocate the remaining iterates and work buffers for image `x` and
# displacement `u`.
#
# Returns `(x, y, Δx, Δy, x̄, u)` where
#   - `x`, `u` are the arguments themselves (not copied),
#   - `y` is a zero array of size `(2, size(x)...)`,
#   - `Δx` and `x̄` are independent copies of `x`,
#   - `Δy` is an independent copy of `y`.
function init_rest(x::Image, u::DisplacementT) where DisplacementT
    y = zeros(2, size(x)...)
    Δy = copy(y)
    Δx, x̄ = copy(x), copy(x)

    return (x, y, Δx, Δy, x̄, u)
end

# Build all iterates starting from an initial image.
#
# `xinit` is copied first, so the caller's array is never aliased by the
# returned iterates. `n` is the number of displacements to allocate.
# Returns the tuple produced by `init_rest`.
function init_iterates(::Type{DisplacementT},
                       xinit::Image,
                       n::Integer) where DisplacementT
    x₀, u₀ = init_displ(copy(xinit), DisplacementT, n)
    return init_rest(x₀, u₀)
end

# Build all iterates starting from an all-zero image of size `dim`.
#
# `n` is the number of displacements to allocate.
# Returns the tuple produced by `init_rest`.
function init_iterates(::Type{DisplacementT},
                       dim::ImageSize,
                       n::Integer) where DisplacementT
    blank = zeros(dim...)
    x₀, u₀ = init_displ(blank, DisplacementT, n)
    return init_rest(x₀, u₀)
end

############
# Algorithm
############

# Predictive online PDPS for optical flow with an unknown displacement,
# estimated jointly with the image over a sliding window of at most
# `params.displacement_count` frames.
#
# Keyword arguments:
#   dim     - size of the image iterate (splatted into `zeros`).
#   iterate - driver that repeatedly calls the step closure defined below with
#             `(verbose, b, <unused known displacement>, b_next)`.
#             NOTE(review): the default refers to `AlgTools.simple_iterate`, but
#             the visible imports (`using AlgTools.Util`, `import AlgTools.Iterate`)
#             do not appear to bind the name `AlgTools` itself, so relying on the
#             default may fail unless `AlgTools` is in scope by other means — confirm.
#   params  - named tuple; this function reads `displacement_count`, `init`,
#             `α`, `ρ`, `λ`, `θ`, `kernel`, `timestep`, `dual_flow` and
#             `prox_predict`, and also forwards `params` to `step_lengths`
#             and to `iterate`.
#
# Returns `(x, y, v)`: the final primal and dual iterates and whatever
# `iterate` returned.
function solve( :: Type{DisplacementT};
               dim :: ImageSize,
               iterate = AlgTools.simple_iterate,
               params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    n = params.displacement_count
    k = 0 # number of displacements we have already

    x, y, Δx, Δy, x̄, u = init_iterates(DisplacementT, dim, n)
    init_data = (params.init == :data)
    hs = [ConstantDisplacementHornSchunckData() for i=1:n]
    #hs = Array{ConstantDisplacementHornSchunckData}(undef, n)
    # Per-displacement recursion coefficients u^j = A^j u^{j-1} + d^j; every
    # entry that is read is (re)written earlier in the same step.
    A = Array{Float64,3}(undef, n, 2, 2)
    d = Array{Float64,2}(undef, n, 2)

    # 2×2 identity used in the backward sweep; built once instead of once per
    # sweep iteration.
    I₂ = [1 0; 0 1]

    # … for tracking cumulative movement
    ucumulbase = [0.0, 0.0]

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ = params.α, params.ρ, params.λ, params.θ
    R_K² = ∇₂_norm₂₂_est²
    γ = 1
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel
    T = params.timestep

    ####################
    # Run the algorithm
    ####################

    # Filtered next frame carried over between steps (passed back into
    # `filter_hs` on the following call).
    b_next_filt = nothing
    # Work buffer for the frame-to-frame displacement used in prediction.
    diffu = similar(u[1, :])

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           🚫unused_v_known :: DisplacementT,
                           b_next :: Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ################################################
        # Prediction step
        ################################################

        # Predict x and y
        if k==0
            # Very first frame: optionally start from the data itself.
            if init_data
                x .= b
                init_data = false
            end
        else
            # Displacement from previous to this image is estimated as
            # the difference of their displacements from beginning of window.
            if k>1
                @. @views diffu = u[k, :] - u[k-1, :]
            else
                @. @views diffu = u[k, :]
            end

            pdflow!(x, Δx, y, Δy, diffu, params.dual_flow)
        end

        # Shift stored prox matrices
        if k==n
            # Window is full: drop the oldest displacement, re-base the
            # remaining ones on it, and remember it in the cumulative offset.
            # (Slice indexing already yields a fresh copy here.)
            tmp = u[1, :]
            ucumulbase .+= tmp
            for j=1:(n-1)
                @. @views u[j, :] = u[j+1, :] - tmp
                hs[j] = hs[j+1]
            end
            # Create new struct as original contains references to objects that
            # have been moved to index n-1.
            hs[n]=ConstantDisplacementHornSchunckData()
        else
            k += 1
        end

        # Predict u: zero displacement from current to next image, i.e.,
        # same displacement to beginning of window.
        if k==1
            @. @views u[k, :] = 0.0
        else
            @. @views u[k, :] = u[k-1, :]
        end

        # Predictor proximal step
        if params.prox_predict
            ∇₂!(Δy, x)
            @. y = (y + σ̃*Δy)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y, α)
        end

        #################################################################################
        # PDPS step
        #
        # For the displacements, with τ̃=τ/k, we need to solve for 2≤j<k,
        #
        # (1) (I/τ̃+M₀^j+M₀^{j+1})u^j = M₀^ju^{j-1} + M₀^{j+1}u^{j+1}
        #                              + ũ^j/τ̃ - z^j + z^{j+1},
        #
        # as well as
        #
        # (2) (I/τ̃+M₀^k)u^k = M₀^k u^{k-1} + ũ^k/τ̃ - z^k
        #
        # and
        #
        # (3) (I/τ̃+M₀^1+M₀^2)u^1 = 0 + M₀^{2}u^{2} + ũ^1/τ̃ - z^1 + z^{2}
        #
        # We first construct from (2) that
        #
        #     u^k = A^k u^{k-1} + d^k
        #
        # for
        #
        #     A^k := (I/τ̃+M₀^k)^{-1} M₀^k
        #     d^k := (I/τ̃+M₀^k)^{-1} (ũ^k/τ̃ - z^k).
        #
        # Inserting this into (1) we need
        #
        # (4)  (I/τ̃+M₀^j+M₀^{j+1}(I-A^{j+1}))u^j = M₀^ju^{j-1} + M₀^{j+1}d^{j+1}
        #                                          + ũ^j/τ̃ - z^j + z^{j+1}.
        #
        # This is well-defined because A^{j+1} < I. It also has the same form as (1), so
        # we continue with
        #
        # (5)   u^j = A^j u^{j-1} + d^j
        #
        # for
        #
        #      A^j := (I/τ̃+M₀^j+M₀^{j+1}(I-A^{j+1}))^{-1} M₀^j
        #      d^j := (I/τ̃+M₀^j+M₀^{j+1}(I-A^{j+1}))^{-1}
        #               (M₀^{j+1}d^{j+1} + ũ^j/τ̃ - z^j + z^{j+1})
        #
        # Finally from (3) with these we need
        #
        #      (I/τ̃+M₀^1+M₀^2(I-A^2))u^1 = M₀^2d^2 + ũ^1/τ̃ - z^1 + z^2,
        #
        # which is of the same form as (4) with u^0=0, so by (5) u^1=d^1.
        #
        #################################################################################

        ∇₂ᵀ!(Δx, y)                         # primal step:
        @. x̄ = x                            # |  save old x for over-relax
        @. x = (x-τ*(Δx-b))/(1+τ)           # |  prox
                                            # |  | for displacement
        # Calculate matrices for latest data; rest is stored.
        @views begin
            horn_schunck_reg_prox_op!(hs[k], b_next_filt, b_filt, θ, λ, T)

            τ̃=τ/k

            # Last displacement in the window: equation (2).
            B = hs[k].M₀
            c = u[k, :]./τ̃-hs[k].z
            mldivide_step_plus_sym2x2!(A[k, :, :], B, B, τ̃)
            mldivide_step_plus_sym2x2!(d[k, :], B, c, τ̃)

            # Backward sweep: eliminate u^{j+1} to obtain A^j, d^j as in (4)–(5).
            for j=(k-1):-1:1
                B = hs[j].M₀+hs[j+1].M₀*(I₂-A[j+1, :, :])
                c = hs[j+1].M₀*d[j+1, :]+u[j, :]./τ̃-hs[j].z+hs[j+1].z
                mldivide_step_plus_sym2x2!(A[j, :, :], B, hs[j].M₀, τ̃)
                mldivide_step_plus_sym2x2!(d[j, :], B, c, τ̃)
            end

            # Forward substitution: u^1 = d^1, then u^j = A^j u^{j-1} + d^j.
            u[1, :] .= d[1, :]
            for j=2:k
                u[j, :] .= A[j, :, :]*u[j-1, :] + d[j, :]
            end
        end

        @. x̄ = 2x - x̄                       # over-relax: x̄ = 2x-x_old
        ∇₂!(Δy, x̄)                          # dual step: y
        @. y = (y + σ*Δy)/(1 + σ*ρ/α)       # |
        proj_norm₂₁ball!(y, α)              # |  prox

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################
        # Distinct local names are used below on purpose: reusing `v` here would
        # assign to the enclosing function's `v` (closures share enclosing
        # locals), which forces it to be boxed and is easy to misread.
        step_result = verbose() do
            ∇₂!(Δy, x)
            hs_plus_reg=0
            for j=1:k
                # Increment between consecutive displacements in the window.
                δu=(j==1 ? u[j, :] : u[j, :]-u[j-1, :])
                hs_plus_reg += hs[j].cv/2 + dot(hs[j].Mv*δu, δu)/2+dot(hs[j].av, δu)
            end
            value = (norm₂²(b-x)/2 + hs_plus_reg/k + α*γnorm₂₁(Δy, ρ))

            # Function value, image, latest cumulative displacement, and all
            # cumulative displacements currently in the window.
            value, x, u[k, :]+ucumulbase, u[1:k,:].+ucumulbase'
        end

        return step_result
    end

    return x, y, v
end

end # Module

mercurial