--- a/src/AlgorithmBothNL.jl Thu Apr 25 13:05:40 2024 -0500 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,242 +0,0 @@ -###################################################################### -# Predictive online PDPS for optical flow with unknown velocity field -###################################################################### - -__precompile__() - -module AlgorithmBothNL - -identifier = "pdps_unknown_nl" - - -using Printf - -using AlgTools.Util -import AlgTools.Iterate -using ImageTools.Gradient - -using ..OpticalFlow: ImageSize, - Image, - Gradient, - DisplacementConstant, - DisplacementFull, - pdflow!, - pointwise_gradiprod_2d!, - pointwise_gradiprod_2dᵀ!, - filter_hs - -using ..Algorithm: step_lengths - -############# -# Data types -############# - -struct Primal{DisplacementT} - x :: Image - u :: DisplacementT -end - -function Base.similar(x::Primal{DisplacementT}) where DisplacementT - return Primal{DisplacementT}(Base.similar(x.x), Base.similar(x.u)) -end - -function Base.copy(x::Primal{DisplacementT}) where DisplacementT - return Primal{DisplacementT}(Base.copy(x.x), Base.copy(x.u)) - end - -struct Dual - tv :: Gradient - flow :: Image -end - -function Base.similar(y::Dual) - return Dual(Base.similar(y.tv), Base.similar(y.flow)) -end - -function Base.copy(y::Dual) - return Dual(Base.copy(y.tv), Base.copy(y.flow)) - end - -######################### -# Iterate initialisation -######################### - -function init_primal(xinit::Image, ::Type{DisplacementConstant}) - return Primal{DisplacementConstant}(xinit, zeros(2)) -end - -function init_primal(xinit::Image, ::Type{DisplacementFull}) - return Primal{DisplacementFull}(xinit, zeros(2, size(xinit)...)) -end - -function init_rest(x::Primal{DisplacementT}) where DisplacementT - imdim=size(x.x) - - y = Dual(zeros(2, imdim...), zeros(imdim)) - Δx = copy(x) - Δy = copy(y) - x̄ = copy(x) - - return x, y, Δx, Δy, x̄ -end - -function init_iterates( :: Type{DisplacementT}, - xinit::Primal{DisplacementT}) where DisplacementT - return init_rest(copy(xinit)) -end - -function init_iterates( :: Type{DisplacementT}, xinit::Image) where DisplacementT - return init_rest(init_primal(copy(xinit), DisplacementT)) -end - -function init_iterates( :: Type{DisplacementT}, dim::ImageSize) where DisplacementT - return init_rest(init_primal(zeros(dim...), DisplacementT)) -end - -############################################## -# Weighting for different displacements types -############################################## - -norm²weight( :: Type{DisplacementConstant}, sz ) = prod(sz) -norm²weight( :: Type{DisplacementFull}, sz ) = 1 - -############ -# Algorithm -############ - -function solve( :: Type{DisplacementT}; - dim :: ImageSize, - iterate = AlgTools.simple_iterate, - params::NamedTuple) where DisplacementT - - ###################### - # Initialise iterates - ###################### - - x, y, Δx, Δy, x̄ = init_iterates(DisplacementT, dim) - init_data = (params.init == :data) - - # … for tracking cumulative movement - if DisplacementT == DisplacementConstant - ucumul = [0.0, 0.0] - else - ucumul = [NaN, NaN] - end - - ############################################# - # Extract parameters and set up step lengths - ############################################# - - α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep - R_K² = max(∇₂_norm₂₂_est², ∇₂_norm₂∞_est²*params.dynrange^2) - γ = min(1, λ*norm²weight(DisplacementT, size(x.x))) - τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²) - - kernel = params.kernel - - #################### - # Run the algorithm - #################### - - b_next_filt=nothing - - v = iterate(params) do verbose :: Function, - b :: Image, - 🚫unused_v_known :: DisplacementT, - b_next :: Image - - #################################### - # Smooth data for Horn–Schunck term - #################################### - - b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel) - - ############################ - # Construct K for this step - ############################ - - K! = (yʹ, xʹ) -> begin - # Optical flow part - @. yʹ.flow = b_filt - flow!(yʹ.flow, Δx.x, xʹ.u) - @. yʹ.flow = yʹ.flow - b_next_filt - # TV part - ∇₂!(yʹ.tv, xʹ.x) - end - Kᵀ! = (xʹ, yʹ) -> begin - # Optical flow part: ∇b_k ⋅ y - # - # TODO: This really should depend x.u, but x.u is zero. - # - pointwise_gradiprod_2dᵀ!(xʹ.u, yʹ.flow, b_filt) - # TV part - ∇₂ᵀ!(xʹ.x, yʹ.tv) - end - - ################## - # Prediction step - ################## - - if init_data - x .= b - init_data = false - end - - pdflow!(x.x, Δx.x, y.tv, Δy.tv, y.flow, Δy.flow, x.u, params.dual_flow) - - # Predict zero displacement - x.u .= 0 - if params.prox_predict - K!(Δy, x) - @. y.tv = (y.tv + σ̃*Δy.tv)/(1 + σ̃*(ρ̃+ρ/α)) - proj_norm₂₁ball!(y.tv, α) - @. y.flow = (y.flow+σ̃*Δy.flow)/(1+σ̃*(ρ̃+1/θ)) - end - - ############ - # PDPS step - # - # NOTE: For DisplacementConstant, the x.u update is supposed to be with - # respect to the 𝟙^*𝟙 norm/inner product that makes the norm equivalent - # to full-space norm when restricted to constant displacements. Since - # `OpticalFlow.pointwise_gradiprod_2dᵀ!` already uses this inner product, - # and the λ-weighted term in the problem is with respect to this norm, - # all the norm weights disappear in this update. - ############ - - Kᵀ!(Δx, y) # primal step: - @. x̄.x = x.x # | save old x for over-relax - @. x̄.u = x.u # | - @. x.x = (x.x-τ*(Δx.x-b))/(1+τ) # | prox - @. x.u = (x.u-τ*Δx.u)/(1+τ*λ) # | - @. x̄.x = 2x.x - x̄.x # over-relax - @. x̄.u = 2x.u - x̄.u # | - K!(Δy, x̄) # dual step: y - @. y.tv = (y.tv + σ*Δy.tv)/(1 + σ*ρ/α) # | - proj_norm₂₁ball!(y.tv, α) # | prox - @. y.flow = (y.flow+σ*Δy.flow)/(1+σ/θ) - - if DisplacementT == DisplacementConstant - ucumul .+= x.u - end - - ######################################################## - # Give function value and cumulative movement if needed - ######################################################## - v = verbose() do - K!(Δy, x) - value = (norm₂²(b-x.x)/2 + θ*norm₂²(Δy.flow) - + λ*norm₂²(x.u)/2 + α*γnorm₂₁(Δy.tv, ρ)) - - value, x.x, ucumul, nothing - end - - return v - end - - return x, y, v -end - -end # Module - -