src/PET/OpticalFlow.jl

Wed, 24 Apr 2024 17:04:57 +0300

author
Neil Dizon <neil.dizon@helsinki.fi>
date
Wed, 24 Apr 2024 17:04:57 +0300
changeset 32
88632284396f
parent 26
ccd22bbbb02f
permissions
-rw-r--r--

added plotting and table generator

################################
# Code relevant to optical flow
################################

__precompile__()

module OpticalFlow

using AlgTools.Util
using ImageTools.Gradient
import ImageTools.Translate
using ImageTools.ImFilter

# using ImageTransformations
# using Images, CoordinateTransformations, Rotations, OffsetArrays
# using Interpolations

import Images: center, warp
import CoordinateTransformations: recenter
import Rotations: RotMatrix
import Interpolations: Flat

##########
# Exports
##########

# Image flows and primal–dual predictor steps
export flow!, flow_interp!, pdflow!, petpdflow!, flow_grad!

# Predictor selectors for `pdflow!` / `petpdflow!`
export DualScaling, Greedy, Rotation

# Linearised optical flow and Horn–Schunck regularisation
export pointwise_gradiprod_2d!, pointwise_gradiprod_2dᵀ!,
       horn_schunck_reg_prox!, horn_schunck_reg_prox_op!,
       mldivide_step_plus_sym2x2!, linearised_optical_flow_error,
       filter_hs

# Array type aliases
export Image, AbstractImage, ImageSize, Gradient, Displacement,
       DisplacementFull, DisplacementConstant

# NOTE(review): these names are exported but not defined anywhere in this
# module; confirm whether they are still needed.
export estimate_Λ², estimate_linear_Λ², HornSchunckData

###############################################
# Types (several imported from ImageTools.Translate)
###############################################

# Shared array types. They are declared `const` because a plain global binding
# is untyped (`Any`) to the compiler. Without `const`, any use of these names
# inside a function body (here or in modules importing them) would be
# dynamically dispatched.
const Image = Translate.Image
const AbstractImage = AbstractArray{Float64,2}
const Displacement = Translate.Displacement
const DisplacementFull = Translate.DisplacementFull
const DisplacementConstant = Translate.DisplacementConstant
const Gradient = Array{Float64,3}   # gradient components along the first axis
const ImageSize = Tuple{Int64,Int64}


#################################
# Struct for flow
#################################
"""
Predictor selector: after flowing the primal image, multiply the dual variable
pointwise by a factor that decreases with the change the flow caused in the image.
"""
struct DualScaling
end

"""
Predictor selector: after flowing the primal image, rescale the dual variable
componentwise by the ratio of the old to the new image gradient.
"""
struct Greedy
end

"""
Predictor selector: rotation-based update of the dual variable.
"""
struct Rotation
end

#################################
# Displacement field based flow
#################################

"""
    flow_interp!(x, u, tmp; threads=false)
    flow_interp!(x, u; threads=false)

Flow the image `x` in place by the displacement `u` by calling
`Translate.translate_image!(x, tmp, u)`. Before that call, `tmp` is filled with
a copy of the pre-flow `x`. The two-argument method allocates that buffer
itself. Returns whatever `translate_image!` returns.
"""
function flow_interp!(x::AbstractImage, u::Displacement, tmp::AbstractImage;
                      threads = false)
    tmp .= x
    return Translate.translate_image!(x, tmp, u; threads=threads)
end

function flow_interp!(x::AbstractImage, u::Displacement;
                      threads = false)
    # The 3-argument method copies `x` into the buffer, so `similar` suffices.
    return flow_interp!(x, u, similar(x); threads=threads)
end

# `const`, so that calls through the alias (e.g. in `pdflow!`) are resolved
# statically instead of through an untyped global binding.
const flow! = flow_interp!

"""
    pdflow!(x, Δx, y, Δy, u, dual_flow; threads=:none)

Predictor step for the primal–dual iterates. It flows the primal image `x` along
`u` with `flow!`, using `Δx` as the temporary; `Δx` ends up holding the pre-flow
`x`. If `dual_flow` is true, both channels `y[1,:,:]` and `y[2,:,:]` of the dual
variable are flowed as well, with the matching slices of `Δy` as temporaries.

Threading options:

- `threads == :outer` lets `@backgroundif` run the primal and dual flows
  concurrently.
- `threads == :inner` requests threading inside each `flow!` call.
"""
function pdflow!(x, Δx, y, Δy, u, dual_flow; threads=:none)
    inner = (threads==:inner)
    if dual_flow
        @backgroundif (threads==:outer) begin
            flow!(x, u, Δx; threads=inner)
        end begin
            flow!(@view(y[1, :, :]), u, @view(Δy[1, :, :]); threads=inner)
            flow!(@view(y[2, :, :]), u, @view(Δy[2, :, :]); threads=inner)
        end
    else
        # Honour `threads` here too, as the other `pdflow!` methods do.
        flow!(x, u, Δx; threads=inner)
    end
end

# Like `pdflow!(x, Δx, y, Δy, u, dual_flow)`, but a second image `z` (with
# temporary `Δz`) is flowed together with `x`. With `dual_flow`, the two dual
# channels of `y` are flowed as well. Under `threads == :outer`, the x/z flows
# and the y flows are the two blocks handed to `@backgroundif`.
function pdflow!(x, Δx, y, Δy, z, Δz, u, dual_flow; threads=:none)
    inner = threads == :inner
    if dual_flow
        @backgroundif (threads==:outer) begin
            flow!(x, u, Δx; threads=inner)
            flow!(z, u, Δz; threads=inner)
        end begin
            flow!(@view(y[1, :, :]), u, @view(Δy[1, :, :]); threads=inner)
            flow!(@view(y[2, :, :]), u, @view(Δy[2, :, :]); threads=inner)
        end
    else
        flow!(x, u, Δx; threads=inner)
        flow!(z, u, Δz; threads=inner)
    end
end

# Additional method for Greedy
# Greedy predictor. Translate `x` by the constant displacement `u`, then scale
# every component of the dual `y` by (old gradient)/(new gradient) of the image.
# Components whose new gradient is small (|·| ≤ 0.1) are left unchanged.
# `Δx`/`Δy` receive the pre-flow `x`/`y`.
function pdflow!(x, Δx, y, Δy, u, flow :: Greedy; threads=:none)
    @assert(size(u)==(2,))
    Δx .= x
    Δy .= y
    flow!(x, u; threads=(threads==:inner))
    grad_new = similar(Δy)
    grad_old = similar(Δy)
    ∇₂!(grad_new, x)
    ∇₂!(grad_old, Δx)
    @. y = ifelse(abs(grad_new) ≤ 1e-1, y, y * grad_old / grad_new)
end

# Additional method for Rotation
# Rotation predictor. First translate `x` by the constant displacement `u`. Then,
# at every pixel where both the old and the new image gradient have squared norm
# > 1e-4, rotate the dual vector y[:, i, j] by the oriented angle between the
# two gradients. `Δx` receives the pre-flow image; `Δy` is not used.
function pdflow!(x, Δx, y, Δy, u, flow :: Rotation; threads=:none)
    @assert(size(u)==(2,))
    Δx .= x
    flow!(x, u; threads=(threads==:inner))
    (m,n) = size(x)
    dx = similar(y)          # gradient of the pre-flow image (A below)
    dx_banana = similar(y)   # gradient of the post-flow image (B below)
    ∇₂!(dx, Δx)
    ∇₂!(dx_banana, x)
    # Pixels are processed independently. Keep `i` innermost to follow Julia's
    # column-major layout, and read scalars instead of allocating 2-vectors.
    for j=1:n, i=1:m
        A₁, A₂ = dx[1, i, j], dx[2, i, j]
        B₁, B₂ = dx_banana[1, i, j], dx_banana[2, i, j]
        ndx = A₁^2 + A₂^2
        ndx_banana = B₁^2 + B₂^2
        if ndx > 1e-4 && ndx_banana > 1e-4
            # atan(B×A, B⋅A) is the oriented angle from B to A (not from A to B).
            # So y is rotated by the inverse of the rotation that takes the old
            # gradient direction to the new one.
            # NOTE(review): confirm this rotation direction is the intended one.
            theta = atan(B₁ * A₂ - B₂ * A₁, B₁ * A₁ + B₂ * A₂)
            cos_theta = cos(theta)
            sin_theta = sin(theta)
            y₁, y₂ = y[1, i, j], y[2, i, j]
            y[1, i, j] = cos_theta * y₁ - sin_theta * y₂
            y[2, i, j] = sin_theta * y₁ + cos_theta * y₂
        end
    end
    return nothing
end

# Additional method for Dual Scaling
# Dual-scaling predictor. Translate `x` by the constant displacement `u`, then
# multiply both dual components pointwise by (1 - |Δ|/max|Δ|)^10, where Δ is
# the change the flow made to `x`. `Δx` and `Δy` are not used.
function pdflow!(x, Δx, y, Δy, u, flow :: DualScaling; threads=:none)
    @assert(size(u)==(2,))
    x_before = copy(x)
    flow!(x, u; threads=(threads==:inner))
    change = abs.(x - x_before)
    # The lower bound avoids dividing by zero when the flow leaves `x` unchanged.
    peak = max(1e-12, maximum(change))
    damping = (1 .- change ./ peak) .^ 10
    factor = similar(y)
    for k = 1:2
        factor[k, :, :] .= damping
    end
    y .= factor .* y
end


##########################
# PET
##########################
"""
    petflow_interp!(x, tmp, u, theta_known; threads=false)

Resample the image `x` in place under the rigid transform
`recenter(RotMatrix(theta_known[1]), center(x) .+ u)`.

- `warp` produces `out[I] = x_old[tform(I)]` over `axes(x)`.
- Samples outside the image use `Flat()` extrapolation.
- `tmp` receives a copy of the pre-warp `x`, mirroring `flow_interp!`.
- `threads` is accepted for interface compatibility but is currently unused.

Returns `x`.
"""
function petflow_interp!(x::AbstractImage, tmp::AbstractImage, u::DisplacementConstant,
                         theta_known::DisplacementConstant; threads = false)
    tmp .= x
    pivot = center(x) .+ u
    tform = recenter(RotMatrix(theta_known[1]), pivot)
    # Bind the warp result to its own name instead of shadowing the caller's
    # `tmp` buffer.
    warped = warp(x, tform, axes(x), fillvalue=Flat())
    x .= warped
    return x
end

# `const`, so that calls through the alias in `petpdflow!` dispatch statically.
const petflow! = petflow_interp!

# PET predictor. Move `x` (and, when `dual_flow` is true, both channels of `y`)
# by the rigid motion given by shift `u` and angle `theta_known[1]`. The
# pre-flow images are copied into `Δx` / the slices of `Δy` by `petflow!`.
function petpdflow!(x, Δx, y, Δy, u, theta_known, dual_flow; threads=:none)
    inner = threads == :inner
    if dual_flow
        @backgroundif (threads==:outer) begin
            petflow!(x, Δx, u, theta_known; threads=inner)
        end begin
            petflow!(@view(y[1, :, :]), @view(Δy[1, :, :]), u, theta_known; threads=inner)
            petflow!(@view(y[2, :, :]), @view(Δy[2, :, :]), u, theta_known; threads=inner)
        end
    else
        petflow!(x, Δx, u, theta_known)
    end
end

# Method for greedy predictor
# PET greedy predictor. Warp `x` by the known rigid motion, then scale every
# component of `y` by (old gradient)/(new gradient). Components whose new
# gradient is small (|·| ≤ 0.01) are left unchanged. `Δy` receives the
# pre-update `y`; the caller's `Δx` is not written to.
function petpdflow!(x, Δx, y, Δy, u, theta_known, flow :: Greedy; threads=:none)
    x_prev = copy(x)
    tform = recenter(RotMatrix(theta_known[1]), center(x) .+ u)
    warped = warp(x, tform, axes(x), fillvalue=Flat())
    x .= warped
    Δy .= y
    grad_new = copy(Δy)
    grad_old = copy(Δy)
    ∇₂!(grad_new, x)
    ∇₂!(grad_old, x_prev)
    @. y = ifelse(abs(grad_new) ≤ 1e-2, y, y * grad_old / grad_new)
end

# Method for dual scaling predictor
# PET dual-scaling predictor. Warp `x` by the known rigid motion, then multiply
# both dual components pointwise by (1 - |Δ|/max|Δ|)^10, where Δ is the change
# the warp made to `x`. `Δx` and `Δy` are not written to.
function petpdflow!(x, Δx, y, Δy, u, theta_known, flow :: DualScaling; threads=:none)
    x_before = copy(x)
    tform = recenter(RotMatrix(theta_known[1]), center(x) .+ u)
    warped = warp(x, tform, axes(x), fillvalue=Flat())
    x .= warped
    change = abs.(x - x_before)
    # The lower bound avoids dividing by zero when the warp leaves `x` unchanged.
    peak = max(1e-12, maximum(change))
    damping = (1 .- change ./ peak) .^ 10
    factor = similar(y)
    for k = 1:2
        factor[k, :, :] .= damping
    end
    y .= factor .* y
end

# Method for rotation prediction (exploiting property of inverse rotation)
# PET rotation predictor. Move `x` by the known rigid motion (shift `u`, angle
# `theta_known`). Move both dual channels with the same shift but the negated
# angle `-theta_known`.
function petpdflow!(x, Δx, y, Δy, u, theta_known, flow :: Rotation; threads=:none)
    inner = threads == :inner
    theta_inv = -theta_known
    @backgroundif (threads==:outer) begin
        petflow!(x, Δx, u, theta_known; threads=inner)
    end begin
        petflow!(@view(y[1, :, :]), @view(Δy[1, :, :]), u, theta_inv; threads=inner)
        petflow!(@view(y[2, :, :]), @view(Δy[2, :, :]), u, theta_inv; threads=inner)
    end
end

##########################
# Linearised optical flow
##########################

# ⟨⟨u, ∇b⟩⟩
# ⟨⟨u, ∇b⟩⟩ for a per-pixel displacement field. Sets y[k] = u[:, k] ⋅ ∇b[:, k] at
# every pixel k, or adds it to y[k] when `add` is true. `vtmp` is overwritten
# with the gradient of `b` computed by `∇₂c!`.
function pointwise_gradiprod_2d!(y::Image, vtmp::Gradient,
                                 u::DisplacementFull, b::Image;
                                 add = false)
    ∇₂c!(vtmp, b)

    # One column per pixel; `yflat` is y in the same linear order.
    ucols = reshape(u, size(u, 1), prod(size(u)[2:end]))
    gcols = reshape(vtmp, size(vtmp, 1), prod(size(vtmp)[2:end]))
    yflat = vec(y)

    if add
        @inbounds @simd for k in eachindex(yflat)
            yflat[k] += dot(@view(ucols[:, k]), @view(gcols[:, k]))
        end
    else
        @inbounds @simd for k in eachindex(yflat)
            yflat[k] = dot(@view(ucols[:, k]), @view(gcols[:, k]))
        end
    end
end

# ⟨⟨u, ∇b⟩⟩ for a constant displacement `u`. Sets y[k] = u ⋅ ∇b[:, k] at every
# pixel k, or adds it to y[k] when `add` is true. `vtmp` is overwritten with the
# gradient of `b` computed by `∇₂c!`.
function pointwise_gradiprod_2d!(y::Image, vtmp::Gradient,
                                 u::DisplacementConstant, b::Image;
                                 add = false)
    ∇₂c!(vtmp, b)

    gcols = reshape(vtmp, size(vtmp, 1), prod(size(vtmp)[2:end]))
    yflat = vec(y)

    if add
        @inbounds @simd for k in eachindex(yflat)
            yflat[k] += dot(u, @view(gcols[:, k]))
        end
    else
        @inbounds @simd for k in eachindex(yflat)
            yflat[k] = dot(u, @view(gcols[:, k]))
        end
    end
end

# ∇b ⋅ y
# ∇b ⋅ y (adjoint of `pointwise_gradiprod_2d!` for a full displacement field).
# Writes the gradient of `b` into `u` with `∇₂c!`, then scales each per-pixel
# vector u[:, k] by y[k]. Pixels correspond in linear (column-major) order.
# Throws `DimensionMismatch` if `y` has a different number of pixels than `u`.
function pointwise_gradiprod_2dᵀ!(u::DisplacementFull, y::Image, b::Image)
    spatial = size(u)[2:end]
    length(y) == prod(spatial) ||
        throw(DimensionMismatch("y has $(length(y)) pixels, u has $(prod(spatial))"))
    ∇₂c!(u, b)

    # A single broadcast over a leading singleton axis. The old per-pixel
    # `u′[:, i] *= ...` allocated a column copy on every iteration.
    u .*= reshape(y, 1, spatial...)
    return nothing
end

# ∇b ⋅ y for a constant displacement. This is the adjoint of
# `pointwise_gradiprod_2d!`: u = (1/#pixels) Σ_{i,j} ∇b(i, j) y[i, j]. The
# gradient at each pixel is supplied by `∇₂cfold!`.
function pointwise_gradiprod_2dᵀ!(u::DisplacementConstant, y::Image, b::Image)
    @assert(size(y)==size(b) && size(u)==(2,))
    fill!(u, 0)
    ∇₂cfold!(b, nothing) do grad, state, (i, j)
        @inbounds u .+= grad .* y[i, j]
        return state
    end
    # Reweight to be with respect to the 𝟙^*𝟙 inner product.
    u ./= length(b)
end

# Work space for the constant-displacement Horn–Schunck prox step. All fields
# start at zero and are filled in by `horn_schunck_reg_prox_op!`.
mutable struct ConstantDisplacementHornSchunckData
    M₀::Array{Float64,2}   # 2×2 system matrix
    z::Array{Float64,1}    # length-2 right-hand-side term
    Mv::Array{Float64,2}   # 2×2 unnormalised counterpart of M₀
    av::Array{Float64,1}   # length-2 vector
    cv::Float64            # scalar

    ConstantDisplacementHornSchunckData() =
        new(zeros(2, 2), zeros(2), zeros(2, 2), zeros(2), 0.0)
end

# For DisplacementConstant, for the simple prox step
#
# (1) argmin_u 1/(2τ)|u-ũ|^2 + (θ/2)|b⁺-b+<<u-ŭ,∇b>>|^2 + (λ/2)|u-ŭ|^2,
#
# construct matrix M₀ and vector z such that we can solve u from
#
# (2) (I/τ+M₀)u = M₀ŭ + ũ/τ - z
#
# Note that the problem
#
#    argmin_u 1/(2τ)|u-ũ|^2 + (θ/2)|b⁺-b+<<u-ŭ,∇b>>|^2 + (λ/2)|u-ŭ|^2
#                           + (θ/2)|b⁺⁺-b⁺+<<uʹ-u,∇b⁺>>|^2 + (λ/2)|u-uʹ|^2
#
# has with respect to u the system
#
#     (I/τ+M₀+M₀ʹ)u = M₀ŭ + M₀ʹuʹ + ũ/τ - z + zʹ,
#
# where M₀ʹ and zʹ are the matrix and vector of (2) constructed for the following
# variant of (1), with uʹ in place of u:
#
#    argmin_uʹ 1/(2τ)|uʹ-ũʹ|^2 + (θ/2)|b⁺⁺-b⁺+<<uʹ-u,∇b⁺>>|^2 + (λ/2)|uʹ-u|^2
#
"""
    horn_schunck_reg_prox_op!(hs, bnext, b, θ, λ, T)

Fill the work space `hs` for the constant-displacement Horn–Schunck prox step.
Notation:

- δ = bnext - b
- g is the gradient of `b` computed by `∇₂c!`
- w is the number of pixels
- G = Σ g gᵀ (all sums run over all pixels)

The call sets

  - `hs.M₀ = λI + (θ/w) G`
  - `hs.z  = (θ/(w T)) Σ g δ`
  - `hs.Mv = w λ I + θ G`
  - `hs.av = θ Σ g δ`
  - `hs.cv = θ Σ δ²`

and returns `hs.z`. Every field is recomputed from scratch, so `hs` can be
reused across calls.
"""
function horn_schunck_reg_prox_op!(hs::ConstantDisplacementHornSchunckData,
                                   bnext::Image, b::Image, θ, λ, T)
    @assert(size(b)==size(bnext))
    w = prod(size(b))
    # Accumulators for Σ g δ, Σ δ² and the symmetric matrix G = [a c; c d].
    # They are Float64 from the start so the loop is type-stable.
    s₁, s₂ = 0.0, 0.0
    cv = 0.0
    a, c, d = 0.0, 0.0, 0.0
    # This used to use ∇₂cfold but it is faster to allocate temporary
    # storage for the full gradient due to probably better memory and SIMD
    # instruction usage.
    g = zeros(2, size(b)...)
    ∇₂c!(g, b)
    @inbounds for i=1:size(b, 1)
        for j=1:size(b, 2)
            δ = bnext[i,j]-b[i,j]
            g₁, g₂ = g[1,i,j], g[2,i,j]
            s₁ += g₁*δ
            s₂ += g₂*δ
            cv += δ*δ
            a += g₁*g₁
            c += g₁*g₂
            d += g₂*g₂
        end
    end
    w₀ = λ
    w₂ = θ/w
    aʹ = w₀ + w₂*a
    cʹ = w₂*c
    dʹ = w₀ + w₂*d
    hs.M₀ .= [aʹ cʹ; cʹ dʹ]
    hs.Mv .= [w*λ+θ*a θ*c; θ*c w*λ+θ*d]
    hs.cv = cv*θ
    # Overwrite z rather than accumulating into it. Otherwise a reused `hs`
    # would carry the previous call's already-rescaled z into these sums.
    hs.z[1], hs.z[2] = s₁, s₂
    hs.av .= hs.z.*θ
    hs.z .*= w₂/T
    return hs.z
end

# Solve the 2D system (I/τ+M₀)u = z
@inline function mldivide_step_plus_sym2x2!(u, M₀, z, τ)
    # System matrix S = I/τ + M₀ = [p q; q r]. M₀ is taken to be symmetric,
    # so only its upper off-diagonal entry is read.
    invτ = 1/τ
    p = invτ + M₀[1, 1]
    q = M₀[1, 2]
    r = invτ + M₀[2, 2]
    # S⁻¹ = adj(S) / det(S) for a symmetric 2×2 matrix.
    adjugate = [r -q; -q p]
    determinant = p*r - q*q
    u .= (adjugate*z) ./ determinant
end

# One prox step for the constant displacement `u`. It builds M₀ and z from the
# frames `b`, `bnext`, then solves (I/τ + M₀) u = u/τ - z in place.
function horn_schunck_reg_prox!(u::DisplacementConstant, bnext::Image, b::Image,
                                θ, λ, T, τ)
    data = ConstantDisplacementHornSchunckData()
    horn_schunck_reg_prox_op!(data, bnext, b, θ, λ, T)
    rhs = u ./ τ .- data.z
    mldivide_step_plus_sym2x2!(u, data.M₀, rhs, τ)
end

# Linearised flow: x ← x + ⟨⟨u, ∇x⟩⟩, with `u` first scaled by `δ` when one is
# given. `vtmp` is scratch space for the gradient.
function flow_grad!(x::Image, vtmp::Gradient, u::Displacement; δ=nothing)
    u_eff = isnothing(δ) ? u : δ .* u
    pointwise_gradiprod_2d!(x, vtmp, u_eff, x; add=true)
end

# Error b-b_prev+⟨⟨u, ∇b⟩⟩ for Horn–Schunck type penalisation
function linearised_optical_flow_error(u::Displacement, b::Image, b_prev::Image)
    # Start from the frame difference, then add ⟨⟨u, ∇b_prev⟩⟩ in place.
    err = b - b_prev
    grad_scratch = zeros(2, size(b)...)
    pointwise_gradiprod_2d!(err, grad_scratch, u, b_prev; add=true)
    return err
end

##############################################
# Helper to smooth data for Horn–Schunck term
##############################################

"""
    filter_hs(b, b_next, b_next_filt, kernel)

Return `(b_filt, b_next_filt)`, the data `b` and `b_next` smoothed with `kernel`
via `simple_imfilter`. When `kernel` is `nothing`, the inputs are returned
unfiltered.

If `b_next_filt` is not `nothing`, it is taken to be the already-filtered `b`
(it was `b_next` in the previous step) and is reused as `b_filt`.
"""
function filter_hs(b, b_next, b_next_filt, kernel)
    # `isnothing` rather than `== nothing`: it is always a Bool, whereas
    # `missing == nothing` would be `missing` and fail inside a condition.
    f = isnothing(kernel) ? identity : x -> simple_imfilter(x, kernel; threads=true)

    # We already filtered b in the previous step (b_next in that step)
    b_filt = isnothing(b_next_filt) ? f(b) : b_next_filt

    return b_filt, f(b_next)
end

end # Module

mercurial