src/python_access/diff_mapping.rs

changeset 1
a4137aedcb3a
child 2
69002abe5dcb
child 3
c3a4f4bb87f7
equal deleted inserted replaced
0:7ec1cfe19a24 1:a4137aedcb3a
1 /*!
2 [`DifferentiableMapping`]s implemented in Python.
3 */
4 use super::{process_error, NumpyArray_f64};
5 use crate::dolfinx_access::DolfinxPyFunction_f64;
6 use alg_tools::direct_product::Pair;
7 use alg_tools::error::DynResult;
8 use alg_tools::linops::IdOp;
9 use alg_tools::mapping::{ClosedSpace, DifferentiableImpl, Instance, Mapping, Space};
10 use anyhow::anyhow;
11 use ndarray::Dimension;
12 use numpy::{Ix1, Ix2};
13 use pointsource_algs::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
14 use pointsource_algs::prox_penalty::{RadonSquared, StepLengthBound, StepLengthBoundPair};
15 use pyo3::conversion::FromPyObject;
16 use pyo3::intern;
17 use pyo3::prelude::*;
18 use pyo3::PyClass;
19 use std::marker::PhantomData;
20
/// Marker for [`PythonMapping`]s whose Python object can also be differentiated.
///
/// The wrapped type parameter identifies (at the type level only) the codomain of the
/// derivative.
#[derive(Clone, Copy, Debug)]
pub struct Differentiable<DerivativeDomainMarker>(DerivativeDomainMarker);
24
/// Marker for [`PythonMapping`]s that carry no further properties (in particular, no
/// differentiability).
#[derive(Clone, Copy, Debug)]
pub struct Basic;
28
29 #[derive(Debug)]
30 pub struct PythonMapping<'py, Domain, Codomain, Marker>
31 where
32 Domain: Space,
33 Codomain: Space,
34 {
35 pub(super) obj: Bound<'py, PyAny>,
36 pub(super) _phantoms: PhantomData<(Domain, Codomain, Marker)>,
37 }
38
39 macro_rules! intern_many {
40 ($py:expr, $($method:literal),*) => {[ $(
41 intern!($py, $method)
42 ),*]}
43 }
44
45 // Export for super.
46 pub(super) use intern_many;
47
48 impl<'a, 'py, Domain, Codomain> FromPyObject<'a, 'py>
49 for PythonMapping<'py, Domain, Codomain, Basic>
50 where
51 Domain: Space,
52 Codomain: Space + PyClass + FromPyObject<'a, 'py>,
53 {
54 type Error = PyErr;
55
56 fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
57 let obj = obj_.to_owned();
58 // Verify that the necessary methods exist
59 for method in intern_many!(obj.py(), "apply", "diff_lipschitz_factor") {
60 obj.getattr(method)?; //.downcast::<PyFunction>()?;
61 }
62 Ok(PythonMapping { obj, _phantoms: PhantomData })
63 }
64 }
65
66 impl<'a, 'py, Domain, Codomain, DerivativeMarker> FromPyObject<'a, 'py>
67 for PythonMapping<'py, Domain, Codomain, Differentiable<DerivativeMarker>>
68 where
69 Domain: Space,
70 Codomain: Space + FromPyObject<'a, 'py>,
71 {
72 type Error = PyErr;
73
74 fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
75 let obj = obj_.to_owned();
76 // Verify that the necessary methods exist
77 for method in intern_many!(obj.py(), "apply", "diff", "diff_lipschitz_factor") {
78 obj.getattr(method)?; //.downcast::<PyFunction>()?;
79 }
80 Ok(PythonMapping { obj, _phantoms: PhantomData })
81 }
82 }
83
84 impl<'py, Domain, Codomain, AnyMarker> Mapping<Domain>
85 for PythonMapping<'py, Domain, Codomain, AnyMarker>
86 where
87 Domain: Space,
88 Domain::Principal: IntoPyObject<'py>,
89 Codomain: ClosedSpace + for<'a> FromPyObject<'a, 'py, Error = PyErr>,
90 {
91 type Codomain = Codomain;
92
93 /// Compute the value of `self` at `x`.
94 fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain {
95 // TODO: use references and internal mutability?
96 //x_py = x.own().to_python(py).unwrap();
97 let apply = intern!(self.obj.py(), "apply");
98 process_error(
99 "PythonMapping::apply",
100 self.obj.py(),
101 self.obj
102 .call_method1(apply, (x.own(),))
103 .and_then(|r| r.extract()),
104 )
105 .unwrap()
106 }
107 }
108
109 impl<'py, Domain, Codomain, AnyMarker> PythonMapping<'py, Domain, Codomain, AnyMarker>
110 where
111 Domain: Space,
112 Codomain: Space + for<'a> FromPyObject<'a, 'py>,
113 {
114 pub(crate) fn get_obj(&self) -> Bound<'py, PyAny> {
115 self.obj.clone()
116 }
117 }
118
/// Implements [`DifferentiableImpl`] for
/// `PythonMapping<'py, Domain, f64, Differentiable<$marker>>` with derivative type
/// `$derivative`.
///
/// `$marker` is the stand-in type used in the mapping's own type. `$derivative` is the Rust
/// type extracted from the Python return value; it may mention the `'py` lifetime of the
/// generated `impl`.
macro_rules! impl_differentiable {
    ($derivative:ty, $marker:ty) => {
        impl<'py, Domain> DifferentiableImpl<Domain>
            for PythonMapping<'py, Domain, f64, Differentiable<$marker>>
        where
            Domain: Space,
            Domain::Principal: IntoPyObject<'py>,
        {
            type Derivative = $derivative;

            /// Computes the differential of `self` at `x`.
            ///
            /// Calls the Python object's `diff` method with the owned principal value of
            /// `x` and extracts the result as the derivative type.
            ///
            /// # Panics
            /// Panics (after `process_error` has handled the error) if the Python call
            /// fails or its result cannot be extracted.
            fn differential_impl<I: Instance<Domain>>(&self, x: I) -> $derivative {
                let diff = intern!(self.obj.py(), "diff");
                process_error(
                    "PythonMapping::differential_impl",
                    self.obj.py(),
                    self.obj
                        .call_method1(diff, (x.own(),))
                        .and_then(|r| r.extract()),
                )
                .unwrap()
            }

            /// Computes both the value and the differential of `self` at `x` in one call.
            ///
            /// Calls the Python object's `apply_and_diff` method, whose result is
            /// extracted as a `(value, derivative)` tuple.
            ///
            /// # Panics
            /// Panics (after `process_error` has handled the error) if the Python call
            /// fails or its result cannot be extracted.
            fn apply_and_differential_impl<I: Instance<Domain>>(
                &self,
                x: I,
            ) -> (<Self as Mapping<Domain>>::Codomain, $derivative) {
                let apply_and_diff = intern!(self.obj.py(), "apply_and_diff");
                process_error(
                    "PythonMapping::apply_and_differential_impl",
                    self.obj.py(),
                    self.obj
                        .call_method1(apply_and_diff, (x.own(),))
                        .and_then(|r| r.extract()),
                )
                .unwrap()
            }
        }
    };
}
165
166 impl<'py, Domain, Marker> BoundedCurvature<f64>
167 for PythonMapping<'py, Domain, f64, Differentiable<Marker>>
168 where
169 Domain: Space,
170 {
171 fn curvature_bound_components(
172 &self,
173 _guess: BoundedCurvatureGuess,
174 ) -> (DynResult<f64>, DynResult<f64>) {
175 let m = intern!(self.obj.py(), "curvature_bound_components");
176 match process_error::<(Option<f64>, Option<f64>)>(
177 "curvature_bound_components",
178 self.obj.py(),
179 self.obj.call_method0(m).and_then(|r| r.extract()),
180 ) {
181 Ok((l, θ2)) => (
182 l.ok_or_else(|| anyhow!("l is None")),
183 θ2.ok_or_else(|| anyhow!("θ2 is None")),
184 ),
185 Err(e) => (
186 Err(anyhow!("(same error as second component)")),
187 Err(e.into()),
188 ),
189 }
190 }
191 }
192
193 impl<'py, Domain, Marker>
194 StepLengthBound<f64, PythonMapping<'py, Domain, f64, Differentiable<Marker>>> for RadonSquared
195 where
196 Domain: Space,
197 {
198 fn step_length_bound(
199 &self,
200 f: &PythonMapping<'py, Domain, f64, Differentiable<Marker>>,
201 ) -> DynResult<f64> {
202 let m = intern!(f.obj.py(), "diff_lipschitz_factor");
203 process_error(
204 "PythonMapping::diff_lipschitz_factor",
205 f.obj.py(),
206 f.obj.call_method0(m).and_then(|r| r.extract()),
207 )
208 }
209 }
210
211 impl<'py, 'a, Domain, Marker, Z>
212 StepLengthBoundPair<f64, PythonMapping<'py, Domain, f64, Differentiable<Marker>>>
213 for Pair<&'a RadonSquared, &'a IdOp<Z>>
214 where
215 Domain: Space,
216 {
217 fn step_length_bound_pair(
218 &self,
219 f: &PythonMapping<'py, Domain, f64, Differentiable<Marker>>,
220 ) -> DynResult<(f64, f64)> {
221 let m = intern!(f.obj.py(), "diff_lipschitz_factor_pair");
222 process_error(
223 "PythonMapping::diff_lipschitz_factor_pair",
224 f.obj.py(),
225 f.obj.call_method0(m).and_then(|r| r.extract()),
226 )
227 }
228 }
229
/// Marker type identifying the derivative codomain of a differentiable [`PythonMapping`].
///
/// Ideally we would just use, e.g.,
/// ```text
/// PythonMapping<Domain, f64, Differentiable<DolfinxPyFunction_f64<2,2,1>>>
/// ```
/// to encode the derivative codomain into the type. However, `DolfinxPyFunction_f64<2,2,1>`
/// is not [`Send`], so [`pyo3`] does not allow such a mapping in a `pyclass`. We never
/// actually pass a `DolfinxPyFunction_f64` to Python: it is a Rust wrapper for a Python
/// object. But pyo3, not knowing better, refuses to let the type even appear in signatures.
/// (The real type may still appear where `pyo3` does not see it, such as the associated type
/// `DifferentiableImpl::Derivative`.) That is why this marker is used in place of the real
/// type:
/// ```text
/// PythonMapping<Domain, f64, Differentiable<DolfinxPyFunctionMarker<2,2,1>>>
/// ```
#[derive(Copy, Debug, Clone)]
pub struct DolfinxPyFunctionMarker<const N: usize, const O: usize, const D: usize>;
246 pub struct NumpyArrayMarker<Ix: Dimension>(Ix);
247
248 //impl_differentiable!(DolfinxPyFunction_f64<1,2,1>, DolfinxPyFunctionMarker<1,2,1>);
249 impl_differentiable!(DolfinxPyFunction_f64<'py, 2,2,1>, DolfinxPyFunctionMarker<2,2,1>);
250 impl_differentiable!(
251 Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, NumpyArray_f64<'py, Ix1>>,
252 Pair<DolfinxPyFunctionMarker<2, 2, 1>, NumpyArrayMarker<Ix1>>
253 );
254 impl_differentiable!(
255 Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, NumpyArray_f64<'py, Ix2>>,
256 Pair<DolfinxPyFunctionMarker<2, 2, 1>, NumpyArrayMarker<Ix2>>
257 );
258 impl_differentiable!(
259 Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, Pair<f64, Pair<f64, f64>>>,
260 Pair<DolfinxPyFunctionMarker<2, 2, 1>, Pair<f64, Pair<f64, f64>>>
261 );
262 impl_differentiable!(
263 Pair<
264 DolfinxPyFunction_f64<'py, 2, 2, 1>,
265 Pair<f64, Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>>,
266 >,
267 Pair<
268 DolfinxPyFunctionMarker<2, 2, 1>,
269 Pair<f64, Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>>,
270 >
271 );
272 impl_differentiable!(
273 Pair<
274 DolfinxPyFunction_f64<'py, 2, 2, 1>,
275 Pair<
276 NumpyArray_f64<'py, Ix2>,
277 Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>,
278 >,
279 >,
280 Pair<
281 DolfinxPyFunctionMarker<2, 2, 1>,
282 Pair<
283 NumpyArrayMarker<Ix2>,
284 Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>,
285 >,
286 >
287 );
288 impl_differentiable!(
289 Pair<
290 DolfinxPyFunction_f64<'py, 2, 2, 1>,
291 Pair<
292 DolfinxPyFunction_f64<'py, 2, 2, 1>,
293 Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>,
294 >,
295 >,
296 Pair<
297 DolfinxPyFunctionMarker<2, 2, 1>,
298 Pair<
299 DolfinxPyFunctionMarker<2, 2, 1>,
300 Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>,
301 >,
302 >
303 );

mercurial