src/python_access/plot.rs

Wed, 22 Apr 2026 22:32:00 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Wed, 22 Apr 2026 22:32:00 -0500
changeset 3
c3a4f4bb87f7
parent 1
a4137aedcb3a
permissions
-rw-r--r--

Dolfin update, fixes, additional experiment, build instructions.

/*!
 Access to Python-side plotter implementations.
*/
use super::process_error;
use alg_tools::types::Float;
use pointsource_algs::measures::RNDM;
use pointsource_algs::plot::Plotter;
use pyo3::conversion::FromPyObject;
use pyo3::intern;
use pyo3::prelude::*;
use std::marker::PhantomData;

/// A factory for generating a [`PythonPlotter`].
///
/// Wraps an optional Python object that is expected to provide a `produce` method.
/// When `obj` is `None`, this is a dummy factory whose produced plotters do nothing.
///
/// Note that the derived `Debug`/`Clone` impls additionally require `T1` and `T2` to
/// implement those traits, even though they are only held as phantom types.
#[derive(Debug, Clone)]
pub struct PythonPlotFactory<'py, T1, T2> {
    /// Python side object implementing the plot factory, or `None` for a dummy factory
    pub(super) obj: Option<Bound<'py, PyAny>>,
    /// Marks the otherwise unused type parameters `T1` and `T2`
    pub(super) _phantoms: PhantomData<(T1, T2)>,
}

impl<'a, 'py, T1, T2> FromPyObject<'a, 'py> for PythonPlotFactory<'py, T1, T2> {
    type Error = PyErr;

    /// Accepts a Python object as a plot factory.
    ///
    /// The object must have a `produce` attribute; if looking it up fails, that
    /// `getattr` error is returned. Whether the attribute is callable is not checked.
    fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
        let factory = obj_.to_owned();
        let produce = intern!(factory.py(), "produce");
        // Only the presence of the attribute matters; its value is discarded.
        let _ = factory.getattr(produce)?;
        Ok(Self { obj: Some(factory), _phantoms: PhantomData })
    }
}

/// Python-side implementation of [`Plotter`].
///
/// Wraps an optional Python object that is expected to provide a `plot` method.
/// When `obj` is `None`, this is a dummy plotter and plotting does nothing.
///
/// Note that the derived `Debug`/`Clone` impls additionally require `T1` and `T2` to
/// implement those traits, even though they are only held as phantom types.
#[derive(Debug, Clone)]
pub struct PythonPlotter<'py, T1, T2> {
    /// Python side object implementing the [`Plotter`], or `None` for a dummy plotter
    pub(super) obj: Option<Bound<'py, PyAny>>,
    /// Marks the otherwise unused type parameters `T1` and `T2`
    pub(super) _phantoms: PhantomData<(T1, T2)>,
}

impl<'a, 'py, T1, T2> FromPyObject<'a, 'py> for PythonPlotter<'py, T1, T2> {
    type Error = PyErr;

    /// Accepts a Python object as a plotter.
    ///
    /// The object must have a `plot` attribute; if looking it up fails, that
    /// `getattr` error is returned. Whether the attribute is callable is not checked.
    fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
        let plotter = obj_.to_owned();
        let plot = intern!(plotter.py(), "plot");
        // Only the presence of the attribute matters; its value is discarded.
        let _ = plotter.getattr(plot)?;
        Ok(Self { obj: Some(plotter), _phantoms: PhantomData })
    }
}

/// Forwards spike plots to the wrapped Python object.
///
/// Only the iteration number and the measure `μ` are passed on, as the arguments of the
/// Python-side `plot(iter, μ)` call; the optional `_g` and `_ω` are ignored.
impl<'py, T1, T2, F, const N: usize> Plotter<T1, T2, RNDM<N, F>> for PythonPlotter<'py, T1, T2>
where
    F: Float,
    for<'a> &'a RNDM<N, F>: IntoPyObject<'py>,
{
    fn plot_spikes(&mut self, iter: usize, _g: Option<&T1>, _ω: Option<&T2>, μ: &RNDM<N, F>) {
        // A dummy plotter (no Python object) skips plotting entirely.
        let Some(target) = &self.obj else {
            return;
        };
        let py = target.py();
        let outcome = target.call_method1(intern!(py, "plot"), (iter, μ)).map(|_| ());
        // Errors are routed through `process_error`; a failure result aborts via `unwrap`.
        process_error("PythonPlotter::plot", py, outcome).unwrap()
    }
}

impl<'py, T1, T2> PythonPlotFactory<'py, T1, T2> {
    /// Creates a dummy plotter factory that does nothing.
    ///
    /// The factory holds no Python object, so [`PythonPlotFactory::produce`] on it
    /// returns a plotter that also holds none, and whose plotting calls are no-ops.
    pub fn dummy() -> Self {
        PythonPlotFactory { obj: None, _phantoms: PhantomData }
    }
}

impl<'py, T1: Clone, T2: Clone> PythonPlotFactory<'py, T1, T2> {
    /// Produces a [`PythonPlotter`] out of the factory.
    /// The `prefix` parameter indicates where to store the files.
    pub fn produce(&self, prefix: String) -> PythonPlotter<'py, T1, T2> {
        if let Some(ref obj) = self.obj {
            let produce = intern!(obj.py(), "produce");
            process_error(
                "PythonPlotFactory::produce",
                obj.py(),
                obj.call_method1(produce, (prefix,))
                    .and_then(|r| r.extract()),
            )
            .unwrap()
        } else {
            PythonPlotter { obj: None, _phantoms: PhantomData }
        }
    }
}

mercurial