Thu, 18 Apr 2024 10:51:10 +0300
commit before adding PET
######################################################################
# Predictive online PDPS for optical flow with unknown velocity field
######################################################################

__precompile__()

module AlgorithmBothMulti

using Printf

using AlgTools.Util
import AlgTools.Iterate
using ImageTools.Gradient
using ImageTools.ImFilter

using ..OpticalFlow: Image,
                     ImageSize,
                     DisplacementConstant,
                     pdflow!,
                     horn_schunck_reg_prox!,
                     pointwise_gradiprod_2d!,
                     horn_schunck_reg_prox_op!,
                     mldivide_step_plus_sym2x2!,
                     ConstantDisplacementHornSchunckData,
                     filter_hs

using ..Algorithm: step_lengths

#########################
# Iterate initialisation
#########################

"""
    init_displ(xinit, DisplacementConstant, n)

Pair the initial image `xinit` with an `n × 2` zero matrix: one constant 2-D
displacement per slot of the `n`-image window.
"""
function init_displ(xinit::Image, ::Type{DisplacementConstant}, n::Integer)
    return xinit, zeros(n, 2)
end

# No method is defined here for `DisplacementFull`; a sketch of what it could be:
#
# function init_displ(xinit::Image, ::Type{DisplacementFull}, n::Integer)
#     return xinit, zeros(n, 2, size(xinit)...)
# end

"""
    init_rest(x, u)

Build the full iterate tuple `(x, y, Δx, Δy, x̄, u)` around the primal image `x`
and displacement iterate `u`. The dual variable `y` is a zero array of size
`(2, size(x)...)`; `Δx`, `Δy` and `x̄` are work buffers with the shapes of
`x`, `y` and `x` respectively.
"""
function init_rest(x::Image, u::DisplacementT) where DisplacementT
    imdim = size(x)

    y = zeros(2, imdim...)
    Δx = copy(x)
    Δy = copy(y)
    x̄ = copy(x)

    return x, y, Δx, Δy, x̄, u
end

"""
    init_iterates(DisplacementT, xinit::Image, n)
    init_iterates(DisplacementT, dim::ImageSize, n)

Create all iterates for a window of `n` displacements. The primal image starts
either as a copy of `xinit` or as zeros of size `dim`.
"""
function init_iterates(::Type{DisplacementT},
                       xinit::Image,
                       n::Integer) where DisplacementT
    return init_rest(init_displ(copy(xinit), DisplacementT, n)...)
end

function init_iterates(::Type{DisplacementT},
                       dim::ImageSize,
                       n::Integer) where DisplacementT
    return init_rest(init_displ(zeros(dim...), DisplacementT, n)...)
end

############
# Algorithm
############

"""
    solve(DisplacementT; dim, iterate = AlgTools.simple_iterate, params)

Run the predictive online primal–dual (PDPS) method. It jointly estimates the
image `x`, the dual variable `y` and a sliding window of up to
`params.displacement_count` displacements `u` (row `j` of `u` is the
displacement of window image `j` relative to the window start) from a stream of
data images.

`iterate(params) do verbose, b, v_known, b_next … end` is expected to invoke the
body once per data image `b`, together with the following image `b_next`. The
body returns whatever `verbose` returns. In this module that is
`(value, x, cumulative newest displacement, cumulative window displacements)`
whenever the verbose closure is run.

Fields read from `params` here: `displacement_count`, `init`, `α`, `ρ`, `λ`,
`θ`, `kernel`, `timestep`, `dual_flow` and `prox_predict`, plus whatever
`step_lengths` reads.

Returns `(x, y, v)`: the final image, the final dual variable, and the value
returned by `iterate`.

Throws `ArgumentError` if `params.displacement_count < 1`.
"""
function solve(::Type{DisplacementT};
               dim::ImageSize,
               iterate = AlgTools.simple_iterate,
               params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    n = params.displacement_count
    # The window logic below indexes hs[n] and u[k, :] with k ≥ 1, so an empty
    # window would otherwise fail later with an opaque BoundsError.
    n ≥ 1 || throw(ArgumentError("params.displacement_count must be at least 1, got $n"))
    k = 0 # number of displacements currently stored in the window (≤ n)

    x, y, Δx, Δy, x̄, u = init_iterates(DisplacementT, dim, n)
    init_data = (params.init == :data)
    hs = [ConstantDisplacementHornSchunckData() for i = 1:n]
    A = Array{Float64,3}(undef, n, 2, 2)
    d = Array{Float64,2}(undef, n, 2)

    # … for tracking cumulative movement: the sum of the displacements dropped
    # off the front of the window. It is added back when reporting.
    ucumulbase = [0.0, 0.0]

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ = params.α, params.ρ, params.λ, params.θ
    R_K² = ∇₂_norm₂₂_est²
    γ = 1
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel
    T = params.timestep

    ####################
    # Run the algorithm
    ####################

    b_next_filt = nothing
    diffu = similar(u[1, :])

    v = iterate(params) do verbose :: Function, b :: Image, 🚫unused_v_known :: DisplacementT, b_next :: Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ################################################
        # Prediction step
        ################################################

        # Predict x and y
        if k == 0
            if init_data
                x .= b
                init_data = false
            end
        else
            # Displacement from previous to this image is estimated as
            # the difference of their displacements from beginning of window.
            if k > 1
                @. @views diffu = u[k, :] - u[k-1, :]
            else
                @. @views diffu = u[k, :]
            end

            pdflow!(x, Δx, y, Δy, diffu, params.dual_flow)
        end

        # Shift stored prox matrices (and displacements) once the window is full.
        if k == n
            ushift = u[1, :]            # slicing already makes a copy
            ucumulbase .+= ushift
            for j = 1:(n-1)
                # Re-base displacements on the new window start.
                @. @views u[j, :] = u[j+1, :] - ushift
                hs[j] = hs[j+1]
            end
            # Create new struct as original contains references to objects that
            # have been moved to index n-1.
            hs[n] = ConstantDisplacementHornSchunckData()
        else
            k += 1
        end

        # Predict u: zero displacement from current to next image, i.e.,
        # same displacement to beginning of window.
        if k == 1
            @. @views u[k, :] = 0.0
        else
            @. @views u[k, :] = u[k-1, :]
        end

        # Predictor proximal step
        if params.prox_predict
            ∇₂!(Δy, x)
            @. y = (y + σ̃*Δy)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y, α)
        end

        #################################################################################
        # PDPS step
        #
        # For the displacements, with τ̃=τ/k, we need to solve for 2≤j<k,
        #
        # (1) (I/τ̃+M₀^j+M₀^{j+1})u^j = M₀^ju^{j-1} + M₀^{j+1}u^{j+1}
        #                             + ũ^j/τ̃ - z^j + z^{j+1},
        #
        # as well as
        #
        # (2) (I/τ̃+M₀^k)u^k = M₀^k u^{k-1} + ũ^k/τ̃ - z^k
        #
        # and
        #
        # (3) (I/τ̃+M₀^1+M₀^2)u^1 = 0 + M₀^{2}u^{2} + ũ^1/τ̃ - z^1 + z^{2}
        #
        # We first construct from (2) that
        #
        #   u^k = A^k u^{k-1} + d^k
        #
        # for
        #
        #   A^k := (I/τ̃+M₀^k)^{-1} M₀^k
        #   d^k := (I/τ̃+M₀^k)^{-1} (ũ^k/τ̃ - z^k).
        #
        # Inserting this into (1) we need
        #
        # (4) (I/τ̃+M₀^j+M₀^{j+1}(I-A^{j+1}))u^j = M₀^ju^{j-1} + M₀^{j+1}d^{j+1}
        #                                        + ũ^j/τ̃ - z^j + z^{j+1}.
        #
        # This is well-defined because A^{j+1} < I. It also has the same form as (1), so
        # we continue with
        #
        # (5) u^j = A^j u^{j-1} + d^j
        #
        # for
        #
        #   A^j := (I/τ̃+M₀^j+M₀^{j+1}(I-A^{j+1}))^{-1} M₀^j
        #   d^j := (I/τ̃+M₀^j+M₀^{j+1}(I-A^{j+1}))^{-1}
        #          (M₀^{j+1}d^{j+1} + ũ^j/τ̃ - z^j + z^{j+1})
        #
        # Finally from (3) with these we need
        #
        #   (I/τ̃+M₀^1+M₀^2(I-A^2))u^1 = M₀^2d^2 + ũ^1/τ̃ - z^1 + z^2,
        #
        # which is of the same form as (4) with u^0=0, so by (5) u^1=d^1.
        #
        #################################################################################

        ∇₂ᵀ!(Δx, y)                    # primal step:
        @. x̄ = x                       # | save old x for over-relax
        @. x = (x-τ*(Δx-b))/(1+τ)      # | prox
                                        # | | for displacement
        # Calculate matrices for latest data; rest is stored.
        @views begin
            horn_schunck_reg_prox_op!(hs[k], b_next_filt, b_filt, θ, λ, T)
            τ̃ = τ/k
            # Backward sweep: (2) for j = k, then (4)/(5) for j = k-1, …, 1.
            B = hs[k].M₀
            c = u[k, :]./τ̃ - hs[k].z
            mldivide_step_plus_sym2x2!(A[k, :, :], B, B, τ̃)
            mldivide_step_plus_sym2x2!(d[k, :], B, c, τ̃)
            for j = (k-1):-1:1
                B = hs[j].M₀ + hs[j+1].M₀*([1 0; 0 1] - A[j+1, :, :])
                c = hs[j+1].M₀*d[j+1, :] + u[j, :]./τ̃ - hs[j].z + hs[j+1].z
                mldivide_step_plus_sym2x2!(A[j, :, :], B, hs[j].M₀, τ̃)
                mldivide_step_plus_sym2x2!(d[j, :], B, c, τ̃)
            end
            # Forward substitution: u^1 = d^1, u^j = A^j u^{j-1} + d^j.
            u[1, :] .= d[1, :]
            for j = 2:k
                u[j, :] .= A[j, :, :]*u[j-1, :] + d[j, :]
            end
        end

        @. x̄ = 2x - x̄                  # over-relax: x̄ = 2x-x_old
        ∇₂!(Δy, x̄)                     # dual step: y
        @. y = (y + σ*Δy)/(1 + σ*ρ/α)  # |
        proj_norm₂₁ball!(y, α)         # | prox

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################
        # Closure-local names (`verbose_out`, `du`) are used on purpose.
        # Reassigning `solve`'s own `v` here would box it and make the
        # function's return type uninferable.
        verbose_out = verbose() do
            ∇₂!(Δy, x)
            hs_plus_reg = 0.0
            for j = 1:k
                # Displacement between consecutive window images.
                du = (j == 1 ? u[j, :] : u[j, :] - u[j-1, :])
                hs_plus_reg += hs[j].cv/2 + dot(hs[j].Mv*du, du)/2 + dot(hs[j].av, du)
            end
            value = (norm₂²(b-x)/2 + hs_plus_reg/k + α*γnorm₂₁(Δy, ρ))

            value, x, u[k, :] + ucumulbase, u[1:k, :] .+ ucumulbase'
        end

        return verbose_out
    end

    return x, y, v
end

end # Module