|
1 ###################################################################### |
|
2 # Predictive online PDPS for optical flow with unknown velocity field |
|
3 ###################################################################### |
|
4 |
|
5 __precompile__() |
|
6 |
|
7 module AlgorithmBothGreedyV |
|
8 |
|
# Name string for this algorithm variant ("pdps_unknown_greedyv").
# NOTE(review): presumably read by the experiment driver to select/label this
# module's results — confirm against callers.
identifier = "pdps_unknown_greedyv"
|
10 |
|
11 using Printf |
|
12 |
|
13 using AlgTools.Util |
|
14 import AlgTools.Iterate |
|
15 using ImageTools.Gradient |
|
16 |
|
17 using ..OpticalFlow: Image, |
|
18 ImageSize, |
|
19 DisplacementConstant, |
|
20 DisplacementFull, |
|
21 pdflow!, |
|
22 horn_schunck_reg_prox!, |
|
23 pointwise_gradiprod_2d!, |
|
24 filter_hs |
|
25 |
|
26 using ..Algorithm: step_lengths |
|
27 |
|
28 ######################### |
|
29 # Iterate initialisation |
|
30 ######################### |
|
31 |
|
"""
    init_displ(xinit, DisplacementConstant)

Pair the image iterate `xinit` (returned as-is, not copied) with a zero
constant displacement, i.e. the 2-vector `zeros(2)`.
"""
init_displ(xinit::Image, ::Type{DisplacementConstant}) = (xinit, zeros(2))
|
35 |
|
"""
    init_displ(xinit, DisplacementFull)

Pair the image iterate `xinit` (returned as-is, not copied) with a zero
displacement field of size `(2, size(xinit)...)`, i.e. one 2-vector per
entry of `xinit`.
"""
init_displ(xinit::Image, ::Type{DisplacementFull}) = (xinit, zeros(2, size(xinit)...))
|
39 |
|
"""
    init_rest(x, u)

Build the full iterate tuple from an image iterate `x` and a displacement `u`.

Creates a zero dual variable `y` of size `(2, size(x)...)`, plus the arrays
`Δx = copy(x)`, `Δy = copy(y)` and `x̄ = copy(x)`. `x` and `u` themselves are
passed through without copying.

Returns `(x, y, Δx, Δy, x̄, u)`.
"""
function init_rest(x::Image, u::DisplacementT) where DisplacementT
    y = zeros(2, size(x)...)
    Δx, Δy, x̄ = copy(x), copy(y), copy(x)
    return x, y, Δx, Δy, x̄, u
end
|
50 |
|
"""
    init_iterates(DisplacementT, xinit::Image)

Initialise all iterates, starting the image iterate from a copy of `xinit`
(the caller's array is not aliased) and the displacement from zero.
"""
function init_iterates(::Type{DisplacementT}, xinit::Image) where DisplacementT
    x, u = init_displ(copy(xinit), DisplacementT)
    return init_rest(x, u)
end
|
54 |
|
"""
    init_iterates(DisplacementT, dim::ImageSize)

Initialise all iterates for an image of size `dim`. The image iterate starts
as `zeros(dim...)` and the displacement starts at zero.
"""
function init_iterates(::Type{DisplacementT}, dim::ImageSize) where DisplacementT
    x, u = init_displ(zeros(dim...), DisplacementT)
    return init_rest(x, u)
end
|
58 |
|
59 ############ |
|
60 # Algorithm |
|
61 ############ |
|
62 |
|
"""
    solve(DisplacementT; dim, iterate = AlgTools.simple_iterate, params)

Run predictive online PDPS (primal–dual proximal splitting) for optical flow
with an unknown displacement. At each step the displacement `u` is reset to
zero and then updated by `horn_schunck_reg_prox!` from consecutive filtered
frames.

# Arguments
- `DisplacementT`: displacement representation. In this module `init_displ` is
  defined for `DisplacementConstant` and `DisplacementFull`.

# Keywords
- `dim::ImageSize`: image size; all iterates are initialised to zero.
- `iterate`: driver called as `iterate(params) do verbose, b, v_known, b_next`.
  The known displacement `v_known` is ignored by this algorithm.
- `params::NamedTuple`: this function reads `init`, `α`, `ρ`, `λ`, `θ`,
  `timestep`, `kernel`, `prox_predict` and `dual_flow`, plus whatever
  `step_lengths` requires.

# Returns
`(x, y, v)`: the final primal (image) iterate, the final dual iterate, and the
value returned by `iterate`.
"""
function solve( :: Type{DisplacementT};
                dim :: ImageSize,
                iterate = AlgTools.simple_iterate,
                params::NamedTuple) where DisplacementT

    ######################
    # Initialise iterates
    ######################

    x, y, Δx, Δy, x̄, u = init_iterates(DisplacementT, dim)
    # With `params.init == :data`, the first frame is copied into `x` on the
    # first step instead of starting from zero.
    init_data = (params.init == :data)

    # … for tracking cumulative movement. This is only accumulated for a
    # constant displacement; otherwise NaN is reported.
    if DisplacementT == DisplacementConstant
        ucumul = [0.0, 0.0]
    else
        ucumul = [NaN, NaN]
    end

    #############################################
    # Extract parameters and set up step lengths
    #############################################

    α, ρ, λ, θ, T = params.α, params.ρ, params.λ, params.θ, params.timestep
    R_K² = ∇₂_norm₂₂_est²
    γ = 1
    τ, σ, σ̃, ρ̃ = step_lengths(params, γ, R_K²)

    kernel = params.kernel

    ####################
    # Run the algorithm
    ####################

    # The filtered next frame is passed back into `filter_hs` on the following
    # step (presumably so it can be reused as the current filtered frame).
    b_next_filt = nothing

    v = iterate(params) do verbose :: Function,
                           b :: Image,
                           🚫unused_v_known :: DisplacementT,
                           b_next :: Image

        ####################################
        # Smooth data for Horn–Schunck term
        ####################################

        b_filt, b_next_filt = filter_hs(b, b_next, b_next_filt, kernel)

        ##################
        # Prediction step
        ##################

        if init_data
            x .= b
            init_data = false
        end

        # NOTE(review): presumably transports x and y along the current
        # displacement u, with Δx/Δy as workspace — see `pdflow!`.
        pdflow!(x, Δx, y, Δy, u, params.dual_flow)

        # Predict zero displacement
        u .= 0
        if params.prox_predict
            # Dual proximal prediction step: y ← prox(y + σ̃∇x).
            # The gradient goes into the work array Δy. Writing it into y
            # would discard the dual iterate.
            ∇₂!(Δy, x)
            @. y = (y + σ̃*Δy)/(1 + σ̃*(ρ̃+ρ/α))
            proj_norm₂₁ball!(y, α)
        end

        ############
        # PDPS step
        ############

        ∇₂ᵀ!(Δx, y)                          # primal step:
        @. x̄ = x                             # |  save old x for over-relax
        @. x = (x-τ*(Δx-b))/(1+τ)            # |  prox
        horn_schunck_reg_prox!(u, b_next_filt, b_filt, θ, λ, T, τ)
        @. x̄ = 2x - x̄                        # over-relax
        ∇₂!(Δy, x̄)                           # dual step: Δy = ∇x̄ (work array;
                                             # must not overwrite y)
        @. y = (y + σ*Δy)/(1 + σ*ρ/α)        # |  y ← (y + σ∇x̄)/(1 + σρ/α)
        proj_norm₂₁ball!(y, α)               # |  prox

        if DisplacementT == DisplacementConstant
            ucumul .+= u
        end

        ########################################################
        # Give function value and cumulative movement if needed
        ########################################################
        v = verbose() do
            ∇₂!(Δy, x)
            # `tmp` is filled by `pointwise_gradiprod_2d!` from ∇x, u and
            # b_filt (see that function for the exact pointwise product).
            tmp = zeros(size(b_filt))
            pointwise_gradiprod_2d!(tmp, Δy, u, b_filt)
            value = (norm₂²(b-x)/2 + θ*norm₂²((b_next_filt-b_filt)./T+tmp)
                     + λ*norm₂²(u)/2 + α*γnorm₂₁(Δy, ρ))

            value, x, ucumul, nothing
        end

        return v
    end

    return x, y, v
end
|
164 |
|
165 end # Module |
|
166 |
|
167 |