src/python_access/plot.rs

changeset 1
a4137aedcb3a
child 3
c3a4f4bb87f7
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/python_access/plot.rs	Thu Feb 26 09:32:12 2026 -0500
@@ -0,0 +1,83 @@
+use super::process_error;
+use alg_tools::types::Float;
+use pointsource_algs::measures::RNDM;
+use pointsource_algs::plot::Plotter;
+use pyo3::conversion::FromPyObject;
+use pyo3::intern;
+use pyo3::prelude::*;
+use std::marker::PhantomData;
+
+#[derive(Debug, Clone)]
+pub struct PythonPlotFactory<'py, T1, T2> {
+    pub(super) obj: Option<Bound<'py, PyAny>>,
+    pub(super) _phantoms: PhantomData<(T1, T2)>,
+}
+
+impl<'a, 'py, T1, T2> FromPyObject<'a, 'py> for PythonPlotFactory<'py, T1, T2> {
+    type Error = PyErr;
+
+    fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
+        let obj = obj_.to_owned();
+        // Verify that the necessary methods exist
+        obj.getattr(intern!(obj.py(), "produce"))?;
+        Ok(PythonPlotFactory { obj: Some(obj), _phantoms: PhantomData })
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct PythonPlotter<'py, T1, T2> {
+    pub(super) obj: Option<Bound<'py, PyAny>>,
+    pub(super) _phantoms: PhantomData<(T1, T2)>,
+}
+
+impl<'a, 'py, T1, T2> FromPyObject<'a, 'py> for PythonPlotter<'py, T1, T2> {
+    type Error = PyErr;
+
+    fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
+        let obj = obj_.to_owned();
+        // Verify that the necessary methods exist
+        obj.getattr(intern!(obj.py(), "plot"))?;
+        Ok(PythonPlotter { obj: Some(obj), _phantoms: PhantomData })
+    }
+}
+
+impl<'py, T1, T2, F, const N: usize> Plotter<T1, T2, RNDM<N, F>> for PythonPlotter<'py, T1, T2>
+where
+    F: Float,
+    for<'a> &'a RNDM<N, F>: IntoPyObject<'py>,
+{
+    fn plot_spikes(&mut self, iter: usize, _g: Option<&T1>, _ω: Option<&T2>, μ: &RNDM<N, F>) {
+        if let Some(ref obj) = self.obj {
+            let plot = intern!(obj.py(), "plot");
+            process_error(
+                "PythonPlotter::plot",
+                obj.py(),
+                obj.call_method1(plot, (iter, μ)).map(|_| ()),
+            )
+            .unwrap()
+        }
+    }
+}
+
+impl<'py, T1, T2> PythonPlotFactory<'py, T1, T2> {
+    pub fn dummy() -> Self {
+        PythonPlotFactory { obj: None, _phantoms: PhantomData }
+    }
+}
+
+impl<'py, T1: Clone, T2: Clone> PythonPlotFactory<'py, T1, T2> {
+    pub fn produce(&self, prefix: String) -> PythonPlotter<'py, T1, T2> {
+        if let Some(ref obj) = self.obj {
+            let produce = intern!(obj.py(), "produce");
+            process_error(
+                "PythonPlotFactory::produce",
+                obj.py(),
+                obj.call_method1(produce, (prefix,))
+                    .and_then(|r| r.extract()),
+            )
+            .unwrap()
+        } else {
+            PythonPlotter { obj: None, _phantoms: PhantomData }
+        }
+    }
+}

mercurial