src/AlgorithmBoth.jl

changeset 0
a55e35d20336
equal deleted inserted replaced
-1:000000000000 0:a55e35d20336
1 ######################################################################
2 # Predictive online PDPS for optical flow with unknown velocity field
3 ######################################################################
4
# Mark this file as safe to precompile (`true` is also the default argument).
__precompile__(true)
6
7 module AlgorithmBoth
8
# Short name identifying this algorithm variant. It is `const` because it is
# never reassigned, and an untyped non-const global would be type-unstable.
const identifier = "pdps_unknown_basic"
10
11 using Printf
12
13 using AlgTools.Util
14 import AlgTools.Iterate
15 using ImageTools.Gradient
16
17 using ..OpticalFlow: ImageSize,
18 Image,
19 Gradient,
20 DisplacementConstant,
21 DisplacementFull,
22 pdflow!,
23 pointwise_gradiprod_2d!,
24 pointwise_gradiprod_2dᵀ!,
25 filter_hs
26
27 using ..Algorithm: step_lengths
28
29 #############
30 # Data types
31 #############
32
"""
    Primal{DisplacementT}

Primal iterate of the algorithm: an image `x` together with a displacement
`u` of type `DisplacementT`.
"""
struct Primal{DisplacementT}
    x::Image            # image component
    u::DisplacementT    # displacement component
end
37
# Uninitialised `Primal` whose components have the same types/shapes as `x`'s.
Base.similar(x::Primal{D}) where D = Primal{D}(similar(x.x), similar(x.u))
41
# New `Primal` holding (shallow) copies of both components of `x`.
Base.copy(x::Primal{D}) where D = Primal{D}(copy(x.x), copy(x.u))
45
"""
    Dual

Dual iterate of the algorithm. `tv` is paired with the image gradient
(the TV part of the operator) and `flow` with the optical-flow part.
"""
struct Dual
    tv::Gradient    # dual variable for the TV term
    flow::Image     # dual variable for the optical-flow term
end
50
# Uninitialised `Dual` with component types/shapes matching `y`.
Base.similar(y::Dual) = Dual(similar(y.tv), similar(y.flow))
54
# New `Dual` holding (shallow) copies of both components of `y`.
Base.copy(y::Dual) = Dual(copy(y.tv), copy(y.flow))
58
59 #########################
60 # Iterate initialisation
61 #########################
62
# Constant displacement: `u` is a single zero 2-vector; `xinit` is used as-is.
init_primal(xinit::Image, ::Type{DisplacementConstant}) =
    Primal{DisplacementConstant}(xinit, zeros(2))
66
# Full displacement field: `u` is a zero array of size 2 × size(xinit);
# `xinit` is used as-is.
init_primal(xinit::Image, ::Type{DisplacementFull}) =
    Primal{DisplacementFull}(xinit, zeros(2, size(xinit)...))
70
# Complete the iterate tuple around the primal iterate `x`, returned as-is
# (not copied). Also returns:
#  - `y`: a zero dual, with a 2 × image-size TV part and an image-size flow part,
#  - `Δx`, `Δy`: scratch storage shaped like `x` and `y`,
#  - `x̄`: storage for the over-relaxed primal point, shaped like `x`.
function init_rest(x::Primal{DisplacementT}) where DisplacementT
    sz = size(x.x)
    y = Dual(zeros(2, sz...), zeros(sz))
    return (x, y, copy(x), copy(y), copy(x))
end
81
# Build `(x, y, Δx, Δy, x̄)` from one of three starting points:
#  - an existing primal iterate (copied, so the caller's value is untouched),
#  - an initial image (copied, paired with a zero displacement),
#  - an image size (zero image, paired with a zero displacement).
init_iterates(::Type{DisplacementT},
              xinit::Primal{DisplacementT}) where DisplacementT =
    init_rest(copy(xinit))

function init_iterates(::Type{DisplacementT}, xinit::Image) where DisplacementT
    start = init_primal(copy(xinit), DisplacementT)
    return init_rest(start)
end

function init_iterates(::Type{DisplacementT}, dim::ImageSize) where DisplacementT
    blank = zeros(dim...)
    return init_rest(init_primal(blank, DisplacementT))
end
94
95 ##############################################
96 # Weighting for different displacement types
97 ##############################################
98
# Weight applied to ‖u‖² for each displacement type: the pixel count `prod(sz)`
# for a constant displacement, and 1 for a full displacement field.
norm²weight(::Type{DisplacementConstant}, sz) = prod(sz)
norm²weight(::Type{DisplacementFull}, ::Any) = 1
101
102 ############
103 # Algorithm
104 ############
105
"""
    solve(DisplacementT; dim, iterate = AlgTools.simple_iterate, params)

Run the predictive online primal–dual proximal splitting (PDPS) method for
optical flow with an unknown displacement of type `DisplacementT`.

The primal iterate `x` holds an image `x.x` and a displacement `x.u`. The dual
iterate `y` holds a TV part `y.tv` and an optical-flow part `y.flow`. On each
call of the step closure, `iterate` supplies the current frame `b` and the
next frame `b_next`. A known displacement is also passed, but this algorithm
ignores it.

# Keywords
- `dim`: image size; the initial iterates are zero arrays of this size.
- `iterate`: driver that calls the step closure with
  `(verbose, b, v_known, b_next)`. Its return value is returned as `v`.
  NOTE(review): this file's imports bind `Iterate` and `Util` from AlgTools,
  but not the name `AlgTools` itself. Confirm that this default resolves when
  `iterate` is not passed explicitly.
- `params`: named tuple. Fields read here are `init`, `α`, `ρ`, `λ`, `θ`,
  `timestep`, `dynrange`, `kernel`, `dual_flow` and `prox_predict`, plus
  anything read by `step_lengths` and `iterate`.

# Returns
`(x, y, v)`: the final primal iterate, the final dual iterate, and whatever
`iterate` returned.
"""
function solve( :: Type{DisplacementT};
                dim :: ImageSize,
                iterate = AlgTools.simple_iterate,
                params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄ = init_iterates(DisplacementT, dim)
    init_data = (params.init == :data)

    # Cumulative movement. It is only tracked for a constant displacement;
    # otherwise NaN is reported.
    if DisplacementT == DisplacementConstant
        ucumul = [0.0, 0.0]
    else
        ucumul = [NaN, NaN]
    end

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep
    R_K² = max(∇₂_norm₂₂_est², ∇₂_norm₂∞_est²*params.dynrange^2)
    γ = min(1, λ*norm²weight(DisplacementT, size(x.x)))
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel

    ####################
    # Run the algorithm
    ####################

    # The filtered next frame from the previous step is passed back into
    # `filter_hs`. It starts as `nothing` before the first step.
    b_next_filt = nothing

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           🚫unused_v_known :: DisplacementT,
                           b_next :: Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ############################
        # Construct K for this step
        ############################

        K! = (yʹ, xʹ) -> begin
            # Optical flow part: ⟨⟨u, ∇b_k⟩⟩.
            # yʹ.tv serves as temporary gradient storage here; the TV part
            # below overwrites it afterwards.
            pointwise_gradiprod_2d!(yʹ.flow, yʹ.tv, xʹ.u, b_filt)
            # TV part
            ∇₂!(yʹ.tv, xʹ.x)
        end
        Kᵀ! = (xʹ, yʹ) -> begin
            # Optical flow part: ∇b_k ⋅ y
            pointwise_gradiprod_2dᵀ!(xʹ.u, yʹ.flow, b_filt)
            # TV part
            ∇₂ᵀ!(xʹ.x, yʹ.tv)
        end

        ##################
        # Prediction step
        ##################

        if init_data
            # Start the image estimate from the first frame. `x` is a `Primal`
            # and not an array, so only its image component can receive `b`.
            x.x .= b
            init_data = false
        end

        pdflow!(x.x, Δx.x, y.tv, Δy.tv, y.flow, Δy.flow, x.u, params.dual_flow)

        # Predict zero displacement
        x.u .= 0
        if params.prox_predict
            K!(Δy, x)
            @. y.tv = (y.tv + σ̃*Δy.tv)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y.tv, α)
            @. y.flow = (y.flow+σ̃*((b_next_filt-b_filt)/T+Δy.flow))/(1+σ̃*(ρ̃+1/θ))
        end

        ############
        # PDPS step
        #
        # NOTE: For DisplacementConstant, the x.u update is supposed to be with
        # respect to the 𝟙^*𝟙 norm/inner product that makes the norm equivalent
        # to full-space norm when restricted to constant displacements. Since
        # `OpticalFlow.pointwise_gradiprod_2dᵀ!` already uses this inner product,
        # and the λ-weighted term in the problem is with respect to this norm,
        # all the norm weights disappear in this update.
        ############

        Kᵀ!(Δx, y)                                 # primal step:
        @. x̄.x = x.x                               # | save old x for over-relax
        @. x̄.u = x.u                               # |
        @. x.x = (x.x-τ*(Δx.x-b))/(1+τ)            # | prox
        @. x.u = (x.u-τ*Δx.u)/(1+τ*λ)              # |
        @. x̄.x = 2x.x - x̄.x                        # over-relax
        @. x̄.u = 2x.u - x̄.u                        # |
        K!(Δy, x̄)                                  # dual step: y
        @. y.tv = (y.tv + σ*Δy.tv)/(1 + σ*ρ/α)     # |
        proj_norm₂₁ball!(y.tv, α)                  # | prox
        @. y.flow = (y.flow+σ*((b_next_filt-b_filt)/T+Δy.flow))/(1+σ/θ)

        if DisplacementT == DisplacementConstant
            ucumul .+= x.u
        end

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################

        # The closure returns the report directly. Assigning it to a local
        # named `v` here would rebind, and therefore box, the outer `v`.
        return verbose() do
            K!(Δy, x)
            # NOTE(review): the y.flow update above is the prox of the
            # conjugate of θ/2‖(b_next-b)/T + ⟨∇b,u⟩‖², but the flow term
            # below is θ‖·‖², without the 1/2 used for the other terms.
            # Confirm which scaling is intended.
            value = (norm₂²(b-x.x)/2 + θ*norm₂²((b_next_filt-b_filt)./T+Δy.flow)
                     + λ*norm₂²(x.u)/2 + α*γnorm₂₁(Δy.tv, ρ))

            value, x.x, ucumul, nothing
        end
    end

    return x, y, v
end
236
237 end # Module
238
239

mercurial