src/linsolve.rs

changeset 93
123f7f38e161
parent 92
e11986179a4b
equal deleted inserted replaced
92:e11986179a4b 93:123f7f38e161
6 #[cfg(feature = "nightly")] 6 #[cfg(feature = "nightly")]
7 use std::mem::MaybeUninit; 7 use std::mem::MaybeUninit;
8 8
9 /// Gaussian elimination for $AX=B$, where $A$ and $B$ are both stored in `ab`, 9 /// Gaussian elimination for $AX=B$, where $A$ and $B$ are both stored in `ab`,
10 /// $A \in \mathbb{R}^{M \times M}$ and $X, B \in \mathbb{R}^{M \times K}$. 10 /// $A \in \mathbb{R}^{M \times M}$ and $X, B \in \mathbb{R}^{M \times K}$.
11 pub fn linsolve0<F : Float, const M : usize, const N : usize, const K : usize>( 11 pub fn linsolve0<F: Float, const M: usize, const N: usize, const K: usize>(
12 mut ab : [[F; N]; M] 12 mut ab: [[F; N]; M],
13 ) -> [[F; K]; M] { 13 ) -> [[F; K]; M] {
14 assert_eq!(M + K , N); 14 assert_eq!(M + K, N);
15 15
16 let mut k = 0; 16 let mut k = 0;
17 17
18 // Convert to row-echelon form 18 // Convert to row-echelon form
19 for h in 0..(M-1) { 19 for h in 0..(M - 1) {
20 // Find pivotable column (has some non-zero entries in rows ≥ h) 20 // Find pivotable column (has some non-zero entries in rows ≥ h)
21 'find_pivot: while k < N { 21 'find_pivot: while k < N {
22 let (mut î, mut v) = (h, ab[h][k].abs()); 22 let (mut î, mut v) = (h, ab[h][k].abs());
23 // Find row ≥ h of maximum absolute value in this column 23 // Find row ≥ h of maximum absolute value in this column
24 for i in (h+1)..M { 24 for i in (h + 1)..M {
25 let ṽ = ab[i][k].abs(); 25 let ṽ = ab[i][k].abs();
26 if ṽ > v { 26 if ṽ > v {
27 î = i; 27 î = i;
28 v = ṽ; 28 v = ṽ;
29 } 29 }
30 } 30 }
31 if v > F::ZERO { 31 if v > F::ZERO {
32 ab.swap(h, î); 32 ab.swap(h, î);
33 for i in (h+1)..M { 33 for i in (h + 1)..M {
34 let f = ab[i][k] / ab[h][k]; 34 let f = ab[i][k] / ab[h][k];
35 ab[i][k] = F::ZERO; 35 ab[i][k] = F::ZERO;
36 for j in (k+1)..N { 36 for j in (k + 1)..N {
37 ab[i][j] -= ab[h][j]*f; 37 ab[i][j] -= ab[h][j] * f;
38 } 38 }
39 } 39 }
40 k += 1; 40 k += 1;
41 break 'find_pivot; 41 break 'find_pivot;
42 } 42 }
54 { 54 {
55 let mut x: [[MaybeUninit<F>; K]; M] = [[const { MaybeUninit::uninit() }; K]; M]; 55 let mut x: [[MaybeUninit<F>; K]; M] = [[const { MaybeUninit::uninit() }; K]; M];
56 //unsafe { std::mem::MaybeUninit::uninit().assume_init() }; 56 //unsafe { std::mem::MaybeUninit::uninit().assume_init() };
57 for i in (0..M).rev() { 57 for i in (0..M).rev() {
58 for 𝓁 in 0..K { 58 for 𝓁 in 0..K {
59 let mut tmp = ab[i][M+𝓁]; 59 let mut tmp = ab[i][M + 𝓁];
60 for j in (i+1)..M { 60 for j in (i + 1)..M {
61 tmp -= ab[i][j] * unsafe { *(x[j][𝓁].assume_init_ref()) }; 61 tmp -= ab[i][j] * unsafe { *(x[j][𝓁].assume_init_ref()) };
62 } 62 }
63 tmp /= ab[i][i]; 63 tmp /= ab[i][i];
64 x[i][𝓁].write(tmp); 64 x[i][𝓁].write(tmp);
65 } 65 }
69 (&x as *const _ as *const [[F; K]; M]).read() 69 (&x as *const _ as *const [[F; K]; M]).read()
70 } 70 }
71 } 71 }
72 #[cfg(not(feature = "nightly"))] 72 #[cfg(not(feature = "nightly"))]
73 { 73 {
74 let mut x : [[F; K]; M] = [[F::ZERO; K]; M]; 74 let mut x: [[F; K]; M] = [[F::ZERO; K]; M];
75 for i in (0..M).rev() { 75 for i in (0..M).rev() {
76 for 𝓁 in 0..K { 76 for 𝓁 in 0..K {
77 let mut tmp = ab[i][M+𝓁]; 77 let mut tmp = ab[i][M + 𝓁];
78 for j in (i+1)..M { 78 for j in (i + 1)..M {
79 tmp -= ab[i][j] * x[j][𝓁]; 79 tmp -= ab[i][j] * x[j][𝓁];
80 } 80 }
81 tmp /= ab[i][i]; 81 tmp /= ab[i][i];
82 x[i][𝓁] = tmp; 82 x[i][𝓁] = tmp;
83 } 83 }
87 } 87 }
88 88
89 /// Gaussian elimination for $Ax=b$, where $A$ and $b$ are both stored in `ab`, 89 /// Gaussian elimination for $Ax=b$, where $A$ and $b$ are both stored in `ab`,
90 /// $A \in \mathbb{R}^{M \times M}$ and $x, b \in \mathbb{R}^M$. 90 /// $A \in \mathbb{R}^{M \times M}$ and $x, b \in \mathbb{R}^M$.
91 #[inline] 91 #[inline]
92 pub fn linsolve<F : Float, const M : usize, const N : usize>(ab : [[F; N]; M]) -> [F; M] { 92 pub fn linsolve<F: Float, const M: usize, const N: usize>(ab: [[F; N]; M]) -> [F; M] {
93 let x : [[F; 1]; M] = linsolve0(ab); 93 let x: [[F; 1]; M] = linsolve0(ab);
94 unsafe { *((&x as *const [F; 1]) as *const [F; M] ) } 94 unsafe { *((&x as *const [F; 1]) as *const [F; M]) }
95 } 95 }
96
97 96
98 #[cfg(test)] 97 #[cfg(test)]
99 mod tests { 98 mod tests {
100 use super::*; 99 use super::*;
101 100
102 #[test] 101 #[test]
103 fn linsolve_test() { 102 fn linsolve_test() {
104 let ab1 = [[1.0, 2.0, 3.0], [2.0, 1.0, 6.0]]; 103 let ab1 = [[1.0, 2.0, 3.0], [2.0, 1.0, 6.0]];
105 assert_eq!(linsolve(ab1), [3.0, 0.0]); 104 assert_eq!(linsolve(ab1), [3.0, 0.0]);
106 let ab2 = [[1.0, 2.0, 0.0, 1.0], [4.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]; 105 let ab2 = [
106 [1.0, 2.0, 0.0, 1.0],
107 [4.0, 0.0, 0.0, 0.0],
108 [0.0, 0.0, 1.0, 0.0],
109 ];
107 assert_eq!(linsolve(ab2), [0.0, 0.5, 0.0]); 110 assert_eq!(linsolve(ab2), [0.0, 0.5, 0.0]);
108 } 111 }
109 } 112 }
110

mercurial