| |
1 /*! |
| |
Python wrappers for measures. Only enabled with the crate feature `pyo3`.
| |
3 |
| |
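These should really be in the `pointsource_pde` crate, but Rust's orphan rule does not allow that: the foreign `pyo3` traits (`FromPyObject`, `IntoPyObject`) can only be implemented for `DiscreteMeasure` in the crate that defines it.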
| |
5 */ |
| |
6 |
| |
7 use super::{DeltaMeasure, DiscreteMeasure}; |
| |
8 use alg_tools::loc::Loc; |
| |
9 use numpy::{Ix1, PyArray}; |
| |
10 use pyo3::prelude::*; |
| |
11 use pyo3::types::{PyList, PyModule, PyModuleMethods}; |
| |
12 use pyo3::{pymodule, Borrowed, Bound, PyResult}; |
| |
13 |
| |
/// Generates a concrete, non-generic Python class wrapping
/// `DiscreteMeasure<Loc<$N, $F>, $F>`, together with a Python iterator class over it and
/// the `FromPyObject` / `IntoPyObject` conversions for the wrapped measure type.
///
/// Arguments: `$name` is the wrapper class name, `$iter` the iterator class name, `$N` the
/// spatial dimension (a literal), and `$F` the floating point type.
macro_rules! create_interface {
    ($name:ident, $iter:ident, $N:literal, $F:ident) => {
        // The generated names (e.g. `DiscreteMeasure_2_f64`) deliberately encode dimension
        // and float type with underscores; they double as the Python-visible class names.
        #[allow(non_camel_case_types)]
        #[pyclass(module = "pointsource_algs")]
        /// Python wrapper for a [`DiscreteMeasure`] with [`Loc`] locations and scalar
        /// weights.
        ///
        /// The dimension and float type are fixed by the `create_interface!` invocation that
        /// generated this class. The wrapper is needed because pyo3 does not support generic
        /// classes, not even when they are instantiated to specific types.
        pub struct $name {
            inner: DiscreteMeasure<Loc<$N, $F>, $F>,
        }

        impl $name {
            /// Wraps an existing measure (by move, without copying it).
            pub fn wrap(inner: DiscreteMeasure<Loc<$N, $F>, $F>) -> Self {
                Self { inner }
            }
        }

        #[pymethods]
        impl $name {
            /// Builds a measure from a Python list of `(coords, weight)` pairs.
            ///
            /// Each `coords` must be extractable as a fixed-length array matching this
            /// class's dimension. Extraction failures are propagated as Python exceptions.
            #[new]
            pub fn new(contents: &Bound<'_, PyList>) -> PyResult<Self> {
                let mut res = DiscreteMeasure::new();
                for v in contents.iter() {
                    let (x, α): ([$F; $N], $F) = v.extract()?;
                    res.push(DeltaMeasure { x: x.into(), α });
                }
                Ok(Self { inner: res })
            }

            /// Returns a Python iterator yielding `(coords, weight)` tuples.
            fn __iter__(slf: Bound<'_, Self>) -> PyResult<Py<$iter>> {
                Py::new(slf.py(), $iter { measure_ref: slf.unbind(), next: 0 })
            }
        }

        // See the naming note on the wrapper class above.
        #[allow(non_camel_case_types)]
        #[pyclass(module = "pointsource_algs")]
        /// Python-side iterator over a wrapped [`DiscreteMeasure`].
        ///
        /// Yields tuples `(coords, weight)`, where `coords` is a one-dimensional NumPy array.
        pub struct $iter {
            /// The measure being iterated; it is borrowed afresh on every `__next__` call.
            measure_ref: Py<$name>,
            /// Index of the next delta to yield.
            next: usize,
        }

        #[pymethods]
        impl $iter {
            /// Iterators are their own iterator, as required by the Python protocol.
            fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
                slf
            }

            /// Returns the next `(coords, weight)` tuple, or `None` (Python `StopIteration`)
            /// once all deltas have been yielded.
            // We cannot use lifetimes in types exported to Python, so have to use a
            // very primitive index-based iterator.
            fn __next__(
                mut slf: PyRefMut<'_, Self>,
            ) -> PyResult<Option<(Bound<'_, PyArray<$F, Ix1>>, $F)>> {
                let py = slf.py();
                let meas_: PyRef<'_, $name> = slf.measure_ref.extract(py)?;
                let meas = &(meas_.inner);
                let next = &mut slf.next;
                Ok((*next < meas.len()).then(|| {
                    let δ = meas[*next];
                    *next += 1;
                    (PyArray::from_slice(py, &δ.x.0), δ.α)
                }))
            }
        }

        /// Extracts a copy of the measure directly from a Python wrapper object, so that
        /// Rust functions exposed to Python can take the measure type itself.
        impl<'a, 'py> FromPyObject<'a, 'py> for DiscreteMeasure<Loc<$N, $F>, $F> {
            type Error = PyErr;
            fn extract(obj: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
                let wrapper: PyRef<'_, $name> = obj.extract()?;
                Ok(wrapper.inner.clone())
            }
        }

        /// Converts a mutable reference by cloning the measure into a new Python wrapper;
        /// the referenced measure itself is not modified.
        impl<'a, 'py> IntoPyObject<'py> for &'a mut DiscreteMeasure<Loc<$N, $F>, $F> {
            type Target = $name;
            type Error = PyErr;
            type Output = Bound<'py, $name>;

            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                Bound::new(py, $name { inner: self.clone() })
            }
        }

        /// Moves an owned measure into a new Python wrapper without copying.
        impl<'py> IntoPyObject<'py> for DiscreteMeasure<Loc<$N, $F>, $F> {
            type Target = $name;
            type Error = PyErr;
            type Output = Bound<'py, $name>;

            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                Bound::new(py, $name { inner: self })
            }
        }
    };
}
| |
119 |
| |
// Concrete instantiations exposed to Python: measures on 1-, 2- and 3-dimensional `Loc`
// locations, all using `f64` for coordinates and weights.
create_interface!(DiscreteMeasure_1_f64, DiscreteMeasureIter_1_f64, 1, f64);
create_interface!(DiscreteMeasure_2_f64, DiscreteMeasureIter_2_f64, 2, f64);
create_interface!(DiscreteMeasure_3_f64, DiscreteMeasureIter_3_f64, 3, f64);
| |
123 |
| |
124 /// Populates the Python module. |
| |
125 /// |
| |
126 /// Needs to be called with [`pyo3::append_to_inittab`]: |
| |
127 /// ``` |
| |
128 /// append_to_inittab!(pymod_pointsource_algs); |
| |
129 /// ``` |
| |
130 /// before initialising the intepreter. |
| |
131 #[pymodule] |
| |
132 #[pyo3(name = "measures")] |
| |
133 pub fn pymod(m: &Bound<'_, PyModule>) -> PyResult<()> { |
| |
134 //m.add_class::<crate::run::DefaultAlgorithm>()?; |
| |
135 m.add_class::<crate::python::DiscreteMeasure_1_f64>()?; |
| |
136 m.add_class::<crate::python::DiscreteMeasure_2_f64>()?; |
| |
137 m.add_class::<crate::python::DiscreteMeasure_3_f64>()?; |
| |
138 m.add_class::<crate::python::DiscreteMeasureIter_1_f64>()?; |
| |
139 m.add_class::<crate::python::DiscreteMeasureIter_2_f64>()?; |
| |
140 m.add_class::<crate::python::DiscreteMeasureIter_3_f64>()?; |
| |
141 Ok(()) |
| |
142 } |