Sat, 24 Jan 2026 19:24:29 -0500
Add truncate
/*! Python wrapper to measures. Only enabled with crate feature `pyo3`. These should really be in the `pointsource_pde` crate, but Rust doesn't allow that. */ use super::{DeltaMeasure, DiscreteMeasure}; use alg_tools::loc::Loc; use numpy::{Ix1, PyArray}; use pyo3::prelude::*; use pyo3::types::{PyList, PyModule, PyModuleMethods}; use pyo3::{pymodule, Borrowed, Bound, PyResult}; macro_rules! create_interface { ($name: ident, $iter:ident, $N: literal, $F:ident) => { #[allow(non_camel_case_types)] #[pyclass(module = "pointsource_algs")] /// Wrapper to [`DiscreteMeasure<Loc<$N, $F>, $F>,`] /// /// This is mainly needed because pyo3 does not support generics, not even instantiating /// them to specific values. pub struct $name { inner: DiscreteMeasure<Loc<$N, $F>, $F>, } impl $name { pub fn wrap(inner: DiscreteMeasure<Loc<$N, $F>, $F>) -> Self { Self { inner } } } #[pymethods] impl $name { #[new] pub fn new(contents: &Bound<'_, PyList>) -> PyResult<Self> { // let vec: Vec<(PyArray<$F, Ix1>, $F)> = contents.extract()?; // Ok(Self { inner: vec.into() }) let mut res = DiscreteMeasure::new(); for v in contents.iter() { let (x, α): ([$F; $N], $F) = v.extract()?; res.push(DeltaMeasure { x: x.into(), α }); } Ok(Self { inner: res }) } /*pub fn from(vec: Vec<([$F; $N], $F)>) -> Self { Self { inner: DiscreteMeasure::from_iter(vec.into_iter()), } }*/ fn __iter__(slf: Bound<'_, Self>) -> PyResult<Py<$iter>> { Py::new(slf.py(), $iter { measure_ref: slf.unbind(), next: 0, pad: false, }) } /// Same as iterating the object, but pads (or cuts) location vectors to 3 dimensions. fn iter_padded(slf: Bound<'_, Self>) -> PyResult<Py<$iter>> { Py::new(slf.py(), $iter { measure_ref: slf.unbind(), next: 0, pad: true, }) } } #[allow(non_camel_case_types)] #[pyclass(module = "pointsource_algs")] /// Python-side iterator for [`DiscreteMeasure<Loc<$N, $F>, $F>,`] /// Returns tuples (weight, coords) pub struct $iter { measure_ref: Py<$name>, next: usize, pad: bool, } #[pymethods] impl $iter { fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { slf } // We cannot use lifetimes in types exported to Python, so have to use a // very primitive iterator. fn __next__( mut slf: PyRefMut<'_, Self>, ) -> PyResult<Option<(Bound<'_, PyArray<$F, Ix1>>, $F)>> { let py = slf.py(); let meas_: PyRef<'_, $name> = slf.measure_ref.extract(py)?; let meas = &(meas_.inner); let pad = slf.pad; let next = &mut slf.next; Ok((*next < meas.len()).then(|| { let δ = meas[*next]; *next += 1; if pad { ( PyArray::from_iter( py, δ.x.iter().copied().chain(std::iter::repeat(0.0)).take(3), ), δ.α, ) } else { (PyArray::from_slice(py, &δ.x.0), δ.α) } })) } } // Direct access without passing through the wrappers impl<'a, 'py> FromPyObject<'a, 'py> for DiscreteMeasure<Loc<$N, $F>, $F> { type Error = PyErr; fn extract(obj: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> { let wrapper: PyRef<'_, $name> = obj.extract()?; Ok(wrapper.inner.clone()) } } impl<'a, 'py> IntoPyObject<'py> for &'a mut DiscreteMeasure<Loc<$N, $F>, $F> { type Target = $name; type Error = PyErr; type Output = Bound<'py, $name>; fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> { Bound::new(py, $name { inner: self.clone() }) } } impl<'py> IntoPyObject<'py> for DiscreteMeasure<Loc<$N, $F>, $F> { type Target = $name; type Error = PyErr; type Output = Bound<'py, $name>; fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> { Bound::new(py, $name { inner: self }) } } }; } create_interface!(DiscreteMeasure_1_f64, DiscreteMeasureIter_1_f64, 1, f64); create_interface!(DiscreteMeasure_2_f64, DiscreteMeasureIter_2_f64, 2, f64); create_interface!(DiscreteMeasure_3_f64, DiscreteMeasureIter_3_f64, 3, f64); /// Populates the Python module. /// /// Needs to be called with [`pyo3::append_to_inittab`]: /// ``` /// append_to_inittab!(pymod_pointsource_algs); /// ``` /// before initialising the intepreter. #[pymodule] #[pyo3(name = "measures")] pub fn pymod(m: &Bound<'_, PyModule>) -> PyResult<()> { //m.add_class::<crate::run::DefaultAlgorithm>()?; m.add_class::<crate::python::DiscreteMeasure_1_f64>()?; m.add_class::<crate::python::DiscreteMeasure_2_f64>()?; m.add_class::<crate::python::DiscreteMeasure_3_f64>()?; m.add_class::<crate::python::DiscreteMeasureIter_1_f64>()?; m.add_class::<crate::python::DiscreteMeasureIter_2_f64>()?; m.add_class::<crate::python::DiscreteMeasureIter_3_f64>()?; Ok(()) }