src/linsolve.rs

Fri, 13 Oct 2023 13:32:15 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Fri, 13 Oct 2023 13:32:15 -0500
changeset 22
013274b0b388
parent 5
59dc4c5883f4
child 24
8efa7abce7c7
permissions
-rw-r--r--

Update Cargo.lock to stop build failures with current nightly rust.

/*!
Linear equation solvers for small problems stored in Rust arrays.
*/

use crate::types::Float;
use std::mem::MaybeUninit;

/// Gaussian elimination for $AX=B$, where $A$ and $B$ are both stored in `ab`,
/// $A \in \mathbb{R}^{M \times M}$ and $X, B \in \mathbb{R}^{M \times K}$.
pub fn linsolve0<F : Float, const M : usize, const N : usize, const K : usize>(
    mut ab : [[F; N]; M]
) -> [[F; K]; M] {
    assert_eq!(M + K , N);

    let mut k = 0;

    // Convert to row-echelon form
    for h in 0..(M-1) {
        // Find pivotable column (has some non-zero entries in rows ≥ h)
        'find_pivot: while k < N {
            let (mut î, mut v) = (h, ab[h][k].abs());
            // Find row ≥ h of maximum absolute value in this column
            for i in (h+1)..M {
                let ṽ = ab[i][k].abs();
                if ṽ > v  {
                    î = i;
                    v = ṽ;
                }
            }
            if v > F::ZERO {
                ab.swap(h, î);
                for i in (h+1)..M {
                    let f = ab[i][k] / ab[h][k];
                    ab[i][k] = F::ZERO;
                    for j in (k+1)..N {
                        ab[i][j] -= ab[h][j]*f;
                    }
                }
                k += 1;
                break 'find_pivot;
            }
            k += 1
        }
    }

    // Solve UAX=UB for X where UA with U presenting the transformations above an
    // upper triangular matrix.
    // This use of MaybeUninit assumes F : Copy. Otherwise undefined behaviour may occur.
    let mut x : [[MaybeUninit<F>; K]; M] = core::array::from_fn(|_| MaybeUninit::uninit_array::<K>() );
    //unsafe { std::mem::MaybeUninit::uninit().assume_init() };
    for i in (0..M).rev() {
        for 𝓁 in 0..K {
            let mut tmp  = ab[i][M+𝓁];
            for j in (i+1)..M {
                tmp -= ab[i][j] * unsafe { *(x[j][𝓁].assume_init_ref()) };
            }
            tmp /= ab[i][i];
            x[i][𝓁].write(tmp);
        }
    }
    //unsafe { MaybeUninit::array_assume_init(x) };
    let xinit = unsafe {
        //core::intrinsics::assert_inhabited::<[[F; K]; M]>();
        (&x as *const _ as *const [[F; K]; M]).read()
    };

    std::mem::forget(x);
    xinit
}

/// Gaussian elimination for $Ax=b$, where $A$ and $b$ are both stored in `ab`,
/// $A \in \mathbb{R}^{M \times M}$ and $x, b \in \mathbb{R}^M$.
///
/// This is [`linsolve0`] with a single right-hand side: each row of `ab` holds a
/// row of $A$ followed by the corresponding entry of $b$.
///
/// # Panics
///
/// Panics if `N != M + 1`.
#[inline]
pub fn linsolve<F : Float, const M : usize, const N : usize>(ab : [[F; N]; M]) -> [F; M] {
    let x : [[F; 1]; M] = linsolve0(ab);
    // Unwrap each single-element row; a safe replacement for reinterpreting the
    // `[[F; 1]; M]` storage as `[F; M]` through raw pointers.
    x.map(|[xi]| xi)
}


#[cfg(test)]
mod tests {
    use super::*;

    /// Single right-hand side, including a system that needs a row swap.
    #[test]
    fn linsolve_test() {
        let ab1 = [[1.0, 2.0, 3.0], [2.0, 1.0, 6.0]];
        assert_eq!(linsolve(ab1), [3.0, 0.0]);
        let ab2 = [[1.0, 2.0, 0.0, 1.0], [4.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]];
        assert_eq!(linsolve(ab2), [0.0, 0.5, 0.0]);
        // A 1×1 system needs no elimination step at all.
        assert_eq!(linsolve([[2.0, 6.0]]), [3.0]);
    }

    /// Several right-hand sides solved at once; column 0 forces a row swap.
    #[test]
    fn linsolve0_multiple_rhs_test() {
        let ab = [[1.0, 2.0, 3.0, 5.0], [2.0, 1.0, 6.0, 4.0]];
        let x : [[f64; 2]; 2] = linsolve0(ab);
        assert_eq!(x, [[3.0, 1.0], [0.0, 2.0]]);
    }
}

mercurial