src/python.rs

changeset 0
e8f3b6c55ce7
child 3
fbdee8e4a78d
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/python.rs	Fri Nov 28 12:48:17 2025 -0500
@@ -0,0 +1,142 @@
+/*!
+Python wrapper to measures. Only enabled with crate feature `pyo3`.
+
+These should really be in the `pointsource_pde` crate, but Rust doesn't allow that.
+*/
+
+use super::{DeltaMeasure, DiscreteMeasure};
+use alg_tools::loc::Loc;
+use numpy::{Ix1, PyArray};
+use pyo3::prelude::*;
+use pyo3::types::{PyList, PyModule, PyModuleMethods};
+use pyo3::{pymodule, Borrowed, Bound, PyResult};
+
+macro_rules! create_interface {
+    ($name: ident, $iter:ident, $N: literal, $F:ident) => {
+        #[allow(non_camel_case_types)]
+        #[pyclass(module = "pointsource_algs")]
+        /// Wrapper to [`DiscreteMeasure<Loc<$N, $F>, $F>,`]
+        ///
+        /// This is mainly needed because pyo3 does not support generics, not even instantiating
+        /// them to specific values.
+        pub struct $name {
+            inner: DiscreteMeasure<Loc<$N, $F>, $F>,
+        }
+
+        impl $name {
+            pub fn wrap(inner: DiscreteMeasure<Loc<$N, $F>, $F>) -> Self {
+                Self { inner }
+            }
+        }
+
+        #[pymethods]
+        impl $name {
+            #[new]
+            pub fn new(contents: &Bound<'_, PyList>) -> PyResult<Self> {
+                // let vec: Vec<(PyArray<$F, Ix1>, $F)> = contents.extract()?;
+                // Ok(Self { inner: vec.into() })
+                let mut res = DiscreteMeasure::new();
+                for v in contents.iter() {
+                    let (x, α): ([$F; $N], $F) = v.extract()?;
+                    res.push(DeltaMeasure { x: x.into(), α });
+                }
+                Ok(Self { inner: res })
+            }
+
+            /*pub fn from(vec: Vec<([$F; $N], $F)>) -> Self {
+                Self {
+                    inner: DiscreteMeasure::from_iter(vec.into_iter()),
+                }
+            }*/
+
+            fn __iter__(slf: Bound<'_, Self>) -> PyResult<Py<$iter>> {
+                Py::new(slf.py(), $iter { measure_ref: slf.unbind(), next: 0 })
+            }
+        }
+
+        #[allow(non_camel_case_types)]
+        #[pyclass(module = "pointsource_algs")]
+        /// Python-side iterator for [`DiscreteMeasure<Loc<$N, $F>, $F>,`]
+        /// Returns tuples (weight, coords)
+        pub struct $iter {
+            measure_ref: Py<$name>,
+            next: usize,
+        }
+
+        #[pymethods]
+        impl $iter {
+            fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+                slf
+            }
+
+            // We cannot use lifetimes in types exported to Python, so have to use a
+            // very primitive iterator.
+            fn __next__(
+                mut slf: PyRefMut<'_, Self>,
+            ) -> PyResult<Option<(Bound<'_, PyArray<$F, Ix1>>, $F)>> {
+                let py = slf.py();
+                let meas_: PyRef<'_, $name> = slf.measure_ref.extract(py)?;
+                let meas = &(meas_.inner);
+                let next = &mut slf.next;
+                Ok((*next < meas.len()).then(|| {
+                    let δ = meas[*next];
+                    *next += 1;
+                    (PyArray::from_slice(py, &δ.x.0), δ.α)
+                }))
+            }
+        }
+
+        // Direct access without passing through the wrappers
+        impl<'a, 'py> FromPyObject<'a, 'py> for DiscreteMeasure<Loc<$N, $F>, $F> {
+            type Error = PyErr;
+            fn extract(obj: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
+                let wrapper: PyRef<'_, $name> = obj.extract()?;
+                Ok(wrapper.inner.clone())
+            }
+        }
+
+        impl<'a, 'py> IntoPyObject<'py> for &'a mut DiscreteMeasure<Loc<$N, $F>, $F> {
+            type Target = $name;
+            type Error = PyErr;
+            type Output = Bound<'py, $name>;
+
+            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
+                Bound::new(py, $name { inner: self.clone() })
+            }
+        }
+
+        impl<'py> IntoPyObject<'py> for DiscreteMeasure<Loc<$N, $F>, $F> {
+            type Target = $name;
+            type Error = PyErr;
+            type Output = Bound<'py, $name>;
+
+            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
+                Bound::new(py, $name { inner: self })
+            }
+        }
+    };
+}
+
+create_interface!(DiscreteMeasure_1_f64, DiscreteMeasureIter_1_f64, 1, f64);
+create_interface!(DiscreteMeasure_2_f64, DiscreteMeasureIter_2_f64, 2, f64);
+create_interface!(DiscreteMeasure_3_f64, DiscreteMeasureIter_3_f64, 3, f64);
+
+/// Populates the Python module.
+///
+/// Needs to be called with [`pyo3::append_to_inittab`]:
+/// ```
+/// append_to_inittab!(pymod_pointsource_algs);
+/// ```
+/// before initialising the intepreter.
+#[pymodule]
+#[pyo3(name = "measures")]
+pub fn pymod(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    //m.add_class::<crate::run::DefaultAlgorithm>()?;
+    m.add_class::<crate::python::DiscreteMeasure_1_f64>()?;
+    m.add_class::<crate::python::DiscreteMeasure_2_f64>()?;
+    m.add_class::<crate::python::DiscreteMeasure_3_f64>()?;
+    m.add_class::<crate::python::DiscreteMeasureIter_1_f64>()?;
+    m.add_class::<crate::python::DiscreteMeasureIter_2_f64>()?;
+    m.add_class::<crate::python::DiscreteMeasureIter_3_f64>()?;
+    Ok(())
+}

mercurial