src/python_access/prox_mapping.rs

changeset 1
a4137aedcb3a
equal deleted inserted replaced
0:7ec1cfe19a24 1:a4137aedcb3a
1 /*!
2 Implementation of [`Prox`] for mappings implemented in Python.
3 */
4
5 use crate::python_access::process_error;
6
7 use super::diff_mapping::intern_many;
8 use super::diff_mapping::PythonMapping;
9 use alg_tools::convex::Prox;
10 use alg_tools::mapping::{ClosedSpace, Instance, Mapping, Space};
11 use pyo3::conversion::FromPyObject;
12 use pyo3::intern;
13 use pyo3::prelude::*;
14 use std::marker::PhantomData;
15
/// Marker type parameter for a [`PythonMapping`] whose underlying Python object
/// also exposes a `prox` method, so that the mapping can provide [`Prox`].
#[derive(Clone, Copy, Debug)]
pub struct HasProx;
19
20 impl<'py, Domain> Prox<Domain> for PythonMapping<'py, Domain, f64, HasProx>
21 where
22 Domain: ClosedSpace + Clone + for<'a> FromPyObject<'a, 'py, Error = PyErr> + IntoPyObject<'py>,
23 {
24 type Prox<'a>
25 = PythonProx<'py, Domain>
26 where
27 Self: 'a;
28
29 fn prox_mapping(&self, τ: Self::Codomain) -> Self::Prox<'_> {
30 PythonProx { τ, obj: self.get_obj(), _phantoms: PhantomData }
31 }
32 }
33
34 #[derive(Debug)]
35 pub struct PythonProx<'py, Domain>
36 where
37 Domain: Space + for<'a> FromPyObject<'a, 'py>,
38 {
39 τ: f64,
40 obj: Bound<'py, PyAny>,
41 _phantoms: PhantomData<Domain>,
42 }
43
44 impl<'py, Domain> Mapping<Domain> for PythonProx<'py, Domain>
45 where
46 Domain: ClosedSpace + for<'a> FromPyObject<'a, 'py, Error = PyErr> + IntoPyObject<'py>,
47 {
48 type Codomain = Domain;
49
50 /// Compute the value of `self` at `x`.
51 fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain {
52 // TODO: use references and internal mutability?
53 let apply = intern!(self.obj.py(), "prox");
54 process_error(
55 "PythonProx::apply",
56 self.obj.py(),
57 self.obj
58 .call_method1(apply, (self.τ, x.own()))
59 .and_then(|r| r.extract()),
60 )
61 .unwrap()
62 }
63 }
64
65 impl<'a, 'py, Domain, Codomain> FromPyObject<'a, 'py>
66 for PythonMapping<'py, Domain, Codomain, HasProx>
67 where
68 Domain: Space,
69 Codomain: ClosedSpace + FromPyObject<'a, 'py>,
70 {
71 type Error = PyErr;
72
73 fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
74 let obj = obj_.to_owned();
75 // Verify that the necessary methods exist
76 for method in intern_many!(obj.py(), "apply", "prox") {
77 obj.getattr(method)?; //.downcast::<PyFunction>()?;
78 }
79 Ok(PythonMapping { obj, _phantoms: PhantomData })
80 }
81 }

mercurial