src/LinSolve.jl

changeset 37
f8be66557e0f
equal deleted inserted replaced
36:6dfa8001eed2 37:f8be66557e0f
1 """
2 Linear solvers for small problems.
3 """
4 module LinSolve
5
6 using ..Metaprogramming
7
8 export linsolve,
9 TupleMatrix
10
# A small dense matrix stored row-wise as an `M`-tuple of rows, each row an `N`-tuple of
# `Float64`s. Written with `Vararg`; this is exactly `NTuple{M, NTuple{N, Float64}}`.
const TupleMatrix{M,N} = Tuple{Vararg{NTuple{N,Float64}, M}}
12
"""
    linsolve₀(AB :: TupleMatrix{M,N}, :: Val{K}) :: TupleMatrix{M,K}

where

    TupleMatrix{M,N} = NTuple{M, NTuple{N, Float64}}

“Immutable” Gaussian elimination with partial pivoting on tuples: solve `AX = B` for `X`.
Both `A` and `B` are stored row-wise in `AB`: the first `M` columns hold `A`, the last `K`
columns hold `B`, so `N` must equal `M + K`. The solution `X` is returned as `M` rows of
`K` entries each.

Throws `DimensionMismatch` unless `M == N - K`. Singular `A` is not detected; a zero pivot
shows up as a division by zero (`Inf`/`NaN` entries) during back substitution.
"""
@polly function linsolve₀(AB :: TupleMatrix{M,N}, :: Val{K}) :: TupleMatrix{M, K} where {N,M,K}
    M == N - K || throw(DimensionMismatch("AB has $N columns, expected M + K = $(M + K)"))

    # Most recently used pivot column; each elimination step searches strictly to its right.
    k = 0

    # Convert to row-echelon form
    for h = 1:(M-1)
        # Find pivotable column (has some non-zero entries in rows ≥ h)
        v = 0.0
        î = h
        while k ≤ N-1 && iszero(v)
            k += 1
            v = abs(AB[h][k])
            # Find row ≥ h of maximum absolute value in this column
            for i = (h+1):M
                local v′ = abs(AB[i][k])
                if v′ > v
                    î = i
                    v = v′
                end
            end
        end

        if v > 0
            # Rebuild AB: rows before h are unchanged, the pivot row î moves to position h,
            # and every later row (the old row h takes î's slot) has the pivot row
            # subtracted so that its columns 1..k become zero.
            AB = (
                AB[1:(h-1)]...,
                AB[î],
                (let ĩ = (i == î ? h : i), f = AB[ĩ][k] / AB[î][k]
                    ((0.0 for _ = 1:k)..., (AB[ĩ][j] - AB[î][j] * f for j = (k+1):N)...,)
                end for i = (h+1):M)...,
            )
        end
    end

    # Back substitution. After elimination the A-part of AB (UA, with U the row operations
    # applied above) is upper triangular, so solve UA·X = UB from the last row upwards.
    X = ()
    for i = M:-1:1
        # X currently holds solution rows i+1..M, hence X[j-i] is solution row j.
        r = .+(AB[i][M+1:end], (-AB[i][j] .* X[j-i] for j = i+1:M)...) ./ AB[i][i]
        X = (r, X...)
    end
    return X
end
66
"""
    linsolve₀(AB :: TupleMatrix{M,N}) :: NTuple{M,Float64}

Single right-hand-side form: `AB` holds `A` in its first `M` columns and `b` in its last
column. Solves `Ax = b` via `linsolve₀(AB, Val(1))` and returns `x` as a flat tuple.
"""
@inline function linsolve₀(AB :: TupleMatrix{M,N}) :: NTuple{M,Float64} where {N,M}
    # Every solution row is a one-element tuple; keep just that element.
    single_column = linsolve₀(AB, Val(1))
    return map(first, single_column)
end
71
"""
    generate_linsolve(AB :: Symbol, M :: Int, N :: Int, K :: Int)

Build straight-line code for “immutable” Gaussian elimination on an augmented matrix with
`M` rows and `N = M + K` columns, held row-wise in the variable named by `AB` (the code
reads `AB[1]`, …, `AB[M]`).

Return `(X, code)`. `X` is a tuple of `M` fresh symbols. `code` is a `quote` block that,
spliced into a body where `AB` is bound, assigns row `i` of the solution of `AX = B` to
the variable `X[i]`.

Throws `DimensionMismatch` unless `M == N - K`.
"""
function generate_linsolve(AB :: Symbol, M :: Int, N :: Int, K :: Int)
    M == N - K || throw(DimensionMismatch("AB has $N columns, expected M + K = $(M + K)"))
    # Number of stages that get their own variables. An M×M system needs M-1 row-wise
    # steps, but even a 1×1 system needs one stage to hold its single row (previously
    # M == 1 indexed ABN[0] and failed with a BoundsError).
    nstages = max(M - 1, 1)
    # The variables of ABN collect the stepwise stages of the transformed matrix.
    # Initial i-1 entries of ABN[i] are never used, as the previous steps have already
    # finalised the corresponding rows, but are included as “missing” (at compile time)
    # for the sake of indexing clarity.
    step_ABN(step) = ((missing for i = 1:step-1)...,
                      (gensym("ABN_$(step)_$(i)") for i ∈ step:M)...,)
    ABN = ((step_ABN(step) for step = 1:nstages)...,)
    # UAB “diagonally” refers to ABN to collate the final rows of the transformed matrix.
    # Row i is final after step i; the last row is final after the last stage.
    UAB = ((ABN[min(i, nstages)][i] for i = 1:M)...,)
    # The variables of X collect the rows of the solution to AX=B.
    X = ((gensym("X_$(i)") for i ∈ 1:M)...,)

    # One row-wise elimination step. Pick the row ≥ h with the largest leading entry as
    # pivot, then reduce the rows below it and strip their (now zero) leading column.
    # After the last step UAB is upper-triangular.
    convert_row(ABNout, h, ABNin) = quote
        # Find pivotable column (has some non-zero entries in rows ≥ h)
        $(ABNout[h]) = $(ABNin[h])
        local v = abs($(ABNout[h])[1])
        # Find row ≥ h of maximum absolute value in this column
        $(sequence_exprs(:(
            let v′ = abs($(ABNin[i])[1])
                if v′ > v
                    $(ABNout[h]), $(ABNin[i]) = $(ABNin[i]), $(ABNout[h])
                    v = v′
                end
            end
        ) for i = (h+1):M))

        $(lift_exprs(ABNout[h+1:M])) = v > 0 ? ( $(lift_exprs( :(
            # Transform
            $(ABNin[i])[2:$N-$h+1] .- $(ABNout[h])[2:$N-$h+1].*( $(ABNin[i])[1] / $(ABNout[h])[1])
        ) for i = h+1:M)) ) : ( $(lift_exprs( :(
            # Strip leading zeroes
            $(ABNin[i])[2:$N-$h+1]
        ) for i = h+1:M)) )
    end

    # Back substitution for row i. UAB[i] starts at its diagonal entry, so original column
    # j sits at index j-i+1 and column k of B sits at index M-i+1+k.
    solve_row(UAB, i) = :(
        $(X[i]) = $(lift_exprs( :(
            +($(UAB[i])[$M-$i+1+$k], $(( :( -$(UAB[i])[$j-$i+1]*$(X[j])[$k] ) for j = i+1:M)...))
        ) for k = 1:K )) ./ $(UAB[i])[1]
    )

    return X, quote
        $(lift_exprs(ABN[1][i] for i = 1:M)) = $(lift_exprs( :( $AB[$i] ) for i = 1:M))
        # Step 1 works in place on the stage-1 variables; step h ≥ 2 reads stage h-1.
        $((convert_row(ABN[h], h, ABN[max(h - 1, 1)]) for h = 1:(M-1))...)
        $((solve_row(UAB, i) for i = M:-1:1)...)
    end
end
130
"""
    linsolve₁(AB :: TupleMatrix{M,N}, :: Val{K}) :: TupleMatrix{M,K}

Generated variant of the tuple-based Gaussian elimination. `AB` holds `A` in its first `M`
columns and `B` in its last `K`. For the concrete sizes `M`, `N` and `K`, the solver body
is produced by `generate_linsolve`, and the `M` solution rows of `AX = B` are returned.
"""
@inline @generated function linsolve₁(AB :: TupleMatrix{M,N}, :: Val{K}) :: TupleMatrix{M, K} where {N,M,K}
    solution_rows, elimination = generate_linsolve(:AB, M, N, K)
    gathered = lift_exprs(solution_rows[r] for r = 1:M)
    return quote
        $elimination
        return $gathered
    end
end
138
"""
    linsolve₁(AB :: TupleMatrix{M,N}) :: NTuple{M,Float64}

Single right-hand-side form of the generated solver. `AB` holds `A` in its first `M`
columns and `b` in its last column; the solution of `Ax = b` is returned as a flat tuple.
"""
@inline @generated function linsolve₁(AB :: TupleMatrix{M,N}) :: NTuple{M,Float64} where {N,M}
    rows, elimination = generate_linsolve(:AB, M, N, 1)
    # Each generated solution row is a one-element tuple; unwrap it.
    flattened = lift_exprs(:( $(rows[r])[1] ) for r = 1:M)
    return quote
        $elimination
        return $flattened
    end
end

"""
    linsolve(AB :: TupleMatrix{M,N}, :: Val{K}) :: TupleMatrix{M,K}
    linsolve(AB :: TupleMatrix{M,N}) :: NTuple{M,Float64}

Public name of the small-system solver; bound to the generated implementation `linsolve₁`.
"""
const linsolve = linsolve₁
148
# Entry (i, j) of a coefficient "matrix" that is either a real matrix or a vector of rows.
_tuplify_entry(A :: AbstractMatrix, i, j) = A[i, j]
_tuplify_entry(A, i, j) = A[i][j]

"""
    tuplify(M, N, A, b)
    tuplify(A :: AbstractMatrix, b)

Pack the first `M` rows and `N` columns of `A`, together with the right-hand side `b`, into
the row-wise augmented layout `((A[1,1], …, A[1,N], b[1]), …, (A[M,1], …, A[M,N], b[M]))`
used by `linsolve`. `A` may be an `AbstractMatrix` or a collection of rows (`A[i][j]`).
Entries are converted to `Float64`, so the result is a `TupleMatrix{M, N+1}`.
The two-argument form takes `M` and `N` from `size(A)`.
"""
function tuplify(M, N, A, b)
    # Row i pairs with b[i] (not b[j]: j only exists inside the column generator).
    return ((((Float64(_tuplify_entry(A, i, j)) for j = 1:N)..., Float64(b[i])) for i = 1:M)...,)
end

tuplify(A :: AbstractMatrix, b) = tuplify(size(A, 1), size(A, 2), A, b)
152
153
"""
    compare(; dim=5, n_matrices=10000, n_testvectors=100)

Benchmark the solvers on random `dim × dim` systems. Every one of `n_matrices` random
matrices is paired with every one of `n_testvectors` random right-hand sides. For each of
`linsolve₀`, `linsolve₁` and `\\`, a cyan header is printed and `@time` reports the time
for the whole sweep. Each solver is called once beforehand so that compilation is not
timed.
"""
function compare(; dim=5, n_matrices=10000, n_testvectors=100)
    testmatrices = [randn(dim, dim) for _ = 1:n_matrices]
    # Previously these were bound to a misspelt name (`test_vectors`), leaving
    # `testvectors` empty so that nothing was ever evaluated.
    testvectors = [randn(dim) for _ = 1:n_testvectors]

    function evaluate(fn)
        for A ∈ testmatrices
            for b ∈ testvectors
                fn(A, b)
            end
        end
    end

    function evaluate_and_report(fn, name)
        printstyled("Evaluating $name…\n", color=:cyan)
        # Warm-up call so that compilation (notably of the generated solver) is not timed.
        if !isempty(testmatrices) && !isempty(testvectors)
            fn(first(testmatrices), first(testvectors))
        end
        @time evaluate(fn)
    end

    evaluate_and_report("tuple-linsolve, ungenerated") do A, b
        linsolve₀(tuplify(dim, dim, A, b))
    end

    evaluate_and_report("tuple-linsolve, generated") do A, b
        linsolve₁(tuplify(dim, dim, A, b))
    end

    evaluate_and_report("backslash") do A, b
        A \ b
    end
    return nothing
end
189
190 end # module

mercurial