src/python.rs

Fri, 28 Nov 2025 12:48:17 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Fri, 28 Nov 2025 12:48:17 -0500
changeset 0
e8f3b6c55ce7
child 3
fbdee8e4a78d
permissions
-rw-r--r--

Initialise repository, separating measure from pointsource_algs

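/*!
Python wrappers for measures. Only enabled with the crate feature `pyo3`.

These should really live in the `pointsource_pde` crate, but Rust's orphan rules forbid that:
the `FromPyObject` and `IntoPyObject` implementations below implement foreign (pyo3) traits for
[`DiscreteMeasure`], so they have to be in the crate that defines `DiscreteMeasure`.
*/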

use super::{DeltaMeasure, DiscreteMeasure};
use alg_tools::loc::Loc;
use numpy::{Ix1, PyArray};
use pyo3::prelude::*;
use pyo3::types::{PyList, PyModule, PyModuleMethods};
use pyo3::{pymodule, Borrowed, Bound, PyResult};

/// Generates a Python class `$name` wrapping `DiscreteMeasure<Loc<$N, $F>, $F>`, a companion
/// Python iterator class `$iter`, and `FromPyObject`/`IntoPyObject` conversions so that the
/// measure type can be passed to and from Python directly.
///
/// pyo3 classes cannot be generic, so this macro is instantiated once per `($N, $F)` pair.
macro_rules! create_interface {
    ($name: ident, $iter:ident, $N: literal, $F:ident) => {
        #[allow(non_camel_case_types)] // names encode dimension and float type, e.g. `_2_f64`
        // NOTE(review): `module` is "pointsource_algs", but `pymod` registers these classes in a
        // module named "measures"; confirm which `__module__` Python code (e.g. pickle) expects.
        #[pyclass(module = "pointsource_algs")]
        // Metavariables are not substituted inside `///` comments, so build the text explicitly.
        #[doc = concat!(
            "Wrapper to [`DiscreteMeasure`]`<Loc<", stringify!($N), ", ", stringify!($F),
            ">, ", stringify!($F), ">`."
        )]
        ///
        /// This is mainly needed because pyo3 does not support generics, not even instantiating
        /// them to specific values.
        pub struct $name {
            inner: DiscreteMeasure<Loc<$N, $F>, $F>,
        }

        impl $name {
            /// Wraps `inner` (taken by value) so that it can be handed to Python.
            pub fn wrap(inner: DiscreteMeasure<Loc<$N, $F>, $F>) -> Self {
                Self { inner }
            }
        }

        #[pymethods]
        impl $name {
            /// Builds a measure from a Python list of `(coords, weight)` pairs.
            ///
            /// Each `coords` must be extractable as a fixed-size array whose length is the
            /// spatial dimension of this class; each `weight` must be extractable as a float.
            ///
            /// # Errors
            /// Returns the extraction error of the first element that does not have this form.
            #[new]
            pub fn new(contents: &Bound<'_, PyList>) -> PyResult<Self> {
                let mut res = DiscreteMeasure::new();
                for v in contents.iter() {
                    let (x, α): ([$F; $N], $F) = v.extract()?;
                    res.push(DeltaMeasure { x: x.into(), α });
                }
                Ok(Self { inner: res })
            }

            /// Returns a new Python iterator over the point masses of this measure,
            /// starting from the first one.
            fn __iter__(slf: Bound<'_, Self>) -> PyResult<Py<$iter>> {
                Py::new(slf.py(), $iter { measure_ref: slf.unbind(), next: 0 })
            }
        }

        #[allow(non_camel_case_types)] // names encode dimension and float type, e.g. `_2_f64`
        #[pyclass(module = "pointsource_algs")]
        #[doc = concat!(
            "Python-side iterator over [`DiscreteMeasure`]`<Loc<", stringify!($N), ", ",
            stringify!($F), ">, ", stringify!($F), ">`."
        )]
        ///
        /// Yields tuples `(coords, weight)`, where `coords` is a one-dimensional NumPy array.
        pub struct $iter {
            /// Reference to the measure being iterated over.
            measure_ref: Py<$name>,
            /// Index of the next point mass to yield.
            next: usize,
        }

        #[pymethods]
        impl $iter {
            /// Iterators are their own iterators, as the Python protocol requires.
            fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
                slf
            }

            /// Returns the next `(coords, weight)` pair, or `None` (which pyo3 turns into
            /// `StopIteration`) once every point mass has been yielded.
            ///
            /// # Errors
            /// Fails if the referenced measure cannot be borrowed as `PyRef<$name>`.
            fn __next__(
                mut slf: PyRefMut<'_, Self>,
            ) -> PyResult<Option<(Bound<'_, PyArray<$F, Ix1>>, $F)>> {
                // We cannot use lifetimes in types exported to Python, so have to use a
                // very primitive index-based iterator instead of wrapping a Rust iterator.
                let py = slf.py();
                let meas_: PyRef<'_, $name> = slf.measure_ref.extract(py)?;
                let meas = &(meas_.inner);
                let next = &mut slf.next;
                Ok((*next < meas.len()).then(|| {
                    let δ = meas[*next];
                    *next += 1;
                    (PyArray::from_slice(py, &δ.x.0), δ.α)
                }))
            }
        }

        // Direct access without passing through the wrappers: a Python `$name` object
        // extracts as a clone of the wrapped measure.
        impl<'a, 'py> FromPyObject<'a, 'py> for DiscreteMeasure<Loc<$N, $F>, $F> {
            type Error = PyErr;
            fn extract(obj: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
                let wrapper: PyRef<'_, $name> = obj.extract()?;
                Ok(wrapper.inner.clone())
            }
        }

        // Converting a borrowed measure clones it into a new Python wrapper object.
        impl<'a, 'py> IntoPyObject<'py> for &'a mut DiscreteMeasure<Loc<$N, $F>, $F> {
            type Target = $name;
            type Error = PyErr;
            type Output = Bound<'py, $name>;

            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                Bound::new(py, $name { inner: self.clone() })
            }
        }

        // Converting an owned measure moves it into a new Python wrapper object.
        impl<'py> IntoPyObject<'py> for DiscreteMeasure<Loc<$N, $F>, $F> {
            type Target = $name;
            type Error = PyErr;
            type Output = Bound<'py, $name>;

            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                Bound::new(py, $name { inner: self })
            }
        }
    };
}

// Concrete Python-facing instantiations: one measure class and one iterator class per
// supported spatial dimension (1, 2 and 3), all with `f64` coordinates and weights.
create_interface! { DiscreteMeasure_1_f64, DiscreteMeasureIter_1_f64, 1, f64 }
create_interface! { DiscreteMeasure_2_f64, DiscreteMeasureIter_2_f64, 2, f64 }
create_interface! { DiscreteMeasure_3_f64, DiscreteMeasureIter_3_f64, 3, f64 }

/// Populates the Python module.
///
/// Needs to be called with [`pyo3::append_to_inittab`]:
/// ```
/// append_to_inittab!(pymod_pointsource_algs);
/// ```
/// before initialising the intepreter.
#[pymodule]
#[pyo3(name = "measures")]
pub fn pymod(m: &Bound<'_, PyModule>) -> PyResult<()> {
    //m.add_class::<crate::run::DefaultAlgorithm>()?;
    m.add_class::<crate::python::DiscreteMeasure_1_f64>()?;
    m.add_class::<crate::python::DiscreteMeasure_2_f64>()?;
    m.add_class::<crate::python::DiscreteMeasure_3_f64>()?;
    m.add_class::<crate::python::DiscreteMeasureIter_1_f64>()?;
    m.add_class::<crate::python::DiscreteMeasureIter_2_f64>()?;
    m.add_class::<crate::python::DiscreteMeasureIter_3_f64>()?;
    Ok(())
}

mercurial