src/PET/OpticalFlow.jl

changeset 36
e4a8f662a1ac
parent 35
74b1a9f0c35e
child 37
bba159cf1636
equal deleted inserted replaced
35:74b1a9f0c35e 36:e4a8f662a1ac
1 ################################
2 # Code relevant to optical flow
3 ################################
4
5 __precompile__()
6
7 module OpticalFlow
8
9 using AlgTools.Util
10 using ImageTools.Gradient
11 import ImageTools.Translate
12 using ImageTools.ImFilter
13
14 # using ImageTransformations
15 # using Images, CoordinateTransformations, Rotations, OffsetArrays
16 # using Interpolations
17
18 import Images: center, warp
19 import CoordinateTransformations: recenter
20 import Rotations: RotMatrix
21 import Interpolations: Flat
22
23 ##########
24 # Exports
25 ##########
26
27 export flow!,
28 pdflow!,
29 flow_grad!,
30 flow_interp!,
31 estimate_Λ²,
32 estimate_linear_Λ²,
33 pointwise_gradiprod_2d!,
34 pointwise_gradiprod_2dᵀ!,
35 horn_schunck_reg_prox!,
36 horn_schunck_reg_prox_op!,
37 mldivide_step_plus_sym2x2!,
38 linearised_optical_flow_error,
39 Image, AbstractImage, ImageSize,
40 Gradient, Displacement,
41 DisplacementFull, DisplacementConstant,
42 HornSchunckData,
43 filter_hs,
44 petpdflow!,
45 DualScaling, Greedy, Rotation
46
47 ###############################################
48 # Types (several imported from ImageTools.Translate)
49 ###############################################
50
# Image, displacement and gradient type aliases. Declared `const` so that the
# bindings cannot be reassigned and every use (method signatures, function
# bodies) refers to a known constant rather than an untyped global.
const Image = Translate.Image
const AbstractImage = AbstractArray{Float64,2}
const Displacement = Translate.Displacement
const DisplacementFull = Translate.DisplacementFull
const DisplacementConstant = Translate.DisplacementConstant
const Gradient = Array{Float64,3}
const ImageSize = Tuple{Int64,Int64}
58
59
#################################
# Predictor tags
#################################
# Field-less singleton types. Passing an instance as the `flow` argument of
# pdflow!/petpdflow! selects the corresponding predictor method.

# Tag: rotate the dual variable with the local gradient direction.
struct Rotation end

# Tag: rescale the dual variable by the ratio of old to new gradients.
struct Greedy end

# Tag: damp the dual variable where the image changed the most.
struct DualScaling end
66
67 #################################
68 # Displacement field based flow
69 #################################
70
# Flow the image `x` in place by the displacement `u` using the caller-supplied
# buffer `tmp`. `tmp` is first overwritten with a copy of `x` and then handed to
# Translate.translate_image! as its second argument (presumably the source
# image — confirm in ImageTools.Translate). `threads` is forwarded unchanged.
# Returns whatever translate_image! returns.
function flow_interp!(x::AbstractImage, u::Displacement, tmp::AbstractImage;
                      threads = false)
    source = tmp
    source .= x
    return Translate.translate_image!(x, source, u; threads=threads)
end
76
# Allocating variant of flow_interp!: a fresh copy of `x` is made and passed to
# Translate.translate_image! together with `x` and the displacement `u`.
# `threads` is forwarded unchanged.
function flow_interp!(x::AbstractImage, u::Displacement;
                      threads = false)
    snapshot = copy(x)
    return Translate.translate_image!(x, snapshot, u; threads=threads)
end
82
# Short alias used by the pdflow! methods. `const` so that calls through it are
# resolved statically instead of via an untyped global binding.
const flow! = flow_interp!
84
# Flow the primal image `x` (scratch buffer `Δx`) by the displacement `u`.
# When `dual_flow` is true, also flow both components y[1, :, :] and
# y[2, :, :] of the dual variable (scratch buffers from `Δy`).
#
# `threads` selects the parallelism:
#   :outer — the primal and dual work are handed to AlgTools' @backgroundif
#            (presumably run concurrently — see AlgTools.Util);
#   :inner — each flow! call gets `threads=true`;
#   other  — everything sequential.
# The non-dual branch now forwards `threads` like the sibling method with
# `z`, so `threads=:inner` is honoured there as well.
function pdflow!(x, Δx, y, Δy, u, dual_flow; threads=:none)
    if dual_flow
        @backgroundif (threads==:outer) begin
            flow!(x, u, Δx; threads=(threads==:inner))
        end begin
            flow!(@view(y[1, :, :]), u, @view(Δy[1, :, :]); threads=(threads==:inner))
            flow!(@view(y[2, :, :]), u, @view(Δy[2, :, :]); threads=(threads==:inner))
        end
    else
        flow!(x, u, Δx; threads=(threads==:inner))
    end
end
99
# Variant of pdflow! that also flows a second primal image `z` (scratch
# buffer `Δz`). When `dual_flow` is true, the components of the dual variable
# `y` are flowed as well. With threads == :outer the primal and dual groups
# are passed to @backgroundif; with :inner each flow! call gets threads=true.
function pdflow!(x, Δx, y, Δy, z, Δz, u, dual_flow; threads=:none)
    if !dual_flow
        # Only the primal images move.
        flow!(x, u, Δx; threads=(threads==:inner))
        return flow!(z, u, Δz; threads=(threads==:inner))
    end
    @backgroundif (threads==:outer) begin
        flow!(x, u, Δx; threads=(threads==:inner))
        flow!(z, u, Δz; threads=(threads==:inner))
    end begin
        flow!(@view(y[1, :, :]), u, @view(Δy[1, :, :]); threads=(threads==:inner))
        flow!(@view(y[2, :, :]), u, @view(Δy[2, :, :]); threads=(threads==:inner))
    end
end
114
# Greedy predictor (constant displacement `u` of length 2).
# Saves the pre-flow image into `Δx` and the dual variable into `Δy`, flows `x`,
# then rescales `y` elementwise by (gradient of old x) ./ (gradient of new x).
# Wherever a component of the new gradient has magnitude ≤ 1e-1 the ratio is
# forced to 1, leaving that entry of `y` unchanged.
function pdflow!(x, Δx, y, Δy, u, flow :: Greedy; threads=:none)
    @assert(size(u)==(2,))
    Δx .= x
    Δy .= y
    flow!(x, u; threads=(threads==:inner))
    grad_new, grad_old = similar(Δy), similar(Δy)
    ∇₂!(grad_new, x)
    ∇₂!(grad_old, Δx)
    flat = abs.(grad_new) .≤ 1e-1
    grad_new[flat] .= 1
    grad_old[flat] .= 1
    y .= y .* grad_old ./ grad_new
end
130
# Rotation predictor (constant displacement `u` of length 2).
#
# Saves the pre-flow image into `Δx`, flows `x`, and then at every pixel where
# both the old gradient a = ∇(old x) and the new gradient b = ∇(new x) have
# squared norm > 1e-4, rotates the dual vector y[:, i, j] by the oriented
# angle θ from a to b. That is the rotation R(θ) = [cos θ  -sin θ; sin θ  cos θ]
# with R(θ)a ∥ b, so a dual vector aligned with the old gradient ends up
# aligned with the new one.
#
# Fix: the previous version used atan(b₁a₂ - b₂a₁, a⋅b), which is the angle
# from b to a (the negative of θ), contrary to its own comment. cos θ and
# sin θ are now taken directly from the dot and cross products, which also
# avoids atan/cos/sin and per-pixel slice allocations. The traversal is
# column-major.
function pdflow!(x, Δx, y, Δy, u, flow :: Rotation; threads=:none)
    @assert(size(u)==(2,))
    Δx .= x
    flow!(x, u; threads=(threads==:inner))
    (m, n) = size(x)
    grad_old = similar(y)
    grad_new = similar(y)
    ∇₂!(grad_old, Δx)
    ∇₂!(grad_new, x)
    for j = 1:n
        for i = 1:m
            a₁, a₂ = grad_old[1, i, j], grad_old[2, i, j]
            b₁, b₂ = grad_new[1, i, j], grad_new[2, i, j]
            na² = a₁^2 + a₂^2
            nb² = b₁^2 + b₂^2
            if na² > 1e-4 && nb² > 1e-4
                s = sqrt(na² * nb²)
                cosθ = (a₁*b₁ + a₂*b₂) / s     # a⋅b / (|a||b|)
                sinθ = (a₁*b₂ - a₂*b₁) / s     # (a × b) / (|a||b|)
                y₁, y₂ = y[1, i, j], y[2, i, j]
                y[1, i, j] = cosθ*y₁ - sinθ*y₂
                y[2, i, j] = sinθ*y₁ + cosθ*y₂
            end
        end
    end
    return nothing
end
159
# Dual-scaling predictor (constant displacement `u` of length 2).
# Flows `x`, measures the pointwise change c = |x_new - x_old|, and multiplies
# both components of `y` by (1 - c / max(1e-12, maximum(c)))^10. The dual
# variable is therefore damped most where the image changed most.
# Returns `y`.
function pdflow!(x, Δx, y, Δy, u, flow :: DualScaling; threads=:none)
    @assert(size(u)==(2,))
    x_prev = copy(x)
    flow!(x, u; threads=(threads==:inner))
    change = abs.(x - x_prev)
    scale = max(1e-12, maximum(change))
    weight = (1 .- change ./ scale) .^ 10
    @views y[1, :, :] .*= weight
    @views y[2, :, :] .*= weight
    return y
end
173
174
175 ##########################
176 # PET
177 ##########################
# PET flow of an image: warp `x` in place with the transform
# recenter(RotMatrix(theta_known[1]), center(x) .+ u), using Flat() as the
# fill value and keeping the axes of `x`.
# On return the caller's buffer `tmp` holds the pre-warp image. The warp
# itself allocates its own output, which is then copied into `x`.
# `threads` is accepted for interface compatibility but is not used.
# Returns `x`.
function petflow_interp!(x::AbstractImage, tmp::AbstractImage, u::DisplacementConstant, theta_known::DisplacementConstant;
                         threads = false)
    tmp .= x
    pivot = center(x) .+ u
    transform = recenter(RotMatrix(theta_known[1]), pivot)
    warped = warp(x, transform, axes(x), fillvalue=Flat())
    x .= warped
end
186
# Short alias used by the petpdflow! methods. `const` so that calls through it
# are resolved statically instead of via an untyped global binding.
const petflow! = petflow_interp!
188
# PET counterpart of pdflow!: apply petflow! (rotation by theta_known[1] about
# center .+ u) to the primal image `x`. When `dual_flow` is true, apply it to
# both dual components y[1, :, :] and y[2, :, :] as well.
# With threads == :outer the primal and dual work are passed to @backgroundif.
# The `threads` keyword is now forwarded on the non-dual branch too, for
# consistency with the dual branch.
function petpdflow!(x, Δx, y, Δy, u, theta_known, dual_flow; threads=:none)
    if dual_flow
        @backgroundif (threads==:outer) begin
            petflow!(x, Δx, u, theta_known; threads=(threads==:inner))
        end begin
            petflow!(@view(y[1, :, :]), @view(Δy[1, :, :]), u, theta_known; threads=(threads==:inner))
            petflow!(@view(y[2, :, :]), @view(Δy[2, :, :]), u, theta_known; threads=(threads==:inner))
        end
    else
        petflow!(x, Δx, u, theta_known; threads=(threads==:inner))
    end
end
201
# Greedy predictor for the PET flow.
# Warps `x` as petflow! does, copies `y` into `Δy`, and rescales `y` elementwise
# by (gradient of old x) ./ (gradient of new x). Wherever a component of the
# new gradient has magnitude ≤ 1e-2 the ratio is forced to 1.
# NOTE(review): unlike the non-PET Greedy method, `Δx` is left untouched here.
function petpdflow!(x, Δx, y, Δy, u, theta_known, flow :: Greedy; threads=:none)
    x_prev = copy(x)
    pivot = center(x) .+ u
    transform = recenter(RotMatrix(theta_known[1]), pivot)
    x .= warp(x, transform, axes(x), fillvalue=Flat())
    Δy .= y
    grad_new = copy(Δy)
    grad_old = copy(Δy)
    ∇₂!(grad_new, x)
    ∇₂!(grad_old, x_prev)
    flat = abs.(grad_new) .≤ 1e-2
    grad_new[flat] .= 1
    grad_old[flat] .= 1
    y .= y .* grad_old ./ grad_new
end
219
# Dual-scaling predictor for the PET flow.
# Warps `x` as petflow! does, measures the pointwise change c = |x_new - x_old|,
# and multiplies both components of `y` by (1 - c / max(1e-12, maximum(c)))^10.
# Returns `y`.
function petpdflow!(x, Δx, y, Δy, u, theta_known, flow :: DualScaling; threads=:none)
    x_prev = copy(x)
    transform = recenter(RotMatrix(theta_known[1]), center(x) .+ u)
    x .= warp(x, transform, axes(x), fillvalue=Flat())
    change = abs.(x - x_prev)
    scale = max(1e-12, maximum(change))
    weight = (1 .- change ./ scale) .^ 10
    @views y[1, :, :] .*= weight
    @views y[2, :, :] .*= weight
    return y
end
235
# Rotation predictor for the PET flow: apply petflow! to `x` with theta_known,
# and to both dual components y[1, :, :] and y[2, :, :] with the negated angle
# -theta_known (the inverse rotation).
# With threads == :outer the primal and dual work are passed to @backgroundif.
function petpdflow!(x, Δx, y, Δy, u, theta_known, flow :: Rotation; threads=:none)
    inner = (threads == :inner)
    theta_inverse = -theta_known
    @backgroundif (threads==:outer) begin
        petflow!(x, Δx, u, theta_known; threads=inner)
    end begin
        petflow!(@view(y[1, :, :]), @view(Δy[1, :, :]), u, theta_inverse; threads=inner)
        petflow!(@view(y[2, :, :]), @view(Δy[2, :, :]), u, theta_inverse; threads=inner)
    end
end
245
246 ##########################
247 # Linearised optical flow
248 ##########################
249
# ⟨⟨u, ∇b⟩⟩ for a full displacement field.
# Writes the gradient of `b` (via ∇₂c!) into `vtmp`, then for every pixel k
# stores the dot product of u[:, k] and vtmp[:, k] into y[k]. Both arrays are
# viewed as (components × pixels). With `add = true` the products are added to
# `y` instead of overwriting it.
function pointwise_gradiprod_2d!(y::Image, vtmp::Gradient,
                                 u::DisplacementFull, b::Image;
                                 add = false)
    ∇₂c!(vtmp, b)

    U = reshape(u, size(u, 1), prod(Base.tail(size(u))))
    G = reshape(vtmp, size(vtmp, 1), prod(Base.tail(size(vtmp))))
    yv = vec(y)

    @simd for k = 1:length(yv)
        @inbounds begin
            s = dot(@view(U[:, k]), @view(G[:, k]))
            yv[k] = add ? yv[k] + s : s
        end
    end
end
270
# ⟨⟨u, ∇b⟩⟩ for a constant displacement `u` (same vector at every pixel).
# Writes the gradient of `b` (via ∇₂c!) into `vtmp`, then stores the dot
# product of u and vtmp[:, k] into y[k] for every pixel k. With `add = true`
# the products are added to `y` instead.
function pointwise_gradiprod_2d!(y::Image, vtmp::Gradient,
                                 u::DisplacementConstant, b::Image;
                                 add = false)
    ∇₂c!(vtmp, b)

    G = reshape(vtmp, size(vtmp, 1), prod(Base.tail(size(vtmp))))
    yv = vec(y)

    @simd for k = 1:length(yv)
        @inbounds begin
            s = dot(u, @view(G[:, k]))
            yv[k] = add ? yv[k] + s : s
        end
    end
end
289
# ∇b ⋅ y for a full displacement field: writes the gradient of `b` into `u`
# (via ∇₂c!) and then scales the gradient vector at each pixel k by y[k].
function pointwise_gradiprod_2dᵀ!(u::DisplacementFull, y::Image, b::Image)
    ∇₂c!(u, b)

    U = reshape(u, size(u, 1), prod(Base.tail(size(u))))
    yv = vec(y)

    for k = 1:length(yv)
        weight = yv[k]
        @inbounds for c in axes(U, 1)
            U[c, k] *= weight
        end
    end
end
301
# ∇b ⋅ y for a constant displacement: accumulates Σ_{i,j} ∇b(i,j)·y[i,j] into
# the length-2 vector `u` via ∇₂cfold!, then divides by the pixel count so the
# result is taken with respect to the 𝟙^*𝟙 inner product. Returns `u`.
function pointwise_gradiprod_2dᵀ!(u::DisplacementConstant, y::Image, b::Image)
    @assert(size(y)==size(b) && size(u)==(2,))
    fill!(u, 0)
    ∇₂cfold!(b, nothing) do grad, state, (r, c)
        @inbounds u .+= grad .* y[r, c]
        return state
    end
    u ./= length(b)
end
312
# Workspace filled by horn_schunck_reg_prox_op! for a constant displacement.
# All arrays start zeroed; `cv` starts at 0.0.
mutable struct ConstantDisplacementHornSchunckData
    M₀::Array{Float64,2}   # 2×2 system matrix of the prox step
    z::Array{Float64,1}    # length-2 right-hand-side correction
    Mv::Array{Float64,2}   # 2×2 unnormalised quadratic term
    av::Array{Float64,1}   # length-2 unnormalised linear term
    cv::Float64            # scalar constant term

    function ConstantDisplacementHornSchunckData()
        return new(zeros(2, 2), zeros(2), zeros(2, 2), zeros(2), 0.0)
    end
end

# The module exports `HornSchunckData`, which was otherwise never defined.
# Provide it as an alias of this workspace type.
const HornSchunckData = ConstantDisplacementHornSchunckData
324
# For DisplacementConstant, for the simple prox step
#
# (1) argmin_u 1/(2τ)|u-ũ|^2 + (θ/2)|b⁺-b+<<u-ŭ,∇b>>|^2 + (λ/2)|u-ŭ|^2,
#
# construct the matrix M₀ and the vector z such that u can be solved from
#
# (2) (I/τ+M₀)u = M₀ŭ + ũ/τ - z.
#
# Note that the problem
#
#     argmin_u 1/(2τ)|u-ũ|^2 + (θ/2)|b⁺-b+<<u-ŭ,∇b>>|^2 + (λ/2)|u-ŭ|^2
#              + (θ/2)|b⁺⁺-b⁺+<<uʹ-u,∇b⁺>>|^2 + (λ/2)|u-uʹ|^2
#
# has, with respect to u, the optimality system
#
#     (I/τ+M₀+M₀ʹ)u = M₀ŭ + M₀ʹuʹ + ũ/τ - z + zʹ,
#
# where the primed quantities M₀ʹ and zʹ are those given by (2) when (1) is
# posed for uʹ in place of u:
#
#     argmin_uʹ 1/(2τ)|uʹ-ũʹ|^2 + (θ/2)|b⁺⁺-b⁺+<<uʹ-u,∇b⁺>>|^2 + (λ/2)|uʹ-u|^2
#
# Concretely, with g = ∇b (from ∇₂c!), δ = bnext - b and w = length(b), this
# function stores into `hs`:
#
#     M₀ = λI + (θ/w) Σ g gᵀ          Mv = wλI + θ Σ g gᵀ
#     z  = θ/(w T) Σ δ g              av = θ Σ δ g
#     cv = θ Σ δ²
#
# Everything is recomputed from scratch on each call, so `hs` may be reused.
# Returns `hs.z`.
function horn_schunck_reg_prox_op!(hs::ConstantDisplacementHornSchunckData,
                                   bnext::Image, b::Image, θ, λ, T)
    @assert(size(b)==size(bnext))
    w = length(b)
    z = hs.z
    # Reset the accumulator so that a reused `hs` does not carry over the
    # (already rescaled) z of an earlier call.
    z .= 0
    cv = 0.0                 # Float64 from the start keeps the loop type-stable
    # Entries of the symmetric matrix [a c; c d] = Σ g gᵀ
    a, c, d = 0.0, 0.0, 0.0
    # This used to use ∇₂cfold but it is faster to allocate temporary
    # storage for the full gradient due to probably better memory and SIMD
    # instruction usage.
    g = zeros(2, size(b)...)
    ∇₂c!(g, b)
    # Column-major traversal; scalar updates avoid allocating g[:, i, j].
    @inbounds for j = 1:size(b, 2), i = 1:size(b, 1)
        δ = bnext[i, j] - b[i, j]
        g₁ = g[1, i, j]
        g₂ = g[2, i, j]
        z[1] += g₁*δ
        z[2] += g₂*δ
        cv += δ*δ
        a += g₁*g₁
        c += g₁*g₂
        d += g₂*g₂
    end
    w₀ = λ
    w₂ = θ/w
    aʹ = w₀ + w₂*a
    cʹ = w₂*c
    dʹ = w₀ + w₂*d
    hs.M₀ .= [aʹ cʹ; cʹ dʹ]
    hs.Mv .= [w*λ+θ*a θ*c; θ*c w*λ+θ*d]
    hs.cv = cv*θ
    hs.av .= z .* θ
    hs.z .*= w₂/T
    return hs.z
end
380
# Solve the 2×2 system (I/τ + M₀)u = z by Cramer's rule and write the
# solution into `u`. Only M₀[1,1], M₀[1,2] and M₀[2,2] are read (M₀ is treated
# as symmetric). No temporaries are allocated, `z` may alias `u`, and `z` may
# be any indexable pair (e.g. a Vector or a Tuple). Returns `u`.
@inline function mldivide_step_plus_sym2x2!(u, M₀, z, τ)
    a = 1/τ + M₀[1, 1]
    c = M₀[1, 2]
    d = 1/τ + M₀[2, 2]
    detA = a*d - c*c
    # Read z before writing u so that aliasing z === u is safe.
    z₁, z₂ = z[1], z[2]
    u[1] = (d*z₁ - c*z₂) / detA
    u[2] = (a*z₂ - c*z₁) / detA
    return u
end
388
# Horn–Schunck prox step for a constant displacement, performed in place on `u`.
# Builds a fresh workspace with horn_schunck_reg_prox_op! and solves
# (I/τ + M₀)u_new = u/τ - z. This corresponds to (2) with ũ the incoming `u`
# and the M₀ŭ term omitted. Returns `u`.
function horn_schunck_reg_prox!(u::DisplacementConstant, bnext::Image, b::Image,
                                θ, λ, T, τ)
    workspace = ConstantDisplacementHornSchunckData()
    horn_schunck_reg_prox_op!(workspace, bnext, b, θ, λ, T)
    rhs = u ./ τ .- workspace.z
    return mldivide_step_plus_sym2x2!(u, workspace.M₀, rhs, τ)
end
395
# Linearised flow: add ⟨⟨u, ∇x⟩⟩ to `x` in place, using `vtmp` as gradient
# storage. The gradient is taken of `x` before the update. When `δ` is given,
# the displacement is scaled to δ .* u first.
function flow_grad!(x::Image, vtmp::Gradient, u::Displacement; δ=nothing)
    displacement = isnothing(δ) ? u : δ .* u
    pointwise_gradiprod_2d!(x, vtmp, displacement, x; add=true)
end
402
# Error b - b_prev + ⟨⟨u, ∇b_prev⟩⟩ for Horn–Schunck type penalisation.
# Returns a newly allocated image; the inputs are not modified.
function linearised_optical_flow_error(u::Displacement, b::Image, b_prev::Image)
    err = b - b_prev
    grad_buffer = zeros(2, size(b)...)
    pointwise_gradiprod_2d!(err, grad_buffer, u, b_prev; add=true)
    return err
end
411
412 ##############################################
413 # Helper to smooth data for Horn–Schunck term
414 ##############################################
415
# Smooth the data pair (b, b_next) for the Horn–Schunck term.
#
# - `kernel === nothing` means no filtering (identity); otherwise each image is
#   filtered with simple_imfilter(x, kernel; threads=true).
# - `b_next_filt` is the filtered `b_next` from the previous step, i.e. the
#   already filtered current `b`. When it is given it is reused instead of
#   filtering `b` again; pass `nothing` to force filtering `b`.
#
# Returns (b_filt, b_next_filt).
function filter_hs(b, b_next, b_next_filt, kernel)
    f = isnothing(kernel) ? identity : (x -> simple_imfilter(x, kernel; threads=true))

    # We already filtered b in the previous step (b_next in that step).
    b_filt = isnothing(b_next_filt) ? f(b) : b_next_filt

    return b_filt, f(b_next)
end
429
430 end # Module

mercurial