src/experiments.rs

Fri, 08 May 2026 17:16:34 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Fri, 08 May 2026 17:16:34 -0500
changeset 4
49b062acace9
parent 3
c3a4f4bb87f7
permissions
-rw-r--r--

Do not directly depend on ndarray, but through numpy

/*!
Access to experiments provided in Python code.
*/

use crate::dolfinx_access::DolfinxPyFunction_f64;
use crate::python_access::{
    process_error, Differentiable, HasProx, PythonMapping, PythonPlotFactory,
};
use alg_tools::direct_product::Pair;
use alg_tools::error::{DynError, DynResult};
use alg_tools::loc::Loc;
use anyhow::bail;
use clap::Parser;
use log::debug;
use pointsource_algs::measures::DiscreteMeasure;
use pointsource_algs::prox_penalty::RadonSquared;
use pointsource_algs::regularisation::{NonnegRadonRegTerm, RadonRegTerm, Regularisation};
use pointsource_algs::run::{
    run_fb, run_fb_pair, AlgorithmConfig, DefaultAlgorithm, Named, RunError::NotImplemented,
    RunnableExperiment, RunnableExperimentExtras,
};
use pointsource_algs::{AlgorithmOverrides, CommandLineArgs, ExperimentSetup};
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyFunction};
use serde::{Deserialize, Serialize};
use serde_json;
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use std::ffi::CString;
use std::path::Path;

// NOTE(review): `Parser` is derived below, so the `///` comments on this struct and its fields
// are presumably picked up by clap as command line help text; they are therefore left verbatim.
// The same struct is also exposed to Python through `#[pyclass]`, with every field readable and
// writable from there (`#[pyo3(get, set)]`).
/// Command line experiment setup overrides
#[skip_serializing_none]
#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
#[pyclass(module = "pointsource_pde")]
pub struct Experiments {
    // Positional arguments: each entry is a path that `runnables` reads, loads as a Python
    // module, and turns into one `PythonExperiment`.
    /// List of Python setup files of experiments to perform
    #[arg(value_name = "EXPERIMENT.PY")]
    #[pyo3(get, set)]
    experiments: Vec<String>,

    // NOTE(review): not read by any Rust code visible in this file — confirm where the
    // override is applied.
    #[arg(long)]
    #[pyo3(get, set)]
    /// Regularisation parameter override.
    ///
    /// Only use if running just a single experiment, as different experiments have different
    /// regularisation parameters.
    alpha: Option<f64>,

    // NOTE(review): not read by any Rust code visible in this file — confirm where the
    // override is applied.
    #[arg(long)]
    #[pyo3(get, set)]
    /// Noise variance override
    variance: Option<f64>,
}

/// An experiment implemented in Python code
///
/// Instances are created by `Experiments::runnables`, one per Python setup file. Only `name`,
/// `filename` and `algorithm_overrides` take part in serialisation; the Python handles are
/// skipped.
#[derive(Serialize)]
pub struct PythonExperiment {
    /// Name of the experiment, extracted from the `name` attribute of the Python module
    name: String,
    /// Source file of experiment, as given on the command line
    filename: String,
    /// The setup function, i.e., the `setup` attribute of the Python module. It is called with
    /// the output prefix in `runall`, and is expected to return a `Problem`.
    #[serde(skip_serializing)]
    setup: Py<PyFunction>,
    /// The Python module
    // Never read on the Rust side (hence `allow(unused)`); presumably stored to keep a
    // reference to the loaded module for the lifetime of the experiment — TODO confirm.
    #[serde(skip_serializing)]
    #[allow(unused)]
    module: Py<PyModule>,
    /// Algorithm overrides, decoded from the optional `algorithm_overrides` attribute of the
    /// Python module; empty when the attribute could not be read.
    algorithm_overrides: HashMap<DefaultAlgorithm, AlgorithmOverrides<f64>>,
}

impl ExperimentSetup for Experiments {
    type FloatType = f64;

    fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>> {
        debug!("Loading runnable experiments. Attaching Python.");

        Python::attach(|py| {
            debug!("Loading json Python module.");

            // Load the "json" Python module, and "dumps" from it.
            let dumps = PyModule::import(py, "json")?.getattr("dumps")?;

            // Load the Python file describing each experiment, and extract
            // fields to file `PythonExperiment`.
            self.experiments
                .iter()
                .cloned()
                .map(|filename| -> DynResult<Box<dyn RunnableExperiment<_>>> {
                    debug!("Trying to load experiment from {filename}");
                    let code = std::fs::read_to_string(&filename)?;
                    let modname = filename.as_bytes();
                    // Add path of this file to module search path, to allow local dependencies
                    let parent = Path::new(&filename).parent().unwrap().to_str().unwrap();
                    let locals = PyDict::new(py);
                    locals.set_item("this_path", parent)?;
                    py.run(
                        c_str!("import sys; sys.path.insert(0, this_path)"),
                        None,
                        Some(&locals),
                    )?;
                    // Load module
                    let module = PyModule::from_code(
                        py,
                        CString::new(code)?.as_ref(),
                        CString::new(filename.as_str())?.as_ref(),
                        CString::new(modname)?.as_ref(),
                    )?;
                    let name = module.getattr("name")?.extract()?;
                    let setup = module
                        .getattr("setup")?
                        .cast_into() // Check that a Pyfunction
                        .map_err(PyErr::from)? // Can't return a DowncastIntoError
                        .unbind();
                    let algorithm_overrides = match module.getattr("algorithm_overrides") {
                        Err(_) => Default::default(),
                        Ok(o) => {
                            // This passes through JSON and Serde as pyo3 does not allow
                            // generics in AlgorithmOverides<_> - not even instantiating
                            // them to specific values. It would be interesting to write
                            // a proper serde Deserializer to not have to pass through JSON.
                            // One day.
                            serde_json::from_str(dumps.call1((o,))?.extract()?)?
                        }
                    };
                    debug!("… Found {name} with algorithm overrides {algorithm_overrides:?}");
                    Ok(Box::new(PythonExperiment {
                        name,
                        filename,
                        setup,
                        algorithm_overrides,
                        module: module.unbind(),
                    }))
                })
                .try_collect()
        })
    }
}

/// Types for experiments in two dimensions with `f64` floats.
///
/// These aliases are used by the branch of `PythonExperiment::runall` that handles problems
/// with neither an auxiliary term nor an initial auxiliary variable.
#[allow(non_camel_case_types)]
// Not every import and alias below is referenced (e.g. `NumpyArrayMarker`).
#[allow(unused)]
mod exp_f64_2 {
    use super::*;
    use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker};

    /// Type of unknowns: discrete measures with locations in `Loc<2, f64>` and `f64` weights.
    pub type Domain = DiscreteMeasure<Loc<2, f64>, f64>;
    /// Type of derivatives of objective function
    // NOTE(review): the meaning of the const parameters `2, 2, 1` is defined in
    // `crate::dolfinx_access` and cannot be told from this file.
    pub type DerivativeCodomain<'py> = DolfinxPyFunction_f64<'py, 2, 2, 1>;
    //pub type AuxTODOSpecifyFlexiblise<'py> = DolfinxPyFunction_f64<'py, 2, 2, 1>;
    /// Marker type given to `Differentiable` in [`DataTerm`]; it carries the same const
    /// parameters as [`DerivativeCodomain`], but no Python lifetime.
    pub type DerivativeMarker = DolfinxPyFunctionMarker<2, 2, 1>;
    /// Type of data terms: Python-implemented mappings from [`Domain`] to `f64`, marked
    /// differentiable.
    pub type DataTerm<'py> = PythonMapping<'py, Domain, f64, Differentiable<DerivativeMarker>>;
    /// Plotter
    // `PythonExperiment::runall` uses this plot factory type in all of its branches.
    pub type PlotFactory<'py> =
        PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>;
}

/// Types for experiments in two dimensions with `f64` floats and aux
///
/// Here the auxiliary variable consists of three `DolfinxPyFunction_f64` components. In
/// `PythonExperiment::runall`, these types are selected when the initial auxiliary variable
/// given from Python can be extracted as [`AuxVar`].
#[allow(non_camel_case_types, non_snake_case)]
// Not every import and alias below is referenced (e.g. `NumpyArrayMarker`, `Ix2`).
#[allow(unused)]
mod exp_f64_2_auxvar {
    use super::*;
    use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker};
    use numpy::Ix2;

    /// Type of unknowns: a discrete measure paired with the auxiliary variable.
    pub type Domain<'py> = Pair<DiscreteMeasure<Loc<2, f64>, f64>, AuxVar<'py>>;
    /// Type of derivatives of objective function: one component for the measure, followed by
    /// the components for the auxiliary variable.
    pub type DerivativeCodomain<'py> = Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, AuxVar<'py>>;
    /// Type of auxiliary variables. The same type is used for derivatives wrt. auxiliary
    /// variables (see [`DerivativeCodomain`]).
    pub type AuxVar<'py> = Pair<
        DolfinxPyFunction_f64<'py, 2, 2, 1>,
        Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>,
    >;
    /// Marker type given to `Differentiable` in [`DataTerm`]. Its nesting of `Pair`s mirrors
    /// that of [`DerivativeCodomain`], with marker types in place of the function types.
    pub type DerivativeMarker = Pair<
        DolfinxPyFunctionMarker<2, 2, 1>,
        Pair<
            DolfinxPyFunctionMarker<2, 2, 1>,
            Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>,
        >,
    >;
    /// Type of data terms: Python-implemented mappings from [`Domain`] to `f64`, marked
    /// differentiable.
    pub type DataTerm<'py> = PythonMapping<'py, Domain<'py>, f64, Differentiable<DerivativeMarker>>;
    /// Auxiliary objective: a Python-implemented mapping from [`AuxVar`] to `f64`, marked
    /// `HasProx`.
    pub type AuxTerm<'py> = PythonMapping<'py, AuxVar<'py>, f64, HasProx>;
    /// Plotter
    // NOTE(review): not referenced in the code visible in this file;
    // `PythonExperiment::runall` uses `exp_f64_2::PlotFactory` also for auxiliary-variable
    // problems.
    pub type PlotFactory<'py> =
        PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>;
}

/// Types for experiments in two dimensions with `f64` floats and scalar aux
///
/// Here the auxiliary variable consists of three `f64` scalars. In
/// `PythonExperiment::runall`, these types are the fallback when the initial auxiliary
/// variable given from Python cannot be extracted as `exp_f64_2_auxvar::AuxVar`.
#[allow(non_camel_case_types, non_snake_case)]
// Not every import and alias below is referenced (e.g. `NumpyArrayMarker`, `Ix2`).
#[allow(unused)]
mod exp_f64_2_auxvar_scalar {
    use super::*;
    use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker};
    use numpy::Ix2;

    /// Type of unknowns: a discrete measure paired with the auxiliary variable.
    pub type Domain<'py> = Pair<DiscreteMeasure<Loc<2, f64>, f64>, AuxVar<'py>>;
    /// Type of derivatives of objective function: one component for the measure, followed by
    /// the components for the auxiliary variable.
    pub type DerivativeCodomain<'py> = Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, AuxVar<'py>>;
    /// Type of auxiliary variables. The same type is used for derivatives wrt. auxiliary
    /// variables (see [`DerivativeCodomain`]).
    pub type AuxVar<'py> = Pair<f64, Pair<f64, f64>>;
    /// Marker type given to `Differentiable` in [`DataTerm`]: a function marker for the
    /// measure component, and plain `f64`s for the scalar auxiliary components.
    pub type DerivativeMarker = Pair<DolfinxPyFunctionMarker<2, 2, 1>, Pair<f64, Pair<f64, f64>>>;
    /// Type of data terms: Python-implemented mappings from [`Domain`] to `f64`, marked
    /// differentiable.
    pub type DataTerm<'py> = PythonMapping<'py, Domain<'py>, f64, Differentiable<DerivativeMarker>>;
    /// Auxiliary objective: a Python-implemented mapping from [`AuxVar`] to `f64`, marked
    /// `HasProx`.
    pub type AuxTerm<'py> = PythonMapping<'py, AuxVar<'py>, f64, HasProx>;
    /// Plotter
    // NOTE(review): not referenced in the code visible in this file;
    // `PythonExperiment::runall` uses `exp_f64_2::PlotFactory` also for auxiliary-variable
    // problems.
    pub type PlotFactory<'py> =
        PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>;
}

/// Available regularisation terms, exported to the Python side
///
/// On the Python side the class is called `RegTerm`. Each variant carries its weight `α`,
/// and can be converted into the corresponding `Regularisation` variant with `From`/`Into`.
#[pyclass(module = "pointsource_pde", name = "RegTerm")]
#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum RegTermPy {
    /// Radon norm with weight `α`.
    Radon(f64),
    /// Radon norm with weight `α` and a positivity constraint.
    NonnegRadon(f64),
}

impl From<RegTermPy> for Regularisation {
    fn from(reg: RegTermPy) -> Regularisation {
        match reg {
            RegTermPy::Radon(α) => Regularisation::Radon(α),
            RegTermPy::NonnegRadon(α) => Regularisation::NonnegRadon(α),
        }
    }
}

/// Problem description, exported to the Python side to be filled there.
///
/// The `setup` function of a Python experiment is expected to return an instance of this
/// class. The auxiliary term and the initial auxiliary variable have to be given either
/// both or not at all; running the experiment fails otherwise.
#[pyclass(module = "pointsource_pde")]
#[derive(Debug)]
pub struct Problem {
    /// Data term. On the python side, this should be a `class` that implements
    /// `apply` and `diff`.
    // Only the setter is generated here; the getter is written out by hand in the
    // `#[pymethods]` block as `get_dataterm`.
    #[pyo3(set)]
    dataterm: Py<PyAny>,

    /// Regularisation
    #[pyo3(set, get)]
    regterm: RegTermPy,

    /// Auxiliary variable regularisation term or similar
    #[pyo3(set, get)]
    auxterm: Option<Py<PyAny>>,

    /// Initial auxiliary variable
    // Its Python-side type decides which set of Rust types is used to run the experiment:
    // extraction as function-valued auxiliary variables is tried first, then as scalars.
    #[pyo3(set, get)]
    auxinit: Option<Py<PyAny>>,

    /// Initial measure
    #[pyo3(set, get)]
    μinit: Option<DiscreteMeasure<Loc<2, f64>, f64>>,

    /// Plotter. When not given, a dummy plot factory is used instead.
    #[pyo3(set, get)]
    plot_factory: Option<Py<PyAny>>,
}

#[pymethods]
impl Problem {
    #[new]
    #[pyo3(signature = (dataterm, regterm, auxterm = None, auxinit = None, μinit = None, plot_factory=None))]
    fn new(
        dataterm: Py<PyAny>,
        regterm: RegTermPy,
        auxterm: Option<Py<PyAny>>,
        auxinit: Option<Py<PyAny>>,
        μinit: Option<DiscreteMeasure<Loc<2, f64>, f64>>,
        plot_factory: Option<Py<PyAny>>,
    ) -> Self {
        Self { dataterm, regterm, auxterm, auxinit, μinit, plot_factory }
    }

    #[getter]
    fn get_dataterm<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
        self.dataterm.bind(py).clone()
    }
}

impl RunnableExperiment<f64> for PythonExperiment {
    /// Name of the experiment, as read from the Python module.
    fn name(&self) -> &str {
        &self.name
    }

    /// Sets the problem up on the Python side and runs it with all the requested algorithms.
    ///
    /// `algs` lists the algorithms to run; when it is `None`, only `DefaultAlgorithm::FB` is
    /// run. The Python `setup` function is called with the output prefix obtained from
    /// `self.start(cli)`, and has to return a [`Problem`]. Depending on the contents of the
    /// problem, one of three sets of types is used:
    ///
    /// * no auxiliary term and no initial auxiliary variable: [`exp_f64_2`] with `run_fb`;
    /// * both given, and the initial variable extracts as `exp_f64_2_auxvar::AuxVar`:
    ///   [`exp_f64_2_auxvar`] with `run_fb_pair`;
    /// * both given otherwise: [`exp_f64_2_auxvar_scalar`] with `run_fb_pair`.
    ///
    /// # Errors
    ///
    /// Fails if the Python-side setup or any of the extractions fail, if only one of the
    /// auxiliary term and the initial auxiliary variable is given, or if `do_runall` reports
    /// an error.
    fn runall(
        &self,
        cli: &CommandLineArgs,
        algs: Option<Vec<Named<AlgorithmConfig<f64>>>>,
    ) -> DynError {
        // Set up algorithms
        let algorithms = algs.unwrap_or_else(|| vec![DefaultAlgorithm::FB.get_named()]);

        debug!(
            "Entered PythonExperiment::runall for experiment {}. Attaching Python.",
            self.name
        );

        Python::attach(|py| {
            debug!("Calling Python-side experiment initialisation.");

            let prefix = self.start(cli)?;

            // Call the Python `setup` function, and check that it returned a `Problem`.
            let problem: PyRef<Problem> = process_error(
                "PythonExperiment::runall",
                py,
                self.setup
                    .call1(py, (&prefix,))
                    .and_then(|r| r.extract(py).map_err(PyErr::from)),
            )?;

            // No extra data is produced by the runs below, so there is nothing extra to save.
            let save_extra = |_, ()| Ok(());
            let μ0 = problem.μinit.clone();

            match (problem.regterm, &problem.auxterm, &problem.auxinit) {
                // Plain problem without auxiliary variables.
                (regterm, None, None) => {
                    debug!("… Extracting data term.");
                    let dataterm: exp_f64_2::DataTerm<'_> = problem.dataterm.extract(py)?;
                    debug!("… Extracting plotter.");
                    let plot_factory: exp_f64_2::PlotFactory<'_> =
                        if let Some(ref p) = problem.plot_factory {
                            p.extract(py)?
                        } else {
                            PythonPlotFactory::dummy()
                        };
                    let make_plotter = |prefix| plot_factory.produce(prefix);

                    debug!("Running.");
                    self.do_runall(
                        &prefix,
                        cli,
                        algorithms,
                        make_plotter,
                        save_extra,
                        μ0,
                        |p| match regterm {
                            RegTermPy::Radon(α) => {
                                run_fb(&dataterm, &RadonRegTerm(α), &RadonSquared, p, |_| {
                                    Err(NotImplemented.into())
                                })
                                .map(|μ| (μ, ()))
                            }
                            RegTermPy::NonnegRadon(α) => {
                                run_fb(&dataterm, &NonnegRadonRegTerm(α), &RadonSquared, p, |_| {
                                    Err(NotImplemented.into())
                                })
                                .map(|μ| (μ, ()))
                            }
                        },
                    )
                }
                // Problem with auxiliary variables. The type of the initial auxiliary variable
                // decides between the function-valued and the scalar set of types.
                (regterm, Some(auxterm), Some(z_)) => {
                    debug!("… Extracting auxiliary variable.");
                    let z0: PyResult<exp_f64_2_auxvar::AuxVar<'_>> = z_.extract(py);
                    match z0 {
                        Ok(z) => {
                            debug!("… Extracting data term.");
                            let dataterm: exp_f64_2_auxvar::DataTerm<'_> =
                                problem.dataterm.extract(py)?;
                            debug!("… Extracting plotter.");
                            let plot_factory: exp_f64_2::PlotFactory<'_> =
                                if let Some(ref p) = problem.plot_factory {
                                    p.extract(py)?
                                } else {
                                    PythonPlotFactory::dummy()
                                };
                            let make_plotter = |prefix| plot_factory.produce(prefix);
                            debug!("… Extracting auxiliary term.");
                            let auxterm: exp_f64_2_auxvar::AuxTerm<'_> = auxterm.extract(py)?;

                            debug!("Running.");
                            self.do_runall(
                                &prefix,
                                cli,
                                algorithms,
                                make_plotter,
                                save_extra,
                                (μ0, z),
                                // Only the measure component of the result is passed on.
                                |p| match regterm {
                                    RegTermPy::Radon(α) => run_fb_pair(
                                        &dataterm,
                                        &RadonRegTerm(α),
                                        &RadonSquared,
                                        &auxterm,
                                        p,
                                        |_| Err(NotImplemented.into()),
                                    )
                                    .map(|Pair(μ, _)| (μ, ())),
                                    RegTermPy::NonnegRadon(α) => run_fb_pair(
                                        &dataterm,
                                        &NonnegRadonRegTerm(α),
                                        &RadonSquared,
                                        &auxterm,
                                        p,
                                        |_| Err(NotImplemented.into()),
                                    )
                                    .map(|Pair(μ, _)| (μ, ())),
                                },
                            )
                        }
                        // The function-valued extraction failed: its error is discarded, and
                        // the scalar types are tried instead. An error from this second
                        // attempt is the one that gets reported.
                        Err(_) => {
                            let z: exp_f64_2_auxvar_scalar::AuxVar<'_> = z_.extract(py)?;
                            debug!("… Extracting data term.");
                            let dataterm: exp_f64_2_auxvar_scalar::DataTerm<'_> =
                                problem.dataterm.extract(py)?;
                            debug!("… Extracting plotter.");
                            let plot_factory: exp_f64_2::PlotFactory<'_> =
                                if let Some(ref p) = problem.plot_factory {
                                    p.extract(py)?
                                } else {
                                    PythonPlotFactory::dummy()
                                };
                            let make_plotter = |prefix| plot_factory.produce(prefix);
                            debug!("… Extracting auxiliary term.");
                            let auxterm: exp_f64_2_auxvar_scalar::AuxTerm<'_> =
                                auxterm.extract(py)?;

                            debug!("Running.");
                            self.do_runall(
                                &prefix,
                                cli,
                                algorithms,
                                make_plotter,
                                save_extra,
                                (μ0, z),
                                // Only the measure component of the result is passed on.
                                |p| match regterm {
                                    RegTermPy::Radon(α) => run_fb_pair(
                                        &dataterm,
                                        &RadonRegTerm(α),
                                        &RadonSquared,
                                        &auxterm,
                                        p,
                                        |_| Err(NotImplemented.into()),
                                    )
                                    .map(|Pair(μ, _)| (μ, ())),
                                    RegTermPy::NonnegRadon(α) => run_fb_pair(
                                        &dataterm,
                                        &NonnegRadonRegTerm(α),
                                        &RadonSquared,
                                        &auxterm,
                                        p,
                                        |_| Err(NotImplemented.into()),
                                    )
                                    .map(|Pair(μ, _)| (μ, ())),
                                },
                            )
                        }
                    }
                }
                // Exactly one of the auxiliary term and the initial auxiliary variable given.
                (_, _, _) => {
                    bail!("Errors in problem {} setup: auxiliary term or auxiliary variable initialisation given without the other", self.name);
                }
            }
        })?;

        // All the algorithms have been run inside the closure above; any error has already
        // been propagated.
        Ok(())
    }

    /// Overrides of this experiment for the algorithm `alg`.
    ///
    /// Returns a copy of the overrides read from the Python module, or the default (empty)
    /// overrides when none were given for `alg`.
    fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<f64> {
        self.algorithm_overrides
            .get(&alg)
            .cloned()
            .unwrap_or_default()
    }
}

mercurial