Sat, 20 Apr 2024 12:31:37 +0300
update params
0 | 1 | ###################################################################### |
2 | # Predictive online PDPS for optical flow with unknown velocity field | |
3 | ###################################################################### | |
4 | ||
5 | __precompile__() | |
6 | ||
7 | module AlgorithmBothNL | |
8 | ||
9 | identifier = "pdps_unknown_nl" | |
10 | ||
11 | ||
12 | using Printf | |
13 | ||
14 | using AlgTools.Util | |
15 | import AlgTools.Iterate | |
16 | using ImageTools.Gradient | |
17 | ||
18 | using ..OpticalFlow: ImageSize, | |
19 | Image, | |
20 | Gradient, | |
21 | DisplacementConstant, | |
22 | DisplacementFull, | |
23 | pdflow!, | |
24 | pointwise_gradiprod_2d!, | |
25 | pointwise_gradiprod_2dᵀ!, | |
26 | filter_hs | |
27 | ||
28 | using ..Algorithm: step_lengths | |
29 | ||
30 | ############# | |
31 | # Data types | |
32 | ############# | |
33 | ||
"""
    Primal{DisplacementT}

Primal iterate of the algorithm: the image `x` together with the
displacement `u`, whose representation is selected by `DisplacementT`.
"""
struct Primal{DisplacementT}
    x::Image
    u::DisplacementT
end
38 | ||
# Allocate a new `Primal` whose image and displacement components are
# `similar` to those of `p` (same layout, contents not copied).
Base.similar(p::Primal{DisplacementT}) where DisplacementT =
    Primal{DisplacementT}(similar(p.x), similar(p.u))
42 | ||
# Component-wise copy of a `Primal`: both the image and the displacement
# are copied, so the result can be mutated independently of `p`.
Base.copy(p::Primal{DisplacementT}) where DisplacementT =
    Primal{DisplacementT}(copy(p.x), copy(p.u))
46 | ||
"""
    Dual

Dual iterate of the algorithm: `tv` is the component paired with the
image gradient (TV part), `flow` the component paired with the optical
flow part of the operator.
"""
struct Dual
    tv::Gradient
    flow::Image
end
51 | ||
# Allocate a new `Dual` whose components are `similar` to those of `d`
# (contents not copied).
Base.similar(d::Dual) = Dual(similar(d.tv), similar(d.flow))
55 | ||
# Component-wise copy of a `Dual`.
Base.copy(d::Dual) = Dual(copy(d.tv), copy(d.flow))
59 | ||
60 | ######################### | |
61 | # Iterate initialisation | |
62 | ######################### | |
63 | ||
# Build the initial primal iterate for a constant displacement field:
# the image is `xinit` itself (not copied) and the displacement is a
# zero 2-vector.
function init_primal(xinit::Image, ::Type{DisplacementConstant})
    u_init = zeros(2)
    return Primal{DisplacementConstant}(xinit, u_init)
end
67 | ||
# Build the initial primal iterate for a full displacement field:
# the image is `xinit` itself (not copied) and the displacement is a
# zero array of size `2 × size(xinit)...`.
function init_primal(xinit::Image, ::Type{DisplacementFull})
    u_init = zeros(2, size(xinit)...)
    return Primal{DisplacementFull}(xinit, u_init)
end
71 | ||
# Given the initial primal iterate `x`, create the remaining iterates:
# a zero dual variable `y`, work variables `Δx` and `Δy` (copies of `x`
# and `y`), and the over-relaxation variable `x̄` (a copy of `x`).
# `x` itself is returned as given, without copying.
function init_rest(x::Primal{DisplacementT}) where DisplacementT
    sz = size(x.x)

    tv_init = zeros(2, sz...)
    flow_init = zeros(sz)
    y = Dual(tv_init, flow_init)

    return x, y, copy(x), copy(y), copy(x)
end
82 | ||
# Initialise all iterates from an existing primal variable. `xinit` is
# copied first, so the caller's value is not aliased by the iterates.
function init_iterates(::Type{DisplacementT},
                       xinit::Primal{DisplacementT}) where DisplacementT
    x_start = copy(xinit)
    return init_rest(x_start)
end
87 | ||
# Initialise all iterates from an initial image. The image is copied and
# the displacement is set up by `init_primal` for the given displacement
# type.
function init_iterates(::Type{DisplacementT}, xinit::Image) where DisplacementT
    x_start = init_primal(copy(xinit), DisplacementT)
    return init_rest(x_start)
end
91 | ||
# Initialise all iterates from an image size only: the initial image is
# all zeros with dimensions `dim`.
function init_iterates(::Type{DisplacementT}, dim::ImageSize) where DisplacementT
    blank = zeros(dim...)
    return init_rest(init_primal(blank, DisplacementT))
end
95 | ||
96 | ############################################## | |
97 | # Weighting for different displacements types | |
98 | ############################################## | |
99 | ||
# Weight applied to λ when computing γ in `solve`, depending on the
# displacement type: the number of pixels `prod(sz)` for a constant
# displacement, and 1 for a full displacement field.
function norm²weight(::Type{DisplacementConstant}, sz)
    return prod(sz)
end

function norm²weight(::Type{DisplacementFull}, sz)
    return 1
end
102 | ||
103 | ############ | |
104 | # Algorithm | |
105 | ############ | |
106 | ||
"""
    solve(::Type{DisplacementT}; dim, iterate = AlgTools.simple_iterate, params)

Run the predictive online PDPS method for optical flow with unknown
displacement field, with the displacement represented as `DisplacementT`.

# Keywords
- `dim::ImageSize`: size of the images; iterates are initialised to zero
  images of this size.
- `iterate`: driver that is called as `iterate(params) do verbose, b, v_known,
  b_next ... end` and supplies the data for each step. The third argument
  (known displacement) is ignored by this algorithm.
- `params::NamedTuple`: parameters. This function reads the fields `init`, `α`,
  `ρ`, `λ`, `θ`, `timestep`, `dynrange`, `kernel`, `dual_flow` and
  `prox_predict`; `params` is also passed on to `step_lengths` and `iterate`.

# Returns
The tuple `(x, y, v)`: the final primal iterate, the final dual iterate, and
whatever `iterate` returned.
"""
function solve( :: Type{DisplacementT};
                dim :: ImageSize,
                iterate = AlgTools.simple_iterate,
                params::NamedTuple) where DisplacementT
    # NOTE(review): only `AlgTools.Util` and `AlgTools.Iterate` are brought into
    # scope at the top of this file; confirm that the name `AlgTools` itself is
    # bound when the default value of `iterate` is used.

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄ = init_iterates(DisplacementT, dim)
    init_data = (params.init == :data)

    # … for tracking cumulative movement
    if DisplacementT == DisplacementConstant
        ucumul = [0.0, 0.0]
    else
        # Cumulative movement is only tracked for constant displacements.
        ucumul = [NaN, NaN]
    end

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    # NOTE: `T` (the time step) is extracted but not used below.
    α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep
    R_K² = max(∇₂_norm₂₂_est², ∇₂_norm₂∞_est²*params.dynrange^2)
    γ = min(1, λ*norm²weight(DisplacementT, size(x.x)))
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel

    ####################
    # Run the algorithm
    ####################

    # Filtered next frame, carried over between steps by `filter_hs`.
    b_next_filt=nothing

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           🚫unused_v_known :: DisplacementT,
                           b_next :: Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ############################
        # Construct K for this step
        ############################

        K! = (yʹ, xʹ) -> begin
            # Optical flow part
            #
            # NOTE(review): `flow!` is not among the names explicitly imported
            # from `..OpticalFlow` at the top of this file; confirm that it is
            # in scope. `Δx.x` is passed to it and may be overwritten here.
            @. yʹ.flow = b_filt
            flow!(yʹ.flow, Δx.x, xʹ.u)
            @. yʹ.flow = yʹ.flow - b_next_filt
            # TV part
            ∇₂!(yʹ.tv, xʹ.x)
        end
        Kᵀ! = (xʹ, yʹ) -> begin
            # Optical flow part: ∇b_k ⋅ y
            #
            # TODO: This really should depend x.u, but x.u is zero.
            #
            pointwise_gradiprod_2dᵀ!(xʹ.u, yʹ.flow, b_filt)
            # TV part
            ∇₂ᵀ!(xʹ.x, yʹ.tv)
        end

        ##################
        # Prediction step
        ##################

        if init_data
            # Initialise the image component from the first data frame.
            # The destination must be `x.x`: `x` is a `Primal` struct, which
            # does not support broadcast assignment.
            x.x .= b
            init_data = false
        end

        pdflow!(x.x, Δx.x, y.tv, Δy.tv, y.flow, Δy.flow, x.u, params.dual_flow)

        # Predict zero displacement
        x.u .= 0
        if params.prox_predict
            K!(Δy, x)
            @. y.tv = (y.tv + σ̃*Δy.tv)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y.tv, α)
            @. y.flow = (y.flow+σ̃*Δy.flow)/(1+σ̃*(ρ̃+1/θ))
        end

        ############
        # PDPS step
        #
        # NOTE: For DisplacementConstant, the x.u update is supposed to be with
        # respect to the 𝟙^*𝟙 norm/inner product that makes the norm equivalent
        # to full-space norm when restricted to constant displacements. Since
        # `OpticalFlow.pointwise_gradiprod_2dᵀ!` already uses this inner product,
        # and the λ-weighted term in the problem is with respect to this norm,
        # all the norm weights disappear in this update.
        ############

        Kᵀ!(Δx, y)                                # primal step:
        @. x̄.x = x.x                              # |  save old x for over-relax
        @. x̄.u = x.u                              # |
        @. x.x = (x.x-τ*(Δx.x-b))/(1+τ)           # |  prox
        @. x.u = (x.u-τ*Δx.u)/(1+τ*λ)             # |
        @. x̄.x = 2x.x - x̄.x                       # over-relax
        @. x̄.u = 2x.u - x̄.u                       # |
        K!(Δy, x̄)                                 # dual step: y
        @. y.tv = (y.tv + σ*Δy.tv)/(1 + σ*ρ/α)    # |
        proj_norm₂₁ball!(y.tv, α)                 # |  prox
        @. y.flow = (y.flow+σ*Δy.flow)/(1+σ/θ)

        if DisplacementT == DisplacementConstant
            ucumul .+= x.u
        end

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################
        v = verbose() do
            # Recompute K applied to the current iterate; this overwrites Δy.
            K!(Δy, x)
            value = (norm₂²(b-x.x)/2 + θ*norm₂²(Δy.flow)
                     + λ*norm₂²(x.u)/2 + α*γnorm₂₁(Δy.tv, ρ))

            value, x.x, ucumul, nothing
        end

        return v
    end

    return x, y, v
end
239 | ||
240 | end # Module | |
241 | ||
242 |