src/python_access/prox_mapping.rs

changeset 1
a4137aedcb3a
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/python_access/prox_mapping.rs	Thu Feb 26 09:32:12 2026 -0500
@@ -0,0 +1,81 @@
+/*!
+Implementation of [`Prox`] for mappings implemented in Python.
+*/
+
+use crate::python_access::process_error;
+
+use super::diff_mapping::intern_many;
+use super::diff_mapping::PythonMapping;
+use alg_tools::convex::Prox;
+use alg_tools::mapping::{ClosedSpace, Instance, Mapping, Space};
+use pyo3::conversion::FromPyObject;
+use pyo3::intern;
+use pyo3::prelude::*;
+use std::marker::PhantomData;
+
+#[derive(Copy, Debug, Clone)]
+/// Marker for PythonMappings that implement Prox
+pub struct HasProx;
+
+impl<'py, Domain> Prox<Domain> for PythonMapping<'py, Domain, f64, HasProx>
+where
+    Domain: ClosedSpace + Clone + for<'a> FromPyObject<'a, 'py, Error = PyErr> + IntoPyObject<'py>,
+{
+    type Prox<'a>
+        = PythonProx<'py, Domain>
+    where
+        Self: 'a;
+
+    fn prox_mapping(&self, τ: Self::Codomain) -> Self::Prox<'_> {
+        PythonProx { τ, obj: self.get_obj(), _phantoms: PhantomData }
+    }
+}
+
+#[derive(Debug)]
+pub struct PythonProx<'py, Domain>
+where
+    Domain: Space + for<'a> FromPyObject<'a, 'py>,
+{
+    τ: f64,
+    obj: Bound<'py, PyAny>,
+    _phantoms: PhantomData<Domain>,
+}
+
+impl<'py, Domain> Mapping<Domain> for PythonProx<'py, Domain>
+where
+    Domain: ClosedSpace + for<'a> FromPyObject<'a, 'py, Error = PyErr> + IntoPyObject<'py>,
+{
+    type Codomain = Domain;
+
+    /// Compute the value of `self` at `x`.
+    fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain {
+        // TODO: use references and internal mutability?
+        let apply = intern!(self.obj.py(), "prox");
+        process_error(
+            "PythonProx::apply",
+            self.obj.py(),
+            self.obj
+                .call_method1(apply, (self.τ, x.own()))
+                .and_then(|r| r.extract()),
+        )
+        .unwrap()
+    }
+}
+
+impl<'a, 'py, Domain, Codomain> FromPyObject<'a, 'py>
+    for PythonMapping<'py, Domain, Codomain, HasProx>
+where
+    Domain: Space,
+    Codomain: ClosedSpace + FromPyObject<'a, 'py>,
+{
+    type Error = PyErr;
+
+    fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
+        let obj = obj_.to_owned();
+        // Verify that the necessary methods exist
+        for method in intern_many!(obj.py(), "apply", "prox") {
+            obj.getattr(method)?; //.downcast::<PyFunction>()?;
+        }
+        Ok(PythonMapping { obj, _phantoms: PhantomData })
+    }
+}

mercurial