diff -r 6dfa8001eed2 -r f8be66557e0f src/LinSolve.jl --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/LinSolve.jl Wed Dec 22 11:13:38 2021 +0200 @@ -0,0 +1,190 @@ +""" +Linear solvers for small problems. +""" +module LinSolve + +using ..Metaprogramming + +export linsolve, + TupleMatrix + +const TupleMatrix{M,N} = NTuple{M, NTuple{N, Float64}} + +""" +`linsolve(AB :: TupleMatrix{M,N}, :: Type{TupleMatrix{M, K}}) :: TupleMatrix{M, K}` + +where + +`TupleMatrix{M, N} = NTuple{M, NTuple{N, Float64}}` + +“Immutable” Gaussian elimination on tuples: solve AX=B for X, +Both A and B are stored in AB. The second type parameter indicates the size of B. +""" +@polly function linsolve₀(AB :: TupleMatrix{M,N}, :: Val{K}) :: TupleMatrix{M, K} where {N,M,K} + @assert(M == N - K) + + k = 0 + + # Convert to row-echelon form + for h = 1:(M-1) + # Find pivotable column (has some non-zero entries in rows ≥ h) + v = 0.0 + î = h + while k ≤ N-1 && v == 0 + k = k + 1 + v = abs(AB[h][k]) + # Find row ≥ h of maximum absolute value in this column + for i=(h+1):M + local v′ = abs(AB[i][k]) + if v′ > v + î = i + v = v′ + end + end + end + + if v > 0 + AB = ( + AB[1:(h-1)]..., + AB[î], + (let ĩ = (i==î ? h : i), h̃ = î, f = AB[ĩ][k] / AB[h̃][k] + ((0.0 for _ = 1:k)..., (AB[ĩ][j]-AB[h̃][j]*f for j = (k+1):N)...,) + end for i = (h+1):M)... + ) + end + end + + # Solve UAX=UB for X where UA with U presenting the transformations above an + # upper triangular matrix. + X = () + for i=M:-1:1 + r = .+(AB[i][M+1:end], (-AB[i][j].*X[j-i] for j=i+1:M)...)./AB[i][i] + X = (r, X...) + end + return X +end + +@inline function linsolve₀(AB :: TupleMatrix{M,N}) :: NTuple{M,Float64} where {N,M} + X = linsolve₀(AB, Val(1)) + return ((x for (x,) ∈ X)...,) +end + +#@generated function linsolve₁(AB :: TupleMatrix{M,N}, :: Type{TupleMatrix{M, K}}) :: TupleMatrix{M, K} where {N,M,K} +function generate_linsolve(AB :: Symbol, M :: Int, N :: Int, K :: Int) + @assert(M == N - K) + # The variables of ABN collect the stepwise stages of the transformed matrix. + # Initial i-1 entries of ABN[i] are never used, as the previous steps have already + # finalised the corresponding rows, but are included as “missing” (at compile time) + # for the sake of indexing clarity. The M-1:th row-wise step finalises the row-echelon + # form, so ABN has M-1 rows itself. + step_ABN(step) = ((missing for i=1:step-1)..., + (gensym("ABN_$(step)_$(i)") for i ∈ step:M)...,) + ABN = ((step_ABN(step) for step=1:M-1)...,) + # UAB “diagonally” refers to ABN to collate the final rows of the transformed matrix UAB. + # Since the M-1:th row-wise step finalises the row-echelon form, the last row comes already + # from the M-1:th step. + UAB = ((ABN[i][i] for i=1:M-1)..., ABN[M-1][M]) + # The variables of X collect the rows of the solution to AX=B. + X = ((gensym("X_$(i)") for i ∈ 1:M)...,) + + # Convert to row-echelon form. On each step we strip leading zeroes from ABN. + # In the end UAB should be upper-triangular. + convert_row(ABNout, h, ABNin) = quote + # Find pivotable column (has some non-zero entries in rows ≥ h) + $(ABNout[h]) = $(ABNin[h]) + local v = abs($(ABNout[h])[1]) + # Find row ≥ h of maximum absolute value in this column + $(sequence_exprs(:( + let v′ = abs($(ABNin[i])[1]) + if v′ > v + $(ABNout[h]), $(ABNin[i]) = $(ABNin[i]), $(ABNout[h]) + v = v′ + end + end + ) for i=(h+1):M)) + + $(lift_exprs(ABNout[h+1:M])) = v > 0 ? ( $(lift_exprs( :( + # Transform + $(ABNin[i])[2:$N-$h+1] .- $(ABNout[h])[2:$N-$h+1].*( $(ABNin[i])[1] / $(ABNout[h])[1]) + ) for i=h+1:M)) ) : ( $(lift_exprs( :( + # Strip leading zeroes + $(ABNin[i])[2:$N-$h+1] + ) for i=h+1:M)) ) + end + + # Solve UAX=UB for X where UA with U presenting the transformations above an + # upper triangular matrix. + solve_row(UAB, i) = :( + $(X[i]) = $(lift_exprs( :( + +($(UAB[i])[$M-$i+1+$k], $(( :( -$(UAB[i])[$j-$i+1]*$(X[j])[$k] ) for j=i+1:M)...)) + ) for k=1:K )) ./ $(UAB[i])[1] + ) + + return X, quote + $(lift_exprs(ABN[1][i] for i=1:M)) = $(lift_exprs( :( $AB[$i] ) for i=1:M)) + $(convert_row(ABN[1], 1, ABN[1])) + $((convert_row(ABN[h], h, ABN[h-1]) for h = 2:(M-1))...) + $((solve_row(UAB, i) for i=M:-1:1)...) + end +end + +@inline @generated function linsolve₁(AB :: TupleMatrix{M,N}, :: Val{K}) :: TupleMatrix{M, K} where {N,M,K} + X, solver = generate_linsolve(:AB, M, N, K) + return quote + $solver + return $(lift_exprs( X[i] for i=1:M )) + end +end + +@inline @generated function linsolve₁(AB :: TupleMatrix{M,N}) :: NTuple{M,Float64} where {N,M} + X, solver = generate_linsolve(:AB, M, N, 1) + return quote + $solver + return $(lift_exprs( :( $(X[i])[1] ) for i=1:M )) + end +end + +const linsolve = linsolve₁ + +function tuplify(M, N, A, b) + ((((A[i][j] for j=1:N)..., b[j]) for i=1:M)...,) +end + + +function compare(; dim=5, n_matrices=10000, n_testvectors=100) + testmatrices=[] + testvectors=[] + while length(testmatrices)