src/python_access/numpy_array.rs

Fri, 08 May 2026 17:16:34 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Fri, 08 May 2026 17:16:34 -0500
changeset 4
49b062acace9
parent 3
c3a4f4bb87f7
permissions
-rw-r--r--

Do not directly depend on ndarray, but through numpy

/*!
Wrapper for numpy arrays.
*/

use alg_tools::euclidean::wrap::{WrapGuard, WrapGuardMut, Wrapped};
use alg_tools::types::Float;
use nalgebra::{DMatrix, DMatrixView, DMatrixViewMut, Dyn};
use numpy::ndarray::{Dimension, Ix1, Ix2};
use numpy::{PyArray, PyArrayMethods, PyReadonlyArray, PyReadwriteArray, ToPyArray};
use pyo3::prelude::*;

/// An `f64` array that is either a numpy array owned by Python or a matrix owned by Rust.
///
/// `D` is the ndarray dimension type of the numpy side, e.g. `Ix1` or `Ix2`
/// (see the `NumpyVector_f64` and `NumpyMatrix_f64` aliases).
/// Values extracted from Python are stored as [`PyWrapped`](Self::PyWrapped), which holds a
/// reference to the Python object rather than a copy of its data; values produced on the
/// Rust side (via `Wrapped::wrap`) are stored as [`Nalgebra`](Self::Nalgebra).
#[allow(non_camel_case_types)] // the `_f64` suffix names the element type
#[derive(Debug, Clone)]
pub enum NumpyArray_f64<'py, D> {
    /// Reference to a Python numpy array.
    PyWrapped {
        /// Python object.
        x: Bound<'py, PyArray<f64, D>>,
    },
    /// Rust-owned nalgebra matrix.
    ///
    /// NOTE(review): for `D = Ix1` this is presumably a single-column matrix; conversion to
    /// Python reshapes it to length `x.len()` — confirm no wider matrix is ever stored here.
    Nalgebra {
        /// Matrix data.
        x: DMatrix<f64>,
    },
}

/// One-dimensional [`NumpyArray_f64`]: a numpy vector of `f64`.
#[allow(non_camel_case_types, unused)]
pub type NumpyVector_f64<'py> = NumpyArray_f64<'py, Ix1>;

/// Two-dimensional [`NumpyArray_f64`]: a numpy matrix of `f64`.
#[allow(non_camel_case_types, unused)]
pub type NumpyMatrix_f64<'py> = NumpyArray_f64<'py, Ix2>;

impl<'a, 'py, D: Dimension> FromPyObject<'a, 'py> for NumpyArray_f64<'py, D> {
    type Error = PyErr;

    fn extract(x: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
        Ok(NumpyArray_f64::PyWrapped { x: x.to_owned().cast_into()? })
    }
}

/// Conversion of a one-dimensional [`NumpyArray_f64`] back into a numpy array.
///
/// A Python-backed value is returned unchanged; a Rust-owned matrix is converted with
/// `to_pyarray` and then reshaped to a single dimension of length `x.len()`.
impl<'py> IntoPyObject<'py> for NumpyArray_f64<'py, Ix1> {
    type Target = PyArray<f64, Ix1>;
    type Error = PyErr;
    type Output = pyo3::Bound<'py, Self::Target>;

    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let matrix = match self {
            // Already a Python object: hand it back as-is.
            Self::PyWrapped { x } => return Ok(x),
            Self::Nalgebra { x } => x,
        };
        let len = matrix.len();
        // `to_pyarray` on a `DMatrix` yields a 2-D array; flatten it to the 1-D `Target`.
        matrix.to_pyarray(py).reshape((len,))
    }
}

/// Conversion of a two-dimensional [`NumpyArray_f64`] back into a numpy array.
///
/// A Python-backed value is returned unchanged; a Rust-owned matrix is converted with
/// `to_pyarray`, which directly produces a 2-D array.
impl<'py> IntoPyObject<'py> for NumpyArray_f64<'py, Ix2> {
    type Target = PyArray<f64, Ix2>;
    type Error = PyErr;
    type Output = pyo3::Bound<'py, Self::Target>;

    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        let array = match self {
            Self::PyWrapped { x } => x,
            Self::Nalgebra { x } => x.to_pyarray(py),
        };
        Ok(array)
    }
}

/// Read access to the data of a [`NumpyArray_f64`], as returned by `Wrapped::get_guard`.
///
/// `D` is the ndarray dimension type of the wrapped array and `F` its element type
/// (defaulting to `f64`).
///
/// NOTE(review): `PyReadonlyArray` takes part in rust-numpy's dynamic borrow checking;
/// presumably the Python array stays borrowed until this guard is dropped — confirm.
#[derive(Debug)]
pub enum ArrayGuard<'a, 'py, D: Dimension, F: numpy::Element + Float = f64> {
    /// Read-only borrow of a Python numpy array.
    PyWrapped {
        /// Python object.
        x: PyReadonlyArray<'py, F, D>,
    },
    /// Shared reference to a Rust-owned nalgebra matrix.
    Nalgebra {
        /// Borrowed matrix data.
        x: &'a DMatrix<F>,
    },
}

/// Write access to the data of a [`NumpyArray_f64`], as returned by `Wrapped::get_guard_mut`.
///
/// `D` is the ndarray dimension type of the wrapped array and `F` its element type
/// (defaulting to `f64`).
///
/// NOTE(review): `PyReadwriteArray` takes part in rust-numpy's dynamic borrow checking;
/// presumably the Python array stays exclusively borrowed until this guard is dropped — confirm.
#[derive(Debug)]
pub enum ArrayGuardMut<'a, 'py, D: Dimension, F: numpy::Element + Float = f64> {
    /// Read-write borrow of a Python numpy array.
    PyWrapped {
        /// Python object.
        x: PyReadwriteArray<'py, F, D>,
    },
    /// Exclusive reference to a Rust-owned nalgebra matrix.
    Nalgebra {
        /// Borrowed matrix data.
        x: &'a mut DMatrix<F>,
    },
}

// Implements `WrapGuard`, `WrapGuardMut` and `Wrapped` (plus `alg_tools::wrap!`) for
// `NumpyArray_f64<'py, $dim>` and its guard types.
//
// Both variants are exposed to `alg_tools` as dynamically-strided nalgebra matrix views:
// Python-backed arrays through rust-numpy's nalgebra integration (`as_matrix` /
// `as_matrix_mut`), Rust-owned matrices through nalgebra's `as_view` / `as_view_mut`.
//
// Patterns below rely on match ergonomics (no explicit `ref` / `ref mut`), which is the
// only form accepted when matching on a reference under the 2024 edition.
macro_rules! impl_euclidean {
    ($dim:ty) => {
        impl<'a, 'py> WrapGuard<'a, f64> for ArrayGuard<'a, 'py, $dim, f64>
        where
            'py: 'a,
        {
            type View<'b> = DMatrixView<'b, f64, Dyn, Dyn> where Self: 'b;

            /// Returns a read-only nalgebra view of the guarded data.
            ///
            /// NOTE(review): rust-numpy's `as_matrix` unwraps a fallible conversion and may
            /// panic on layouts it cannot represent (e.g. negative strides) — confirm.
            #[inline]
            fn get_view(&self) -> Self::View<'_> {
                match self {
                    ArrayGuard::PyWrapped { x } => x.as_matrix(),
                    ArrayGuard::Nalgebra { x } => x.as_view(),
                }
            }
        }

        impl<'a, 'py> WrapGuardMut<'a, f64> for ArrayGuardMut<'a, 'py, $dim, f64>
        where
            'py: 'a,
        {
            type ViewMut<'b> = DMatrixViewMut<'b, f64, Dyn, Dyn> where Self: 'b;

            /// Returns a mutable nalgebra view of the guarded data.
            ///
            /// NOTE(review): rust-numpy's `as_matrix_mut` unwraps a fallible conversion and
            /// may panic on layouts it cannot represent (e.g. negative strides) — confirm.
            #[inline]
            fn get_view_mut(&mut self) -> Self::ViewMut<'_> {
                match self {
                    ArrayGuardMut::PyWrapped { x } => x.as_matrix_mut(),
                    ArrayGuardMut::Nalgebra { x } => x.as_view_mut(),
                }
            }
        }

        impl<'py> Wrapped for NumpyArray_f64<'py, $dim>
        where
            Self: 'py,
        {
            type WrappedField = f64;
            type Guard<'a> = ArrayGuard<'a, 'py, $dim, f64> where Self: 'a;
            type GuardMut<'a> = ArrayGuardMut<'a, 'py, $dim, f64> where Self: 'a;
            type UnwrappedOutput = DMatrix<f64>;
            type WrappedOutput = Self;

            /// Borrows the data for reading.
            ///
            /// NOTE(review): rust-numpy's `readonly()` panics if the array is currently
            /// mutably borrowed; this signature cannot report that as an error — confirm
            /// that is acceptable.
            #[inline]
            fn get_guard(&self) -> Self::Guard<'_> {
                match self {
                    NumpyArray_f64::PyWrapped { x } => ArrayGuard::PyWrapped { x: x.readonly() },
                    NumpyArray_f64::Nalgebra { x } => ArrayGuard::Nalgebra { x },
                }
            }

            /// Borrows the data for writing.
            ///
            /// NOTE(review): rust-numpy's `readwrite()` panics if the array is already
            /// borrowed or is not writeable — confirm callers never hit that.
            #[inline]
            fn get_guard_mut(&mut self) -> Self::GuardMut<'_> {
                match self {
                    NumpyArray_f64::PyWrapped { x } => {
                        ArrayGuardMut::PyWrapped { x: x.readwrite() }
                    }
                    NumpyArray_f64::Nalgebra { x } => ArrayGuardMut::Nalgebra { x },
                }
            }

            /// Stores a Rust-side result as the `Nalgebra` variant (no Python object is
            /// created here).
            fn wrap(x: Self::UnwrappedOutput) -> Self {
                NumpyArray_f64::Nalgebra { x }
            }
        }

        alg_tools::wrap!(f64; NumpyArray_f64<'py, $dim> where 'py);
    };
}

// Instantiate the wrapping for one-dimensional (`Ix1`) and two-dimensional (`Ix2`) arrays.
impl_euclidean!(Ix1);
impl_euclidean!(Ix2);

mercurial