Thu, 25 Apr 2024 13:05:40 -0500
Reduce code duplication.
######################################################################
# Predictive online PDPS for optical flow with unknown velocity field
######################################################################

__precompile__()

module AlgorithmBoth

identifier = "pdps_unknown_basic"

using Printf

using AlgTools.Util
import AlgTools.Iterate
using ImageTools.Gradient

using ..OpticalFlow: ImageSize,
                     Image,
                     Gradient,
                     DisplacementConstant,
                     DisplacementFull,
                     pdflow!,
                     pointwise_gradiprod_2d!,
                     pointwise_gradiprod_2dᵀ!,
                     filter_hs

using ..Algorithm: step_lengths

#############
# Data types
#############

"""
    Primal{DisplacementT}

Primal iterate of the algorithm: the image `x` together with the displacement
`u`, whose representation is given by `DisplacementT`.
"""
struct Primal{DisplacementT}
    x :: Image
    u :: DisplacementT
end

# `similar` and `copy` act field by field, so the result never aliases the input.
function Base.similar(x::Primal{DisplacementT}) where DisplacementT
    return Primal{DisplacementT}(Base.similar(x.x), Base.similar(x.u))
end

function Base.copy(x::Primal{DisplacementT}) where DisplacementT
    return Primal{DisplacementT}(Base.copy(x.x), Base.copy(x.u))
end

"""
    Dual

Dual iterate of the algorithm: `tv` is the dual variable of the TV term and
`flow` the dual variable of the optical flow term.
"""
struct Dual
    tv :: Gradient
    flow :: Image
end

function Base.similar(y::Dual)
    return Dual(Base.similar(y.tv), Base.similar(y.flow))
end

function Base.copy(y::Dual)
    return Dual(Base.copy(y.tv), Base.copy(y.flow))
end

#########################
# Iterate initialisation
#########################

"""
    init_primal(xinit::Image, ::Type{DisplacementConstant})
    init_primal(xinit::Image, ::Type{DisplacementFull})

Wrap `xinit` into a `Primal` with a zero displacement: a 2-vector for
`DisplacementConstant`, and a `2 × size(xinit)...` array for `DisplacementFull`.

`xinit` is stored as given, without copying.
"""
function init_primal(xinit::Image, ::Type{DisplacementConstant})
    return Primal{DisplacementConstant}(xinit, zeros(2))
end

function init_primal(xinit::Image, ::Type{DisplacementFull})
    return Primal{DisplacementFull}(xinit, zeros(2, size(xinit)...))
end

"""
    init_rest(x::Primal)

Given the initial primal iterate `x`, construct the remaining work variables.

Returns `(x, y, Δx, Δy, x̄)`, where `y` is a zero-initialised `Dual` sized after
`x.x`, `Δx` and `x̄` are copies of `x`, and `Δy` is a copy of `y`.
"""
function init_rest(x::Primal{DisplacementT}) where DisplacementT
    imdim = size(x.x)

    y = Dual(zeros(2, imdim...), zeros(imdim))
    Δx = copy(x)
    Δy = copy(y)
    x̄ = copy(x)

    return x, y, Δx, Δy, x̄
end

"""
    init_iterates(::Type{DisplacementT}, xinit::Primal{DisplacementT})
    init_iterates(::Type{DisplacementT}, xinit::Image)
    init_iterates(::Type{DisplacementT}, dim::ImageSize)

Initialise all iterates from an existing primal variable, from an initial image
(with zero displacement), or from an image size (zero image, zero displacement).
A given `xinit` is copied, so the caller's data is not modified through the
returned iterates.

Returns the tuple produced by `init_rest`.
"""
function init_iterates( :: Type{DisplacementT},
                       xinit::Primal{DisplacementT}) where DisplacementT
    return init_rest(copy(xinit))
end

function init_iterates( :: Type{DisplacementT}, xinit::Image) where DisplacementT
    return init_rest(init_primal(copy(xinit), DisplacementT))
end

function init_iterates( :: Type{DisplacementT}, dim::ImageSize) where DisplacementT
    return init_rest(init_primal(zeros(dim...), DisplacementT))
end

##############################################
# Weighting for different displacements types
##############################################

# Weight multiplying λ when computing γ in `solve`: the number of pixels for a
# constant displacement, and 1 for a full displacement field.
norm²weight( :: Type{DisplacementConstant}, sz ) = prod(sz)
norm²weight( :: Type{DisplacementFull}, sz ) = 1

############
# Algorithm
############

"""
    solve(::Type{DisplacementT}; dim, iterate, params)

Run the predictive online PDPS method with the displacement representation
`DisplacementT`.

# Keywords
- `dim :: ImageSize`: size of the images; the iterates are initialised to zero
  with this size.
- `iterate`: iteration driver, called as `iterate(f, params)`. It invokes
  `f(verbose, b, v_known, b_next)` once per step, where `b` and `b_next` are the
  current and next data frames. `v_known` is ignored by this algorithm.
- `params :: NamedTuple`: parameters. This function reads `init`, `α`, `ρ`, `λ`,
  `θ`, `timestep`, `dynrange`, `kernel`, `dual_flow` and `prox_predict`;
  `params` is also passed on to `step_lengths` and `iterate`.

# Returns
The tuple `(x, y, v)`: the final primal and dual iterates and the value
returned by `iterate`.
"""
function solve( :: Type{DisplacementT};
               dim :: ImageSize,
               iterate = AlgTools.simple_iterate,
               params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄ = init_iterates(DisplacementT, dim)
    init_data = (params.init == :data)

    # … for tracking cumulative movement
    # (only accumulated for constant displacements; NaN placeholders otherwise)
    ucumul = DisplacementT == DisplacementConstant ? [0.0, 0.0] : [NaN, NaN]

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep
    R_K² = max(∇₂_norm₂₂_est², ∇₂_norm₂∞_est²*params.dynrange^2)
    γ = min(1, λ*norm²weight(DisplacementT, size(x.x)))
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel

    ####################
    # Run the algorithm
    ####################

    # Filter state carried from one step to the next; updated inside the closure.
    b_next_filt = nothing

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           🚫unused_v_known :: DisplacementT,
                           b_next :: Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ############################
        # Construct K for this step
        ############################

        K! = (yʹ, xʹ) -> begin
            # Optical flow part: ⟨⟨u, ∇b_k⟩⟩.
            # Use y.tv as temporary gradient storage.
            pointwise_gradiprod_2d!(yʹ.flow, yʹ.tv, xʹ.u, b_filt)
            #@. yʹ.flow = -yʹ.flow
            # TV part
            ∇₂!(yʹ.tv, xʹ.x)
        end
        Kᵀ! = (xʹ, yʹ) -> begin
            # Optical flow part: ∇b_k ⋅ y
            pointwise_gradiprod_2dᵀ!(xʹ.u, yʹ.flow, b_filt)
            #@. xʹ.u = -xʹ.u
            # TV part
            ∇₂ᵀ!(xʹ.x, yʹ.tv)
        end

        ##################
        # Prediction step
        ##################

        if init_data
            # Initialise the image component from the first data frame.
            # `x` is a `Primal`, so the assignment must target its image field.
            x.x .= b
            init_data = false
        end

        pdflow!(x.x, Δx.x, y.tv, Δy.tv, y.flow, Δy.flow, x.u, params.dual_flow)

        # Predict zero displacement
        x.u .= 0
        if params.prox_predict
            K!(Δy, x)
            @. y.tv = (y.tv + σ̃*Δy.tv)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y.tv, α)
            @. y.flow = (y.flow+σ̃*((b_next_filt-b_filt)/T+Δy.flow))/(1+σ̃*(ρ̃+1/θ))
        end

        ############
        # PDPS step
        #
        # NOTE: For DisplacementConstant, the x.u update is supposed to be with
        # respect to the 𝟙^*𝟙 norm/inner product that makes the norm equivalent
        # to full-space norm when restricted to constant displacements. Since
        # `OpticalFlow.pointwise_gradiprod_2dᵀ!` already uses this inner product,
        # and the λ-weighted term in the problem is with respect to this norm,
        # all the norm weights disappear in this update.
        ############

        Kᵀ!(Δx, y)                             # primal step:
        @. x̄.x = x.x                           # |  save old x for over-relax
        @. x̄.u = x.u                           # |
        @. x.x = (x.x-τ*(Δx.x-b))/(1+τ)        # |  prox
        @. x.u = (x.u-τ*Δx.u)/(1+τ*λ)          # |
        @. x̄.x = 2x.x - x̄.x                    # over-relax
        @. x̄.u = 2x.u - x̄.u                    # |
        K!(Δy, x̄)                              # dual step: y
        @. y.tv = (y.tv + σ*Δy.tv)/(1 + σ*ρ/α) # |
        proj_norm₂₁ball!(y.tv, α)              # |  prox
        @. y.flow = (y.flow+σ*((b_next_filt-b_filt)/T+Δy.flow))/(1+σ/θ)

        if DisplacementT == DisplacementConstant
            ucumul .+= x.u
        end

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################

        v_step = verbose() do
            K!(Δy, x)
            # NOTE(review): the dual update above divides by (1 + σ/θ), which
            # would presumably correspond to a θ/2-weighted flow term, whereas
            # the weight used here is θ. Left unchanged — confirm intent.
            value = (norm₂²(b-x.x)/2 + θ*norm₂²((b_next_filt-b_filt)./T+Δy.flow)
                     + λ*norm₂²(x.u)/2 + α*γnorm₂₁(Δy.tv, ρ))

            value, x.x, ucumul, nothing
        end

        return v_step
    end

    return x, y, v
end

end # Module