src/python_access/prox_mapping.rs

Fri, 08 May 2026 17:28:21 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Fri, 08 May 2026 17:28:21 -0500
changeset 5
3e164c024a01
parent 1
a4137aedcb3a
permissions
-rw-r--r--

Change README title

/*!
Implementation of [`Prox`] for mappings implemented in Python.
*/

use crate::python_access::process_error;

use super::diff_mapping::intern_many;
use super::diff_mapping::PythonMapping;
use alg_tools::convex::Prox;
use alg_tools::mapping::{ClosedSpace, Instance, Mapping, Space};
use pyo3::conversion::FromPyObject;
use pyo3::intern;
use pyo3::prelude::*;
use std::marker::PhantomData;

/// Type-level marker selecting the [`PythonMapping`] variant whose Python
/// object must also provide a `prox` method, which unlocks the [`Prox`]
/// implementation for such mappings.
#[derive(Clone, Copy, Debug)]
pub struct HasProx;

impl<'py, Domain> Prox<Domain> for PythonMapping<'py, Domain, f64, HasProx>
where
    Domain: ClosedSpace + Clone + for<'a> FromPyObject<'a, 'py, Error = PyErr> + IntoPyObject<'py>,
{
    type Prox<'a>
        = PythonProx<'py, Domain>
    where
        Self: 'a;

    /// Returns a [`PythonProx`] that pairs the parameter `τ` (presumably the
    /// proximal step length) with the Python object obtained from
    /// `self.get_obj()`.
    ///
    /// No Python code runs here. The Python `prox` method is only invoked when the
    /// returned mapping is applied.
    fn prox_mapping(&self, τ: Self::Codomain) -> Self::Prox<'_> {
        let obj = self.get_obj();
        PythonProx { obj, τ, _phantoms: PhantomData }
    }
}

#[derive(Debug)]
pub struct PythonProx<'py, Domain>
where
    Domain: Space + for<'a> FromPyObject<'a, 'py>,
{
    τ: f64,
    obj: Bound<'py, PyAny>,
    _phantoms: PhantomData<Domain>,
}

impl<'py, Domain> Mapping<Domain> for PythonProx<'py, Domain>
where
    Domain: ClosedSpace + for<'a> FromPyObject<'a, 'py, Error = PyErr> + IntoPyObject<'py>,
{
    type Codomain = Domain;

    /// Evaluates the proximal map at `x`.
    ///
    /// The call forwards to the Python object's `prox(τ, x)` method, passing an
    /// owned copy of `x`, and extracts the returned Python value as `Domain`.
    ///
    /// # Panics
    ///
    /// Panics via `unwrap` when `process_error` does not yield a value. This
    /// presumably happens whenever the Python call or the extraction fails
    /// (TODO: confirm against `process_error`).
    fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain {
        // TODO: use references and internal mutability?
        let py = self.obj.py();
        let prox_name = intern!(py, "prox");
        let outcome = self
            .obj
            .call_method1(prox_name, (self.τ, x.own()))
            .and_then(|value| value.extract());
        process_error("PythonProx::apply", py, outcome).unwrap()
    }
}

impl<'a, 'py, Domain, Codomain> FromPyObject<'a, 'py>
    for PythonMapping<'py, Domain, Codomain, HasProx>
where
    Domain: Space,
    Codomain: ClosedSpace + FromPyObject<'a, 'py>,
{
    type Error = PyErr;

    fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
        let obj = obj_.to_owned();
        // Verify that the necessary methods exist
        for method in intern_many!(obj.py(), "apply", "prox") {
            obj.getattr(method)?; //.downcast::<PyFunction>()?;
        }
        Ok(PythonMapping { obj, _phantoms: PhantomData })
    }
}

mercurial