src/AlgorithmBothNL.jl

changeset 0
a55e35d20336
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/AlgorithmBothNL.jl	Tue Apr 07 14:19:48 2020 -0500
@@ -0,0 +1,242 @@
+######################################################################
+# Predictive online PDPS for optical flow with unknown velocity field
+######################################################################
+
+__precompile__()
+
+module AlgorithmBothNL
+
+identifier = "pdps_unknown_nl"
+
+
+using Printf
+
+using AlgTools.Util
+import AlgTools.Iterate
+using ImageTools.Gradient
+
+using ..OpticalFlow: ImageSize,
+                     Image,
+                     Gradient,
+                     DisplacementConstant,
+                     DisplacementFull,
+                     pdflow!,
+                     pointwise_gradiprod_2d!,
+                     pointwise_gradiprod_2dᵀ!,
+                     filter_hs
+
+using ..Algorithm: step_lengths
+
+#############
+# Data types
+#############
+
+struct Primal{DisplacementT}
+    x :: Image
+    u :: DisplacementT
+end
+
+function Base.similar(x::Primal{DisplacementT}) where DisplacementT
+   return Primal{DisplacementT}(Base.similar(x.x), Base.similar(x.u))
+end
+
+function Base.copy(x::Primal{DisplacementT}) where DisplacementT
+    return Primal{DisplacementT}(Base.copy(x.x), Base.copy(x.u))
+ end
+
+struct Dual
+    tv :: Gradient
+    flow :: Image
+end
+
+function Base.similar(y::Dual)
+   return Dual(Base.similar(y.tv), Base.similar(y.flow))
+end
+
+function Base.copy(y::Dual)
+    return Dual(Base.copy(y.tv), Base.copy(y.flow))
+ end
+
+#########################
+# Iterate initialisation
+#########################
+
+function init_primal(xinit::Image, ::Type{DisplacementConstant})
+    return Primal{DisplacementConstant}(xinit, zeros(2))
+end
+
+function init_primal(xinit::Image, ::Type{DisplacementFull})
+    return Primal{DisplacementFull}(xinit, zeros(2, size(xinit)...))
+end
+
+function init_rest(x::Primal{DisplacementT}) where DisplacementT
+    imdim=size(x.x)
+
+    y = Dual(zeros(2, imdim...), zeros(imdim))
+    Δx = copy(x)
+    Δy = copy(y)
+    x̄ = copy(x)
+
+    return x, y, Δx, Δy, x̄
+end
+
+function init_iterates( :: Type{DisplacementT},
+                       xinit::Primal{DisplacementT}) where DisplacementT       
+    return init_rest(copy(xinit))
+end
+
+function init_iterates( :: Type{DisplacementT}, xinit::Image) where DisplacementT    
+    return init_rest(init_primal(copy(xinit), DisplacementT))
+end
+
+function init_iterates( :: Type{DisplacementT}, dim::ImageSize) where DisplacementT
+    return init_rest(init_primal(zeros(dim...), DisplacementT))
+end
+
+##############################################
+# Weighting for different displacements types
+##############################################
+
+norm²weight( :: Type{DisplacementConstant}, sz ) = prod(sz)
+norm²weight( :: Type{DisplacementFull}, sz ) = 1
+
+############
+# Algorithm
+############
+
+function solve( :: Type{DisplacementT};
+               dim :: ImageSize,
+               iterate = AlgTools.simple_iterate, 
+               params::NamedTuple) where DisplacementT
+
+    ######################
+    # Initialise iterates
+    ######################
+
+    x, y, Δx, Δy, x̄ = init_iterates(DisplacementT, dim)
+    init_data = (params.init == :data)
+
+    # … for tracking cumulative movement
+    if DisplacementT == DisplacementConstant
+        ucumul = [0.0, 0.0]
+    else
+        ucumul = [NaN, NaN]
+    end
+
+    #############################################
+    # Extract parameters and set up step lengths
+    #############################################
+
+    α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep
+    R_K² = max(∇₂_norm₂₂_est², ∇₂_norm₂∞_est²*params.dynrange^2)
+    γ = min(1, λ*norm²weight(DisplacementT, size(x.x)))
+    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)
+
+    kernel = params.kernel
+
+    ####################
+    # Run the algorithm
+    ####################
+
+    b_next_filt=nothing
+
+    v = iterate(params) do verbose :: Function,
+                           b :: Image,
+                           🚫unused_v_known :: DisplacementT,
+                           b_next :: Image
+
+        ####################################
+        # Smooth data for Horn–Schunck term
+        ####################################
+
+        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)
+
+        ############################
+        # Construct K for this step
+        ############################
+
+        K! = (yʹ, xʹ) -> begin
+            # Optical flow part
+            @. yʹ.flow = b_filt
+            flow!(yʹ.flow, Δx.x, xʹ.u)
+            @. yʹ.flow = yʹ.flow - b_next_filt
+            # TV part
+            ∇₂!(yʹ.tv, xʹ.x) 
+        end
+        Kᵀ! = (xʹ, yʹ) -> begin
+            # Optical flow part: ∇b_k ⋅ y
+            #
+            # TODO: This really should depend x.u, but x.u is zero. 
+            #
+            pointwise_gradiprod_2dᵀ!(xʹ.u, yʹ.flow, b_filt)
+            # TV part
+            ∇₂ᵀ!(xʹ.x, yʹ.tv) 
+        end
+
+        ##################
+        # Prediction step
+        ##################
+
+        if init_data
+            x .= b
+            init_data = false
+        end
+
+        pdflow!(x.x, Δx.x, y.tv, Δy.tv, y.flow, Δy.flow, x.u, params.dual_flow)
+
+        # Predict zero displacement
+        x.u .= 0
+        if params.prox_predict
+            K!(Δy, x)
+            @. y.tv = (y.tv + σ̃*Δy.tv)/(1 + σ̃*(ρ̃+ρ/α))
+            proj_norm₂₁ball!(y.tv, α) 
+            @. y.flow = (y.flow+σ̃*Δy.flow)/(1+σ̃*(ρ̃+1/θ))
+        end
+
+        ############
+        # PDPS step
+        #
+        # NOTE: For DisplacementConstant, the x.u update is supposed to be with
+        # respect to the 𝟙^*𝟙 norm/inner product that makes the norm equivalent
+        # to full-space norm when restricted to constant displacements. Since
+        # `OpticalFlow.pointwise_gradiprod_2dᵀ!` already uses this inner product,
+        # and the λ-weighted term in the problem is with respect to this norm,
+        # all the norm weights disappear in this update.
+        ############
+
+        Kᵀ!(Δx, y)                             # primal step:
+        @. x̄.x = x.x                           # |  save old x for over-relax
+        @. x̄.u = x.u                           # |  
+        @. x.x = (x.x-τ*(Δx.x-b))/(1+τ)        # |  prox
+        @. x.u = (x.u-τ*Δx.u)/(1+τ*λ)          # |  
+        @. x̄.x = 2x.x - x̄.x                    # over-relax
+        @. x̄.u = 2x.u - x̄.u                    # |
+        K!(Δy, x̄)                              # dual step: y
+        @. y.tv = (y.tv + σ*Δy.tv)/(1 + σ*ρ/α) # |
+        proj_norm₂₁ball!(y.tv, α)              # |  prox
+        @. y.flow = (y.flow+σ*Δy.flow)/(1+σ/θ)
+
+        if DisplacementT == DisplacementConstant
+            ucumul .+= x.u
+        end
+
+        ########################################################
+        # Give function value and cumulative movement if needed
+        ########################################################
+        v = verbose() do
+            K!(Δy, x)
+            value = (norm₂²(b-x.x)/2 + θ*norm₂²(Δy.flow)
+                     + λ*norm₂²(x.u)/2 + α*γnorm₂₁(Δy.tv, ρ))
+
+            value, x.x, ucumul, nothing
+        end
+
+        return v
+    end
+
+    return x, y, v
+end
+
+end # Module
+
+

mercurial