src/python_access/diff_mapping.rs

Thu, 26 Feb 2026 09:32:12 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 26 Feb 2026 09:32:12 -0500
changeset 1
a4137aedcb3a
child 2
69002abe5dcb
child 3
c3a4f4bb87f7
permissions
-rw-r--r--

Initial working version for known convectivity and diffusivity

/*!
[`DifferentiableMapping`]s implemented in Python.
*/
use super::{process_error, NumpyArray_f64};
use crate::dolfinx_access::DolfinxPyFunction_f64;
use alg_tools::direct_product::Pair;
use alg_tools::error::DynResult;
use alg_tools::linops::IdOp;
use alg_tools::mapping::{ClosedSpace, DifferentiableImpl, Instance, Mapping, Space};
use anyhow::anyhow;
use ndarray::Dimension;
use numpy::{Ix1, Ix2};
use pointsource_algs::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
use pointsource_algs::prox_penalty::{RadonSquared, StepLengthBound, StepLengthBoundPair};
use pyo3::conversion::FromPyObject;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::PyClass;
use std::marker::PhantomData;

/// Marker for a differentiable [`PythonMapping`].
///
/// The wrapped parameter is itself a marker type that identifies the type of the
/// derivative (see the `impl_differentiable!` invocations in this module).
#[derive(Clone, Copy, Debug)]
pub struct Differentiable<DerivativeMarker>(DerivativeMarker);

/// Marker for a [`PythonMapping`] with no further properties.
///
/// Within this module such a mapping only receives the generic [`Mapping`] implementation.
#[derive(Clone, Copy, Debug)]
pub struct Basic;

/// A mapping from `Domain` to `Codomain` whose operations are delegated to a Python object.
///
/// `Marker` ([`Basic`] or [`Differentiable`]) selects which trait implementations the
/// mapping receives, and hence which Python methods are expected to exist on `obj`.
#[derive(Debug)]
pub struct PythonMapping<'py, Domain: Space, Codomain: Space, Marker> {
    /// The wrapped Python object; methods such as `apply` and `diff` are called on it.
    pub(super) obj: Bound<'py, PyAny>,
    /// Zero-sized carrier for the type parameters that are not stored directly.
    pub(super) _phantoms: PhantomData<(Domain, Codomain, Marker)>,
}

/// Interns several Python string literals at once.
///
/// Expands to an array of interned `PyString` references, in the order the literals are
/// given. The `$py` expression is evaluated exactly once, so any expression yielding a
/// `Python` token (e.g. `obj.py()`) may be passed. A trailing comma is accepted.
///
/// `intern!` is referenced by its full path, so call sites need not import it themselves.
macro_rules! intern_many {
    ($py:expr, $($method:literal),* $(,)?) => {{
        let py = $py;
        [$(::pyo3::intern!(py, $method)),*]
    }};
}

// Re-export so that sibling modules under `super` can use the macro.
pub(super) use intern_many;

impl<'a, 'py, Domain, Codomain> FromPyObject<'a, 'py>
    for PythonMapping<'py, Domain, Codomain, Basic>
where
    Domain: Space,
    Codomain: Space + PyClass + FromPyObject<'a, 'py>,
{
    type Error = PyErr;

    /// Wraps a Python object as a [`Basic`] [`PythonMapping`].
    ///
    /// The attributes `apply` and `diff_lipschitz_factor` are looked up in that order.
    /// The first failing lookup aborts the conversion, and its `getattr` error is returned.
    fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
        let obj = obj_.to_owned();
        intern_many!(obj.py(), "apply", "diff_lipschitz_factor")
            .iter()
            .try_for_each(|name| obj.getattr(*name).map(|_| ()))?;
        Ok(Self {
            obj,
            _phantoms: PhantomData,
        })
    }
}

impl<'a, 'py, Domain, Codomain, DerivativeMarker> FromPyObject<'a, 'py>
    for PythonMapping<'py, Domain, Codomain, Differentiable<DerivativeMarker>>
where
    Domain: Space,
    Codomain: Space + FromPyObject<'a, 'py>,
{
    type Error = PyErr;

    /// Wraps a Python object as a [`Differentiable`] [`PythonMapping`].
    ///
    /// The attributes `apply`, `diff` and `diff_lipschitz_factor` are looked up in that
    /// order. The first failing lookup aborts the conversion, and its `getattr` error is
    /// returned.
    fn extract(obj_: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
        let obj = obj_.to_owned();
        intern_many!(obj.py(), "apply", "diff", "diff_lipschitz_factor")
            .iter()
            .try_for_each(|name| obj.getattr(*name).map(|_| ()))?;
        Ok(Self {
            obj,
            _phantoms: PhantomData,
        })
    }
}

impl<'py, Domain, Codomain, AnyMarker> Mapping<Domain>
    for PythonMapping<'py, Domain, Codomain, AnyMarker>
where
    Domain: Space,
    Domain::Principal: IntoPyObject<'py>,
    Codomain: ClosedSpace + for<'a> FromPyObject<'a, 'py, Error = PyErr>,
{
    type Codomain = Codomain;

    /// Evaluates the mapping at `x`.
    ///
    /// Calls the Python method `apply(x)` and converts the result into `Codomain`.
    ///
    /// # Panics
    ///
    /// Panics if `process_error` reports a failure, i.e. if the Python call raises or its
    /// return value cannot be extracted.
    fn apply<I: Instance<Domain>>(&self, x: I) -> Self::Codomain {
        // TODO: pass `x` by reference (with interior mutability?) instead of `x.own()`.
        let py = self.obj.py();
        let outcome = self
            .obj
            .call_method1(intern!(py, "apply"), (x.own(),))
            .and_then(|value| value.extract());
        process_error("PythonMapping::apply", py, outcome).unwrap()
    }
}

impl<'py, Domain, Codomain, AnyMarker> PythonMapping<'py, Domain, Codomain, AnyMarker>
where
    // Only the bounds required by the struct definition itself; `get_obj` does not need
    // `Codomain` to be extractable from Python.
    Domain: Space,
    Codomain: Space,
{
    /// Returns a clone of the `Bound` handle to the wrapped Python object.
    pub(crate) fn get_obj(&self) -> Bound<'py, PyAny> {
        self.obj.clone()
    }
}

// Implements `DifferentiableImpl` for a scalar-valued `PythonMapping` whose derivative has
// type `$derivative`. In the mapping's type, `$marker` stands in for `$derivative` (see
// `DolfinxPyFunctionMarker` for why a marker is needed).
//
// The differential comes from the Python method `diff(x)`. The value and the differential
// together come from `apply_and_diff(x)`, whose result is extracted as a pair.
macro_rules! impl_differentiable {
    ($derivative:ty, $marker:ty) => {
        impl<'py, Domain> DifferentiableImpl<Domain>
            for PythonMapping<'py, Domain, f64, Differentiable<$marker>>
        where
            Domain: Space,
            Domain::Principal: IntoPyObject<'py>,
        {
            type Derivative = $derivative;

            /// Computes the differential of `self` at `x` via the Python method `diff(x)`.
            ///
            /// # Panics
            ///
            /// Panics if `process_error` reports a failure of the call or the extraction.
            fn differential_impl<I: Instance<Domain>>(&self, x: I) -> $derivative {
                // TODO: pass `x` by reference (with interior mutability?) instead of owning it.
                let py = self.obj.py();
                let outcome = self
                    .obj
                    .call_method1(intern!(py, "diff"), (x.own(),))
                    .and_then(|value| value.extract());
                process_error("PythonMapping::differential_impl", py, outcome).unwrap()
            }

            /// Computes the value and the differential of `self` at `x` together.
            ///
            /// Uses a single call to the Python method `apply_and_diff(x)`.
            ///
            /// # Panics
            ///
            /// Panics if `process_error` reports a failure of the call or the extraction.
            fn apply_and_differential_impl<I: Instance<Domain>>(
                &self,
                x: I,
            ) -> (<Self as Mapping<Domain>>::Codomain, $derivative) {
                // TODO: pass `x` by reference (with interior mutability?) instead of owning it.
                let py = self.obj.py();
                let outcome = self
                    .obj
                    .call_method1(intern!(py, "apply_and_diff"), (x.own(),))
                    .and_then(|value| value.extract());
                process_error("PythonMapping::apply_and_differential_impl", py, outcome)
                    .unwrap()
            }
        }
    };
}

impl<'py, Domain, Marker> BoundedCurvature<f64>
    for PythonMapping<'py, Domain, f64, Differentiable<Marker>>
where
    Domain: Space,
{
    /// Queries the curvature bound components from the Python method
    /// `curvature_bound_components()`.
    ///
    /// The Python result is extracted as a pair `(l, θ2)` of optional floats. A `None`
    /// component becomes an error in the corresponding output component. If the call or
    /// the extraction fails, the error is returned in the second component, and the first
    /// component carries an error pointing to it.
    ///
    /// The `_guess` argument is not used.
    fn curvature_bound_components(
        &self,
        _guess: BoundedCurvatureGuess,
    ) -> (DynResult<f64>, DynResult<f64>) {
        let py = self.obj.py();
        let method = intern!(py, "curvature_bound_components");
        // Use the same "PythonMapping::…" context label as the other Python-backed methods.
        let result = process_error::<(Option<f64>, Option<f64>)>(
            "PythonMapping::curvature_bound_components",
            py,
            self.obj.call_method0(method).and_then(|r| r.extract()),
        );
        match result {
            Ok((l, θ2)) => (
                l.ok_or_else(|| anyhow!("curvature_bound_components: l is None")),
                θ2.ok_or_else(|| anyhow!("curvature_bound_components: θ2 is None")),
            ),
            Err(e) => (
                Err(anyhow!("(same error as second component)")),
                Err(e.into()),
            ),
        }
    }
}

impl<'py, Domain, Marker>
    StepLengthBound<f64, PythonMapping<'py, Domain, f64, Differentiable<Marker>>> for RadonSquared
where
    Domain: Space,
{
    /// Returns the factor reported by the Python method `diff_lipschitz_factor()`.
    ///
    /// Failures of the call or of the conversion to `f64` are returned as errors via
    /// `process_error`.
    fn step_length_bound(
        &self,
        f: &PythonMapping<'py, Domain, f64, Differentiable<Marker>>,
    ) -> DynResult<f64> {
        let py = f.obj.py();
        let factor = f
            .obj
            .call_method0(intern!(py, "diff_lipschitz_factor"))
            .and_then(|value| value.extract());
        process_error("PythonMapping::diff_lipschitz_factor", py, factor)
    }
}

impl<'py, 'a, Domain, Marker, Z>
    StepLengthBoundPair<f64, PythonMapping<'py, Domain, f64, Differentiable<Marker>>>
    for Pair<&'a RadonSquared, &'a IdOp<Z>>
where
    Domain: Space,
{
    /// Returns the pair of factors reported by the Python method
    /// `diff_lipschitz_factor_pair()`.
    ///
    /// Failures of the call or of the conversion to `(f64, f64)` are returned as errors via
    /// `process_error`.
    fn step_length_bound_pair(
        &self,
        f: &PythonMapping<'py, Domain, f64, Differentiable<Marker>>,
    ) -> DynResult<(f64, f64)> {
        let py = f.obj.py();
        let factors = f
            .obj
            .call_method0(intern!(py, "diff_lipschitz_factor_pair"))
            .and_then(|value| value.extract());
        process_error("PythonMapping::diff_lipschitz_factor_pair", py, factors)
    }
}

/// Marker type identifying the derivative codomain of a differentiable [`PythonMapping`].
///
/// Ideally we would encode the derivative codomain directly in the type, e.g.,
/// ```text
/// PythonMapping<Domain, f64, Differentiable<DolfinxPyFunction_f64<2, 2, 1>>>
/// ```
/// However, `DolfinxPyFunction_f64<2, 2, 1>` is not [`Send`], so [`pyo3`] does not allow
/// such a mapping in a `pyclass`. We never actually pass a `DolfinxPyFunction_f64` to
/// Python—it is a Rust wrapper for a Python object—but `pyo3`, not knowing better, rejects
/// the type wherever it appears in the signature. (Using the real type as the associated
/// type `DifferentiableImpl::Derivative` is fine, though.) That is why we need this marker
/// in place of the real type:
/// ```text
/// PythonMapping<Domain, f64, Differentiable<DolfinxPyFunctionMarker<2, 2, 1>>>
/// ```
#[derive(Copy, Debug, Clone)]
pub struct DolfinxPyFunctionMarker<const N: usize, const O: usize, const D: usize>;
/// Marker type standing in for a `NumpyArray_f64<'py, Ix>` derivative component inside
/// [`Differentiable`], analogously to [`DolfinxPyFunctionMarker`].
///
/// The derives matter: `PythonMapping` derives `Debug`, and `Differentiable` derives
/// `Copy`, `Clone` and `Debug`. Those derived impls require every marker parameter to
/// implement the same traits.
#[derive(Copy, Debug, Clone)]
pub struct NumpyArrayMarker<Ix: Dimension>(Ix);

//impl_differentiable!(DolfinxPyFunction_f64<1,2,1>, DolfinxPyFunctionMarker<1,2,1>);
impl_differentiable!(DolfinxPyFunction_f64<'py, 2,2,1>, DolfinxPyFunctionMarker<2,2,1>);
impl_differentiable!(
    Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, NumpyArray_f64<'py, Ix1>>,
    Pair<DolfinxPyFunctionMarker<2, 2, 1>, NumpyArrayMarker<Ix1>>
);
impl_differentiable!(
    Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, NumpyArray_f64<'py, Ix2>>,
    Pair<DolfinxPyFunctionMarker<2, 2, 1>, NumpyArrayMarker<Ix2>>
);
impl_differentiable!(
    Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, Pair<f64, Pair<f64, f64>>>,
    Pair<DolfinxPyFunctionMarker<2, 2, 1>, Pair<f64, Pair<f64, f64>>>
);
impl_differentiable!(
    Pair<
        DolfinxPyFunction_f64<'py, 2, 2, 1>,
        Pair<f64, Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>>,
    >,
    Pair<
        DolfinxPyFunctionMarker<2, 2, 1>,
        Pair<f64, Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>>,
    >
);
impl_differentiable!(
    Pair<
        DolfinxPyFunction_f64<'py, 2, 2, 1>,
        Pair<
            NumpyArray_f64<'py, Ix2>,
            Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>,
        >,
    >,
    Pair<
        DolfinxPyFunctionMarker<2, 2, 1>,
        Pair<
            NumpyArrayMarker<Ix2>,
            Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>,
        >,
    >
);
impl_differentiable!(
    Pair<
        DolfinxPyFunction_f64<'py, 2, 2, 1>,
        Pair<
            DolfinxPyFunction_f64<'py, 2, 2, 1>,
            Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>,
        >,
    >,
    Pair<
        DolfinxPyFunctionMarker<2, 2, 1>,
        Pair<
            DolfinxPyFunctionMarker<2, 2, 1>,
            Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>,
        >,
    >
);

mercurial