134 return sum(v)/prod(size(v)) |
134 return sum(v)/prod(size(v)) |
135 end |
135 end |
136 |
136 |
137 @inline function proj_norm₂₁ball!(y, α) |
137 @inline function proj_norm₂₁ball!(y, α) |
138 α²=α*α |
138 α²=α*α |
139 y′=reshape(y, (size(y, 1), prod(size(y)[2:end]))) |
|
140 |
139 |
141 @inbounds @simd for i=1:size(y′, 2)# in CartesianIndices(size(y)[2:end]) |
140 if ndims(y)==3 && size(y, 1)==2 |
142 n² = norm₂²(@view(y′[:, i])) |
141 @inbounds for i=1:size(y, 2) |
143 if n²>α² |
142 @simd for j=1:size(y, 3) |
144 y′[:, i] .*= (α/√n²) |
143 n² = y[1,i,j]*y[1,i,j]+y[2,i,j]*y[2,i,j] |
|
144 if n²>α² |
|
145 v = α/√n² |
|
146 y[1, i, j] *= v |
|
147 y[2, i, j] *= v |
|
148 end |
|
149 end |
|
150 end |
|
151 else |
|
152 y′=reshape(y, (size(y, 1), prod(size(y)[2:end]))) |
|
153 |
|
154 @inbounds @simd for i=1:size(y′, 2)# in CartesianIndices(size(y)[2:end]) |
|
155 n² = norm₂²(@view(y′[:, i])) |
|
156 if n²>α² |
|
157 y′[:, i] .*= (α/√n²) |
|
158 end |
145 end |
159 end |
146 end |
160 end |
147 end |
161 end |
148 |
162 |
149 end # Module |
163 end # Module |