| |
1 /*! |
| |
2 [`DifferentiableMapping`]s implemented in Python. |
| |
3 */ |
| |
4 use super::{process_error, NumpyArray_f64}; |
| |
5 use crate::dolfinx_access::DolfinxPyFunction_f64; |
| |
6 use alg_tools::direct_product::Pair; |
| |
7 use alg_tools::error::DynResult; |
| |
8 use alg_tools::linops::IdOp; |
| |
9 use alg_tools::mapping::{ClosedSpace, DifferentiableImpl, Instance, Mapping, Space}; |
| |
10 use anyhow::anyhow; |
| |
11 use ndarray::Dimension; |
| |
12 use numpy::{Ix1, Ix2}; |
| |
13 use pointsource_algs::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; |
| |
14 use pointsource_algs::prox_penalty::{RadonSquared, StepLengthBound, StepLengthBoundPair}; |
| |
15 use pyo3::conversion::FromPyObject; |
| |
16 use pyo3::intern; |
| |
17 use pyo3::prelude::*; |
| |
18 use pyo3::PyClass; |
| |
19 use std::marker::PhantomData; |
| |
20 |
| |
/// Marker selecting the differentiable flavour of [`PythonMapping`].
///
/// The type parameter is itself a marker describing the derivative's codomain
/// (see the `impl_differentiable!` invocations in this module).
#[derive(Copy, Debug, Clone)]
pub struct Differentiable<DerivativeDomainMarker>(DerivativeDomainMarker);
| |
24 |
| |
/// Marker for a [`PythonMapping`] that carries no additional properties
/// (in particular, it is not marked as differentiable).
#[derive(Copy, Debug, Clone)]
pub struct Basic;
| |
28 |
| |
29 #[derive(Debug)] |
| |
30 pub struct PythonMapping<'py, Domain, Codomain, Marker> |
| |
31 where |
| |
32 Domain: Space, |
| |
33 Codomain: Space, |
| |
34 { |
| |
35 pub(super) obj: Bound<'py, PyAny>, |
| |
36 pub(super) _phantoms: PhantomData<(Domain, Codomain, Marker)>, |
| |
37 } |
| |
38 |
| |
/// Interns several string literals at once.
///
/// Expands to an array holding one `intern!($py, literal)` expression per literal, in the
/// order given. At least one literal is needed for the array's element type to be inferred.
macro_rules! intern_many {
    ($py:expr, $($name:literal),*) => {
        [ $( intern!($py, $name) ),* ]
    };
}
| |
44 |
| |
45 // Export for super. |
| |
46 pub(super) use intern_many; |
| |
47 |
| |
48 impl<'a, 'py, Domain, Codomain> FromPyObject<'a, 'py> |
| |
49 for PythonMapping<'py, Domain, Codomain, Basic> |
| |
50 where |
| |
51 Domain: Space, |
| |
52 Codomain: Space + PyClass + FromPyObject<'a, 'py>, |
| |
53 { |
| |
54 type Error = PyErr; |
| |
55 |
| |
56 fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> { |
| |
57 let obj = obj_.to_owned(); |
| |
58 // Verify that the necessary methods exist |
| |
59 for method in intern_many!(obj.py(), "apply", "diff_lipschitz_factor") { |
| |
60 obj.getattr(method)?; //.downcast::<PyFunction>()?; |
| |
61 } |
| |
62 Ok(PythonMapping { obj, _phantoms: PhantomData }) |
| |
63 } |
| |
64 } |
| |
65 |
| |
66 impl<'a, 'py, Domain, Codomain, DerivativeMarker> FromPyObject<'a, 'py> |
| |
67 for PythonMapping<'py, Domain, Codomain, Differentiable<DerivativeMarker>> |
| |
68 where |
| |
69 Domain: Space, |
| |
70 Codomain: Space + FromPyObject<'a, 'py>, |
| |
71 { |
| |
72 type Error = PyErr; |
| |
73 |
| |
74 fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> { |
| |
75 let obj = obj_.to_owned(); |
| |
76 // Verify that the necessary methods exist |
| |
77 for method in intern_many!(obj.py(), "apply", "diff", "diff_lipschitz_factor") { |
| |
78 obj.getattr(method)?; //.downcast::<PyFunction>()?; |
| |
79 } |
| |
80 Ok(PythonMapping { obj, _phantoms: PhantomData }) |
| |
81 } |
| |
82 } |
| |
83 |
| |
84 impl<'py, Domain, Codomain, AnyMarker> Mapping<Domain> |
| |
85 for PythonMapping<'py, Domain, Codomain, AnyMarker> |
| |
86 where |
| |
87 Domain: Space, |
| |
88 Domain::Principal: IntoPyObject<'py>, |
| |
89 Codomain: ClosedSpace + for<'a> FromPyObject<'a, 'py, Error = PyErr>, |
| |
90 { |
| |
91 type Codomain = Codomain; |
| |
92 |
| |
93 /// Compute the value of `self` at `x`. |
| |
94 fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain { |
| |
95 // TODO: use references and internal mutability? |
| |
96 //x_py = x.own().to_python(py).unwrap(); |
| |
97 let apply = intern!(self.obj.py(), "apply"); |
| |
98 process_error( |
| |
99 "PythonMapping::apply", |
| |
100 self.obj.py(), |
| |
101 self.obj |
| |
102 .call_method1(apply, (x.own(),)) |
| |
103 .and_then(|r| r.extract()), |
| |
104 ) |
| |
105 .unwrap() |
| |
106 } |
| |
107 } |
| |
108 |
| |
109 impl<'py, Domain, Codomain, AnyMarker> PythonMapping<'py, Domain, Codomain, AnyMarker> |
| |
110 where |
| |
111 Domain: Space, |
| |
112 Codomain: Space + for<'a> FromPyObject<'a, 'py>, |
| |
113 { |
| |
114 pub(crate) fn get_obj(&self) -> Bound<'py, PyAny> { |
| |
115 self.obj.clone() |
| |
116 } |
| |
117 } |
| |
118 |
| |
/// Implements `DifferentiableImpl` for an `f64`-valued [`PythonMapping`].
///
/// * `$derivative` is the Rust type the Python results are extracted into; it may mention
///   the `'py` lifetime introduced by the generated `impl`.
/// * `$marker` is the marker placed inside [`Differentiable`] to select this implementation.
macro_rules! impl_differentiable {
    ($derivative:ty, $marker:ty) => {
        impl<'py, Domain> DifferentiableImpl<Domain>
            for PythonMapping<'py, Domain, f64, Differentiable<$marker>>
        where
            Domain: Space,
            Domain::Principal: IntoPyObject<'py>,
        {
            type Derivative = $derivative;

            /// Computes the differential of `self` at `x` by calling `diff(x)` on the
            /// wrapped Python object.
            ///
            /// # Panics
            ///
            /// Panics if `process_error` reports a failure of the call or the extraction.
            fn differential_impl<I: Instance<Domain>>(&self, x: I) -> $derivative {
                // TODO: use references and internal mutability?
                let py = self.obj.py();
                let derivative: PyResult<$derivative> = self
                    .obj
                    .call_method1(intern!(py, "diff"), (x.own(),))
                    .and_then(|returned| returned.extract());
                process_error("PythonMapping::differential_impl", py, derivative).unwrap()
            }

            /// Computes both the value and the differential of `self` at `x` with a single
            /// call to `apply_and_diff(x)` on the wrapped Python object.
            ///
            /// # Panics
            ///
            /// Panics if `process_error` reports a failure of the call or the extraction.
            fn apply_and_differential_impl<I: Instance<Domain>>(
                &self,
                x: I,
            ) -> (<Self as Mapping<Domain>>::Codomain, $derivative) {
                // TODO: use references and internal mutability?
                let py = self.obj.py();
                let both: PyResult<(<Self as Mapping<Domain>>::Codomain, $derivative)> = self
                    .obj
                    .call_method1(intern!(py, "apply_and_diff"), (x.own(),))
                    .and_then(|returned| returned.extract());
                process_error("PythonMapping::apply_and_differential_impl", py, both).unwrap()
            }
        }
    };
}
| |
165 |
| |
166 impl<'py, Domain, Marker> BoundedCurvature<f64> |
| |
167 for PythonMapping<'py, Domain, f64, Differentiable<Marker>> |
| |
168 where |
| |
169 Domain: Space, |
| |
170 { |
| |
171 fn curvature_bound_components( |
| |
172 &self, |
| |
173 _guess: BoundedCurvatureGuess, |
| |
174 ) -> (DynResult<f64>, DynResult<f64>) { |
| |
175 let m = intern!(self.obj.py(), "curvature_bound_components"); |
| |
176 match process_error::<(Option<f64>, Option<f64>)>( |
| |
177 "curvature_bound_components", |
| |
178 self.obj.py(), |
| |
179 self.obj.call_method0(m).and_then(|r| r.extract()), |
| |
180 ) { |
| |
181 Ok((l, θ2)) => ( |
| |
182 l.ok_or_else(|| anyhow!("l is None")), |
| |
183 θ2.ok_or_else(|| anyhow!("θ2 is None")), |
| |
184 ), |
| |
185 Err(e) => ( |
| |
186 Err(anyhow!("(same error as second component)")), |
| |
187 Err(e.into()), |
| |
188 ), |
| |
189 } |
| |
190 } |
| |
191 } |
| |
192 |
| |
193 impl<'py, Domain, Marker> |
| |
194 StepLengthBound<f64, PythonMapping<'py, Domain, f64, Differentiable<Marker>>> for RadonSquared |
| |
195 where |
| |
196 Domain: Space, |
| |
197 { |
| |
198 fn step_length_bound( |
| |
199 &self, |
| |
200 f: &PythonMapping<'py, Domain, f64, Differentiable<Marker>>, |
| |
201 ) -> DynResult<f64> { |
| |
202 let m = intern!(f.obj.py(), "diff_lipschitz_factor"); |
| |
203 process_error( |
| |
204 "PythonMapping::diff_lipschitz_factor", |
| |
205 f.obj.py(), |
| |
206 f.obj.call_method0(m).and_then(|r| r.extract()), |
| |
207 ) |
| |
208 } |
| |
209 } |
| |
210 |
| |
211 impl<'py, 'a, Domain, Marker, Z> |
| |
212 StepLengthBoundPair<f64, PythonMapping<'py, Domain, f64, Differentiable<Marker>>> |
| |
213 for Pair<&'a RadonSquared, &'a IdOp<Z>> |
| |
214 where |
| |
215 Domain: Space, |
| |
216 { |
| |
217 fn step_length_bound_pair( |
| |
218 &self, |
| |
219 f: &PythonMapping<'py, Domain, f64, Differentiable<Marker>>, |
| |
220 ) -> DynResult<(f64, f64)> { |
| |
221 let m = intern!(f.obj.py(), "diff_lipschitz_factor_pair"); |
| |
222 process_error( |
| |
223 "PythonMapping::diff_lipschitz_factor_pair", |
| |
224 f.obj.py(), |
| |
225 f.obj.call_method0(m).and_then(|r| r.extract()), |
| |
226 ) |
| |
227 } |
| |
228 } |
| |
229 |
| |
/// Marker type identifying the derivative codomain of a differentiable [`PythonMapping`].
///
/// Ideally we would just use, e.g.,
/// ```text
/// PythonMapping<Domain, f64, Differentiable<DolfinxPyFunction_f64<2,2,1>>>
/// ```
/// to encode the derivative codomain into the type. However, `DolfinxPyFunction_f64<2,2,1>` is
/// not [`Send`], so [`pyo3`] does not allow such a mapping in a `pyclass`. We don't actually
/// ever pass a `DolfinxPyFunction_f64` to Python—it's a Rust wrapper for a PyObject—but pyo3,
/// not knowing better, disallows the type from even appearing in signatures. (It is fine for
/// the real type to appear as the `DifferentiableImpl::Derivative` associated type, though.)
/// That's why we need this marker in place of the real type:
/// ```text
/// PythonMapping<Domain, f64, Differentiable<DolfinxPyFunctionMarker<2,2,1>>>
/// ```
///
/// The snippets are fenced as `text` because they are type expressions for illustration
/// only, not compilable doctests.
#[derive(Copy, Debug, Clone)]
pub struct DolfinxPyFunctionMarker<const N: usize, const O: usize, const D: usize>;
| |
246 pub struct NumpyArrayMarker<Ix: Dimension>(Ix); |
| |
247 |
| |
248 //impl_differentiable!(DolfinxPyFunction_f64<1,2,1>, DolfinxPyFunctionMarker<1,2,1>); |
| |
249 impl_differentiable!(DolfinxPyFunction_f64<'py, 2,2,1>, DolfinxPyFunctionMarker<2,2,1>); |
| |
250 impl_differentiable!( |
| |
251 Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, NumpyArray_f64<'py, Ix1>>, |
| |
252 Pair<DolfinxPyFunctionMarker<2, 2, 1>, NumpyArrayMarker<Ix1>> |
| |
253 ); |
| |
254 impl_differentiable!( |
| |
255 Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, NumpyArray_f64<'py, Ix2>>, |
| |
256 Pair<DolfinxPyFunctionMarker<2, 2, 1>, NumpyArrayMarker<Ix2>> |
| |
257 ); |
| |
258 impl_differentiable!( |
| |
259 Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, Pair<f64, Pair<f64, f64>>>, |
| |
260 Pair<DolfinxPyFunctionMarker<2, 2, 1>, Pair<f64, Pair<f64, f64>>> |
| |
261 ); |
| |
262 impl_differentiable!( |
| |
263 Pair< |
| |
264 DolfinxPyFunction_f64<'py, 2, 2, 1>, |
| |
265 Pair<f64, Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>>, |
| |
266 >, |
| |
267 Pair< |
| |
268 DolfinxPyFunctionMarker<2, 2, 1>, |
| |
269 Pair<f64, Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>>, |
| |
270 > |
| |
271 ); |
| |
272 impl_differentiable!( |
| |
273 Pair< |
| |
274 DolfinxPyFunction_f64<'py, 2, 2, 1>, |
| |
275 Pair< |
| |
276 NumpyArray_f64<'py, Ix2>, |
| |
277 Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>, |
| |
278 >, |
| |
279 >, |
| |
280 Pair< |
| |
281 DolfinxPyFunctionMarker<2, 2, 1>, |
| |
282 Pair< |
| |
283 NumpyArrayMarker<Ix2>, |
| |
284 Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>, |
| |
285 >, |
| |
286 > |
| |
287 ); |
| |
288 impl_differentiable!( |
| |
289 Pair< |
| |
290 DolfinxPyFunction_f64<'py, 2, 2, 1>, |
| |
291 Pair< |
| |
292 DolfinxPyFunction_f64<'py, 2, 2, 1>, |
| |
293 Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>, |
| |
294 >, |
| |
295 >, |
| |
296 Pair< |
| |
297 DolfinxPyFunctionMarker<2, 2, 1>, |
| |
298 Pair< |
| |
299 DolfinxPyFunctionMarker<2, 2, 1>, |
| |
300 Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>, |
| |
301 >, |
| |
302 > |
| |
303 ); |