| |
1 use super::process_error; |
| |
2 use alg_tools::types::Float; |
| |
3 use pointsource_algs::measures::RNDM; |
| |
4 use pointsource_algs::plot::Plotter; |
| |
5 use pyo3::conversion::FromPyObject; |
| |
6 use pyo3::intern; |
| |
7 use pyo3::prelude::*; |
| |
8 use std::marker::PhantomData; |
| |
9 |
| |
10 #[derive(Debug, Clone)] |
| |
11 pub struct PythonPlotFactory<'py, T1, T2> { |
| |
12 pub(super) obj: Option<Bound<'py, PyAny>>, |
| |
13 pub(super) _phantoms: PhantomData<(T1, T2)>, |
| |
14 } |
| |
15 |
| |
16 impl<'a, 'py, T1, T2> FromPyObject<'a, 'py> for PythonPlotFactory<'py, T1, T2> { |
| |
17 type Error = PyErr; |
| |
18 |
| |
19 fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> { |
| |
20 let obj = obj_.to_owned(); |
| |
21 // Verify that the necessary methods exist |
| |
22 obj.getattr(intern!(obj.py(), "produce"))?; |
| |
23 Ok(PythonPlotFactory { obj: Some(obj), _phantoms: PhantomData }) |
| |
24 } |
| |
25 } |
| |
26 |
| |
27 #[derive(Debug, Clone)] |
| |
28 pub struct PythonPlotter<'py, T1, T2> { |
| |
29 pub(super) obj: Option<Bound<'py, PyAny>>, |
| |
30 pub(super) _phantoms: PhantomData<(T1, T2)>, |
| |
31 } |
| |
32 |
| |
33 impl<'a, 'py, T1, T2> FromPyObject<'a, 'py> for PythonPlotter<'py, T1, T2> { |
| |
34 type Error = PyErr; |
| |
35 |
| |
36 fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> { |
| |
37 let obj = obj_.to_owned(); |
| |
38 // Verify that the necessary methods exist |
| |
39 obj.getattr(intern!(obj.py(), "plot"))?; |
| |
40 Ok(PythonPlotter { obj: Some(obj), _phantoms: PhantomData }) |
| |
41 } |
| |
42 } |
| |
43 |
| |
44 impl<'py, T1, T2, F, const N: usize> Plotter<T1, T2, RNDM<N, F>> for PythonPlotter<'py, T1, T2> |
| |
45 where |
| |
46 F: Float, |
| |
47 for<'a> &'a RNDM<N, F>: IntoPyObject<'py>, |
| |
48 { |
| |
49 fn plot_spikes(&mut self, iter: usize, _g: Option<&T1>, _ω: Option<&T2>, μ: &RNDM<N, F>) { |
| |
50 if let Some(ref obj) = self.obj { |
| |
51 let plot = intern!(obj.py(), "plot"); |
| |
52 process_error( |
| |
53 "PythonPlotter::plot", |
| |
54 obj.py(), |
| |
55 obj.call_method1(plot, (iter, μ)).map(|_| ()), |
| |
56 ) |
| |
57 .unwrap() |
| |
58 } |
| |
59 } |
| |
60 } |
| |
61 |
| |
62 impl<'py, T1, T2> PythonPlotFactory<'py, T1, T2> { |
| |
63 pub fn dummy() -> Self { |
| |
64 PythonPlotFactory { obj: None, _phantoms: PhantomData } |
| |
65 } |
| |
66 } |
| |
67 |
| |
68 impl<'py, T1: Clone, T2: Clone> PythonPlotFactory<'py, T1, T2> { |
| |
69 pub fn produce(&self, prefix: String) -> PythonPlotter<'py, T1, T2> { |
| |
70 if let Some(ref obj) = self.obj { |
| |
71 let produce = intern!(obj.py(), "produce"); |
| |
72 process_error( |
| |
73 "PythonPlotFactory::produce", |
| |
74 obj.py(), |
| |
75 obj.call_method1(produce, (prefix,)) |
| |
76 .and_then(|r| r.extract()), |
| |
77 ) |
| |
78 .unwrap() |
| |
79 } else { |
| |
80 PythonPlotter { obj: None, _phantoms: PhantomData } |
| |
81 } |
| |
82 } |
| |
83 } |