--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/AlgorithmBothNL.jl Tue Apr 07 14:19:48 2020 -0500 @@ -0,0 +1,242 @@ +###################################################################### +# Predictive online PDPS for optical flow with unknown velocity field +###################################################################### + +__precompile__() + +module AlgorithmBothNL + +identifier = "pdps_unknown_nl" + + +using Printf + +using AlgTools.Util +import AlgTools.Iterate +using ImageTools.Gradient + +using ..OpticalFlow: ImageSize, + Image, + Gradient, + DisplacementConstant, + DisplacementFull, + pdflow!, + pointwise_gradiprod_2d!, + pointwise_gradiprod_2dᵀ!, + filter_hs + +using ..Algorithm: step_lengths + +############# +# Data types +############# + +struct Primal{DisplacementT} + x :: Image + u :: DisplacementT +end + +function Base.similar(x::Primal{DisplacementT}) where DisplacementT + return Primal{DisplacementT}(Base.similar(x.x), Base.similar(x.u)) +end + +function Base.copy(x::Primal{DisplacementT}) where DisplacementT + return Primal{DisplacementT}(Base.copy(x.x), Base.copy(x.u)) + end + +struct Dual + tv :: Gradient + flow :: Image +end + +function Base.similar(y::Dual) + return Dual(Base.similar(y.tv), Base.similar(y.flow)) +end + +function Base.copy(y::Dual) + return Dual(Base.copy(y.tv), Base.copy(y.flow)) + end + +######################### +# Iterate initialisation +######################### + +function init_primal(xinit::Image, ::Type{DisplacementConstant}) + return Primal{DisplacementConstant}(xinit, zeros(2)) +end + +function init_primal(xinit::Image, ::Type{DisplacementFull}) + return Primal{DisplacementFull}(xinit, zeros(2, size(xinit)...)) +end + +function init_rest(x::Primal{DisplacementT}) where DisplacementT + imdim=size(x.x) + + y = Dual(zeros(2, imdim...), zeros(imdim)) + Δx = copy(x) + Δy = copy(y) + x̄ = copy(x) + + return x, y, Δx, Δy, x̄ +end + +function init_iterates( :: Type{DisplacementT}, + xinit::Primal{DisplacementT}) where DisplacementT + return init_rest(copy(xinit)) +end + +function init_iterates( :: Type{DisplacementT}, xinit::Image) where DisplacementT + return init_rest(init_primal(copy(xinit), DisplacementT)) +end + +function init_iterates( :: Type{DisplacementT}, dim::ImageSize) where DisplacementT + return init_rest(init_primal(zeros(dim...), DisplacementT)) +end + +############################################## +# Weighting for different displacements types +############################################## + +norm²weight( :: Type{DisplacementConstant}, sz ) = prod(sz) +norm²weight( :: Type{DisplacementFull}, sz ) = 1 + +############ +# Algorithm +############ + +function solve( :: Type{DisplacementT}; + dim :: ImageSize, + iterate = AlgTools.simple_iterate, + params::NamedTuple) where DisplacementT + + ###################### + # Initialise iterates + ###################### + + x, y, Δx, Δy, x̄ = init_iterates(DisplacementT, dim) + init_data = (params.init == :data) + + # … for tracking cumulative movement + if DisplacementT == DisplacementConstant + ucumul = [0.0, 0.0] + else + ucumul = [NaN, NaN] + end + + ############################################# + # Extract parameters and set up step lengths + ############################################# + + α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep + R_K² = max(∇₂_norm₂₂_est², ∇₂_norm₂∞_est²*params.dynrange^2) + γ = min(1, λ*norm²weight(DisplacementT, size(x.x))) + τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²) + + kernel = params.kernel + + #################### + # Run the algorithm + #################### + + b_next_filt=nothing + + v = iterate(params) do verbose :: Function, + b :: Image, + 🚫unused_v_known :: DisplacementT, + b_next :: Image + + #################################### + # Smooth data for Horn–Schunck term + #################################### + + b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel) + + ############################ + # Construct K for this step + ############################ + + K! = (yʹ, xʹ) -> begin + # Optical flow part + @. yʹ.flow = b_filt + flow!(yʹ.flow, Δx.x, xʹ.u) + @. yʹ.flow = yʹ.flow - b_next_filt + # TV part + ∇₂!(yʹ.tv, xʹ.x) + end + Kᵀ! = (xʹ, yʹ) -> begin + # Optical flow part: ∇b_k ⋅ y + # + # TODO: This really should depend x.u, but x.u is zero. + # + pointwise_gradiprod_2dᵀ!(xʹ.u, yʹ.flow, b_filt) + # TV part + ∇₂ᵀ!(xʹ.x, yʹ.tv) + end + + ################## + # Prediction step + ################## + + if init_data + x .= b + init_data = false + end + + pdflow!(x.x, Δx.x, y.tv, Δy.tv, y.flow, Δy.flow, x.u, params.dual_flow) + + # Predict zero displacement + x.u .= 0 + if params.prox_predict + K!(Δy, x) + @. y.tv = (y.tv + σ̃*Δy.tv)/(1 + σ̃*(ρ̃+ρ/α)) + proj_norm₂₁ball!(y.tv, α) + @. y.flow = (y.flow+σ̃*Δy.flow)/(1+σ̃*(ρ̃+1/θ)) + end + + ############ + # PDPS step + # + # NOTE: For DisplacementConstant, the x.u update is supposed to be with + # respect to the 𝟙^*𝟙 norm/inner product that makes the norm equivalent + # to full-space norm when restricted to constant displacements. Since + # `OpticalFlow.pointwise_gradiprod_2dᵀ!` already uses this inner product, + # and the λ-weighted term in the problem is with respect to this norm, + # all the norm weights disappear in this update. + ############ + + Kᵀ!(Δx, y) # primal step: + @. x̄.x = x.x # | save old x for over-relax + @. x̄.u = x.u # | + @. x.x = (x.x-τ*(Δx.x-b))/(1+τ) # | prox + @. x.u = (x.u-τ*Δx.u)/(1+τ*λ) # | + @. x̄.x = 2x.x - x̄.x # over-relax + @. x̄.u = 2x.u - x̄.u # | + K!(Δy, x̄) # dual step: y + @. y.tv = (y.tv + σ*Δy.tv)/(1 + σ*ρ/α) # | + proj_norm₂₁ball!(y.tv, α) # | prox + @. y.flow = (y.flow+σ*Δy.flow)/(1+σ/θ) + + if DisplacementT == DisplacementConstant + ucumul .+= x.u + end + + ######################################################## + # Give function value and cumulative movement if needed + ######################################################## + v = verbose() do + K!(Δy, x) + value = (norm₂²(b-x.x)/2 + θ*norm₂²(Δy.flow) + + λ*norm₂²(x.u)/2 + α*γnorm₂₁(Δy.tv, ρ)) + + value, x.x, ucumul, nothing + end + + return v + end + + return x, y, v +end + +end # Module + +