src/python_access/diff_mapping.rs

changeset 1
a4137aedcb3a
child 2
69002abe5dcb
child 3
c3a4f4bb87f7
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/python_access/diff_mapping.rs	Thu Feb 26 09:32:12 2026 -0500
@@ -0,0 +1,303 @@
+/*!
+[`DifferentiableMapping`]s implemented in Python.
+*/
+use super::{process_error, NumpyArray_f64};
+use crate::dolfinx_access::DolfinxPyFunction_f64;
+use alg_tools::direct_product::Pair;
+use alg_tools::error::DynResult;
+use alg_tools::linops::IdOp;
+use alg_tools::mapping::{ClosedSpace, DifferentiableImpl, Instance, Mapping, Space};
+use anyhow::anyhow;
+use ndarray::Dimension;
+use numpy::{Ix1, Ix2};
+use pointsource_algs::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
+use pointsource_algs::prox_penalty::{RadonSquared, StepLengthBound, StepLengthBoundPair};
+use pyo3::conversion::FromPyObject;
+use pyo3::intern;
+use pyo3::prelude::*;
+use pyo3::PyClass;
+use std::marker::PhantomData;
+
+#[derive(Copy, Debug, Clone)]
+/// Marker for differentiable PythonMappings
+pub struct Differentiable<DerivativeDomainMarker>(DerivativeDomainMarker);
+
+#[derive(Copy, Debug, Clone)]
+/// Marker for PythonMappings without further properties.
+pub struct Basic;
+
+#[derive(Debug)]
+pub struct PythonMapping<'py, Domain, Codomain, Marker>
+where
+    Domain: Space,
+    Codomain: Space,
+{
+    pub(super) obj: Bound<'py, PyAny>,
+    pub(super) _phantoms: PhantomData<(Domain, Codomain, Marker)>,
+}
+
+macro_rules! intern_many {
+    ($py:expr, $($method:literal),*) => {[ $(
+        intern!($py, $method)
+    ),*]}
+}
+
+// Export for super.
+pub(super) use intern_many;
+
+impl<'a, 'py, Domain, Codomain> FromPyObject<'a, 'py>
+    for PythonMapping<'py, Domain, Codomain, Basic>
+where
+    Domain: Space,
+    Codomain: Space + PyClass + FromPyObject<'a, 'py>,
+{
+    type Error = PyErr;
+
+    fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
+        let obj = obj_.to_owned();
+        // Verify that the necessary methods exist
+        for method in intern_many!(obj.py(), "apply", "diff_lipschitz_factor") {
+            obj.getattr(method)?; //.downcast::<PyFunction>()?;
+        }
+        Ok(PythonMapping { obj, _phantoms: PhantomData })
+    }
+}
+
+impl<'a, 'py, Domain, Codomain, DerivativeMarker> FromPyObject<'a, 'py>
+    for PythonMapping<'py, Domain, Codomain, Differentiable<DerivativeMarker>>
+where
+    Domain: Space,
+    Codomain: Space + FromPyObject<'a, 'py>,
+{
+    type Error = PyErr;
+
+    fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
+        let obj = obj_.to_owned();
+        // Verify that the necessary methods exist
+        for method in intern_many!(obj.py(), "apply", "diff", "diff_lipschitz_factor") {
+            obj.getattr(method)?; //.downcast::<PyFunction>()?;
+        }
+        Ok(PythonMapping { obj, _phantoms: PhantomData })
+    }
+}
+
+impl<'py, Domain, Codomain, AnyMarker> Mapping<Domain>
+    for PythonMapping<'py, Domain, Codomain, AnyMarker>
+where
+    Domain: Space,
+    Domain::Principal: IntoPyObject<'py>,
+    Codomain: ClosedSpace + for<'a> FromPyObject<'a, 'py, Error = PyErr>,
+{
+    type Codomain = Codomain;
+
+    /// Compute the value of `self` at `x`.
+    fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain {
+        // TODO: use references and internal mutability?
+        //x_py = x.own().to_python(py).unwrap();
+        let apply = intern!(self.obj.py(), "apply");
+        process_error(
+            "PythonMapping::apply",
+            self.obj.py(),
+            self.obj
+                .call_method1(apply, (x.own(),))
+                .and_then(|r| r.extract()),
+        )
+        .unwrap()
+    }
+}
+
+impl<'py, Domain, Codomain, AnyMarker> PythonMapping<'py, Domain, Codomain, AnyMarker>
+where
+    Domain: Space,
+    Codomain: Space + for<'a> FromPyObject<'a, 'py>,
+{
+    pub(crate) fn get_obj(&self) -> Bound<'py, PyAny> {
+        self.obj.clone()
+    }
+}
+
+macro_rules! impl_differentiable {
+    ($derivative:ty, $marker:ty) => {
+        impl<'py, Domain> DifferentiableImpl<Domain>
+            for PythonMapping<'py, Domain, f64, Differentiable<$marker>>
+        where
+            Domain: Space,
+            Domain::Principal: IntoPyObject<'py>,
+            //$derivative: Space + for<'py> FromPyObject<'py>,
+            //f64 : for<'py> FromPyObject<'py>,
+        {
+            type Derivative = $derivative;
+
+            /// Compute the value of `self` at `x`.
+            fn differential_impl<I: Instance<Domain>>(&self, x: I) -> $derivative {
+                // TODO: use references and internal mutability?
+                //x_py = x.own().to_python(py).unwrap();
+                let diff = intern!(self.obj.py(), "diff");
+                process_error(
+                    "PythonMapping::differential_impl",
+                    self.obj.py(),
+                    self.obj
+                        .call_method1(diff, (x.own(),))
+                        .and_then(|r| r.extract()),
+                )
+                .unwrap()
+            }
+
+            fn apply_and_differential_impl<I: Instance<Domain>>(
+                &self,
+                x: I,
+            ) -> (<Self as Mapping<Domain>>::Codomain, $derivative) {
+                // TODO: use references and internal mutability?
+                //x_py = x.own().to_python(py).unwrap();
+                let apply_and_diff = intern!(self.obj.py(), "apply_and_diff");
+                process_error(
+                    "PythonMapping::apply_and_differential_impl",
+                    self.obj.py(),
+                    self.obj
+                        .call_method1(apply_and_diff, (x.own(),))
+                        .and_then(|r| r.extract()),
+                )
+                .unwrap()
+            }
+        }
+    };
+}
+
+impl<'py, Domain, Marker> BoundedCurvature<f64>
+    for PythonMapping<'py, Domain, f64, Differentiable<Marker>>
+where
+    Domain: Space,
+{
+    fn curvature_bound_components(
+        &self,
+        _guess: BoundedCurvatureGuess,
+    ) -> (DynResult<f64>, DynResult<f64>) {
+        let m = intern!(self.obj.py(), "curvature_bound_components");
+        match process_error::<(Option<f64>, Option<f64>)>(
+            "curvature_bound_components",
+            self.obj.py(),
+            self.obj.call_method0(m).and_then(|r| r.extract()),
+        ) {
+            Ok((l, θ2)) => (
+                l.ok_or_else(|| anyhow!("l is None")),
+                θ2.ok_or_else(|| anyhow!("θ2 is None")),
+            ),
+            Err(e) => (
+                Err(anyhow!("(same error as second component)")),
+                Err(e.into()),
+            ),
+        }
+    }
+}
+
+impl<'py, Domain, Marker>
+    StepLengthBound<f64, PythonMapping<'py, Domain, f64, Differentiable<Marker>>> for RadonSquared
+where
+    Domain: Space,
+{
+    fn step_length_bound(
+        &self,
+        f: &PythonMapping<'py, Domain, f64, Differentiable<Marker>>,
+    ) -> DynResult<f64> {
+        let m = intern!(f.obj.py(), "diff_lipschitz_factor");
+        process_error(
+            "PythonMapping::diff_lipschitz_factor",
+            f.obj.py(),
+            f.obj.call_method0(m).and_then(|r| r.extract()),
+        )
+    }
+}
+
+impl<'py, 'a, Domain, Marker, Z>
+    StepLengthBoundPair<f64, PythonMapping<'py, Domain, f64, Differentiable<Marker>>>
+    for Pair<&'a RadonSquared, &'a IdOp<Z>>
+where
+    Domain: Space,
+{
+    fn step_length_bound_pair(
+        &self,
+        f: &PythonMapping<'py, Domain, f64, Differentiable<Marker>>,
+    ) -> DynResult<(f64, f64)> {
+        let m = intern!(f.obj.py(), "diff_lipschitz_factor_pair");
+        process_error(
+            "PythonMapping::diff_lipschitz_factor_pair",
+            f.obj.py(),
+            f.obj.call_method0(m).and_then(|r| r.extract()),
+        )
+    }
+}
+
+#[derive(Copy, Debug, Clone)]
+/// This is a marker type for identifying the derivative codomain of a differentiable
+/// [`PythonMapping`]. Ideally we would just use, e.g.,
+/// ```
+/// PythonMapping<Domain, f64, Differentiable<DolfinxPyFunction_f64<2,2,1>>>
+/// ```
+/// to encode the derivative codomain into the type. However `DolfinxPyFunction_f64<2,2,1>` is
+/// not [´Send`] so [`pyo3`] does not allow such a mapping in a `pyclass`. We don't actually
+/// ever pass a `DolfinxPyFunction_f64` to Python—it's a Rust wrapper for a PyObject—but pyo3,
+/// not knowing better, disallows the type from even appearing in signatures. (It's ok to
+/// `pyo3` as `DifferentiableImpl::Domain`, though. That's why we  need this marker in place
+/// of the real type:
+/// ```
+/// PythonMapping<Domain, f64, Differentiable<DolfinxPyFunctionMarker<2,2,1>>>
+/// ```
+pub struct DolfinxPyFunctionMarker<const N: usize, const O: usize, const D: usize>;
+pub struct NumpyArrayMarker<Ix: Dimension>(Ix);
+
+//impl_differentiable!(DolfinxPyFunction_f64<1,2,1>, DolfinxPyFunctionMarker<1,2,1>);
+impl_differentiable!(DolfinxPyFunction_f64<'py, 2,2,1>, DolfinxPyFunctionMarker<2,2,1>);
+impl_differentiable!(
+    Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, NumpyArray_f64<'py, Ix1>>,
+    Pair<DolfinxPyFunctionMarker<2, 2, 1>, NumpyArrayMarker<Ix1>>
+);
+impl_differentiable!(
+    Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, NumpyArray_f64<'py, Ix2>>,
+    Pair<DolfinxPyFunctionMarker<2, 2, 1>, NumpyArrayMarker<Ix2>>
+);
+impl_differentiable!(
+    Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, Pair<f64, Pair<f64, f64>>>,
+    Pair<DolfinxPyFunctionMarker<2, 2, 1>, Pair<f64, Pair<f64, f64>>>
+);
+impl_differentiable!(
+    Pair<
+        DolfinxPyFunction_f64<'py, 2, 2, 1>,
+        Pair<f64, Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>>,
+    >,
+    Pair<
+        DolfinxPyFunctionMarker<2, 2, 1>,
+        Pair<f64, Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>>,
+    >
+);
+impl_differentiable!(
+    Pair<
+        DolfinxPyFunction_f64<'py, 2, 2, 1>,
+        Pair<
+            NumpyArray_f64<'py, Ix2>,
+            Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>,
+        >,
+    >,
+    Pair<
+        DolfinxPyFunctionMarker<2, 2, 1>,
+        Pair<
+            NumpyArrayMarker<Ix2>,
+            Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>,
+        >,
+    >
+);
+impl_differentiable!(
+    Pair<
+        DolfinxPyFunction_f64<'py, 2, 2, 1>,
+        Pair<
+            DolfinxPyFunction_f64<'py, 2, 2, 1>,
+            Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>,
+        >,
+    >,
+    Pair<
+        DolfinxPyFunctionMarker<2, 2, 1>,
+        Pair<
+            DolfinxPyFunctionMarker<2, 2, 1>,
+            Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>,
+        >,
+    >
+);

mercurial