diff -r 7ec1cfe19a24 -r a4137aedcb3a src/python_access/diff_mapping.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/python_access/diff_mapping.rs Thu Feb 26 09:32:12 2026 -0500 @@ -0,0 +1,303 @@ +/*! +[`DifferentiableMapping`]s implemented in Python. +*/ +use super::{process_error, NumpyArray_f64}; +use crate::dolfinx_access::DolfinxPyFunction_f64; +use alg_tools::direct_product::Pair; +use alg_tools::error::DynResult; +use alg_tools::linops::IdOp; +use alg_tools::mapping::{ClosedSpace, DifferentiableImpl, Instance, Mapping, Space}; +use anyhow::anyhow; +use ndarray::Dimension; +use numpy::{Ix1, Ix2}; +use pointsource_algs::forward_model::{BoundedCurvature, BoundedCurvatureGuess}; +use pointsource_algs::prox_penalty::{RadonSquared, StepLengthBound, StepLengthBoundPair}; +use pyo3::conversion::FromPyObject; +use pyo3::intern; +use pyo3::prelude::*; +use pyo3::PyClass; +use std::marker::PhantomData; + +#[derive(Copy, Debug, Clone)] +/// Marker for differentiable PythonMappings +pub struct Differentiable(DerivativeDomainMarker); + +#[derive(Copy, Debug, Clone)] +/// Marker for PythonMappings without further properties. +pub struct Basic; + +#[derive(Debug)] +pub struct PythonMapping<'py, Domain, Codomain, Marker> +where + Domain: Space, + Codomain: Space, +{ + pub(super) obj: Bound<'py, PyAny>, + pub(super) _phantoms: PhantomData<(Domain, Codomain, Marker)>, +} + +macro_rules! intern_many { + ($py:expr, $($method:literal),*) => {[ $( + intern!($py, $method) + ),*]} +} + +// Export for super. +pub(super) use intern_many; + +impl<'a, 'py, Domain, Codomain> FromPyObject<'a, 'py> + for PythonMapping<'py, Domain, Codomain, Basic> +where + Domain: Space, + Codomain: Space + PyClass + FromPyObject<'a, 'py>, +{ + type Error = PyErr; + + fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult { + let obj = obj_.to_owned(); + // Verify that the necessary methods exist + for method in intern_many!(obj.py(), "apply", "diff_lipschitz_factor") { + obj.getattr(method)?; //.downcast::()?; + } + Ok(PythonMapping { obj, _phantoms: PhantomData }) + } +} + +impl<'a, 'py, Domain, Codomain, DerivativeMarker> FromPyObject<'a, 'py> + for PythonMapping<'py, Domain, Codomain, Differentiable> +where + Domain: Space, + Codomain: Space + FromPyObject<'a, 'py>, +{ + type Error = PyErr; + + fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult { + let obj = obj_.to_owned(); + // Verify that the necessary methods exist + for method in intern_many!(obj.py(), "apply", "diff", "diff_lipschitz_factor") { + obj.getattr(method)?; //.downcast::()?; + } + Ok(PythonMapping { obj, _phantoms: PhantomData }) + } +} + +impl<'py, Domain, Codomain, AnyMarker> Mapping + for PythonMapping<'py, Domain, Codomain, AnyMarker> +where + Domain: Space, + Domain::Principal: IntoPyObject<'py>, + Codomain: ClosedSpace + for<'a> FromPyObject<'a, 'py, Error = PyErr>, +{ + type Codomain = Codomain; + + /// Compute the value of `self` at `x`. + fn apply>(&self, x: I) -> Self::Codomain { + // TODO: use references and internal mutability? + //x_py = x.own().to_python(py).unwrap(); + let apply = intern!(self.obj.py(), "apply"); + process_error( + "PythonMapping::apply", + self.obj.py(), + self.obj + .call_method1(apply, (x.own(),)) + .and_then(|r| r.extract()), + ) + .unwrap() + } +} + +impl<'py, Domain, Codomain, AnyMarker> PythonMapping<'py, Domain, Codomain, AnyMarker> +where + Domain: Space, + Codomain: Space + for<'a> FromPyObject<'a, 'py>, +{ + pub(crate) fn get_obj(&self) -> Bound<'py, PyAny> { + self.obj.clone() + } +} + +macro_rules! impl_differentiable { + ($derivative:ty, $marker:ty) => { + impl<'py, Domain> DifferentiableImpl + for PythonMapping<'py, Domain, f64, Differentiable<$marker>> + where + Domain: Space, + Domain::Principal: IntoPyObject<'py>, + //$derivative: Space + for<'py> FromPyObject<'py>, + //f64 : for<'py> FromPyObject<'py>, + { + type Derivative = $derivative; + + /// Compute the value of `self` at `x`. + fn differential_impl>(&self, x: I) -> $derivative { + // TODO: use references and internal mutability? + //x_py = x.own().to_python(py).unwrap(); + let diff = intern!(self.obj.py(), "diff"); + process_error( + "PythonMapping::differential_impl", + self.obj.py(), + self.obj + .call_method1(diff, (x.own(),)) + .and_then(|r| r.extract()), + ) + .unwrap() + } + + fn apply_and_differential_impl>( + &self, + x: I, + ) -> (>::Codomain, $derivative) { + // TODO: use references and internal mutability? + //x_py = x.own().to_python(py).unwrap(); + let apply_and_diff = intern!(self.obj.py(), "apply_and_diff"); + process_error( + "PythonMapping::apply_and_differential_impl", + self.obj.py(), + self.obj + .call_method1(apply_and_diff, (x.own(),)) + .and_then(|r| r.extract()), + ) + .unwrap() + } + } + }; +} + +impl<'py, Domain, Marker> BoundedCurvature + for PythonMapping<'py, Domain, f64, Differentiable> +where + Domain: Space, +{ + fn curvature_bound_components( + &self, + _guess: BoundedCurvatureGuess, + ) -> (DynResult, DynResult) { + let m = intern!(self.obj.py(), "curvature_bound_components"); + match process_error::<(Option, Option)>( + "curvature_bound_components", + self.obj.py(), + self.obj.call_method0(m).and_then(|r| r.extract()), + ) { + Ok((l, θ2)) => ( + l.ok_or_else(|| anyhow!("l is None")), + θ2.ok_or_else(|| anyhow!("θ2 is None")), + ), + Err(e) => ( + Err(anyhow!("(same error as second component)")), + Err(e.into()), + ), + } + } +} + +impl<'py, Domain, Marker> + StepLengthBound>> for RadonSquared +where + Domain: Space, +{ + fn step_length_bound( + &self, + f: &PythonMapping<'py, Domain, f64, Differentiable>, + ) -> DynResult { + let m = intern!(f.obj.py(), "diff_lipschitz_factor"); + process_error( + "PythonMapping::diff_lipschitz_factor", + f.obj.py(), + f.obj.call_method0(m).and_then(|r| r.extract()), + ) + } +} + +impl<'py, 'a, Domain, Marker, Z> + StepLengthBoundPair>> + for Pair<&'a RadonSquared, &'a IdOp> +where + Domain: Space, +{ + fn step_length_bound_pair( + &self, + f: &PythonMapping<'py, Domain, f64, Differentiable>, + ) -> DynResult<(f64, f64)> { + let m = intern!(f.obj.py(), "diff_lipschitz_factor_pair"); + process_error( + "PythonMapping::diff_lipschitz_factor_pair", + f.obj.py(), + f.obj.call_method0(m).and_then(|r| r.extract()), + ) + } +} + +#[derive(Copy, Debug, Clone)] +/// This is a marker type for identifying the derivative codomain of a differentiable +/// [`PythonMapping`]. Ideally we would just use, e.g., +/// ``` +/// PythonMapping>> +/// ``` +/// to encode the derivative codomain into the type. However `DolfinxPyFunction_f64<2,2,1>` is +/// not [´Send`] so [`pyo3`] does not allow such a mapping in a `pyclass`. We don't actually +/// ever pass a `DolfinxPyFunction_f64` to Python—it's a Rust wrapper for a PyObject—but pyo3, +/// not knowing better, disallows the type from even appearing in signatures. (It's ok to +/// `pyo3` as `DifferentiableImpl::Domain`, though. That's why we need this marker in place +/// of the real type: +/// ``` +/// PythonMapping>> +/// ``` +pub struct DolfinxPyFunctionMarker; +pub struct NumpyArrayMarker(Ix); + +//impl_differentiable!(DolfinxPyFunction_f64<1,2,1>, DolfinxPyFunctionMarker<1,2,1>); +impl_differentiable!(DolfinxPyFunction_f64<'py, 2,2,1>, DolfinxPyFunctionMarker<2,2,1>); +impl_differentiable!( + Pair, NumpyArray_f64<'py, Ix1>>, + Pair, NumpyArrayMarker> +); +impl_differentiable!( + Pair, NumpyArray_f64<'py, Ix2>>, + Pair, NumpyArrayMarker> +); +impl_differentiable!( + Pair, Pair>>, + Pair, Pair>> +); +impl_differentiable!( + Pair< + DolfinxPyFunction_f64<'py, 2, 2, 1>, + Pair, DolfinxPyFunction_f64<'py, 2, 2, 1>>>, + >, + Pair< + DolfinxPyFunctionMarker<2, 2, 1>, + Pair, DolfinxPyFunctionMarker<2, 2, 1>>>, + > +); +impl_differentiable!( + Pair< + DolfinxPyFunction_f64<'py, 2, 2, 1>, + Pair< + NumpyArray_f64<'py, Ix2>, + Pair, DolfinxPyFunction_f64<'py, 2, 2, 1>>, + >, + >, + Pair< + DolfinxPyFunctionMarker<2, 2, 1>, + Pair< + NumpyArrayMarker, + Pair, DolfinxPyFunctionMarker<2, 2, 1>>, + >, + > +); +impl_differentiable!( + Pair< + DolfinxPyFunction_f64<'py, 2, 2, 1>, + Pair< + DolfinxPyFunction_f64<'py, 2, 2, 1>, + Pair, DolfinxPyFunction_f64<'py, 2, 2, 1>>, + >, + >, + Pair< + DolfinxPyFunctionMarker<2, 2, 1>, + Pair< + DolfinxPyFunctionMarker<2, 2, 1>, + Pair, DolfinxPyFunctionMarker<2, 2, 1>>, + >, + > +);