| |
1 /*! |
| |
2 Implementation of [`Prox`] for mappings implemented in Python. |
| |
3 */ |
| |
4 |
| |
5 use crate::python_access::process_error; |
| |
6 |
| |
7 use super::diff_mapping::intern_many; |
| |
8 use super::diff_mapping::PythonMapping; |
| |
9 use alg_tools::convex::Prox; |
| |
10 use alg_tools::mapping::{ClosedSpace, Instance, Mapping, Space}; |
| |
11 use pyo3::conversion::FromPyObject; |
| |
12 use pyo3::intern; |
| |
13 use pyo3::prelude::*; |
| |
14 use std::marker::PhantomData; |
| |
15 |
| |
/// Type-level marker selecting the [`PythonMapping`] variant whose underlying
/// Python object provides a `prox` method in addition to `apply`.
#[derive(Clone, Copy, Debug)]
pub struct HasProx;
| |
19 |
| |
20 impl<'py, Domain> Prox<Domain> for PythonMapping<'py, Domain, f64, HasProx> |
| |
21 where |
| |
22 Domain: ClosedSpace + Clone + for<'a> FromPyObject<'a, 'py, Error = PyErr> + IntoPyObject<'py>, |
| |
23 { |
| |
24 type Prox<'a> |
| |
25 = PythonProx<'py, Domain> |
| |
26 where |
| |
27 Self: 'a; |
| |
28 |
| |
29 fn prox_mapping(&self, τ: Self::Codomain) -> Self::Prox<'_> { |
| |
30 PythonProx { τ, obj: self.get_obj(), _phantoms: PhantomData } |
| |
31 } |
| |
32 } |
| |
33 |
| |
34 #[derive(Debug)] |
| |
35 pub struct PythonProx<'py, Domain> |
| |
36 where |
| |
37 Domain: Space + for<'a> FromPyObject<'a, 'py>, |
| |
38 { |
| |
39 τ: f64, |
| |
40 obj: Bound<'py, PyAny>, |
| |
41 _phantoms: PhantomData<Domain>, |
| |
42 } |
| |
43 |
| |
44 impl<'py, Domain> Mapping<Domain> for PythonProx<'py, Domain> |
| |
45 where |
| |
46 Domain: ClosedSpace + for<'a> FromPyObject<'a, 'py, Error = PyErr> + IntoPyObject<'py>, |
| |
47 { |
| |
48 type Codomain = Domain; |
| |
49 |
| |
50 /// Compute the value of `self` at `x`. |
| |
51 fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain { |
| |
52 // TODO: use references and internal mutability? |
| |
53 let apply = intern!(self.obj.py(), "prox"); |
| |
54 process_error( |
| |
55 "PythonProx::apply", |
| |
56 self.obj.py(), |
| |
57 self.obj |
| |
58 .call_method1(apply, (self.τ, x.own())) |
| |
59 .and_then(|r| r.extract()), |
| |
60 ) |
| |
61 .unwrap() |
| |
62 } |
| |
63 } |
| |
64 |
| |
65 impl<'a, 'py, Domain, Codomain> FromPyObject<'a, 'py> |
| |
66 for PythonMapping<'py, Domain, Codomain, HasProx> |
| |
67 where |
| |
68 Domain: Space, |
| |
69 Codomain: ClosedSpace + FromPyObject<'a, 'py>, |
| |
70 { |
| |
71 type Error = PyErr; |
| |
72 |
| |
73 fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> { |
| |
74 let obj = obj_.to_owned(); |
| |
75 // Verify that the necessary methods exist |
| |
76 for method in intern_many!(obj.py(), "apply", "prox") { |
| |
77 obj.getattr(method)?; //.downcast::<PyFunction>()?; |
| |
78 } |
| |
79 Ok(PythonMapping { obj, _phantoms: PhantomData }) |
| |
80 } |
| |
81 } |