src/python.rs

changeset 0
e8f3b6c55ce7
child 3
fbdee8e4a78d
equal deleted inserted replaced
-1:000000000000 0:e8f3b6c55ce7
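/*!
Python wrappers for measures. Only enabled with the crate feature `pyo3`.

These should really be in the `pointsource_pde` crate, but Rust's orphan rules prevent that:
the `FromPyObject`/`IntoPyObject` implementations for [`DiscreteMeasure`] must live in a crate
that defines either the trait or the type.
*/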
6
7 use super::{DeltaMeasure, DiscreteMeasure};
8 use alg_tools::loc::Loc;
9 use numpy::{Ix1, PyArray};
10 use pyo3::prelude::*;
11 use pyo3::types::{PyList, PyModule, PyModuleMethods};
12 use pyo3::{pymodule, Borrowed, Bound, PyResult};
13
/// Generates, for one concrete `DiscreteMeasure<Loc<$N, $F>, $F>` instantiation:
/// a Python wrapper class `$name`, a Python iterator class `$iter` over it, and
/// `FromPyObject`/`IntoPyObject` conversions between the Rust measure and the wrapper.
macro_rules! create_interface {
    ($name: ident, $iter:ident, $N: literal, $F:ident) => {
        // The generated names (e.g. `DiscreteMeasure_2_f64`) are exported to Python verbatim and
        // encode dimension and float type, so they intentionally do not follow CamelCase.
        #[allow(non_camel_case_types)]
        #[pyclass(module = "pointsource_algs")]
        // `///` comments are not macro-expanded (`$N`/`$F` would show up literally in the docs),
        // so the type-specific part of the documentation is assembled with `concat!`.
        #[doc = concat!(
            "Wrapper to [`DiscreteMeasure`]`<Loc<", stringify!($N), ", ", stringify!($F),
            ">, ", stringify!($F), ">`."
        )]
        ///
        /// This is mainly needed because pyo3 does not support generics, not even instantiating
        /// them to specific values.
        pub struct $name {
            inner: DiscreteMeasure<Loc<$N, $F>, $F>,
        }

        impl $name {
            /// Wraps an existing Rust-side measure for exposure to Python.
            pub fn wrap(inner: DiscreteMeasure<Loc<$N, $F>, $F>) -> Self {
                Self { inner }
            }
        }

        #[pymethods]
        impl $name {
            /// Builds a measure from a Python list of `(coords, weight)` pairs, where `coords`
            /// is a sequence of `N` floats and `weight` a float.
            ///
            /// # Errors
            /// Returns the extraction error if some list element is not such a pair.
            #[new]
            pub fn new(contents: &Bound<'_, PyList>) -> PyResult<Self> {
                let mut res = DiscreteMeasure::new();
                for v in contents.iter() {
                    let (x, α): ([$F; $N], $F) = v.extract()?;
                    res.push(DeltaMeasure { x: x.into(), α });
                }
                Ok(Self { inner: res })
            }

            /// Returns a new iterator positioned at the first delta of this measure.
            fn __iter__(slf: Bound<'_, Self>) -> PyResult<Py<$iter>> {
                Py::new(slf.py(), $iter { measure_ref: slf.unbind(), next: 0 })
            }
        }

        #[allow(non_camel_case_types)]
        #[pyclass(module = "pointsource_algs")]
        #[doc = concat!(
            "Python-side iterator over a [`DiscreteMeasure`]`<Loc<", stringify!($N), ", ",
            stringify!($F), ">, ", stringify!($F), ">`."
        )]
        ///
        /// Yields tuples `(coords, weight)`, where `coords` is a one-dimensional NumPy array.
        pub struct $iter {
            // Keeps the iterated measure alive for as long as the iterator exists.
            measure_ref: Py<$name>,
            // Index of the next delta to yield.
            next: usize,
        }

        #[pymethods]
        impl $iter {
            /// Python iterator protocol: an iterator returns itself from `__iter__`.
            fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
                slf
            }

            /// Returns the next `(coords, weight)` pair, or `None` (which pyo3 turns into
            /// `StopIteration`) once every delta has been yielded.
            ///
            /// # Errors
            /// Fails if the underlying measure object cannot be borrowed.
            // We cannot use lifetimes in types exported to Python, so have to use a
            // very primitive index-based iterator.
            fn __next__(
                mut slf: PyRefMut<'_, Self>,
            ) -> PyResult<Option<(Bound<'_, PyArray<$F, Ix1>>, $F)>> {
                let py = slf.py();
                let meas_: PyRef<'_, $name> = slf.measure_ref.extract(py)?;
                let meas = &(meas_.inner);
                let next = &mut slf.next;
                Ok((*next < meas.len()).then(|| {
                    let δ = meas[*next];
                    *next += 1;
                    (PyArray::from_slice(py, &δ.x.0), δ.α)
                }))
            }
        }

        // Direct access without passing through the wrappers: extracts a copy of the measure
        // held by a Python wrapper object (fails if `obj` is not the matching wrapper class).
        impl<'a, 'py> FromPyObject<'a, 'py> for DiscreteMeasure<Loc<$N, $F>, $F> {
            type Error = PyErr;
            fn extract(obj: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
                let wrapper: PyRef<'_, $name> = obj.extract()?;
                Ok(wrapper.inner.clone())
            }
        }

        // Converts by cloning the measure into a new Python wrapper object.
        // NOTE(review): only shared access is needed here; the `&mut` receiver presumably
        // matches existing callers — a `&` impl could be added alongside if useful.
        impl<'a, 'py> IntoPyObject<'py> for &'a mut DiscreteMeasure<Loc<$N, $F>, $F> {
            type Target = $name;
            type Error = PyErr;
            type Output = Bound<'py, $name>;

            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                Bound::new(py, $name::wrap(self.clone()))
            }
        }

        // Moves the measure into a new Python wrapper object without cloning.
        impl<'py> IntoPyObject<'py> for DiscreteMeasure<Loc<$N, $F>, $F> {
            type Target = $name;
            type Error = PyErr;
            type Output = Bound<'py, $name>;

            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                Bound::new(py, $name::wrap(self))
            }
        }
    };
}
119
120 create_interface!(DiscreteMeasure_1_f64, DiscreteMeasureIter_1_f64, 1, f64);
121 create_interface!(DiscreteMeasure_2_f64, DiscreteMeasureIter_2_f64, 2, f64);
122 create_interface!(DiscreteMeasure_3_f64, DiscreteMeasureIter_3_f64, 3, f64);
123
124 /// Populates the Python module.
125 ///
126 /// Needs to be called with [`pyo3::append_to_inittab`]:
127 /// ```
128 /// append_to_inittab!(pymod_pointsource_algs);
129 /// ```
130 /// before initialising the intepreter.
131 #[pymodule]
132 #[pyo3(name = "measures")]
133 pub fn pymod(m: &Bound<'_, PyModule>) -> PyResult<()> {
134 //m.add_class::<crate::run::DefaultAlgorithm>()?;
135 m.add_class::<crate::python::DiscreteMeasure_1_f64>()?;
136 m.add_class::<crate::python::DiscreteMeasure_2_f64>()?;
137 m.add_class::<crate::python::DiscreteMeasure_3_f64>()?;
138 m.add_class::<crate::python::DiscreteMeasureIter_1_f64>()?;
139 m.add_class::<crate::python::DiscreteMeasureIter_2_f64>()?;
140 m.add_class::<crate::python::DiscreteMeasureIter_3_f64>()?;
141 Ok(())
142 }

mercurial