src/Util.jl

changeset 0
888dfd34d24a
child 4
59fd17a3cea0
equal deleted inserted replaced
-1:000000000000 0:888dfd34d24a
#########################
# Some utility functions
#########################
4
module Util

##############
# Our exports
##############

export map_first_slice!,
    reduce_first_slice,
    norm₂,
    γnorm₂,
    norm₂w,
    norm₂²,
    norm₂w²,
    norm₂₁,
    γnorm₂₁,
    dot,
    mean,
    proj_norm₂₁ball!,
    curry,
    ⬿

# NOTE: `dot` and `mean` below are this module's own functions, distinct from
# `LinearAlgebra.dot` / `Statistics.mean`; a scope that `using`s both this
# module and those packages has to qualify the call.

#########################
# Functional programming
#########################

"""
    curry(f, y...)

Partially apply `f` to the leading arguments `y...`: return a function that,
called with `z...`, evaluates `f(y..., z...)`.
"""
curry(f, y...) = (z...) -> f(y..., z...)

###############################
# For working with NamedTuples
###############################

# Infix alias for `merge`, e.g. `a ⬿ b == merge(a, b)`; for NamedTuples the
# fields of `b` take precedence over those of `a`. `const` so uses of the
# alias are type-stable.
const ⬿ = merge

######
# map
######

# Cartesian indices over every dimension of `y` except the first, so that
# `view(y, :, i)` for `i` in the result visits each first-dimension slice once.
# For a vector this is a single empty index, i.e. the whole vector is one slice.
@inline _trailing_indices(y) = CartesianIndices(axes(y)[2:end])

"""
    map_first_slice!(f!, y)

Call `f!(view(y, :, i))` once for every index `i` over the trailing dimensions
of `y` (i.e. once per first-dimension slice). The return value of `f!` is
discarded; any effect comes from `f!` mutating the slice it is given.
Returns `nothing`.
"""
@inline function map_first_slice!(f!, y)
    for i in _trailing_indices(y)
        # `i` comes from `y`'s own axes, so the view cannot be out of bounds.
        # `f!` itself is deliberately not under `@inbounds`, so bounds checks
        # inside the caller's function stay active.
        yslice = @inbounds view(y, :, i)
        f!(yslice)
    end
    return nothing
end

"""
    map_first_slice!(x, f!, y)

Call `f!(view(x, :, i), view(y, :, i))` for every index `i` over the trailing
dimensions of `y`. Iteration is driven by `y`'s shape; `x` must provide the
corresponding slices. Returns `nothing`.

Throws `BoundsError` if a slice of `y` has no counterpart in `x`.
"""
@inline function map_first_slice!(x, f!, y)
    for i in _trailing_indices(y)
        # Only the `y` view is known to be in bounds (its index comes from
        # `y`); `x` may be smaller, so its view is always bounds-checked.
        xslice = view(x, :, i)
        yslice = @inbounds view(y, :, i)
        f!(xslice, yslice)
    end
    return nothing
end

"""
    reduce_first_slice(f, y; init=0.0)

Left-fold `f` over the first-dimension slices of `y`: starting from `init`,
repeatedly set `accum = f(accum, view(y, :, i))` for each trailing index `i`,
and return the final `accum`.
"""
@inline function reduce_first_slice(f, y; init=0.0)
    accum = init
    for i in _trailing_indices(y)
        yslice = @inbounds view(y, :, i)
        accum = f(accum, yslice)
    end
    return accum
end

###########################
# Norms and inner products
###########################

# Typed zero for an accumulator whose terms look like `t`. Adding the Int
# literal `0` yields the same result type the earlier `accum = 0` start
# eventually produced (Int for Bool/small integers, the float/complex type
# otherwise), but fixes it before the loop so the accumulator never changes
# type and empty input returns e.g. `0.0` for Float64 data.
@inline _accum_zero(t) = zero(t) + 0

"""
    dot(x, y)

Return the sum of `xₖ * yₖ` over corresponding elements of `x` and `y`, taken
in iteration order. Unlike `LinearAlgebra.dot`, no complex conjugation is
applied. Empty inputs give a typed zero.

Throws `DimensionMismatch` if `length(x) != length(y)`.
"""
@inline function dot(x, y)
    length(x) == length(y) || throw(DimensionMismatch(
        "dot: arguments have lengths $(length(x)) and $(length(y))"))
    accum = _accum_zero(zero(eltype(x)) * zero(eltype(y)))
    for (xₖ, yₖ) in zip(x, y)
        accum += xₖ * yₖ
    end
    return accum
end

"""
    norm₂w²(y, w)

Return the weighted squared 2-norm `Σₖ abs2(yₖ) * wₖ`.

Throws `DimensionMismatch` if `length(w) != length(y)`.
"""
@inline function norm₂w²(y, w)
    length(w) == length(y) || throw(DimensionMismatch(
        "norm₂w²: y has length $(length(y)) but w has length $(length(w))"))
    # Explicit loop rather than a `sum` over a closure, which the original
    # author observed to allocate heavily.
    accum = _accum_zero(abs2(zero(eltype(y))) * zero(eltype(w)))
    for (yₖ, wₖ) in zip(y, w)
        accum += abs2(yₖ) * wₖ
    end
    return accum
end

"""
    norm₂w(y, w)

Return the weighted 2-norm `√(norm₂w²(y, w))`.
"""
@inline function norm₂w(y, w)
    return √(norm₂w²(y, w))
end

"""
    norm₂²(y)

Return the squared 2-norm `Σₖ abs2(yₖ)` (a typed zero for empty `y`).
"""
@inline function norm₂²(y)
    # Explicit loop rather than a `sum` over a closure, which the original
    # author observed to allocate heavily.
    accum = _accum_zero(abs2(zero(eltype(y))))
    for yₖ in y
        accum += abs2(yₖ)
    end
    return accum
end

"""
    norm₂(y)

Return the 2-norm `√(norm₂²(y))`. Note: the squares are summed directly, with
no rescaling against overflow/underflow as `LinearAlgebra.norm` does.
"""
@inline function norm₂(y)
    return √(norm₂²(y))
end

"""
    γnorm₂(y, γ)

Huber-smoothed 2-norm of `y`. With `n = norm₂(y)`, return `n - γ/2` when
`n > γ` and `n²/(2γ)` otherwise (the two pieces agree at `n == γ`).
`γ == 0` returns the plain 2-norm, avoiding the `0/0` of the quadratic piece.
Presumably `γ ≥ 0` — for negative `γ` the linear piece always applies.
"""
@inline function γnorm₂(y, γ)
    γ == 0 && return norm₂(y)
    n² = norm₂²(y)
    n = √n²
    # `n ≥ 0`, so no separate branch for negative values is needed.
    return n > γ ? n - γ / 2 : n² / (2γ)
end

"""
    norm₂₁(y)

Return the sum, over the first-dimension slices of `y`, of their 2-norms
(accumulated from `0.0`).
"""
function norm₂₁(y)
    return reduce_first_slice((s, x) -> s + norm₂(x), y)
end

"""
    γnorm₂₁(y, γ)

Like `norm₂₁`, but each slice contributes its Huber-smoothed norm
`γnorm₂(slice, γ)`.
"""
function γnorm₂₁(y, γ)
    return reduce_first_slice((s, x) -> s + γnorm₂(x, γ), y)
end

"""
    mean(v)

Return the arithmetic mean `sum(v) / length(v)` (`NaN` for empty `v`).
"""
function mean(v)
    return sum(v) / length(v)
end

"""
    proj_norm₂₁ball!(y, α)

In place, rescale every first-dimension slice of `y` whose 2-norm exceeds `α`
so that its norm becomes exactly `α`; slices already within that bound are
left untouched. This is the Euclidean projection of each slice onto the
2-norm ball of radius `α` (assumes `α ≥ 0` — TODO confirm with callers).
Returns the mutated `y`.
"""
@inline function proj_norm₂₁ball!(y, α)
    α² = α * α
    # Treat `y` as a matrix whose columns are its first-dimension slices;
    # `reshape` shares memory with `y`, so scaling `y′` scales `y`.
    y′ = reshape(y, (size(y, 1), prod(size(y)[2:end])))
    for i in axes(y′, 2)
        col = view(y′, :, i)
        n² = norm₂²(col)
        if n² > α²
            # In-place on the view: no temporary copy of the column.
            col .*= α / √n²
        end
    end
    return y
end

end # Module
148

mercurial