6 #[cfg(feature = "nightly")] |
6 #[cfg(feature = "nightly")] |
7 use std::mem::MaybeUninit; |
7 use std::mem::MaybeUninit; |
8 |
8 |
9 /// Gaussian elimination for $AX=B$, where $A$ and $B$ are both stored in `ab`, |
9 /// Gaussian elimination for $AX=B$, where $A$ and $B$ are both stored in `ab`, |
10 /// $A \in \mathbb{R}^{M \times M}$ and $X, B \in \mathbb{R}^{M \times K}$. |
10 /// $A \in \mathbb{R}^{M \times M}$ and $X, B \in \mathbb{R}^{M \times K}$. |
11 pub fn linsolve0<F : Float, const M : usize, const N : usize, const K : usize>( |
11 pub fn linsolve0<F: Float, const M: usize, const N: usize, const K: usize>( |
12 mut ab : [[F; N]; M] |
12 mut ab: [[F; N]; M], |
13 ) -> [[F; K]; M] { |
13 ) -> [[F; K]; M] { |
14 assert_eq!(M + K , N); |
14 assert_eq!(M + K, N); |
15 |
15 |
16 let mut k = 0; |
16 let mut k = 0; |
17 |
17 |
18 // Convert to row-echelon form |
18 // Convert to row-echelon form |
19 for h in 0..(M-1) { |
19 for h in 0..(M - 1) { |
20 // Find pivotable column (has some non-zero entries in rows ≥ h) |
20 // Find pivotable column (has some non-zero entries in rows ≥ h) |
21 'find_pivot: while k < N { |
21 'find_pivot: while k < N { |
22 let (mut î, mut v) = (h, ab[h][k].abs()); |
22 let (mut î, mut v) = (h, ab[h][k].abs()); |
23 // Find row ≥ h of maximum absolute value in this column |
23 // Find row ≥ h of maximum absolute value in this column |
24 for i in (h+1)..M { |
24 for i in (h + 1)..M { |
25 let ṽ = ab[i][k].abs(); |
25 let ṽ = ab[i][k].abs(); |
26 if ṽ > v { |
26 if ṽ > v { |
27 î = i; |
27 î = i; |
28 v = ṽ; |
28 v = ṽ; |
29 } |
29 } |
30 } |
30 } |
31 if v > F::ZERO { |
31 if v > F::ZERO { |
32 ab.swap(h, î); |
32 ab.swap(h, î); |
33 for i in (h+1)..M { |
33 for i in (h + 1)..M { |
34 let f = ab[i][k] / ab[h][k]; |
34 let f = ab[i][k] / ab[h][k]; |
35 ab[i][k] = F::ZERO; |
35 ab[i][k] = F::ZERO; |
36 for j in (k+1)..N { |
36 for j in (k + 1)..N { |
37 ab[i][j] -= ab[h][j]*f; |
37 ab[i][j] -= ab[h][j] * f; |
38 } |
38 } |
39 } |
39 } |
40 k += 1; |
40 k += 1; |
41 break 'find_pivot; |
41 break 'find_pivot; |
42 } |
42 } |
54 { |
54 { |
55 let mut x: [[MaybeUninit<F>; K]; M] = [[const { MaybeUninit::uninit() }; K]; M]; |
55 let mut x: [[MaybeUninit<F>; K]; M] = [[const { MaybeUninit::uninit() }; K]; M]; |
56 //unsafe { std::mem::MaybeUninit::uninit().assume_init() }; |
56 //unsafe { std::mem::MaybeUninit::uninit().assume_init() }; |
57 for i in (0..M).rev() { |
57 for i in (0..M).rev() { |
58 for 𝓁 in 0..K { |
58 for 𝓁 in 0..K { |
59 let mut tmp = ab[i][M+𝓁]; |
59 let mut tmp = ab[i][M + 𝓁]; |
60 for j in (i+1)..M { |
60 for j in (i + 1)..M { |
61 tmp -= ab[i][j] * unsafe { *(x[j][𝓁].assume_init_ref()) }; |
61 tmp -= ab[i][j] * unsafe { *(x[j][𝓁].assume_init_ref()) }; |
62 } |
62 } |
63 tmp /= ab[i][i]; |
63 tmp /= ab[i][i]; |
64 x[i][𝓁].write(tmp); |
64 x[i][𝓁].write(tmp); |
65 } |
65 } |
69 (&x as *const _ as *const [[F; K]; M]).read() |
69 (&x as *const _ as *const [[F; K]; M]).read() |
70 } |
70 } |
71 } |
71 } |
72 #[cfg(not(feature = "nightly"))] |
72 #[cfg(not(feature = "nightly"))] |
73 { |
73 { |
74 let mut x : [[F; K]; M] = [[F::ZERO; K]; M]; |
74 let mut x: [[F; K]; M] = [[F::ZERO; K]; M]; |
75 for i in (0..M).rev() { |
75 for i in (0..M).rev() { |
76 for 𝓁 in 0..K { |
76 for 𝓁 in 0..K { |
77 let mut tmp = ab[i][M+𝓁]; |
77 let mut tmp = ab[i][M + 𝓁]; |
78 for j in (i+1)..M { |
78 for j in (i + 1)..M { |
79 tmp -= ab[i][j] * x[j][𝓁]; |
79 tmp -= ab[i][j] * x[j][𝓁]; |
80 } |
80 } |
81 tmp /= ab[i][i]; |
81 tmp /= ab[i][i]; |
82 x[i][𝓁] = tmp; |
82 x[i][𝓁] = tmp; |
83 } |
83 } |
87 } |
87 } |
88 |
88 |
89 /// Gaussian elimination for $Ax=b$, where $A$ and $b$ are both stored in `ab`, |
89 /// Gaussian elimination for $Ax=b$, where $A$ and $b$ are both stored in `ab`, |
90 /// $A \in \mathbb{R}^{M \times M}$ and $x, b \in \mathbb{R}^M$. |
90 /// $A \in \mathbb{R}^{M \times M}$ and $x, b \in \mathbb{R}^M$. |
91 #[inline] |
91 #[inline] |
92 pub fn linsolve<F : Float, const M : usize, const N : usize>(ab : [[F; N]; M]) -> [F; M] { |
92 pub fn linsolve<F: Float, const M: usize, const N: usize>(ab: [[F; N]; M]) -> [F; M] { |
93 let x : [[F; 1]; M] = linsolve0(ab); |
93 let x: [[F; 1]; M] = linsolve0(ab); |
94 unsafe { *((&x as *const [F; 1]) as *const [F; M] ) } |
94 unsafe { *((&x as *const [F; 1]) as *const [F; M]) } |
95 } |
95 } |
96 |
|
97 |
96 |
98 #[cfg(test)] |
97 #[cfg(test)] |
99 mod tests { |
98 mod tests { |
100 use super::*; |
99 use super::*; |
101 |
100 |
102 #[test] |
101 #[test] |
103 fn linsolve_test() { |
102 fn linsolve_test() { |
104 let ab1 = [[1.0, 2.0, 3.0], [2.0, 1.0, 6.0]]; |
103 let ab1 = [[1.0, 2.0, 3.0], [2.0, 1.0, 6.0]]; |
105 assert_eq!(linsolve(ab1), [3.0, 0.0]); |
104 assert_eq!(linsolve(ab1), [3.0, 0.0]); |
106 let ab2 = [[1.0, 2.0, 0.0, 1.0], [4.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]; |
105 let ab2 = [ |
|
106 [1.0, 2.0, 0.0, 1.0], |
|
107 [4.0, 0.0, 0.0, 0.0], |
|
108 [0.0, 0.0, 1.0, 0.0], |
|
109 ]; |
107 assert_eq!(linsolve(ab2), [0.0, 0.5, 0.0]); |
110 assert_eq!(linsolve(ab2), [0.0, 0.5, 0.0]); |
108 } |
111 } |
109 } |
112 } |
110 |
|