| 1 ###################################################################### |
|
| 2 # Predictive online PDPS for optical flow with unknown velocity field |
|
| 3 ###################################################################### |
|
| 4 |
|
# Mark the module as precompilable; `true` is the default argument, spelled out.
__precompile__(true)
|
| 6 |
|
| 7 module AlgorithmBothNL |
|
| 8 |
|
# Short name identifying this algorithm variant (PDPS, unknown displacement).
# Declared `const` so the module-level binding is type-stable.
const identifier = "pdps_unknown_nl"
|
| 10 |
|
| 11 |
|
using Printf

using AlgTools.Util
import AlgTools.Iterate
using ImageTools.Gradient

using ..OpticalFlow: ImageSize,
                     Image,
                     Gradient,
                     DisplacementConstant,
                     DisplacementFull,
                     flow!,
                     pdflow!,
                     pointwise_gradiprod_2d!,
                     pointwise_gradiprod_2dᵀ!,
                     filter_hs

using ..Algorithm: step_lengths
|
| 29 |
|
| 30 ############# |
|
| 31 # Data types |
|
| 32 ############# |
|
| 33 |
|
"""
    Primal{DisplacementT}

Primal iterate of the algorithm: the image estimate `x` together with the
displacement `u`. The type `DisplacementT` (e.g. `DisplacementConstant` or
`DisplacementFull`) determines how `u` is stored.
"""
struct Primal{DisplacementT}
    x::Image
    u::DisplacementT
end

# A primal iterate whose components are `similar` to those of `p`
# (same types and shapes; contents unspecified).
Base.similar(p::Primal{D}) where {D} = Primal{D}(similar(p.x), similar(p.u))

# An independent copy of `p`: each component is duplicated with `copy`.
Base.copy(p::Primal{D}) where {D} = Primal{D}(copy(p.x), copy(p.u))
|
| 46 |
|
"""
    Dual

Dual iterate of the algorithm: `tv` is the dual variable paired with the
total-variation (image gradient) term, and `flow` the one paired with the
optical-flow data term.
"""
struct Dual
    tv::Gradient
    flow::Image
end

# A dual iterate whose components are `similar` to those of `d`.
Base.similar(d::Dual) = Dual(similar(d.tv), similar(d.flow))

# An independent copy of `d`: each component is duplicated with `copy`.
Base.copy(d::Dual) = Dual(copy(d.tv), copy(d.flow))
|
| 59 |
|
| 60 ######################### |
|
| 61 # Iterate initialisation |
|
| 62 ######################### |
|
| 63 |
|
"""
    init_primal(xinit, DisplacementT)

Wrap the image `xinit` (stored as-is, not copied) into a `Primal` with zero
displacement: a length-2 vector for `DisplacementConstant`, or a
`2 × size(xinit)...` array for `DisplacementFull`.
"""
init_primal(xinit::Image, ::Type{DisplacementConstant}) =
    Primal{DisplacementConstant}(xinit, zeros(2))

init_primal(xinit::Image, ::Type{DisplacementFull}) =
    Primal{DisplacementFull}(xinit, zeros(2, size(xinit)...))
|
| 71 |
|
"""
    init_rest(x::Primal)

Build the remaining iterates around the primal iterate `x`, which is used as-is.

Returns `(x, y, Δx, Δy, x̄)`:
- `y` is a zero `Dual` sized from the image `x.x` (a `2 × size(x.x)...` TV part
  and a `size(x.x)` flow part);
- `Δx` and `x̄` start as copies of `x`, and `Δy` as a copy of `y`.
"""
function init_rest(x::Primal{DisplacementT}) where DisplacementT
    sz = size(x.x)
    y = Dual(zeros(2, sz...), zeros(sz))
    return (x, y, copy(x), copy(y), copy(x))
end
|
| 82 |
|
"""
    init_iterates(DisplacementT, init)

Create the iterates `(x, y, Δx, Δy, x̄)` from `init`, which may be
- a `Primal{DisplacementT}`, which is copied;
- an `Image`, which is copied and paired with zero displacement of type `DisplacementT`;
- an `ImageSize`, which yields a zero image of that size with zero displacement.
"""
init_iterates(::Type{DisplacementT}, xinit::Primal{DisplacementT}) where DisplacementT =
    init_rest(copy(xinit))

init_iterates(::Type{DisplacementT}, xinit::Image) where DisplacementT =
    init_rest(init_primal(copy(xinit), DisplacementT))

init_iterates(::Type{DisplacementT}, dim::ImageSize) where DisplacementT =
    init_rest(init_primal(zeros(dim...), DisplacementT))
|
| 95 |
|
| 96 ############################################## |
|
| 97 # Weighting for different displacements types |
|
| 98 ############################################## |
|
| 99 |
|
# Weight applied to λ when computing γ in `solve`. A constant displacement is
# measured in the 𝟙*𝟙-weighted norm, so it is weighted by the pixel count
# `prod(sz)`. A full displacement field needs no extra weight.
function norm²weight(::Type{DisplacementConstant}, sz)
    return prod(sz)
end

function norm²weight(::Type{DisplacementFull}, sz)
    return 1
end
|
| 102 |
|
| 103 ############ |
|
| 104 # Algorithm |
|
| 105 ############ |
|
| 106 |
|
"""
    solve(DisplacementT; dim, iterate = AlgTools.simple_iterate, params)

Run the predictive online PDPS method for optical flow with an unknown
displacement of type `DisplacementT` (`DisplacementConstant` or
`DisplacementFull`).

# Keywords
- `dim::ImageSize`: image size; the iterates start at zero.
- `iterate`: driver invoked as `iterate(step, params)` (via `do`-block). It calls
  `step(verbose, b, v_known, b_next)` with the current frame `b`, a known
  displacement (ignored here) and the next frame `b_next`. Its return value is
  passed through as `v`.
  NOTE(review): `import AlgTools.Iterate` binds only `Iterate`, not `AlgTools`,
  so this default appears unresolvable unless callers pass `iterate`. Confirm.
- `params::NamedTuple`: read fields are `init`, `α`, `ρ`, `λ`, `θ`, `dynrange`,
  `kernel`, `dual_flow` and `prox_predict`, plus whatever `step_lengths` reads.

# Returns
`(x, y, v)`: the final primal and dual iterates and the result of `iterate`.
"""
function solve( :: Type{DisplacementT};
               dim :: ImageSize,
               iterate = AlgTools.simple_iterate,
               params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄ = init_iterates(DisplacementT, dim)
    # With `params.init == :data`, the primal image is set to the first data
    # frame at the start of the first step.
    init_data = (params.init == :data)

    # … for tracking cumulative movement. This is only tracked for a constant
    # displacement; NaN marks it as unavailable for full displacement fields.
    if DisplacementT == DisplacementConstant
        ucumul = [0.0, 0.0]
    else
        ucumul = [NaN, NaN]
    end

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ = params.α, params.ρ, params.λ, params.θ
    # Squared operator-norm bound used for the step lengths.
    R_K² = max(∇₂_norm₂₂_est², ∇₂_norm₂∞_est²*params.dynrange^2)
    γ = min(1, λ*norm²weight(DisplacementT, size(x.x)))
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel

    ####################
    # Run the algorithm
    ####################

    # Filtered "next" frame. It is carried across steps and passed back into
    # `filter_hs`, presumably so the filtering can be reused. Confirm in OpticalFlow.
    b_next_filt = nothing

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           🚫unused_v_known :: DisplacementT,
                           b_next :: Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ############################
        # Construct K for this step
        ############################

        # yʹ ← K(xʹ). The TV part is ∇xʹ.x. The flow part is b_filt acted on by
        # `flow!` with displacement xʹ.u, minus b_next_filt, so K! is affine in xʹ.
        # Δx.x is handed to `flow!` as scratch space and is therefore clobbered
        # by every call to K!.
        K! = (yʹ, xʹ) -> begin
            # Optical flow part
            @. yʹ.flow = b_filt
            flow!(yʹ.flow, Δx.x, xʹ.u)
            @. yʹ.flow = yʹ.flow - b_next_filt
            # TV part
            ∇₂!(yʹ.tv, xʹ.x)
        end
        Kᵀ! = (xʹ, yʹ) -> begin
            # Optical flow part: ∇b_k ⋅ y
            #
            # TODO: This really should depend x.u, but x.u is zero.
            #
            pointwise_gradiprod_2dᵀ!(xʹ.u, yʹ.flow, b_filt)
            # TV part
            ∇₂ᵀ!(xʹ.x, yʹ.tv)
        end

        ##################
        # Prediction step
        ##################

        # On the first step only, initialise the image component from the
        # data. `x` is a `Primal` struct (not broadcastable); the image
        # lives in `x.x`.
        if init_data
            x.x .= b
            init_data = false
        end

        pdflow!(x.x, Δx.x, y.tv, Δy.tv, y.flow, Δy.flow, x.u, params.dual_flow)

        # Predict zero displacement
        x.u .= 0
        if params.prox_predict
            K!(Δy, x)
            @. y.tv = (y.tv + σ̃*Δy.tv)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y.tv, α)
            @. y.flow = (y.flow+σ̃*Δy.flow)/(1+σ̃*(ρ̃+1/θ))
        end

        ############
        # PDPS step
        #
        # NOTE: For DisplacementConstant, the x.u update is supposed to be with
        # respect to the 𝟙^*𝟙 norm/inner product that makes the norm equivalent
        # to full-space norm when restricted to constant displacements. Since
        # `OpticalFlow.pointwise_gradiprod_2dᵀ!` already uses this inner product,
        # and the λ-weighted term in the problem is with respect to this norm,
        # all the norm weights disappear in this update.
        ############

        Kᵀ!(Δx, y)                             # primal step:
        @. x̄.x = x.x                           # | save old x for over-relax
        @. x̄.u = x.u                           # |
        @. x.x = (x.x-τ*(Δx.x-b))/(1+τ)        # | prox
        @. x.u = (x.u-τ*Δx.u)/(1+τ*λ)          # |
        @. x̄.x = 2x.x - x̄.x                    # over-relax
        @. x̄.u = 2x.u - x̄.u                    # |
        K!(Δy, x̄)                              # dual step: y
        @. y.tv = (y.tv + σ*Δy.tv)/(1 + σ*ρ/α) # |
        proj_norm₂₁ball!(y.tv, α)              # | prox
        @. y.flow = (y.flow+σ*Δy.flow)/(1+σ/θ)

        if DisplacementT == DisplacementConstant
            ucumul .+= x.u
        end

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################

        # Returned directly rather than assigned to `v`: assigning inside this
        # closure would capture (and box) the outer `v`.
        return verbose() do
            K!(Δy, x)
            value = (norm₂²(b-x.x)/2 + θ*norm₂²(Δy.flow)
                     + λ*norm₂²(x.u)/2 + α*γnorm₂₁(Δy.tv, ρ))

            value, x.x, ucumul, nothing
        end
    end

    return x, y, v
end
|
| 239 |
|
| 240 end # Module |
|
| 241 |
|
| 242 |
|