| 6 #[cfg(feature = "nightly")] |
6 #[cfg(feature = "nightly")] |
| 7 use std::mem::MaybeUninit; |
7 use std::mem::MaybeUninit; |
| 8 |
8 |
| 9 /// Gaussian elimination for $AX=B$, where $A$ and $B$ are both stored in `ab`, |
9 /// Gaussian elimination for $AX=B$, where $A$ and $B$ are both stored in `ab`, |
| 10 /// $A \in \mathbb{R}^{M \times M}$ and $X, B \in \mathbb{R}^{M \times K}$. |
10 /// $A \in \mathbb{R}^{M \times M}$ and $X, B \in \mathbb{R}^{M \times K}$. |
| 11 pub fn linsolve0<F : Float, const M : usize, const N : usize, const K : usize>( |
11 pub fn linsolve0<F: Float, const M: usize, const N: usize, const K: usize>( |
| 12 mut ab : [[F; N]; M] |
12 mut ab: [[F; N]; M], |
| 13 ) -> [[F; K]; M] { |
13 ) -> [[F; K]; M] { |
| 14 assert_eq!(M + K , N); |
14 assert_eq!(M + K, N); |
| 15 |
15 |
| 16 let mut k = 0; |
16 let mut k = 0; |
| 17 |
17 |
| 18 // Convert to row-echelon form |
18 // Convert to row-echelon form |
| 19 for h in 0..(M-1) { |
19 for h in 0..(M - 1) { |
| 20 // Find pivotable column (has some non-zero entries in rows ≥ h) |
20 // Find pivotable column (has some non-zero entries in rows ≥ h) |
| 21 'find_pivot: while k < N { |
21 'find_pivot: while k < N { |
| 22 let (mut î, mut v) = (h, ab[h][k].abs()); |
22 let (mut î, mut v) = (h, ab[h][k].abs()); |
| 23 // Find row ≥ h of maximum absolute value in this column |
23 // Find row ≥ h of maximum absolute value in this column |
| 24 for i in (h+1)..M { |
24 for i in (h + 1)..M { |
| 25 let ṽ = ab[i][k].abs(); |
25 let ṽ = ab[i][k].abs(); |
| 26 if ṽ > v { |
26 if ṽ > v { |
| 27 î = i; |
27 î = i; |
| 28 v = ṽ; |
28 v = ṽ; |
| 29 } |
29 } |
| 30 } |
30 } |
| 31 if v > F::ZERO { |
31 if v > F::ZERO { |
| 32 ab.swap(h, î); |
32 ab.swap(h, î); |
| 33 for i in (h+1)..M { |
33 for i in (h + 1)..M { |
| 34 let f = ab[i][k] / ab[h][k]; |
34 let f = ab[i][k] / ab[h][k]; |
| 35 ab[i][k] = F::ZERO; |
35 ab[i][k] = F::ZERO; |
| 36 for j in (k+1)..N { |
36 for j in (k + 1)..N { |
| 37 ab[i][j] -= ab[h][j]*f; |
37 ab[i][j] -= ab[h][j] * f; |
| 38 } |
38 } |
| 39 } |
39 } |
| 40 k += 1; |
40 k += 1; |
| 41 break 'find_pivot; |
41 break 'find_pivot; |
| 42 } |
42 } |
| 54 { |
54 { |
| 55 let mut x: [[MaybeUninit<F>; K]; M] = [[const { MaybeUninit::uninit() }; K]; M]; |
55 let mut x: [[MaybeUninit<F>; K]; M] = [[const { MaybeUninit::uninit() }; K]; M]; |
| 56 //unsafe { std::mem::MaybeUninit::uninit().assume_init() }; |
56 //unsafe { std::mem::MaybeUninit::uninit().assume_init() }; |
| 57 for i in (0..M).rev() { |
57 for i in (0..M).rev() { |
| 58 for 𝓁 in 0..K { |
58 for 𝓁 in 0..K { |
| 59 let mut tmp = ab[i][M+𝓁]; |
59 let mut tmp = ab[i][M + 𝓁]; |
| 60 for j in (i+1)..M { |
60 for j in (i + 1)..M { |
| 61 tmp -= ab[i][j] * unsafe { *(x[j][𝓁].assume_init_ref()) }; |
61 tmp -= ab[i][j] * unsafe { *(x[j][𝓁].assume_init_ref()) }; |
| 62 } |
62 } |
| 63 tmp /= ab[i][i]; |
63 tmp /= ab[i][i]; |
| 64 x[i][𝓁].write(tmp); |
64 x[i][𝓁].write(tmp); |
| 65 } |
65 } |
| 69 (&x as *const _ as *const [[F; K]; M]).read() |
69 (&x as *const _ as *const [[F; K]; M]).read() |
| 70 } |
70 } |
| 71 } |
71 } |
| 72 #[cfg(not(feature = "nightly"))] |
72 #[cfg(not(feature = "nightly"))] |
| 73 { |
73 { |
| 74 let mut x : [[F; K]; M] = [[F::ZERO; K]; M]; |
74 let mut x: [[F; K]; M] = [[F::ZERO; K]; M]; |
| 75 for i in (0..M).rev() { |
75 for i in (0..M).rev() { |
| 76 for 𝓁 in 0..K { |
76 for 𝓁 in 0..K { |
| 77 let mut tmp = ab[i][M+𝓁]; |
77 let mut tmp = ab[i][M + 𝓁]; |
| 78 for j in (i+1)..M { |
78 for j in (i + 1)..M { |
| 79 tmp -= ab[i][j] * x[j][𝓁]; |
79 tmp -= ab[i][j] * x[j][𝓁]; |
| 80 } |
80 } |
| 81 tmp /= ab[i][i]; |
81 tmp /= ab[i][i]; |
| 82 x[i][𝓁] = tmp; |
82 x[i][𝓁] = tmp; |
| 83 } |
83 } |
| 87 } |
87 } |
| 88 |
88 |
| 89 /// Gaussian elimination for $Ax=b$, where $A$ and $b$ are both stored in `ab`, |
89 /// Gaussian elimination for $Ax=b$, where $A$ and $b$ are both stored in `ab`, |
| 90 /// $A \in \mathbb{R}^{M \times M}$ and $x, b \in \mathbb{R}^M$. |
90 /// $A \in \mathbb{R}^{M \times M}$ and $x, b \in \mathbb{R}^M$. |
| 91 #[inline] |
91 #[inline] |
| 92 pub fn linsolve<F : Float, const M : usize, const N : usize>(ab : [[F; N]; M]) -> [F; M] { |
92 pub fn linsolve<F: Float, const M: usize, const N: usize>(ab: [[F; N]; M]) -> [F; M] { |
| 93 let x : [[F; 1]; M] = linsolve0(ab); |
93 let x: [[F; 1]; M] = linsolve0(ab); |
| 94 unsafe { *((&x as *const [F; 1]) as *const [F; M] ) } |
94 unsafe { *((&x as *const [F; 1]) as *const [F; M]) } |
| 95 } |
95 } |
| 96 |
|
| 97 |
96 |
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn linsolve_test() {
        // x + 2y = 3, 2x + y = 6  =>  (x, y) = (3, 0).
        let two_by_two = [[1.0, 2.0, 3.0], [2.0, 1.0, 6.0]];
        assert_eq!(linsolve(two_by_two), [3.0, 0.0]);

        // x + 2y = 1, 4x = 0, z = 0  =>  (x, y, z) = (0, 0.5, 0).
        // Partial pivoting swaps rows 0 and 1 here (|4| > |1| in column 0).
        let three_by_three = [
            [1.0, 2.0, 0.0, 1.0],
            [4.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
        ];
        assert_eq!(linsolve(three_by_three), [0.0, 0.5, 0.0]);
    }
}
| 110 |
|