Sat, 20 Apr 2024 12:31:37 +0300
update params
0 | 1 | ###################################################################### |
2 | # Predictive online PDPS for optical flow with unknown velocity field | |
3 | ###################################################################### | |
4 | ||
5 | __precompile__() | |
6 | ||
7 | module AlgorithmBoth | |
8 | ||
# Identifier string of this algorithm variant. Declared `const` so the
# module-level binding has a fixed type and cannot be rebound accidentally.
const identifier = "pdps_unknown_basic"
10 | ||
11 | using Printf | |
12 | ||
13 | using AlgTools.Util | |
14 | import AlgTools.Iterate | |
15 | using ImageTools.Gradient | |
16 | ||
17 | using ..OpticalFlow: ImageSize, | |
18 | Image, | |
19 | Gradient, | |
20 | DisplacementConstant, | |
21 | DisplacementFull, | |
22 | pdflow!, | |
23 | pointwise_gradiprod_2d!, | |
24 | pointwise_gradiprod_2dᵀ!, | |
25 | filter_hs | |
26 | ||
27 | using ..Algorithm: step_lengths | |
28 | ||
29 | ############# | |
30 | # Data types | |
31 | ############# | |
32 | ||
"""
    Primal{DisplacementT}

Primal variable of the algorithm: an image `x` paired with a displacement `u`
whose representation is selected by the type parameter `DisplacementT`.
"""
struct Primal{DisplacementT}
    x::Image            # image component
    u::DisplacementT    # displacement component
end
37 | ||
"""
    similar(p::Primal)

Return a new `Primal` with the same displacement type as `p`, whose image and
displacement fields are obtained by calling `similar` on the fields of `p`.
"""
function Base.similar(p::Primal{D}) where {D}
    image = Base.similar(p.x)
    displacement = Base.similar(p.u)
    return Primal{D}(image, displacement)
end
41 | ||
"""
    copy(p::Primal)

Return a new `Primal` with the same displacement type as `p`, holding copies
of both the image and the displacement fields of `p`.
"""
function Base.copy(p::Primal{D}) where {D}
    image = Base.copy(p.x)
    displacement = Base.copy(p.u)
    return Primal{D}(image, displacement)
end
45 | ||
"""
    Dual

Dual variable of the algorithm: a gradient-shaped component `tv` and an
image-shaped component `flow`.
"""
struct Dual
    tv::Gradient    # dual variable of the total-variation part
    flow::Image     # dual variable of the optical-flow part
end
50 | ||
"""
    similar(d::Dual)

Return a new `Dual` whose `tv` and `flow` fields are obtained by calling
`similar` on the corresponding fields of `d`.
"""
function Base.similar(d::Dual)
    tv = Base.similar(d.tv)
    flow = Base.similar(d.flow)
    return Dual(tv, flow)
end
54 | ||
"""
    copy(d::Dual)

Return a new `Dual` holding copies of the `tv` and `flow` fields of `d`.
"""
function Base.copy(d::Dual)
    tv = Base.copy(d.tv)
    flow = Base.copy(d.flow)
    return Dual(tv, flow)
end
58 | ||
59 | ######################### | |
60 | # Iterate initialisation | |
61 | ######################### | |
62 | ||
"""
    init_primal(xinit::Image, ::Type{DisplacementConstant})

Build the initial primal variable for a constant displacement: `xinit` is
stored as given (not copied) together with a zero displacement of length 2.
"""
function init_primal(xinit::Image, ::Type{DisplacementConstant})
    u₀ = zeros(2)
    return Primal{DisplacementConstant}(xinit, u₀)
end
66 | ||
"""
    init_primal(xinit::Image, ::Type{DisplacementFull})

Build the initial primal variable for a full displacement field: `xinit` is
stored as given (not copied) together with a zero array of size
`(2, size(xinit)...)`.
"""
function init_primal(xinit::Image, ::Type{DisplacementFull})
    imdim = size(xinit)
    u₀ = zeros(2, imdim...)
    return Primal{DisplacementFull}(xinit, u₀)
end
70 | ||
"""
    init_rest(x::Primal)

Given the initial primal variable `x`, create the remaining iterates.

Returns the tuple `(x, y, Δx, Δy, x̄)` where `y` is a zero `Dual` sized after
the image `x.x`, `Δx` and `x̄` are copies of `x`, and `Δy` is a copy of `y`.
The `x` passed in is returned as is.
"""
function init_rest(x::Primal{DisplacementT}) where DisplacementT
    sz = size(x.x)
    y = Dual(zeros(2, sz...), zeros(sz))

    # Tuple elements are evaluated left to right: Δx, Δy, then x̄.
    return x, y, copy(x), copy(y), copy(x)
end
81 | ||
"""
    init_iterates(DisplacementT, xinit::Primal{DisplacementT})

Initialise all iterates from an existing primal variable. `xinit` is copied
first, so the caller's value is not shared with the returned iterates.
"""
function init_iterates(::Type{DisplacementT},
                       xinit::Primal{DisplacementT}) where DisplacementT
    x₀ = copy(xinit)
    return init_rest(x₀)
end
86 | ||
"""
    init_iterates(DisplacementT, xinit::Image)

Initialise all iterates from an initial image. The image is copied and paired
with the initial displacement that `init_primal` provides for `DisplacementT`.
"""
function init_iterates(::Type{DisplacementT}, xinit::Image) where DisplacementT
    x₀ = init_primal(copy(xinit), DisplacementT)
    return init_rest(x₀)
end
90 | ||
"""
    init_iterates(DisplacementT, dim::ImageSize)

Initialise all iterates from image dimensions only: the initial image is
`zeros(dim...)`, paired with the initial displacement that `init_primal`
provides for `DisplacementT`.
"""
function init_iterates(::Type{DisplacementT}, dim::ImageSize) where DisplacementT
    blank = zeros(dim...)
    x₀ = init_primal(blank, DisplacementT)
    return init_rest(x₀)
end
94 | ||
95 | ############################################## | |
# Weighting for different displacement types
97 | ############################################## | |
98 | ||
"""
    norm²weight(DisplacementT, sz)

Weight of the squared norm for the given displacement type and image size
`sz`: `prod(sz)` for `DisplacementConstant`, and `1` for `DisplacementFull`.
"""
function norm²weight(::Type{DisplacementConstant}, sz)
    return prod(sz)
end

function norm²weight(::Type{DisplacementFull}, sz)
    return 1
end
101 | ||
102 | ############ | |
103 | # Algorithm | |
104 | ############ | |
105 | ||
"""
    solve(DisplacementT; dim, iterate = Iterate.simple_iterate, params)

Predictive online PDPS for optical flow with unknown velocity field.

# Arguments
- `DisplacementT`: displacement representation type; selects the `init_primal`
  and `norm²weight` methods used.

# Keywords
- `dim::ImageSize`: image dimensions; the iterates are initialised to zero
  with this size.
- `iterate`: driver called as `iterate(step, params)`. It is expected to call
  `step(verbose, b, v_known, b_next)` once per frame; the `v_known` argument
  is ignored by this algorithm.
- `params::NamedTuple`: the fields `init`, `α`, `ρ`, `λ`, `θ`, `timestep`,
  `dynrange`, `kernel`, `dual_flow` and `prox_predict` are read here; the
  whole tuple is also forwarded to `step_lengths` and to `iterate`.

# Returns
The tuple `(x, y, v)`: the final primal and dual iterates, and whatever
`iterate` returned.
"""
function solve( :: Type{DisplacementT};
               dim :: ImageSize,
               # `AlgTools` itself is not bound in this module; only its
               # submodule `Iterate` is imported, so qualify through that.
               iterate = Iterate.simple_iterate,
               params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄ = init_iterates(DisplacementT, dim)
    init_data = (params.init == :data)

    # … for tracking cumulative movement
    if DisplacementT == DisplacementConstant
        ucumul = [0.0, 0.0]
    else
        ucumul = [NaN, NaN]
    end

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep
    R_K² = max(∇₂_norm₂₂_est², ∇₂_norm₂∞_est²*params.dynrange^2)
    γ = min(1, λ*norm²weight(DisplacementT, size(x.x)))
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel

    ####################
    # Run the algorithm
    ####################

    # Carried across iterations: `filter_hs` receives the previous value and
    # returns the updated one.
    b_next_filt=nothing

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           🚫unused_v_known :: DisplacementT,
                           b_next :: Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ############################
        # Construct K for this step
        ############################

        K! = (yʹ, xʹ) -> begin
            # Optical flow part: ⟨⟨u, ∇b_k⟩⟩.
            # Use yʹ.tv as temporary gradient storage.
            pointwise_gradiprod_2d!(yʹ.flow, yʹ.tv, xʹ.u, b_filt)
            #@. yʹ.flow = -yʹ.flow
            # TV part
            ∇₂!(yʹ.tv, xʹ.x)
        end
        Kᵀ! = (xʹ, yʹ) -> begin
            # Optical flow part: ∇b_k ⋅ y
            pointwise_gradiprod_2dᵀ!(xʹ.u, yʹ.flow, b_filt)
            #@. xʹ.u = -xʹ.u
            # TV part
            ∇₂ᵀ!(xʹ.x, yʹ.tv)
        end

        ##################
        # Prediction step
        ##################

        if init_data
            # `x` is a `Primal`, which has no broadcast interface; only its
            # image component is initialised from the first data frame.
            x.x .= b
            init_data = false
        end

        pdflow!(x.x, Δx.x, y.tv, Δy.tv, y.flow, Δy.flow, x.u, params.dual_flow)

        # Predict zero displacement
        x.u .= 0
        if params.prox_predict
            K!(Δy, x)
            @. y.tv = (y.tv + σ̃*Δy.tv)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y.tv, α)
            @. y.flow = (y.flow+σ̃*((b_next_filt-b_filt)/T+Δy.flow))/(1+σ̃*(ρ̃+1/θ))
        end

        ############
        # PDPS step
        #
        # NOTE: For DisplacementConstant, the x.u update is supposed to be with
        # respect to the 𝟙^*𝟙 norm/inner product that makes the norm equivalent
        # to full-space norm when restricted to constant displacements. Since
        # `OpticalFlow.pointwise_gradiprod_2dᵀ!` already uses this inner product,
        # and the λ-weighted term in the problem is with respect to this norm,
        # all the norm weights disappear in this update.
        ############

        Kᵀ!(Δx, y)                             # primal step:
        @. x̄.x = x.x                           # |  save old x for over-relax
        @. x̄.u = x.u                           # |
        @. x.x = (x.x-τ*(Δx.x-b))/(1+τ)        # |  prox
        @. x.u = (x.u-τ*Δx.u)/(1+τ*λ)          # |
        @. x̄.x = 2x.x - x̄.x                    # over-relax
        @. x̄.u = 2x.u - x̄.u                    # |
        K!(Δy, x̄)                              # dual step: y
        @. y.tv = (y.tv + σ*Δy.tv)/(1 + σ*ρ/α) # |
        proj_norm₂₁ball!(y.tv, α)              # |  prox
        @. y.flow = (y.flow+σ*((b_next_filt-b_filt)/T+Δy.flow))/(1+σ/θ)

        if DisplacementT == DisplacementConstant
            ucumul .+= x.u
        end

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################

        # Return the result of `verbose` directly rather than assigning it to
        # `v`, which would rebind the enclosing function's `v` from inside
        # this closure.
        return verbose() do
            K!(Δy, x)
            # NOTE(review): the flow term is weighted by θ, not θ/2 like the
            # other quadratic terms — confirm this is the intended objective.
            value = (norm₂²(b-x.x)/2 + θ*norm₂²((b_next_filt-b_filt)./T+Δy.flow)
                     + λ*norm₂²(x.u)/2 + α*γnorm₂₁(Δy.tv, ρ))

            value, x.x, ucumul, nothing
        end
    end

    return x, y, v
end
236 | ||
237 | end # Module | |
238 | ||
239 |