src/python_access/numpy_array.rs

changeset 1
a4137aedcb3a
child 3
c3a4f4bb87f7
equal deleted inserted replaced
0:7ec1cfe19a24 1:a4137aedcb3a
/*!
Wrapper for NumPy `f64` arrays exchanged with Python through PyO3, which can also hold an
owned `nalgebra` matrix produced on the Rust side.
*/
4
5 use alg_tools::euclidean::wrap::{WrapGuard, WrapGuardMut, Wrapped};
6 use alg_tools::types::Float;
7 use nalgebra::{DMatrix, DMatrixView, DMatrixViewMut, Dyn};
8 use ndarray::{Dimension, Ix1, Ix2};
9 use numpy::{PyArray, PyArrayMethods, PyReadonlyArray, PyReadwriteArray, ToPyArray};
10 use pyo3::prelude::*;
11
12 /// A helper structure of dealing with dolfinx functions.
13 /// `N` is the domain dimension, `O` the order, and `D` is the codomain dimension.
14 #[allow(non_camel_case_types)]
15 #[derive(Debug, Clone)]
16 pub enum NumpyArray_f64<'py, D> {
17 PyWrapped {
18 /// Python object.
19 x: Bound<'py, PyArray<f64, D>>,
20 },
21 Nalgebra {
22 x: DMatrix<f64>,
23 },
24 }
25
26 #[allow(non_camel_case_types)]
27 #[allow(unused)]
28 pub type NumpyVector_f64<'py> = NumpyArray_f64<'py, Ix1>;
29
30 #[allow(non_camel_case_types)]
31 #[allow(unused)]
32 pub type NumpyMatrix_f64<'py> = NumpyArray_f64<'py, Ix2>;
33
34 impl<'a, 'py, D: Dimension> FromPyObject<'a, 'py> for NumpyArray_f64<'py, D> {
35 type Error = PyErr;
36
37 fn extract(x: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
38 Ok(NumpyArray_f64::PyWrapped { x: x.to_owned().cast_into()? })
39 }
40 }
41
42 impl<'py> IntoPyObject<'py> for NumpyArray_f64<'py, Ix1> {
43 type Target = PyArray<f64, Ix1>;
44 type Error = PyErr;
45 type Output = pyo3::Bound<'py, Self::Target>;
46
47 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
48 match self {
49 NumpyArray_f64::PyWrapped { x } => x.into_pyobject(py).map_err(From::from),
50 NumpyArray_f64::Nalgebra { x } => x.to_pyarray(py).reshape((x.len(),)),
51 }
52 }
53 }
54
55 impl<'py> IntoPyObject<'py> for NumpyArray_f64<'py, Ix2> {
56 type Target = PyArray<f64, Ix2>;
57 type Error = PyErr;
58 type Output = pyo3::Bound<'py, Self::Target>;
59
60 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
61 match self {
62 NumpyArray_f64::PyWrapped { x } => x.into_pyobject(py).map_err(From::from),
63 NumpyArray_f64::Nalgebra { x } => Ok(x.to_pyarray(py)),
64 }
65 }
66 }
67
68 #[allow(non_camel_case_types)]
69 #[derive(Debug)]
70 pub enum ArrayGuard<'a, 'py, D: Dimension, F: numpy::Element + Float = f64> {
71 PyWrapped {
72 /// Python object.
73 x: PyReadonlyArray<'py, F, D>,
74 },
75 Nalgebra {
76 x: &'a DMatrix<F>,
77 },
78 }
79
80 #[allow(non_camel_case_types)]
81 #[derive(Debug)]
82 pub enum ArrayGuardMut<'a, 'py, D: Dimension, F: numpy::Element + Float = f64> {
83 PyWrapped {
84 /// Python object.
85 x: PyReadwriteArray<'py, F, D>,
86 },
87 Nalgebra {
88 x: &'a mut DMatrix<F>,
89 },
90 }
91
92 macro_rules! impl_euclidean {
93 ($dim:ty) => {
94 impl<'a, 'py> WrapGuard<'a, f64> for ArrayGuard<'a, 'py, $dim, f64> where 'py : 'a {
95 type View<'b> = DMatrixView<'b, f64, Dyn, Dyn> where Self : 'b;
96
97 #[inline]
98 fn get_view(&self) -> Self::View<'_> {
99 match self {
100 ArrayGuard::PyWrapped{ref x} => x.as_matrix(),
101 ArrayGuard::Nalgebra{x} => x.as_view(),
102 }
103 }
104 }
105
106 impl<'a, 'py> WrapGuardMut<'a, f64> for ArrayGuardMut<'a, 'py, $dim, f64> where 'py : 'a {
107 type ViewMut<'b> = DMatrixViewMut<'b, f64, Dyn, Dyn> where Self : 'b;
108
109 #[inline]
110 fn get_view_mut(&mut self) -> Self::ViewMut<'_> {
111 match self {
112 ArrayGuardMut::PyWrapped{ref x} => x.as_matrix_mut(),
113 ArrayGuardMut::Nalgebra{x} => x.as_view_mut(),
114 }
115 }
116 }
117
118 impl<'py> Wrapped for NumpyArray_f64<'py, $dim> where Self : 'py {
119 type WrappedField = f64;
120 type Guard<'a> = ArrayGuard<'a, 'py, $dim, f64> where Self : 'a;
121 type GuardMut<'a> = ArrayGuardMut<'a, 'py, $dim, f64> where Self : 'a;
122 type UnwrappedOutput = DMatrix<f64>;
123 type WrappedOutput = Self;
124
125 #[inline]
126 fn get_guard(&self) -> Self::Guard<'_> {
127 match self {
128 NumpyArray_f64::PyWrapped{ ref x } => ArrayGuard::PyWrapped{ x : x.readonly() },
129 NumpyArray_f64::Nalgebra{ ref x } => ArrayGuard::Nalgebra{ x },
130 }
131 }
132
133 #[inline]
134 fn get_guard_mut(&mut self) -> Self::GuardMut<'_> {
135 match self {
136 NumpyArray_f64::PyWrapped{ ref mut x } => ArrayGuardMut::PyWrapped{ x : x.readwrite() },
137 NumpyArray_f64::Nalgebra{ ref mut x } => ArrayGuardMut::Nalgebra{ x },
138 }
139 }
140
141
142 fn wrap(x: Self::UnwrappedOutput) -> Self {
143 NumpyArray_f64::Nalgebra{ x }
144 }
145 }
146 alg_tools::wrap!(f64; NumpyArray_f64<'py, $dim> where 'py);
147 };
148 }
149
150 impl_euclidean!(Ix1);
151 impl_euclidean!(Ix2);

mercurial