src/PET/OpticalFlow.jl

changeset 8
e4ad8f7ce671
child 26
ccd22bbbb02f
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/PET/OpticalFlow.jl	Fri Apr 19 17:00:37 2024 +0300
@@ -0,0 +1,431 @@
+################################
+# Code relevant to optical flow
+################################
+
+__precompile__()
+
+module OpticalFlow
+
+using AlgTools.Util
+using ImageTools.Gradient
+import ImageTools.Translate
+using ImageTools.ImFilter
+
+# using ImageTransformations
+# using Images, CoordinateTransformations, Rotations, OffsetArrays
+# using Interpolations
+
+import Images: center, warp
+import CoordinateTransformations: recenter
+import Rotations: RotMatrix
+import Interpolations: Flat
+
+##########
+# Exports
+##########
+
+export flow!,
+       pdflow!,
+       flow_grad!,
+       flow_interp!,
+       estimate_Λ²,
+       estimate_linear_Λ²,
+       pointwise_gradiprod_2d!,
+       pointwise_gradiprod_2dᵀ!,
+       horn_schunck_reg_prox!,
+       horn_schunck_reg_prox_op!,
+       mldivide_step_plus_sym2x2!,
+       linearised_optical_flow_error,
+       Image, AbstractImage, ImageSize,
+       Gradient, Displacement,
+       DisplacementFull, DisplacementConstant,
+       HornSchunckData,
+       filter_hs,
+       petpdflow!
+
+###############################################
+# Types (several imported from ImageTools.Translate)
+###############################################
+
+Image = Translate.Image
+AbstractImage = AbstractArray{Float64,2}
+Displacement = Translate.Displacement
+DisplacementFull = Translate.DisplacementFull
+DisplacementConstant = Translate.DisplacementConstant
+Gradient = Array{Float64,3}
+ImageSize = Tuple{Int64,Int64}
+
+#################################
+# Displacement field based flow
+#################################
+
+function flow_interp!(x::AbstractImage, u::Displacement, tmp::AbstractImage;
+                      threads = false)
+    tmp .= x
+    Translate.translate_image!(x, tmp, u; threads=threads)
+end
+
+function flow_interp!(x::AbstractImage, u::Displacement;
+                      threads = false)
+    tmp = copy(x)
+    Translate.translate_image!(x, tmp, u; threads=threads)
+end
+
+flow! = flow_interp!
+
+function pdflow!(x, Δx, y, Δy, u, dual_flow; threads=:none)
+    if dual_flow
+        #flow!((x, @view(y[1, :, :]), @view(y[2, :, :])), diffu,
+        #      (Δx, @view(Δy[1, :, :]), @view(Δy[2, :, :])))
+        @backgroundif (threads==:outer) begin
+            flow!(x, u, Δx; threads=(threads==:inner))
+        end begin
+            flow!(@view(y[1, :, :]), u, @view(Δy[1, :, :]); threads=(threads==:inner))
+            flow!(@view(y[2, :, :]), u, @view(Δy[2, :, :]); threads=(threads==:inner))
+        end
+    else
+        flow!(x, u, Δx)
+    end
+end
+
+function pdflow!(x, Δx, y, Δy, z, Δz, u, dual_flow; threads=:none)
+    if dual_flow
+        @backgroundif (threads==:outer) begin
+            flow!(x, u, Δx; threads=(threads==:inner))
+            flow!(z, u, Δz; threads=(threads==:inner))
+        end begin
+            flow!(@view(y[1, :, :]), u, @view(Δy[1, :, :]); threads=(threads==:inner))
+            flow!(@view(y[2, :, :]), u, @view(Δy[2, :, :]); threads=(threads==:inner))
+        end
+    else
+        flow!(x, u, Δx; threads=(threads==:inner))
+        flow!(z, u, Δz; threads=(threads==:inner))
+    end
+end
+
+# Additional method for Greedy
+function pdflow!(x, Δx, y, Δy, u; threads=:none)
+    @assert(size(u)==(2,))  
+    Δx .= x
+    Δy .= y
+    flow!(x, u; threads=(threads==:inner))
+    Dxx = similar(Δy)
+    DΔx = similar(Δy)
+    ∇₂!(Dxx, x)
+    ∇₂!(DΔx, Δx)
+    inds = abs.(Dxx) .≤ 1e-1
+    Dxx[inds] .= 1
+    DΔx[inds] .= 1
+    y .= y.* DΔx ./ Dxx  
+end
+
+# Additional method for Rotation
+function pdflow!(x, Δx, y, u; threads=:none) 
+    @assert(size(u)==(2,))
+    Δx .= x
+    flow!(x, u; threads=(threads==:inner))
+
+    (m,n) = size(x)
+    dx = similar(y)
+    dx_banana = similar(y)
+    ∇₂!(dx, Δx)
+    ∇₂!(dx_banana, x)
+
+    for i=1:m
+        for j=1:n
+            ndx = @views sum(dx[:, i, j].^2)
+            ndx_banana = @views sum(dx_banana[:, i, j].^2)
+            if ndx > 1e-4 && ndx_banana > 1e-4
+                A = dx[:, i, j]
+                B = dx_banana[:, i, j]
+                theta = atan(B[1] * A[2] - B[2] * A[1], B[1] * A[1] + B[2] * A[2]) # Oriented angle from A to B
+                cos_theta = cos(theta)
+                sin_theta = sin(theta)
+                a = cos_theta * y[1, i, j] - sin_theta * y[2, i, j]
+                b = sin_theta * y[1, i, j] + cos_theta * y[2, i, j]
+                y[1, i, j] = a
+                y[2, i, j] = b
+            end
+        end
+    end
+end
+
+# Additional method for Dual Scaling
+function pdflow!(x, y, u; threads=:none)
+    @assert(size(u)==(2,))
+    oldx = copy(x)
+    flow!(x, u; threads=(threads==:inner))
+    C = similar(y)
+    cc = abs.(x-oldx)
+    cm = max(1e-12,maximum(cc))
+    c = 1 .* (1 .- cc./ cm) .^(10)
+    C[1,:,:] .= c
+    C[2,:,:] .= c
+    y .= C.*y 
+end
+
+
+##########################
+# PET
+##########################
+function petflow_interp!(x::AbstractImage, tmp::AbstractImage, u::DisplacementConstant, theta_known::DisplacementConstant;
+    threads = false)
+    tmp .= x
+    center_point = center(x) .+ u
+    tform = recenter(RotMatrix(theta_known[1]), center_point)
+    tmp = warp(x, tform, axes(x), fillvalue=Flat())
+    x .= tmp
+end
+
+petflow! = petflow_interp!
+
+function petpdflow!(x, Δx, y, Δy, u, theta_known, dual_flow; threads=:none)
+    if dual_flow
+        @backgroundif (threads==:outer) begin
+        petflow!(x, Δx, u, theta_known; threads=(threads==:inner))
+    end begin
+        petflow!(@view(y[1, :, :]), @view(Δy[1, :, :]), u, theta_known; threads=(threads==:inner))
+        petflow!(@view(y[2, :, :]), @view(Δy[2, :, :]), u, theta_known; threads=(threads==:inner))
+        end
+    else
+    petflow!(x, Δx, u, theta_known)
+    end
+end
+
+# Method for greedy predictor
+function petpdflow!(x, Δx, y, Δy, u, theta_known, dual_flow, β; threads=:none)
+    oldx = copy(x)
+    center_point = center(x) .+ u
+    tform = recenter(RotMatrix(theta_known[1]), center_point)
+    Δx = warp(x, tform, axes(x), fillvalue=Flat())
+    @. x = Δx
+    @. Δy = y
+    if dual_flow
+        Dxx = copy(Δy)
+        DΔx = copy(Δy)
+        ∇₂!(Dxx, x)
+        ∇₂!(DΔx, oldx)
+        inds = abs.(Dxx) .≤ β
+        Dxx[inds] .= 1
+        DΔx[inds] .= 1
+        y .= y.* DΔx ./ Dxx            
+    end
+end
+
+# Method for affine predictor
+function petpdflow!(x, Δx, y, u, theta_known, dual_flow; threads=:none)
+    oldx = copy(x)
+    center_point = center(x) .+ u
+    tform = recenter(RotMatrix(theta_known[1]), center_point)
+    Δx = warp(x, tform, axes(x), fillvalue=Flat())
+    @. x = Δx
+    C = similar(y)
+    cc = abs.(x-oldx)
+    if dual_flow
+        cm = max(1e-12,maximum(cc))
+        c = 1 .* (1 .- cc./ cm) .^(10)
+        C[1,:,:] .= c
+        C[2,:,:] .= c
+        y .= C.*y 
+    end
+end
+
+# Method for rotation prediction (exploiting property of inverse rotation)
+function petpdflow!(x, Δx, y, Δy, u, theta_known, dual_flow, β₁, β₂; threads=:none)
+    if dual_flow
+        @backgroundif (threads==:outer) begin
+            petflow!(x, Δx, u, theta_known; threads=(threads==:inner))
+        end begin
+            petflow!(@view(y[1, :, :]), @view(Δy[1, :, :]), u, -theta_known; threads=(threads==:inner))
+            petflow!(@view(y[2, :, :]), @view(Δy[2, :, :]), u, -theta_known; threads=(threads==:inner))
+        end
+    else
+        petflow!(x, Δx, u, theta_known)
+    end
+end
+
+##########################
+# Linearised optical flow
+##########################
+
+# ⟨⟨u, ∇b⟩⟩
+function pointwise_gradiprod_2d!(y::Image, vtmp::Gradient,
+                                 u::DisplacementFull, b::Image;
+                                 add = false)
+    ∇₂c!(vtmp, b)
+
+    u′=reshape(u, (size(u, 1), prod(size(u)[2:end])))
+    vtmp′=reshape(vtmp, (size(vtmp, 1), prod(size(vtmp)[2:end])))
+    y′=reshape(y, prod(size(y)))
+
+    if add
+        @simd for i = 1:length(y′)
+            @inbounds y′[i] += dot(@view(u′[:, i]), @view(vtmp′[:, i]))
+        end
+    else
+        @simd for i = 1:length(y′)
+            @inbounds y′[i] = dot(@view(u′[:, i]), @view(vtmp′[:, i]))
+        end
+    end
+end
+
+function pointwise_gradiprod_2d!(y::Image, vtmp::Gradient,
+                                 u::DisplacementConstant, b::Image;
+                                 add = false)
+    ∇₂c!(vtmp, b)
+
+    vtmp′=reshape(vtmp, (size(vtmp, 1), prod(size(vtmp)[2:end])))
+    y′=reshape(y, prod(size(y)))
+
+    if add
+        @simd for i = 1:length(y′)
+            @inbounds y′[i] += dot(u, @view(vtmp′[:, i]))
+        end
+    else
+        @simd for i = 1:length(y′)
+            @inbounds y′[i] = dot(u, @view(vtmp′[:, i]))
+        end
+    end
+end
+
+# ∇b ⋅ y
+function pointwise_gradiprod_2dᵀ!(u::DisplacementFull, y::Image, b::Image)
+    ∇₂c!(u, b)
+
+    u′=reshape(u, (size(u, 1), prod(size(u)[2:end])))
+    y′=reshape(y, prod(size(y)))
+
+    @simd for i=1:length(y′)
+        @inbounds @. u′[:, i] *= y′[i]
+    end
+end
+
+function pointwise_gradiprod_2dᵀ!(u::DisplacementConstant, y::Image, b::Image)
+    @assert(size(y)==size(b) && size(u)==(2,))
+    u .= 0
+    ∇₂cfold!(b, nothing) do g, st, (i, j)
+        @inbounds u .+= g.*y[i, j]
+        return st
+    end
+    # Reweight to be with respect to 𝟙^*𝟙 inner product.
+    u ./= prod(size(b))
+end
+
+mutable struct ConstantDisplacementHornSchunckData
+    M₀::Array{Float64,2}
+    z::Array{Float64,1}
+    Mv::Array{Float64,2}
+    av::Array{Float64,1}
+    cv::Float64
+
+    function ConstantDisplacementHornSchunckData()
+        return new(zeros(2, 2), zeros(2), zeros(2,2), zeros(2), 0)
+    end
+end
+
+# For DisplacementConstant, for the simple prox step
+#
+# (1) argmin_u 1/(2τ)|u-ũ|^2 + (θ/2)|b⁺-b+<<u-ŭ,∇b>>|^2 + (λ/2)|u-ŭ|^2,
+#
+# construct matrix M₀ and vector z such that we can solve u from
+#
+# (2) (I/τ+M₀)u = M₀ŭ + ũ/τ - z
+#
+# Note that the problem
+#
+#    argmin_u 1/(2τ)|u-ũ|^2 + (θ/2)|b⁺-b+<<u-ŭ,∇b>>|^2 + (λ/2)|u-ŭ|^2
+#                           + (θ/2)|b⁺⁺-b⁺+<<uʹ-u,∇b⁺>>|^2 + (λ/2)|u-uʹ|^2
+#
+# has with respect to u the system
+#
+#     (I/τ+M₀+M₀ʹ)u = M₀ŭ + M₀ʹuʹ + ũ/τ - z + zʹ,
+#
+# where the primed variables correspond to (2) for (1) for uʹ in place of u:
+#
+#    argmin_uʹ 1/(2τ)|uʹ-ũʹ|^2 + (θ/2)|b⁺⁺-b⁺+<<uʹ-u,∇b⁺>>|^2 + (λ/2)|uʹ-u|^2
+#
+function horn_schunck_reg_prox_op!(hs::ConstantDisplacementHornSchunckData,
+                                   bnext::Image, b::Image, θ, λ, T)
+    @assert(size(b)==size(bnext))
+    w = prod(size(b))
+    z = hs.z
+    cv = 0
+    # Factors of symmetric matrix [a c; c d]
+    a, c, d = 0.0, 0.0, 0.0
+    # This used to use  ∇₂cfold but it is faster to allocate temporary
+    # storage for the full gradient due to probably better memory and SIMD
+    # instruction usage. 
+    g = zeros(2, size(b)...)
+    ∇₂c!(g, b)
+    @inbounds for i=1:size(b, 1)
+        for j=1:size(b, 2)
+            δ = bnext[i,j]-b[i,j]
+            @. z += g[:,i,j]*δ
+            cv += δ*δ
+            a += g[1,i,j]*g[1,i,j]
+            c += g[1,i,j]*g[2,i,j]
+            d += g[2,i,j]*g[2,i,j]
+        end
+    end
+    w₀ = λ
+    w₂ = θ/w
+    aʹ = w₀ + w₂*a
+    cʹ = w₂*c
+    dʹ = w₀ + w₂*d
+    hs.M₀ .= [aʹ cʹ; cʹ dʹ]
+    hs.Mv .= [w*λ+θ*a θ*c; θ*c w*λ+θ*d]
+    hs.cv = cv*θ
+    hs.av .= hs.z.*θ
+    hs.z .*= w₂/T
+end
+
+# Solve the 2D system (I/τ+M₀)u = z
+@inline function mldivide_step_plus_sym2x2!(u, M₀, z, τ)
+    a = 1/τ+M₀[1, 1]
+    c = M₀[1, 2]
+    d = 1/τ+M₀[2, 2]
+    u .= ([d -c; -c a]*z)./(a*d-c*c)
+end
+
+function horn_schunck_reg_prox!(u::DisplacementConstant, bnext::Image, b::Image,
+                                θ, λ, T, τ)
+    hs=ConstantDisplacementHornSchunckData()
+    horn_schunck_reg_prox_op!(hs, bnext, b, θ, λ, T)
+    mldivide_step_plus_sym2x2!(u, hs.M₀, (u./τ)-hs.z, τ)
+end
+
+function flow_grad!(x::Image, vtmp::Gradient, u::Displacement; δ=nothing)
+    if !isnothing(δ)
+        u = δ.*u
+    end
+    pointwise_gradiprod_2d!(x, vtmp, u, x; add=true)
+end
+
+# Error b-b_prev+⟨⟨u, ∇b⟩⟩ for Horn–Schunck type penalisation
+function linearised_optical_flow_error(u::Displacement, b::Image, b_prev::Image)
+    imdim = size(b)
+    vtmp = zeros(2, imdim...)
+    tmp = b-b_prev
+    pointwise_gradiprod_2d!(tmp, vtmp, u, b_prev; add=true)
+    return tmp
+end
+
+##############################################
+# Helper to smooth data for Horn–Schunck term
+##############################################
+
+function filter_hs(b, b_next, b_next_filt, kernel)
+    if kernel==nothing
+        f = x -> x
+    else
+        f = x -> simple_imfilter(x, kernel; threads=true)
+    end
+
+    # We already filtered b in the previous step (b_next in that step) 
+    b_filt = b_next_filt==nothing ? f(b) : b_next_filt
+    b_next_filt = f(b_next)
+
+    return b_filt, b_next_filt
+end
+
+end # Module

mercurial