src/python_access/plot.rs

changeset 1
a4137aedcb3a
child 3
c3a4f4bb87f7
equal deleted inserted replaced
0:7ec1cfe19a24 1:a4137aedcb3a
1 use super::process_error;
2 use alg_tools::types::Float;
3 use pointsource_algs::measures::RNDM;
4 use pointsource_algs::plot::Plotter;
5 use pyo3::conversion::FromPyObject;
6 use pyo3::intern;
7 use pyo3::prelude::*;
8 use std::marker::PhantomData;
9
10 #[derive(Debug, Clone)]
11 pub struct PythonPlotFactory<'py, T1, T2> {
12 pub(super) obj: Option<Bound<'py, PyAny>>,
13 pub(super) _phantoms: PhantomData<(T1, T2)>,
14 }
15
16 impl<'a, 'py, T1, T2> FromPyObject<'a, 'py> for PythonPlotFactory<'py, T1, T2> {
17 type Error = PyErr;
18
19 fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
20 let obj = obj_.to_owned();
21 // Verify that the necessary methods exist
22 obj.getattr(intern!(obj.py(), "produce"))?;
23 Ok(PythonPlotFactory { obj: Some(obj), _phantoms: PhantomData })
24 }
25 }
26
27 #[derive(Debug, Clone)]
28 pub struct PythonPlotter<'py, T1, T2> {
29 pub(super) obj: Option<Bound<'py, PyAny>>,
30 pub(super) _phantoms: PhantomData<(T1, T2)>,
31 }
32
33 impl<'a, 'py, T1, T2> FromPyObject<'a, 'py> for PythonPlotter<'py, T1, T2> {
34 type Error = PyErr;
35
36 fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
37 let obj = obj_.to_owned();
38 // Verify that the necessary methods exist
39 obj.getattr(intern!(obj.py(), "plot"))?;
40 Ok(PythonPlotter { obj: Some(obj), _phantoms: PhantomData })
41 }
42 }
43
44 impl<'py, T1, T2, F, const N: usize> Plotter<T1, T2, RNDM<N, F>> for PythonPlotter<'py, T1, T2>
45 where
46 F: Float,
47 for<'a> &'a RNDM<N, F>: IntoPyObject<'py>,
48 {
49 fn plot_spikes(&mut self, iter: usize, _g: Option<&T1>, _ω: Option<&T2>, μ: &RNDM<N, F>) {
50 if let Some(ref obj) = self.obj {
51 let plot = intern!(obj.py(), "plot");
52 process_error(
53 "PythonPlotter::plot",
54 obj.py(),
55 obj.call_method1(plot, (iter, μ)).map(|_| ()),
56 )
57 .unwrap()
58 }
59 }
60 }
61
62 impl<'py, T1, T2> PythonPlotFactory<'py, T1, T2> {
63 pub fn dummy() -> Self {
64 PythonPlotFactory { obj: None, _phantoms: PhantomData }
65 }
66 }
67
68 impl<'py, T1: Clone, T2: Clone> PythonPlotFactory<'py, T1, T2> {
69 pub fn produce(&self, prefix: String) -> PythonPlotter<'py, T1, T2> {
70 if let Some(ref obj) = self.obj {
71 let produce = intern!(obj.py(), "produce");
72 process_error(
73 "PythonPlotFactory::produce",
74 obj.py(),
75 obj.call_method1(produce, (prefix,))
76 .and_then(|r| r.extract()),
77 )
78 .unwrap()
79 } else {
80 PythonPlotter { obj: None, _phantoms: PhantomData }
81 }
82 }
83 }

mercurial