Sun, 21 Apr 2024 20:42:43 +0300
added plotting functions
######################################################################
# Predictive online PDPS for optical flow with unknown velocity field
######################################################################

__precompile__()

module AlgorithmBothGreedyV

# Identifier of this algorithm variant. Declared `const` so that it is not an
# untyped (type-unstable) module global.
const identifier = "pdps_unknown_greedyv"

using Printf

using AlgTools.Util
import AlgTools.Iterate
using ImageTools.Gradient

using ..OpticalFlow: Image,
                     ImageSize,
                     DisplacementConstant,
                     DisplacementFull,
                     pdflow!,
                     horn_schunck_reg_prox!,
                     pointwise_gradiprod_2d!,
                     filter_hs

using ..Algorithm: step_lengths

#########################
# Iterate initialisation
#########################

"""
    init_displ(xinit, DisplacementConstant) -> (xinit, u)

Pair the initial image `xinit` with a zero constant displacement `u`, a
length-2 vector.
"""
function init_displ(xinit::Image, ::Type{DisplacementConstant})
    return xinit, zeros(2)
end

"""
    init_displ(xinit, DisplacementFull) -> (xinit, u)

Pair the initial image `xinit` with a zero displacement field `u` of size
`(2, size(xinit)...)`.
"""
function init_displ(xinit::Image, ::Type{DisplacementFull})
    return xinit, zeros(2, size(xinit)...)
end

"""
    init_rest(x, u) -> (x, y, Δx, Δy, x̄, u)

Allocate the remaining iterates around the primal image `x` and the
displacement `u`:
- the dual variable `y`: zeros of size `(2, size(x)...)`;
- the work buffers `Δx` and `Δy`, initialised as copies of `x` and `y`;
- the over-relaxation buffer `x̄`, initialised as a copy of `x`.
"""
function init_rest(x::Image, u::DisplacementT) where DisplacementT
    imdim = size(x)

    y = zeros(2, imdim...)
    Δx = copy(x)
    Δy = copy(y)
    x̄ = copy(x)

    return x, y, Δx, Δy, x̄, u
end

"""
    init_iterates(DisplacementT, xinit::Image)
    init_iterates(DisplacementT, dim::ImageSize)

Create all iterates `(x, y, Δx, Δy, x̄, u)`. The primal variable `x` starts
either from a copy of `xinit` or from zeros of size `dim`. The displacement
starts at zero in the representation selected by `DisplacementT`.
"""
function init_iterates(::Type{DisplacementT}, xinit::Image) where DisplacementT
    return init_rest(init_displ(copy(xinit), DisplacementT)...)
end

function init_iterates(::Type{DisplacementT}, dim::ImageSize) where DisplacementT
    return init_rest(init_displ(zeros(dim...), DisplacementT)...)
end

############
# Algorithm
############

"""
    solve(DisplacementT; dim, iterate = AlgTools.simple_iterate, params)

Run predictive online PDPS for optical flow when the displacement `u` is
unknown. Each step resets `u` to zero and updates it with
`horn_schunck_reg_prox!`, using Gaussian-type filtered frames from
`filter_hs`.

`params` must provide at least `init`, `α`, `ρ`, `λ`, `θ`, `timestep`,
`kernel`, `dual_flow` and `prox_predict`, plus whatever `step_lengths` reads.
`iterate(params)` is called with a do-block taking
`(verbose, b, v_known, b_next)`; the known displacement `v_known` is ignored.

Returns `(x, y, v)`: the final primal iterate, the final dual iterate, and
whatever `iterate` returned.
"""
function solve(::Type{DisplacementT};
               dim::ImageSize,
               iterate = AlgTools.simple_iterate,
               params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄, u = init_iterates(DisplacementT, dim)
    init_data = (params.init == :data)

    # Cumulative movement is only tracked for a constant displacement;
    # otherwise it is reported as NaN.
    if DisplacementT == DisplacementConstant
        ucumul = [0.0, 0.0]
    else
        ucumul = [NaN, NaN]
    end

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep
    R_K² = ∇₂_norm₂₂_est²
    γ = 1
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel

    ####################
    # Run the algorithm
    ####################

    # The filtered next frame is passed back into `filter_hs` on the following
    # iteration (presumably so that it can be reused as the current frame).
    b_next_filt = nothing

    v = iterate(params) do verbose::Function, b::Image, 🚫unused_v_known::DisplacementT, b_next::Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ##################
        # Prediction step
        ##################

        if init_data
            x .= b
            init_data = false
        end

        pdflow!(x, Δx, y, Δy, u, params.dual_flow)

        # Predict zero displacement
        u .= 0
        if params.prox_predict
            # The gradient goes into the work buffer Δy. The dual iterate y must
            # survive, because the proximal update below is taken at y + σ̃∇x.
            ∇₂!(Δy, x)
            @. y = (y + σ̃*Δy)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y, α)
        end

        ############
        # PDPS step
        ############

        ∇₂ᵀ!(Δx, y)                    # primal step:
        @. x̄ = x                       # |  save old x for over-relax
        @. x = (x-τ*(Δx-b))/(1+τ)      # |  prox
        horn_schunck_reg_prox!(u, b_next_filt, b_filt, θ, λ, T, τ)
        @. x̄ = 2x - x̄                  # over-relax
        ∇₂!(Δy, x̄)                     # dual step: Δy = ∇x̄ (keep old y!)
        @. y = (y + σ*Δy)/(1 + σ*ρ/α)  # |
        proj_norm₂₁ball!(y, α)         # |  prox

        if DisplacementT == DisplacementConstant
            ucumul .+= u
        end

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################

        # Returned directly rather than assigned to the outer `v`; assigning
        # would make `v` a boxed variable captured by this closure.
        return verbose() do
            ∇₂!(Δy, x)
            tmp = zeros(size(b_filt))
            pointwise_gradiprod_2d!(tmp, Δy, u, b_filt)
            value = (norm₂²(b-x)/2 + θ*norm₂²((b_next_filt-b_filt)./T+tmp)
                     + λ*norm₂²(u)/2 + α*γnorm₂₁(Δy, ρ))

            value, x, ucumul, nothing
        end
    end

    return x, y, v
end

end # Module