src/LinSolve.jl

changeset 37
f8be66557e0f
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/LinSolve.jl	Wed Dec 22 11:13:38 2021 +0200
@@ -0,0 +1,190 @@
+"""
+Linear solvers for small problems.
+"""
+module LinSolve
+
+using ..Metaprogramming
+
+export linsolve,
+       TupleMatrix
+
+const TupleMatrix{M,N} = NTuple{M, NTuple{N, Float64}}
+
+"""
+`linsolve(AB :: TupleMatrix{M,N}, :: Type{TupleMatrix{M, K}}) :: TupleMatrix{M, K}`
+
+where
+
+`TupleMatrix{M, N} = NTuple{M, NTuple{N, Float64}}`
+
+“Immutable” Gaussian elimination on tuples: solve AX=B for X,
+Both A and B are stored in AB. The second type parameter indicates the size of B.
+"""
+@polly function linsolve₀(AB :: TupleMatrix{M,N}, :: Val{K}) :: TupleMatrix{M, K} where {N,M,K}
+    @assert(M == N - K)
+
+    k = 0
+
+    # Convert to row-echelon form
+    for h = 1:(M-1)
+        # Find pivotable column (has some non-zero entries in rows ≥ h)
+        v = 0.0
+        î = h
+        while k ≤ N-1 && v == 0
+            k = k + 1
+            v = abs(AB[h][k])
+            # Find row ≥ h of maximum absolute value in this column
+            for i=(h+1):M
+                local v′ = abs(AB[i][k])
+                if v′ > v
+                    î = i
+                    v = v′
+                end
+            end
+        end
+
+        if v > 0
+            AB = (
+                AB[1:(h-1)]...,
+                AB[î],
+                (let ĩ = (i==î ? h : i), h̃ = î, f = AB[ĩ][k] / AB[h̃][k]
+                    ((0.0 for _ = 1:k)..., (AB[ĩ][j]-AB[h̃][j]*f for j = (k+1):N)...,)
+                end for i = (h+1):M)...
+            )
+        end
+    end
+
+    # Solve UAX=UB for X where UA with U presenting the transformations above an
+    # upper triangular matrix.
+    X = ()
+    for i=M:-1:1
+        r = .+(AB[i][M+1:end], (-AB[i][j].*X[j-i] for j=i+1:M)...)./AB[i][i]
+        X = (r, X...)
+    end
+    return X
+end
+
+@inline function linsolve₀(AB :: TupleMatrix{M,N}) :: NTuple{M,Float64} where {N,M}
+    X = linsolve₀(AB, Val(1))
+    return ((x for (x,) ∈ X)...,)
+end
+
+#@generated function linsolve₁(AB :: TupleMatrix{M,N}, :: Type{TupleMatrix{M, K}}) :: TupleMatrix{M, K} where {N,M,K}
+function generate_linsolve(AB :: Symbol, M :: Int, N :: Int, K :: Int)
+    @assert(M == N - K)
+    # The variables of ABN collect the stepwise stages of the transformed matrix.
+    # Initial i-1 entries of ABN[i] are never used, as the previous steps have already
+    # finalised the corresponding rows, but are included as “missing” (at compile time)
+    # for the sake of indexing clarity. The M-1:th row-wise step finalises the row-echelon
+    # form, so ABN has M-1 rows itself.
+    step_ABN(step) = ((missing for i=1:step-1)...,
+                      (gensym("ABN_$(step)_$(i)") for i ∈ step:M)...,)
+    ABN = ((step_ABN(step) for step=1:M-1)...,)
+    # UAB “diagonally” refers to ABN to collate the final rows of the transformed matrix UAB.
+    # Since the M-1:th row-wise step finalises the row-echelon form, the last row comes already
+    # from the M-1:th step.
+    UAB = ((ABN[i][i] for i=1:M-1)..., ABN[M-1][M])
+    # The variables of X collect the rows of the solution to AX=B.
+    X = ((gensym("X_$(i)") for i ∈ 1:M)...,)
+
+    # Convert to row-echelon form. On each step we strip leading zeroes from ABN.
+    # In the end UAB should be upper-triangular.
+    convert_row(ABNout, h, ABNin) = quote
+        # Find pivotable column (has some non-zero entries in rows ≥ h)
+        $(ABNout[h]) = $(ABNin[h])
+        local v = abs($(ABNout[h])[1])
+        # Find row ≥ h of maximum absolute value in this column
+        $(sequence_exprs(:(
+            let v′ = abs($(ABNin[i])[1])
+                if v′ > v
+                    $(ABNout[h]), $(ABNin[i]) = $(ABNin[i]), $(ABNout[h])
+                    v = v′
+                end
+            end
+        ) for i=(h+1):M))
+
+        $(lift_exprs(ABNout[h+1:M])) = v > 0 ? ( $(lift_exprs( :(
+            # Transform
+            $(ABNin[i])[2:$N-$h+1] .- $(ABNout[h])[2:$N-$h+1].*( $(ABNin[i])[1] / $(ABNout[h])[1]) 
+        ) for i=h+1:M)) ) : ( $(lift_exprs( :(
+            # Strip leading zeroes
+            $(ABNin[i])[2:$N-$h+1]
+        ) for i=h+1:M)) )
+    end
+
+    # Solve UAX=UB for X where UA with U presenting the transformations above an
+    # upper triangular matrix.
+    solve_row(UAB, i) = :(
+        $(X[i]) = $(lift_exprs( :(
+            +($(UAB[i])[$M-$i+1+$k], $(( :( -$(UAB[i])[$j-$i+1]*$(X[j])[$k] ) for j=i+1:M)...))
+        ) for k=1:K )) ./ $(UAB[i])[1]
+    )
+
+    return X, quote
+        $(lift_exprs(ABN[1][i] for i=1:M)) = $(lift_exprs( :( $AB[$i] ) for i=1:M))
+        $(convert_row(ABN[1], 1, ABN[1]))
+        $((convert_row(ABN[h], h, ABN[h-1]) for h = 2:(M-1))...)
+        $((solve_row(UAB, i) for i=M:-1:1)...)
+    end
+end
+
+@inline @generated function linsolve₁(AB :: TupleMatrix{M,N}, :: Val{K}) :: TupleMatrix{M, K} where {N,M,K}
+    X, solver = generate_linsolve(:AB, M, N, K)
+    return quote
+        $solver
+        return $(lift_exprs( X[i] for i=1:M ))
+    end
+end
+
+@inline @generated function linsolve₁(AB :: TupleMatrix{M,N}) :: NTuple{M,Float64} where {N,M}
+    X, solver = generate_linsolve(:AB, M, N, 1)
+    return quote
+        $solver
+        return $(lift_exprs( :( $(X[i])[1] ) for i=1:M ))
+    end
+end
+
+const linsolve = linsolve₁
+
+function tuplify(M, N, A, b)
+    ((((A[i][j] for j=1:N)..., b[j]) for i=1:M)...,)
+end
+
+
+function compare(; dim=5, n_matrices=10000, n_testvectors=100)
+    testmatrices=[]
+    testvectors=[]
+    while length(testmatrices)<n_matrices
+        A=randn(dim, dim)
+        push!(testmatrices, A)
+    end
+    test_vectors=[randn(dim) for _ = 1:n_testvectors]
+
+
+    function evaluate(fn)
+        for A ∈ testmatrices
+            for b ∈ testvectors
+                fn(A, b)
+            end
+        end
+    end
+
+    function evaluate_and_report(fn, name)
+        printstyled("Evaluating $name…\n", color=:cyan)
+        @time evaluate(fn)
+    end
+
+    evaluate_and_report("tuple-linsolve, ungenerated") do A, b
+        linsolve₀(tuplify(A, b))
+    end
+
+    evaluate_and_report("tuple-linsolve, generated") do A, b
+        linsolve₁(tuplify(A, b))
+    end
+
+    evaluate_and_report("backslash") do A, b
+        A \ b
+    end
+end
+
+end # module

mercurial