| |
1 """ |
| |
2 Linear solvers for small problems. |
| |
3 """ |
| |
4 module LinSolve |
| |
5 |
| |
6 using ..Metaprogramming |
| |
7 |
| |
8 export linsolve, |
| |
9 TupleMatrix |
| |
10 |
| |
"""
    TupleMatrix{M,N}

An `M × N` matrix of `Float64` stored row-wise as an `M`-tuple of `N`-tuples, so that
`AB[i][j]` is the entry in row `i`, column `j`. Same type as
`NTuple{M, NTuple{N, Float64}}`.
"""
const TupleMatrix{M,N} = Tuple{Vararg{NTuple{N, Float64}, M}}
| |
12 |
| |
"""
    linsolve₀(AB::TupleMatrix{M,N}, ::Val{K}) :: TupleMatrix{M,K}

where `TupleMatrix{M,N} = NTuple{M, NTuple{N, Float64}}`.

“Immutable” Gaussian elimination with partial pivoting on tuples: solve `A X = B` for
`X`. Both `A` (`M × M`) and `B` (`M × K`) are stored in `AB`: each row of `AB` is a row
of `A` followed by the corresponding row of `B`, so `N == M + K` (checked with
`@assert`). The second argument `Val(K)` gives the number of columns of `B`, and hence
of `X`.

Returns the rows of `X` as an `M`-tuple of `K`-tuples.

No singularity check is made: if `A` is exactly singular, the back-substitution may
divide by a zero diagonal entry, producing `Inf`/`NaN` entries instead of an error.
"""
@polly function linsolve₀(AB :: TupleMatrix{M,N}, :: Val{K}) :: TupleMatrix{M, K} where {N,M,K}
    @assert(M == N - K)

    # Column of the most recent pivot; each step searches from the next column onwards.
    k = 0

    # Convert to row-echelon form
    for h = 1:(M-1)
        # Find the next pivot column (one with a non-zero entry in some row ≥ h) and,
        # in it, the row î ≥ h with the largest absolute value (partial pivoting).
        v = 0.0
        î = h
        while k ≤ N-1 && v == 0
            k = k + 1
            v = abs(AB[h][k])
            for i=(h+1):M
                local v′ = abs(AB[i][k])
                if v′ > v
                    î = i
                    v = v′
                end
            end
        end

        if v > 0
            # Rebuild AB: rows 1..h-1 are kept, the pivot row î moves to position h
            # (the former row h takes position î), and every row below the pivot has
            # column k eliminated, so its first k entries become zero.
            AB = (
                AB[1:(h-1)]...,
                AB[î],
                (let ĩ = (i==î ? h : i), h̃ = î, f = AB[ĩ][k] / AB[h̃][k]
                    ((0.0 for _ = 1:k)..., (AB[ĩ][j]-AB[h̃][j]*f for j = (k+1):N)...,)
                end for i = (h+1):M)...
            )
        end
    end

    # Back-substitution. AB now holds [UA UB], where U is the product of the row
    # operations above and UA is upper triangular; solve UA X = UB from the bottom row up:
    #     X[i] = (UB[i] - Σ_{j>i} UA[i,j] X[j]) / UA[i,i]
    # While row i is being computed, X holds rows i+1..M, so X[j-i] is row j.
    X = ()
    for i=M:-1:1
        r = .+(AB[i][M+1:end], (-AB[i][j].*X[j-i] for j=i+1:M)...)./AB[i][i]
        X = (r, X...)
    end
    return X
end
| |
66 |
| |
# Single right-hand-side form of `linsolve₀` (`B` has one column, so `N == M + 1`):
# solves `A x = b` and returns `x` as a flat `NTuple{M,Float64}` rather than as
# `M` one-element rows.
@inline function linsolve₀(AB :: TupleMatrix{M,N}) :: NTuple{M,Float64} where {N,M}
    rows = linsolve₀(AB, Val(1))
    # Every solution row is a 1-tuple; keep only its entry.
    return map(first, rows)
end
| |
71 |
| |
"""
    generate_linsolve(AB::Symbol, M::Int, N::Int, K::Int) -> (X, code)

Generate fully unrolled Gaussian-elimination code solving `A X = B`. The variable named
`AB` must hold an `M`-tuple of `N`-tuples whose rows are a row of `A` (`M × M`) followed
by the matching row of `B` (`M × K`); `M == N - K` is checked with `@assert`.

Returns two values:
- `X`: a tuple of `M` fresh symbols that, once `code` has run, hold the rows of the
  solution, each of length `K`.
- `code`: the expression performing the elimination and the back-substitution.

`linsolve₁` uses this as its generator.

Pivoting only compares the leading entry of each remaining row. If all of them are
exactly zero, no elimination is done for that step, and the back-substitution later
divides by that zero.
"""
function generate_linsolve(AB :: Symbol, M :: Int, N :: Int, K :: Int)
    @assert(M == N - K)
    # ABN[step][i] names row i of the matrix at elimination stage `step`. Each step strips
    # the eliminated leading entry from the rows below the pivot, so the pivot row fixed
    # by step i starts at column i. The first step-1 slots of ABN[step] correspond to
    # rows finalised by earlier steps. They are never used; they are filled with
    # `missing` only so that slot i is row i.
    # There are M-1 elimination steps, but at least one stage is created so that a single
    # equation (M == 1, no elimination at all) still has a variable holding its row.
    n_stages = max(M - 1, 1)
    step_ABN(step) = ((missing for i = 1:step-1)...,
                      (gensym("ABN_$(step)_$(i)") for i ∈ step:M)...,)
    ABN = ((step_ABN(step) for step = 1:n_stages)...,)
    # UAB collects the final rows of the upper-triangular result: row i is fixed by step
    # i, and the last row comes from the last stage.
    UAB = ((ABN[i][i] for i = 1:M-1)..., ABN[n_stages][M])
    # The variables of X collect the rows of the solution to AX=B.
    X = ((gensym("X_$(i)") for i ∈ 1:M)...,)

    # Elimination step h. The row ≥ h whose leading entry has the largest magnitude is
    # bubbled into ABNout[h] by pairwise swaps. Its leading column is then eliminated from
    # the rows below, and that (now zero) leading entry is stripped from them. If every
    # leading entry is zero, the rows below are only stripped.
    convert_row(ABNout, h, ABNin) = quote
        # Choose the pivot row (largest |leading entry| among rows ≥ h)
        $(ABNout[h]) = $(ABNin[h])
        local v = abs($(ABNout[h])[1])
        $(sequence_exprs(:(
            let v′ = abs($(ABNin[i])[1])
                if v′ > v
                    $(ABNout[h]), $(ABNin[i]) = $(ABNin[i]), $(ABNout[h])
                    v = v′
                end
            end
        ) for i=(h+1):M))

        $(lift_exprs(ABNout[h+1:M])) = v > 0 ? ( $(lift_exprs( :(
            # Transform
            $(ABNin[i])[2:$N-$h+1] .- $(ABNout[h])[2:$N-$h+1].*( $(ABNin[i])[1] / $(ABNout[h])[1])
        ) for i=h+1:M)) ) : ( $(lift_exprs( :(
            # Strip leading zeroes
            $(ABNin[i])[2:$N-$h+1]
        ) for i=h+1:M)) )
    end

    # Back-substitution for row i: UAB[i] starts at column i, so entry j-i+1 is column j
    # and entry M-i+1+k is column k of the right-hand side.
    solve_row(UAB, i) = :(
        $(X[i]) = $(lift_exprs( :(
            +($(UAB[i])[$M-$i+1+$k], $(( :( -$(UAB[i])[$j-$i+1]*$(X[j])[$k] ) for j=i+1:M)...))
        ) for k=1:K )) ./ $(UAB[i])[1]
    )

    return X, quote
        $(lift_exprs(ABN[1][i] for i = 1:M)) = $(lift_exprs( :( $AB[$i] ) for i = 1:M))
        # Step 1 rewrites the initial stage in place; step h ≥ 2 reads stage h-1.
        # No step is emitted when M == 1.
        $((convert_row(ABN[h], h, ABN[max(h - 1, 1)]) for h = 1:(M-1))...)
        $((solve_row(UAB, i) for i = M:-1:1)...)
    end
end
| |
130 |
| |
# `linsolve₁(AB, Val(K))`: solve `A X = B` for `X`. Each row of `AB` is a row of `A`
# (`M × M`) followed by the matching row of `B` (`M × K`). The elimination and the
# back-substitution are unrolled at compile time by `generate_linsolve`. The rows of `X`
# are returned as a `TupleMatrix{M,K}`.
@inline @generated function linsolve₁(AB :: TupleMatrix{M,N}, :: Val{K}) :: TupleMatrix{M, K} where {N,M,K}
    solution_rows, elimination = generate_linsolve(:AB, M, N, K)
    result = lift_exprs(solution_rows[r] for r in 1:M)
    return quote
        $elimination
        return $result
    end
end
| |
138 |
| |
# `linsolve₁(AB)`: single right-hand-side form (`N == M + 1`). Solves `A x = b` with the
# same unrolled elimination and returns `x` as a flat `NTuple{M,Float64}`.
@inline @generated function linsolve₁(AB :: TupleMatrix{M,N}) :: NTuple{M,Float64} where {N,M}
    solution_rows, elimination = generate_linsolve(:AB, M, N, 1)
    # Every solution row is a 1-tuple; return only its entry.
    entries = (:( $(solution_rows[r])[1] ) for r in 1:M)
    return quote
        $elimination
        return $(lift_exprs(entries))
    end
end

"""
    linsolve(AB::TupleMatrix{M,N}) :: NTuple{M,Float64}
    linsolve(AB::TupleMatrix{M,N}, ::Val{K}) :: TupleMatrix{M,K}

Solve `A X = B` by compile-time-unrolled Gaussian elimination. Each row of `AB` is a row
of `A` followed by the matching row of `B`. Alias of `linsolve₁`.
"""
const linsolve = linsolve₁
| |
148 |
| |
# Entry (i, j) of the coefficient data `A`. Matrices are indexed as `A[i, j]`; anything
# else (e.g. a tuple or vector of rows) is indexed as `A[i][j]`.
_tuplify_entry(A :: AbstractMatrix, i, j) = A[i, j]
_tuplify_entry(A, i, j) = A[i][j]

"""
    tuplify(M, N, A, b)
    tuplify(A::AbstractMatrix, b::AbstractVector)

Pack the `M × N` coefficients `A` and the right-hand side `b` (length `M`) into the
augmented tuple matrix `[A b]`. The result is an `M`-tuple of `(N+1)`-tuples whose row
`i` is `(A[i,1], …, A[i,N], b[i])`, with every entry converted to `Float64`. This is the
layout taken by `linsolve₀` and `linsolve₁`.

`A` may be a matrix or a container of rows indexable as `A[i][j]`. The two-argument form
takes `M` and `N` from `size(A)`. It throws `DimensionMismatch` when `length(b)` differs
from the number of rows of `A`.
"""
function tuplify(M, N, A, b)
    row(i) = ((Float64(_tuplify_entry(A, i, j)) for j = 1:N)..., Float64(b[i]))
    return ((row(i) for i = 1:M)...,)
end

function tuplify(A :: AbstractMatrix, b :: AbstractVector)
    size(A, 1) == length(b) || throw(DimensionMismatch(
        "A has $(size(A, 1)) rows but b has length $(length(b))"))
    return tuplify(size(A, 1), size(A, 2), A, b)
end
| |
152 |
| |
153 |
| |
"""
    compare(; dim=5, n_matrices=10000, n_testvectors=100)

Benchmark three solvers on `n_matrices` random `dim × dim` matrices, each paired with
every one of `n_testvectors` random right-hand sides:
- `linsolve₀` (tuple elimination);
- `linsolve₁` (generated, unrolled elimination);
- the built-in backslash solver.

Each solver is called once before timing, so that compilation is excluded. A `@time`
report is printed per solver. Returns `nothing`.
"""
function compare(; dim=5, n_matrices=10000, n_testvectors=100)
    testmatrices = [randn(dim, dim) for _ = 1:n_matrices]
    testvectors = [randn(dim) for _ = 1:n_testvectors]

    # Apply `fn` to every (matrix, vector) pair.
    function evaluate(fn)
        for A ∈ testmatrices
            for b ∈ testvectors
                fn(A, b)
            end
        end
    end

    function evaluate_and_report(fn, name)
        printstyled("Evaluating $name…\n", color=:cyan)
        # Warm up so that first-call compilation (notably of the @generated solver) is
        # not included in the timing.
        if !isempty(testmatrices) && !isempty(testvectors)
            fn(first(testmatrices), first(testvectors))
        end
        @time evaluate(fn)
    end

    evaluate_and_report("tuple-linsolve, ungenerated") do A, b
        linsolve₀(tuplify(dim, dim, A, b))
    end

    evaluate_and_report("tuple-linsolve, generated") do A, b
        linsolve₁(tuplify(dim, dim, A, b))
    end

    evaluate_and_report("backslash") do A, b
        A \ b
    end
    return nothing
end
| |
189 |
| |
190 end # module |