src/AlgorithmBothNL.jl

changeset 0
a55e35d20336
equal deleted inserted replaced
-1:000000000000 0:a55e35d20336
1 ######################################################################
2 # Predictive online PDPS for optical flow with unknown velocity field
3 ######################################################################
4
5 __precompile__()
6
7 module AlgorithmBothNL
8
# Short identifier for this algorithm variant. Declared `const` so that the
# global has a concrete, inferable type wherever it is read.
const identifier = "pdps_unknown_nl"
10
11
12 using Printf
13
14 using AlgTools.Util
15 import AlgTools.Iterate
16 using ImageTools.Gradient
17
18 using ..OpticalFlow: ImageSize,
19 Image,
20 Gradient,
21 DisplacementConstant,
22 DisplacementFull,
23 pdflow!,
24 pointwise_gradiprod_2d!,
25 pointwise_gradiprod_2dᵀ!,
26 filter_hs
27
28 using ..Algorithm: step_lengths
29
30 #############
31 # Data types
32 #############
33
"""
    Primal{DisplacementT}

Primal iterate of the joint problem: the image estimate `x` together with
the displacement `u` (of type `DisplacementT`).
"""
struct Primal{DisplacementT}
    x :: Image
    u :: DisplacementT
end

# Allocate a primal iterate whose components are `similar` to those of `p`
# (shapes kept, contents uninitialised).
Base.similar(p::Primal{D}) where D = Primal{D}(Base.similar(p.x), Base.similar(p.u))

# Independent copy of both the image and the displacement components of `p`.
Base.copy(p::Primal{D}) where D = Primal{D}(Base.copy(p.x), Base.copy(p.u))
46
"""
    Dual

Dual iterate: `tv` is paired with the TV (gradient) part of the operator `K`,
`flow` with its optical-flow part.
"""
struct Dual
    tv :: Gradient
    flow :: Image
end

# Allocate a dual iterate whose components are `similar` to those of `d`.
Base.similar(d::Dual) = Dual(Base.similar(d.tv), Base.similar(d.flow))

# Independent copy of both components of `d`.
Base.copy(d::Dual) = Dual(Base.copy(d.tv), Base.copy(d.flow))
59
60 #########################
61 # Iterate initialisation
62 #########################
63
# Primal starting point with a constant displacement: the image `xinit`
# (not copied) and a zero 2-vector.
init_primal(xinit::Image, ::Type{DisplacementConstant}) =
    Primal{DisplacementConstant}(xinit, zeros(2))

# Primal starting point with a full displacement field: the image `xinit`
# (not copied) and a zero array of size `(2, size(xinit)...)`.
init_primal(xinit::Image, ::Type{DisplacementFull}) =
    Primal{DisplacementFull}(xinit, zeros(2, size(xinit)...))
71
"""
    init_rest(x::Primal)

Allocate the remaining solver iterates around the primal iterate `x`.

Returns `(x, y, Δx, Δy, x̄)`:
- `x` itself (returned as given, not copied);
- `y`, a zero `Dual` with a `(2, size(x.x)...)` TV part and a
  `size(x.x)` flow part;
- `Δx`, `Δy`, `x̄`: fresh copies of `x`, `y` and `x` respectively.
"""
function init_rest(x::Primal{DisplacementT}) where DisplacementT
    sz = size(x.x)
    y = Dual(zeros(2, sz...), zeros(sz))
    return x, y, copy(x), copy(y), copy(x)
end
82
# Start from an existing primal iterate; it is copied so the returned
# iterate does not alias `xinit`.
init_iterates(::Type{DisplacementT}, xinit::Primal{DisplacementT}) where DisplacementT =
    init_rest(copy(xinit))

# Start from a (copied) initial image together with zero displacement.
init_iterates(::Type{DisplacementT}, xinit::Image) where DisplacementT =
    init_rest(init_primal(copy(xinit), DisplacementT))

# Start from an all-zero image of size `dim` together with zero displacement.
init_iterates(::Type{DisplacementT}, dim::ImageSize) where DisplacementT =
    init_rest(init_primal(zeros(dim...), DisplacementT))
95
96 ##############################################
97 # Weighting for different displacements types
98 ##############################################
99
# Weight of the squared displacement norm relative to the full-space norm.
# A constant displacement is measured in the 𝟙^*𝟙 norm, which equals the
# full-space norm of the same vector repeated at every one of the `prod(sz)`
# pixels; a full displacement field needs no extra weight.
function norm²weight(::Type{DisplacementConstant}, sz)
    return prod(sz)
end

function norm²weight(::Type{DisplacementFull}, sz)
    return 1
end
102
103 ############
104 # Algorithm
105 ############
106
"""
    solve(DisplacementT; dim, iterate = AlgTools.simple_iterate, params)

Run the predictive online PDPS for optical flow with an unknown velocity
field, jointly estimating an image `x.x` and a displacement `x.u` of type
`DisplacementT`. The optical-flow part of `K` is evaluated with `flow!`
(nonlinear in the displacement), while its adjoint `Kᵀ!` uses the gradient
of the filtered data, i.e. the linearisation at zero displacement.

# Arguments
- `DisplacementT`: `DisplacementConstant` or `DisplacementFull`.
- `dim`: image size; the iterates are initialised from zero arrays.
- `iterate`: driver called as `iterate(step, params)`; it calls `step` with
  `(verbose, b, v_known, b_next)`, where `v_known` is ignored here.
  NOTE(review): the default names `AlgTools.simple_iterate`, but the visible
  imports bind `Util` and `Iterate` rather than `AlgTools` itself — confirm
  that this default resolves.
- `params`: named tuple; read here are `init`, `α`, `ρ`, `λ`, `θ`,
  `timestep`, `dynrange`, `kernel`, `prox_predict` and `dual_flow`, plus
  whatever `step_lengths` reads.

# Returns
`(x, y, v)`: the final primal iterate, the final dual iterate, and the value
returned by `iterate`.
"""
function solve( :: Type{DisplacementT};
                dim :: ImageSize,
                iterate = AlgTools.simple_iterate,
                params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄ = init_iterates(DisplacementT, dim)
    init_data = (params.init == :data)

    # … for tracking cumulative movement (only accumulated for a constant
    # displacement; a full displacement field reports NaN)
    if DisplacementT == DisplacementConstant
        ucumul = [0.0, 0.0]
    else
        ucumul = [NaN, NaN]
    end

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep
    R_K² = max(∇₂_norm₂₂_est², ∇₂_norm₂∞_est²*params.dynrange^2)
    γ = min(1, λ*norm²weight(DisplacementT, size(x.x)))
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel

    ####################
    # Run the algorithm
    ####################

    # The filtered `b_next` is fed back into `filter_hs` on the next step
    # (presumably to reuse the previous filtered frame — TODO confirm);
    # `nothing` before the first step.
    b_next_filt = nothing

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           🚫unused_v_known :: DisplacementT,
                           b_next :: Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ############################
        # Construct K for this step
        ############################

        K! = (yʹ, xʹ) -> begin
            # Optical flow part: flow of b_filt by xʹ.u, minus b_next_filt.
            # `Δx.x` is handed to `flow!` as an extra argument (presumably
            # workspace — TODO confirm), so K! may overwrite it.
            # NOTE(review): `flow!` is not among the names explicitly
            # imported from `..OpticalFlow`; confirm it is in scope.
            @. yʹ.flow = b_filt
            flow!(yʹ.flow, Δx.x, xʹ.u)
            @. yʹ.flow = yʹ.flow - b_next_filt
            # TV part
            ∇₂!(yʹ.tv, xʹ.x)
        end
        Kᵀ! = (xʹ, yʹ) -> begin
            # Optical flow part: ∇b_k ⋅ y
            #
            # TODO: This really should depend x.u, but x.u is zero.
            #
            pointwise_gradiprod_2dᵀ!(xʹ.u, yʹ.flow, b_filt)
            # TV part
            ∇₂ᵀ!(xʹ.x, yʹ.tv)
        end

        ##################
        # Prediction step
        ##################

        if init_data
            # Initialise the image estimate from the first data frame. `x` is
            # a `Primal`, not an array, so only its image component is
            # assigned (broadcasting into `x` itself would throw).
            x.x .= b
            init_data = false
        end

        pdflow!(x.x, Δx.x, y.tv, Δy.tv, y.flow, Δy.flow, x.u, params.dual_flow)

        # Predict zero displacement
        x.u .= 0
        if params.prox_predict
            K!(Δy, x)
            @. y.tv = (y.tv + σ̃*Δy.tv)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y.tv, α)
            @. y.flow = (y.flow+σ̃*Δy.flow)/(1+σ̃*(ρ̃+1/θ))
        end

        ############
        # PDPS step
        #
        # NOTE: For DisplacementConstant, the x.u update is supposed to be with
        # respect to the 𝟙^*𝟙 norm/inner product that makes the norm equivalent
        # to full-space norm when restricted to constant displacements. Since
        # `OpticalFlow.pointwise_gradiprod_2dᵀ!` already uses this inner product,
        # and the λ-weighted term in the problem is with respect to this norm,
        # all the norm weights disappear in this update.
        ############

        Kᵀ!(Δx, y)                                 # primal step:
        @. x̄.x = x.x                               # | save old x for over-relax
        @. x̄.u = x.u                               # |
        @. x.x = (x.x-τ*(Δx.x-b))/(1+τ)            # | prox
        @. x.u = (x.u-τ*Δx.u)/(1+τ*λ)              # |
        @. x̄.x = 2x.x - x̄.x                        # over-relax
        @. x̄.u = 2x.u - x̄.u                        # |
        K!(Δy, x̄)                                  # dual step: y
        @. y.tv = (y.tv + σ*Δy.tv)/(1 + σ*ρ/α)     # |
        proj_norm₂₁ball!(y.tv, α)                  # | prox
        @. y.flow = (y.flow+σ*Δy.flow)/(1+σ/θ)

        if DisplacementT == DisplacementConstant
            ucumul .+= x.u
        end

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################

        # A name distinct from the outer `v` keeps this closure from capturing
        # (and boxing) the enclosing function's result variable.
        status = verbose() do
            K!(Δy, x)
            value = (norm₂²(b-x.x)/2 + θ*norm₂²(Δy.flow)
                     + λ*norm₂²(x.u)/2 + α*γnorm₂₁(Δy.tv, ρ))

            value, x.x, ucumul, nothing
        end

        return status
    end

    return x, y, v
end
239
240 end # Module
241
242

mercurial