--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/python_access/prox_mapping.rs Thu Feb 26 09:32:12 2026 -0500 @@ -0,0 +1,81 @@ +/*! +Implementation of [`Prox`] for mappings implemented in Python. +*/ + +use crate::python_access::process_error; + +use super::diff_mapping::intern_many; +use super::diff_mapping::PythonMapping; +use alg_tools::convex::Prox; +use alg_tools::mapping::{ClosedSpace, Instance, Mapping, Space}; +use pyo3::conversion::FromPyObject; +use pyo3::intern; +use pyo3::prelude::*; +use std::marker::PhantomData; + +#[derive(Copy, Debug, Clone)] +/// Marker for PythonMappings that implement Prox +pub struct HasProx; + +impl<'py, Domain> Prox<Domain> for PythonMapping<'py, Domain, f64, HasProx> +where + Domain: ClosedSpace + Clone + for<'a> FromPyObject<'a, 'py, Error = PyErr> + IntoPyObject<'py>, +{ + type Prox<'a> + = PythonProx<'py, Domain> + where + Self: 'a; + + fn prox_mapping(&self, τ: Self::Codomain) -> Self::Prox<'_> { + PythonProx { τ, obj: self.get_obj(), _phantoms: PhantomData } + } +} + +#[derive(Debug)] +pub struct PythonProx<'py, Domain> +where + Domain: Space + for<'a> FromPyObject<'a, 'py>, +{ + τ: f64, + obj: Bound<'py, PyAny>, + _phantoms: PhantomData<Domain>, +} + +impl<'py, Domain> Mapping<Domain> for PythonProx<'py, Domain> +where + Domain: ClosedSpace + for<'a> FromPyObject<'a, 'py, Error = PyErr> + IntoPyObject<'py>, +{ + type Codomain = Domain; + + /// Compute the value of `self` at `x`. + fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain { + // TODO: use references and internal mutability? + let apply = intern!(self.obj.py(), "prox"); + process_error( + "PythonProx::apply", + self.obj.py(), + self.obj + .call_method1(apply, (self.τ, x.own())) + .and_then(|r| r.extract()), + ) + .unwrap() + } +} + +impl<'a, 'py, Domain, Codomain> FromPyObject<'a, 'py> + for PythonMapping<'py, Domain, Codomain, HasProx> +where + Domain: Space, + Codomain: ClosedSpace + FromPyObject<'a, 'py>, +{ + type Error = PyErr; + + fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> { + let obj = obj_.to_owned(); + // Verify that the necessary methods exist + for method in intern_many!(obj.py(), "apply", "prox") { + obj.getattr(method)?; //.downcast::<PyFunction>()?; + } + Ok(PythonMapping { obj, _phantoms: PhantomData }) + } +}