|
1 ###################################################################### |
|
2 # Predictive online PDPS for optical flow with unknown velocity field |
|
3 ###################################################################### |
|
4 |
|
5 __precompile__() |
|
6 |
|
7 module AlgorithmBoth |
|
8 |
|
# Short name identifying this algorithm variant. Declared `const` so that reads
# of this module-level global are concretely typed (it is never reassigned).
const identifier = "pdps_unknown_basic"
|
10 |
|
11 using Printf |
|
12 |
|
13 using AlgTools.Util |
|
14 import AlgTools.Iterate |
|
15 using ImageTools.Gradient |
|
16 |
|
17 using ..OpticalFlow: ImageSize, |
|
18 Image, |
|
19 Gradient, |
|
20 DisplacementConstant, |
|
21 DisplacementFull, |
|
22 pdflow!, |
|
23 pointwise_gradiprod_2d!, |
|
24 pointwise_gradiprod_2dᵀ!, |
|
25 filter_hs |
|
26 |
|
27 using ..Algorithm: step_lengths |
|
28 |
|
29 ############# |
|
30 # Data types |
|
31 ############# |
|
32 |
|
"""
    Primal{DisplacementT}

Primal iterate of the joint problem: the image estimate `x` and the displacement
`u`, whose storage type is given by the parameter `DisplacementT`.
"""
struct Primal{DisplacementT}
    x :: Image
    u :: DisplacementT
end

# Allocate an uninitialised primal iterate with the same field shapes/types.
Base.similar(p::Primal{D}) where D = Primal{D}(Base.similar(p.x), Base.similar(p.u))

# Independent copy: both fields are duplicated, so mutating the result leaves
# `p` untouched.
Base.copy(p::Primal{D}) where D = Primal{D}(Base.copy(p.x), Base.copy(p.u))
|
45 |
|
"""
    Dual

Dual iterate: `tv` is paired with the discrete image gradient (TV term) and
`flow` with the pointwise optical-flow residual.
"""
struct Dual
    tv :: Gradient
    flow :: Image
end

# Allocate an uninitialised dual iterate with the same field shapes/types.
Base.similar(d::Dual) = Dual(Base.similar(d.tv), Base.similar(d.flow))

# Independent copy of both dual components.
Base.copy(d::Dual) = Dual(Base.copy(d.tv), Base.copy(d.flow))
|
58 |
|
59 ######################### |
|
60 # Iterate initialisation |
|
61 ######################### |
|
62 |
|
# Wrap the image `xinit` (not copied) into a primal iterate with zero
# displacement. A constant displacement is stored as a single 2-vector.
init_primal(xinit::Image, ::Type{DisplacementConstant}) =
    Primal{DisplacementConstant}(xinit, zeros(2))

# Full displacement: one 2-vector per pixel, stored as a 2×size(xinit)... array.
init_primal(xinit::Image, ::Type{DisplacementFull}) =
    Primal{DisplacementFull}(xinit, zeros(2, size(xinit)...))
|
70 |
|
"""
    init_rest(x::Primal)

Complete the iterate set around the primal `x` (used as-is, not copied).

Returns `(x, y, Δx, Δy, x̄)` where `y` is a zero dual (gradient-shaped `tv`,
image-shaped `flow`), `Δx` and `x̄` are copies of `x`, and `Δy` is a copy of `y`.
"""
function init_rest(x::Primal{DisplacementT}) where DisplacementT
    sz = size(x.x)
    y = Dual(zeros(2, sz...), zeros(sz))
    # Work buffers start as copies so that they have the right shapes.
    return x, y, copy(x), copy(y), copy(x)
end
|
81 |
|
# Build all iterates `(x, y, Δx, Δy, x̄)` for displacement type `DisplacementT`.
# The starting primal may be given explicitly (it is copied), as an image
# (copied, with zero displacement), or only by its size (zero image and zero
# displacement).
init_iterates(::Type{DisplacementT}, xinit::Primal{DisplacementT}) where DisplacementT =
    init_rest(copy(xinit))

init_iterates(::Type{DisplacementT}, xinit::Image) where DisplacementT =
    init_rest(init_primal(copy(xinit), DisplacementT))

init_iterates(::Type{DisplacementT}, dim::ImageSize) where DisplacementT =
    init_rest(init_primal(zeros(dim...), DisplacementT))
|
94 |
|
95 ############################################## |
|
96 # Weighting for different displacements types |
|
97 ############################################## |
|
98 |
|
# Weight applied to the squared displacement norm for image size `sz`: a
# constant displacement counts once per pixel, a full field is unweighted.
function norm²weight(::Type{DisplacementConstant}, sz)
    return prod(sz)
end

function norm²weight(::Type{DisplacementFull}, sz)
    return 1
end
|
101 |
|
102 ############ |
|
103 # Algorithm |
|
104 ############ |
|
105 |
|
"""
    solve(DisplacementT; dim, iterate = AlgTools.simple_iterate, params)

Run predictive online PDPS that estimates an image `x` from data `b` jointly
with an unknown displacement `u` of representation `DisplacementT`
(`DisplacementConstant` or `DisplacementFull`).

# Keywords
- `dim`: image size used to allocate the iterates.
- `iterate`: driver invoked with `params` and a per-step closure taking
  `(verbose, b, v_known, b_next)`; `v_known` is ignored here. Presumably `b` is
  the current frame and `b_next` the next one — confirm against the driver.
  NOTE(review): only `AlgTools.Iterate`/`AlgTools.Util` are brought into scope
  by this module; confirm `AlgTools.simple_iterate` resolves as a default.
- `params`: named tuple; this function reads `init`, `α`, `ρ`, `λ`, `θ`,
  `timestep`, `dynrange`, `kernel`, `prox_predict` and `dual_flow`, and passes
  it whole to `step_lengths` and `iterate`.

# Returns
`(x, y, v)`: final primal and dual iterates, and the value returned by `iterate`.
"""
function solve( :: Type{DisplacementT};
                dim :: ImageSize,
                iterate = AlgTools.simple_iterate,
                params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄ = init_iterates(DisplacementT, dim)
    # When set, the first step overwrites the image iterate with the data.
    init_data = (params.init == :data)

    # … for tracking cumulative movement (only meaningful for a constant
    # displacement; reported as NaN for full fields)
    if DisplacementT == DisplacementConstant
        ucumul = [0.0, 0.0]
    else
        ucumul = [NaN, NaN]
    end

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep
    # Presumably a bound on ‖K‖², combining the gradient-norm estimates with
    # the data dynamic range (which bounds the ∇b factor of the flow part).
    R_K² = max(∇₂_norm₂₂_est², ∇₂_norm₂∞_est²*params.dynrange^2)
    # Strong-convexity factor of the primal terms, capped at 1.
    γ = min(1, λ*norm²weight(DisplacementT, size(x.x)))
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel

    ####################
    # Run the algorithm
    ####################

    # The filtered next frame is carried over between steps and handed back to
    # `filter_hs` (presumably so that it can be reused as the next filtered `b`).
    b_next_filt=nothing

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           🚫unused_v_known :: DisplacementT,
                           b_next :: Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ############################
        # Construct K for this step
        ############################

        K! = (yʹ, xʹ) -> begin
            # Optical flow part: ⟨⟨u, ∇b_k⟩⟩.
            # Use y.tv as temporary gradient storage.
            pointwise_gradiprod_2d!(yʹ.flow, yʹ.tv, xʹ.u, b_filt)
            # TV part
            ∇₂!(yʹ.tv, xʹ.x)
        end
        Kᵀ! = (xʹ, yʹ) -> begin
            # Optical flow part: ∇b_k ⋅ y
            pointwise_gradiprod_2dᵀ!(xʹ.u, yʹ.flow, b_filt)
            # TV part
            ∇₂ᵀ!(xʹ.x, yʹ.tv)
        end

        ##################
        # Prediction step
        ##################

        if init_data
            # `x` is a `Primal` struct, not an array; broadcasting into it
            # (`x .= b`) throws. Initialise the image component from the data.
            x.x .= b
            init_data = false
        end

        pdflow!(x.x, Δx.x, y.tv, Δy.tv, y.flow, Δy.flow, x.u, params.dual_flow)

        # Predict zero displacement
        x.u .= 0
        if params.prox_predict
            # Optional proximal dual prediction from the predicted primal,
            # using the prediction step length σ̃ and extra damping ρ̃.
            K!(Δy, x)
            @. y.tv = (y.tv + σ̃*Δy.tv)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y.tv, α)
            @. y.flow = (y.flow+σ̃*((b_next_filt-b_filt)/T+Δy.flow))/(1+σ̃*(ρ̃+1/θ))
        end

        ############
        # PDPS step
        #
        # NOTE: For DisplacementConstant, the x.u update is supposed to be with
        # respect to the 𝟙^*𝟙 norm/inner product that makes the norm equivalent
        # to full-space norm when restricted to constant displacements. Since
        # `OpticalFlow.pointwise_gradiprod_2dᵀ!` already uses this inner product,
        # and the λ-weighted term in the problem is with respect to this norm,
        # all the norm weights disappear in this update.
        ############

        Kᵀ!(Δx, y)                                 # primal step:
        @. x̄.x = x.x                               # | save old x for over-relax
        @. x̄.u = x.u                               # |
        @. x.x = (x.x-τ*(Δx.x-b))/(1+τ)            # | prox
        @. x.u = (x.u-τ*Δx.u)/(1+τ*λ)              # |
        @. x̄.x = 2x.x - x̄.x                        # over-relax
        @. x̄.u = 2x.u - x̄.u                        # |
        K!(Δy, x̄)                                  # dual step: y
        @. y.tv = (y.tv + σ*Δy.tv)/(1 + σ*ρ/α)     # |
        proj_norm₂₁ball!(y.tv, α)                  # | prox
        @. y.flow = (y.flow+σ*((b_next_filt-b_filt)/T+Δy.flow))/(1+σ/θ)

        if DisplacementT == DisplacementConstant
            ucumul .+= x.u
        end

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################
        v = verbose() do
            K!(Δy, x)
            # NOTE(review): the dual prox (1+σ/θ) above corresponds to a
            # θ/2-weighted flow term, but θ is used here without /2; and for
            # DisplacementConstant the λ-term omits the prod(size) weight used
            # in γ. Confirm the intended objective scaling before relying on it.
            value = (norm₂²(b-x.x)/2 + θ*norm₂²((b_next_filt-b_filt)./T+Δy.flow)
                     + λ*norm₂²(x.u)/2 + α*γnorm₂₁(Δy.tv, ρ))

            value, x.x, ucumul, nothing
        end

        return v
    end

    return x, y, v
end
|
236 |
|
237 end # Module |
|
238 |
|
239 |