Thu, 18 Apr 2024 11:31:32 +0300
added seed restart
0 | 1 | ################### |
2 | # Image generation | |
3 | ################### | |
4 | ||
5 | module ImGenerate | |
6 | ||
7 | using ColorTypes: Gray | |
8 | import TestImages | |
9 | # We don't really *directly* depend on QuartzImageIO. The import here is | |
10 | # merely a workaround to suppress warnings when loading TestImages. | |
11 | # Something is broken in those packages. | |
12 | import QuartzImageIO | |
13 | ||
14 | using AlgTools.Util | |
15 | using AlgTools.Comms | |
16 | using ImageTools.Translate | |
17 | ||
18 | using ..OpticalFlow: Image, DisplacementConstant, DisplacementFull | |
19 | ||
5 | 20 | # Added for reproducibility |
21 | import StableRNGs: StableRNG, Random | |
22 | const rng = StableRNG(9182737465) | |
23 | ||
0 | 24 | ############## |
25 | # Our exports | |
26 | ############## | |
27 | ||
28 | export ImGen, | |
29 | OnlineData, | |
30 | imgen_square, | |
31 | imgen_shake | |
32 | ||
33 | ################## | |
34 | # Data structures | |
35 | ################## | |
36 | ||
"""
    ImGen

Descriptor of an image-sequence generator.

# Fields
- `f`: the generator function. In this module it is built with `curry`, so
  the leading arguments (image size, and the photo where applicable) are
  already bound.
- `dim`: image dimensions; the constructors in this module pass the
  requested size `sz` here.
- `Λ`: scalar constant associated with the generator; both constructors in
  this module pass `1`. NOTE(review): its meaning is not visible here —
  confirm against the code that consumes `ImGen`.
- `dynrange`: dynamic range of the generated images (`1` for the synthetic
  square, `maximum(im)` for a photo).
- `name`: label of the generator, e.g. `"square\$(w)x\$(h)"`.
"""
struct ImGen
    f        :: Function
    dim      :: Tuple{Int64,Int64}
    Λ        :: Float64
    dynrange :: Float64
    name     :: String
end
44 | ||
"""
    OnlineData{DisplacementT}

One frame of data handed from a generator to its consumer through a
`Channel`.

# Fields
- `b_true`: the noise-free image of the frame.
- `b_noisy`: `b_true` with additive Gaussian noise.
- `v`: the noisy displacement measurement (the true displacement perturbed
  multiplicatively).
- `v_true`: the true displacement of this step.
- `v_cumul_true`: the sum of all true displacements up to and including
  this step.
"""
struct OnlineData{DisplacementT}
    b_true       :: Image
    b_noisy      :: Image
    v            :: DisplacementT
    v_true       :: DisplacementT
    v_cumul_true :: DisplacementT
end
52 | ||
53 | ################### | |
54 | # Shake generation | |
55 | ################### | |
56 | ||
"""
    make_const_v(displ, sz)

Build a spatially constant displacement field of size `(2, sz...)`: every
entry of component `k` (first index) is set to `displ[k]`, `k = 1, 2`.
"""
function make_const_v(displ, sz)
    field = zeros(2, sz...)
    for k in 1:2
        field[k, :, :] .= displ[k]
    end
    return field
end
63 | ||
"""
    shake(params)

Return a zero-argument function that draws a random 2-component shake from
the module-level `rng` on every call.

The distribution is selected by `params.shaketype` (`:gaussian` when the
key is absent) and scaled by `params.shake`:

- `:gaussian`: `params.shake` times two independent standard normals;
- `:disk`: uniform angle, radius `params.shake*√u` with `u` uniform on [0,1);
- `:circle`: uniform angle, fixed radius `params.shake`.

Throws an error for any other `shaketype`.
"""
function shake(params)
    kind = haskey(params, :shaketype) ? params.shaketype : :gaussian

    if kind == :gaussian
        return () -> params.shake.*randn(rng, 2)
    elseif kind == :disk
        return function ()
            # Draw order (angle first, then radius) matters for
            # reproducibility with a seeded generator.
            angle = 2π*rand(rng, Float64)
            radius = params.shake*√(rand(rng, Float64))
            return [radius*cos(angle), radius*sin(angle)]
        end
    elseif kind == :circle
        return function ()
            angle = 2π*rand(rng, Float64)
            radius = params.shake
            return [radius*cos(angle), radius*sin(angle)]
        end
    else
        error("Unknown shaketype $(params.shaketype)")
    end
end
83 | ||
# Lift a generator of constant 2-component displacements `shakefn` to a
# generator of full displacement fields of size `(2, sz...)`: each call draws
# one displacement and replicates it over every pixel with `make_const_v`.
# (Previously referred to `make_const_u`, which is not defined in this module.)
pixelwise = (shakefn, sz) -> () -> make_const_v(shakefn(), sz)
85 | ||
86 | ################ | |
87 | # Moving square | |
88 | ################ | |
89 | ||
"""
    generate_square(sz, DisplacementT, datachannel, params)

Produce an endless sequence of frames of a randomly shaken square and push
them to `datachannel` as `OnlineData{DisplacementT}` until the channel is
closed (i.e. until `put_unless_closed!` returns `false`).

# Arguments
- `sz`: image size; splatted into `zeros(sz...)` and indexed as `sz[1]`.
- `DisplacementT`: `DisplacementFull` (one displacement replicated over all
  pixels) or `DisplacementConstant` (a single 2-vector). Any other type
  throws an error.
- `datachannel`: channel receiving the generated frames.
- `params`: must provide `shake`, `noise_level` and `shake_noise_level`;
  `shaketype` is optional (see `shake`).

All randomness is drawn from the module-level `rng`.
"""
function generate_square(sz,
                         :: Type{DisplacementT},
                         datachannel :: Channel{OnlineData{DisplacementT}},
                         params) where DisplacementT

    if DisplacementT == DisplacementFull
        nextv = pixelwise(shake(params), sz)
    elseif DisplacementT == DisplacementConstant
        nextv = shake(params)
    else
        # Throw instead of merely logging: otherwise `nextv` stays undefined
        # and the failure surfaces later as an unrelated `UndefVarError`.
        error("Invalid DisplacementT")
    end

    # Constant linear displacement everywhere has Jacobian determinant one
    # (modulo the boundaries which we ignore here)
    m = round(Int, sz[1]/5)
    b_orig = zeros(sz...)
    b_orig[sz[1].-(2*m:3*m), 2*m:3*m] .= 1

    v_true = nextv()
    v_cumul = copy(v_true)

    while true
        # Flow original data and add noise
        b_true = zeros(sz...)
        translate_image!(b_true, b_orig, v_cumul; threads=true)
        b = b_true .+ params.noise_level.*randn(rng, sz...)
        v = v_true.*(1.0 .+ params.shake_noise_level.*randn(rng, size(v_true)...))
        # Pass true data to iteration routine
        data = OnlineData{DisplacementT}(b_true, b, v, v_true, v_cumul)
        if !put_unless_closed!(datachannel, data)
            return
        end
        # Next step shake
        v_true = nextv()
        # Rebind to a fresh array rather than updating in place: the previous
        # `v_cumul` is still referenced by the `OnlineData` just handed to
        # the consumer and must not change under it.
        v_cumul = v_cumul .+ v_true
    end
end
131 | ||
"""
    imgen_square(sz)

Create the `ImGen` descriptor of the shaken-square generator for images of
size `sz`. `Λ` and `dynrange` are both set to `1`, and the name is
`"square<sz[1]>x<sz[2]>"`.
"""
function imgen_square(sz)
    generator = curry(generate_square, sz)
    label = "square$(sz[1])x$(sz[2])"
    return ImGen(generator, sz, 1, 1, label)
end
135 | ||
136 | ################ | |
137 | # Shake a photo | |
138 | ################ | |
139 | ||
"""
    generate_shake_image(im, sz, DisplacementConstant, datachannel, params)

Produce an endless sequence of randomly shaken `sz`-sized subwindows of the
image `im` and push them to `datachannel` as
`OnlineData{DisplacementConstant}` until the channel is closed (i.e. until
`put_unless_closed!` returns `false`).

The module-level `rng` is re-seeded on entry, so every invocation generates
the same sequence of shakes and noise.

# Arguments
- `im`: source image from which subwindows are extracted.
- `sz`: size of the extracted subwindow; splatted into `zeros(sz...)`.
- `datachannel`: channel receiving the generated frames.
- `params`: must provide `shake`, `noise_level` and `shake_noise_level`;
  `shaketype` is optional (see `shake`).
"""
function generate_shake_image(im, sz,
                              :: Type{DisplacementConstant},
                              datachannel :: Channel{OnlineData{DisplacementConstant}},
                              params :: NamedTuple)

    # Restart the seed to enable comparison across predictors
    Random.seed!(rng, 9182737465)

    nextv = shake(params)
    v_true = nextv()
    v_cumul = copy(v_true)

    while true
        # Extract subwindow of original image and add noise
        b_true = zeros(sz...)
        extract_subimage!(b_true, im, v_cumul; threads=true)
        b = b_true .+ params.noise_level.*randn(rng, sz...)
        v = v_true.*(1.0 .+ params.shake_noise_level.*randn(rng, size(v_true)...))
        # Pass data to iteration routine
        data = OnlineData{DisplacementConstant}(b_true, b, v, v_true, v_cumul)
        if !put_unless_closed!(datachannel, data)
            return
        end
        # Next step shake
        v_true = nextv()
        # Rebind to a fresh array rather than updating in place: the previous
        # `v_cumul` is still referenced by the `OnlineData` just handed to
        # the consumer and must not change under it.
        v_cumul = v_cumul .+ v_true
    end
end
168 | ||
"""
    imgen_shake(imname, sz)

Create the `ImGen` descriptor of the shaken-photo generator: load the test
image `imname` via `TestImages.testimage`, convert it to a grayscale
`Float64` array, and bind it together with the window size `sz` to
`generate_shake_image`. `Λ` is set to `1`, `dynrange` to the maximum
intensity of the image, and the name to `"<imname><sz[1]>x<sz[2]>"`.
"""
function imgen_shake(imname, sz)
    photo = TestImages.testimage(imname)
    im = Float64.(Gray.(photo))
    dynrange = maximum(im)
    generator = curry(generate_shake_image, im, sz)
    label = "$(imname)$(sz[1])x$(sz[2])"
    return ImGen(generator, sz, 1, dynrange, label)
end
175 | ||
176 | end # Module |