src/experiments.rs

changeset 1
a4137aedcb3a
child 3
c3a4f4bb87f7
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/experiments.rs	Thu Feb 26 09:32:12 2026 -0500
@@ -0,0 +1,474 @@
+/*!
+Access to experiments provided in Python code.
+*/
+
+use crate::dolfinx_access::DolfinxPyFunction_f64;
+use crate::python_access::{
+    process_error, Differentiable, HasProx, PythonMapping, PythonPlotFactory,
+};
+use alg_tools::direct_product::Pair;
+use alg_tools::error::{DynError, DynResult};
+use alg_tools::loc::Loc;
+use anyhow::bail;
+use clap::Parser;
+use log::debug;
+use pointsource_algs::measures::DiscreteMeasure;
+use pointsource_algs::prox_penalty::RadonSquared;
+use pointsource_algs::regularisation::{NonnegRadonRegTerm, RadonRegTerm, Regularisation};
+use pointsource_algs::run::{
+    run_fb, run_fb_pair, AlgorithmConfig, DefaultAlgorithm, Named, RunError::NotImplemented,
+    RunnableExperiment, RunnableExperimentExtras,
+};
+use pointsource_algs::{AlgorithmOverrides, CommandLineArgs, ExperimentSetup};
+use pyo3::ffi::c_str;
+use pyo3::prelude::*;
+use pyo3::types::{PyDict, PyFunction};
+use serde::{Deserialize, Serialize};
+use serde_json;
+use serde_with::skip_serializing_none;
+use std::collections::HashMap;
+use std::ffi::CString;
+use std::path::Path;
+
+/// Command line experiment setup overrides
+#[skip_serializing_none]
+#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
+#[pyclass(module = "pointsource_pde")]
+pub struct Experiments {
+    /// List of Python setup files of experiments to perform
+    #[arg(value_name = "EXPERIMENT.PY")]
+    #[pyo3(get, set)]
+    experiments: Vec<String>,
+
+    #[arg(long)]
+    #[pyo3(get, set)]
+    /// Regularisation parameter override.
+    ///
+    /// Only use if running just a single experiment, as different experiments have different
+    /// regularisation parameters.
+    alpha: Option<f64>,
+
+    #[arg(long)]
+    #[pyo3(get, set)]
+    /// Noise variance override
+    variance: Option<f64>,
+}
+
+/// An experiment implemented in Python code
+#[derive(Serialize)]
+pub struct PythonExperiment {
+    /// Name of the experiment
+    name: String,
+    /// Source file of experiment
+    filename: String,
+    /// The setup function
+    #[serde(skip_serializing)]
+    setup: Py<PyFunction>,
+    /// The Python module
+    #[serde(skip_serializing)]
+    #[allow(unused)]
+    module: Py<PyModule>,
+    /// Algorithm overrides
+    algorithm_overrides: HashMap<DefaultAlgorithm, AlgorithmOverrides<f64>>,
+}
+
+impl ExperimentSetup for Experiments {
+    type FloatType = f64;
+
+    fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>> {
+        debug!("Loading runnable experiments. Attaching Python.");
+
+        Python::attach(|py| {
+            debug!("Loading json Python module.");
+
+            // Load the "json" Python module, and "dumps" from it.
+            let dumps = PyModule::import(py, "json")?.getattr("dumps")?;
+
+            // Load the Python file describing each experiment, and extract
+            // fields to file `PythonExperiment`.
+            self.experiments
+                .iter()
+                .cloned()
+                .map(|filename| -> DynResult<Box<dyn RunnableExperiment<_>>> {
+                    debug!("Trying to load experiment from {filename}");
+                    let code = std::fs::read_to_string(&filename)?;
+                    let modname = filename.as_bytes();
+                    // Add path of this file to module search path, to allow local dependencies
+                    let parent = Path::new(&filename).parent().unwrap().to_str().unwrap();
+                    let locals = PyDict::new(py);
+                    locals.set_item("this_path", parent)?;
+                    py.run(
+                        c_str!("import sys; sys.path.insert(0, this_path)"),
+                        None,
+                        Some(&locals),
+                    )?;
+                    // Load module
+                    let module = PyModule::from_code(
+                        py,
+                        CString::new(code)?.as_ref(),
+                        CString::new(filename.as_str())?.as_ref(),
+                        CString::new(modname)?.as_ref(),
+                    )?;
+                    let name = module.getattr("name")?.extract()?;
+                    let setup = module
+                        .getattr("setup")?
+                        .cast_into() // Check that a Pyfunction
+                        .map_err(PyErr::from)? // Can't return a DowncastIntoError
+                        .unbind();
+                    let algorithm_overrides = match module.getattr("algorithm_overrides") {
+                        Err(_) => Default::default(),
+                        Ok(o) => {
+                            // This passes through JSON and Serde as pyo3 does not allow
+                            // generics in AlgorithmOverides<_> - not even instantiating
+                            // them to specific values. It would be interesting to write
+                            // a proper serde Deserializer to not have to pass through JSON.
+                            // One day.
+                            serde_json::from_str(dumps.call1((o,))?.extract()?)?
+                        }
+                    };
+                    debug!("… Found {name} with algorithm overrides {algorithm_overrides:?}");
+                    Ok(Box::new(PythonExperiment {
+                        name,
+                        filename,
+                        setup,
+                        algorithm_overrides,
+                        module: module.unbind(),
+                    }))
+                })
+                .try_collect()
+        })
+    }
+}
+
+/// Types for experiments in two dimensions with `f64` floats.
+#[allow(non_camel_case_types)]
+#[allow(unused)]
+mod exp_f64_2 {
+    use super::*;
+    use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker};
+
+    /// Type of unknowns
+    pub type Domain = DiscreteMeasure<Loc<2, f64>, f64>;
+    /// Type of derivatives of objective function
+    pub type DerivativeCodomain<'py> = DolfinxPyFunction_f64<'py, 2, 2, 1>;
+    //pub type AuxTODOSpecifyFlexiblise<'py> = DolfinxPyFunction_f64<'py, 2, 2, 1>;
+    pub type DerivativeMarker = DolfinxPyFunctionMarker<2, 2, 1>;
+    /// Type of data terms
+    pub type DataTerm<'py> = PythonMapping<'py, Domain, f64, Differentiable<DerivativeMarker>>;
+    /// Plotter
+    pub type PlotFactory<'py> =
+        PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>;
+}
+
+/// Types for experiments in two dimensions with `f64` floats and aux
+#[allow(non_camel_case_types, non_snake_case)]
+#[allow(unused)]
+mod exp_f64_2_auxvar {
+    use super::*;
+    use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker};
+    use numpy::Ix2;
+
+    /// Type of unknowns
+    pub type Domain<'py> = Pair<DiscreteMeasure<Loc<2, f64>, f64>, AuxVar<'py>>;
+    /// Type of derivatives of objective function
+    pub type DerivativeCodomain<'py> = Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, AuxVar<'py>>;
+    /// Type of derivatives wrt. auxiliary variables.
+    pub type AuxVar<'py> = Pair<
+        DolfinxPyFunction_f64<'py, 2, 2, 1>,
+        Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>,
+    >;
+    pub type DerivativeMarker = Pair<
+        DolfinxPyFunctionMarker<2, 2, 1>,
+        Pair<
+            DolfinxPyFunctionMarker<2, 2, 1>,
+            Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>,
+        >,
+    >;
+    /// Type of data terms
+    pub type DataTerm<'py> = PythonMapping<'py, Domain<'py>, f64, Differentiable<DerivativeMarker>>;
+    /// Auxiliary objective
+    pub type AuxTerm<'py> = PythonMapping<'py, AuxVar<'py>, f64, HasProx>;
+    /// Plotter
+    pub type PlotFactory<'py> =
+        PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>;
+}
+
+/// Types for experiments in two dimensions with `f64` floats and scalar aux
+#[allow(non_camel_case_types, non_snake_case)]
+#[allow(unused)]
+mod exp_f64_2_auxvar_scalar {
+    use super::*;
+    use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker};
+    use numpy::Ix2;
+
+    /// Type of unknowns
+    pub type Domain<'py> = Pair<DiscreteMeasure<Loc<2, f64>, f64>, AuxVar<'py>>;
+    /// Type of derivatives of objective function
+    pub type DerivativeCodomain<'py> = Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, AuxVar<'py>>;
+    /// Type of derivatives wrt. auxiliary variables.
+    pub type AuxVar<'py> = Pair<f64, Pair<f64, f64>>;
+    pub type DerivativeMarker = Pair<DolfinxPyFunctionMarker<2, 2, 1>, Pair<f64, Pair<f64, f64>>>;
+    /// Type of data terms
+    pub type DataTerm<'py> = PythonMapping<'py, Domain<'py>, f64, Differentiable<DerivativeMarker>>;
+    /// Auxiliary objective
+    pub type AuxTerm<'py> = PythonMapping<'py, AuxVar<'py>, f64, HasProx>;
+    /// Plotter
+    pub type PlotFactory<'py> =
+        PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>;
+}
+
+#[pyclass(module = "pointsource_pde", name = "RegTerm")]
+#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub enum RegTermPy {
+    // Radon norm with weight `α`.
+    Radon(f64),
+    // Radon norm ith weight `α` and a positivity constraint.
+    NonnegRadon(f64),
+}
+
+impl From<RegTermPy> for Regularisation {
+    fn from(reg: RegTermPy) -> Regularisation {
+        match reg {
+            RegTermPy::Radon(α) => Regularisation::Radon(α),
+            RegTermPy::NonnegRadon(α) => Regularisation::NonnegRadon(α),
+        }
+    }
+}
+
+#[pyclass(module = "pointsource_pde")]
+#[derive(Debug)]
+pub struct Problem {
+    /// Data term. On the python side, this should be a `class` that implements
+    /// `apply` from [`DiscreteMeasure_2_f64`] (TODO: extended parameters) to floats, and `diff`
+    /// from the space space to [`DolfinxPyFunction_f64<2, 1, 2>`].
+    #[pyo3(set)]
+    dataterm: Py<PyAny>, //exp_f64_2::DataTerm,
+
+    // Regularisation
+    #[pyo3(set, get)]
+    regterm: RegTermPy,
+
+    // Auxiliary variable regularisation term or similar
+    #[pyo3(set, get)]
+    auxterm: Option<Py<PyAny>>,
+
+    // Initial auxiliary variable
+    #[pyo3(set, get)]
+    auxinit: Option<Py<PyAny>>,
+
+    // Initial measure
+    #[pyo3(set, get)]
+    μinit: Option<DiscreteMeasure<Loc<2, f64>, f64>>,
+
+    // Plotter
+    #[pyo3(set, get)]
+    plot_factory: Option<Py<PyAny>>,
+}
+
+#[pymethods]
+impl Problem {
+    #[new]
+    #[pyo3(signature = (dataterm, regterm, auxterm = None, auxinit = None, μinit = None, plot_factory=None))]
+    fn new(
+        dataterm: Py<PyAny>,
+        regterm: RegTermPy,
+        auxterm: Option<Py<PyAny>>,
+        auxinit: Option<Py<PyAny>>,
+        μinit: Option<DiscreteMeasure<Loc<2, f64>, f64>>,
+        plot_factory: Option<Py<PyAny>>,
+    ) -> Self {
+        Self { dataterm, regterm, auxterm, auxinit, μinit, plot_factory }
+    }
+
+    #[getter]
+    fn get_dataterm<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
+        self.dataterm.bind(py).clone()
+    }
+}
+
+impl RunnableExperiment<f64> for PythonExperiment {
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn runall(
+        &self,
+        cli: &CommandLineArgs,
+        algs: Option<Vec<Named<AlgorithmConfig<f64>>>>,
+    ) -> DynError {
+        // Set up algorithms
+        let algorithms = algs.unwrap_or_else(|| vec![DefaultAlgorithm::FB.get_named()]);
+
+        debug!(
+            "Entered PythonExperiment::runall for experimen {}. Attaching Python.",
+            self.name
+        );
+
+        Python::attach(|py| {
+            debug!("Calling Python-side experiment initialisation.");
+
+            let prefix = self.start(cli)?;
+
+            let problem: PyRef<Problem> = process_error(
+                "PythonExperiment::runall",
+                py,
+                self.setup
+                    .call1(py, (&prefix,))
+                    .and_then(|r| r.extract(py).map_err(PyErr::from)),
+            )?;
+
+            let save_extra = |_, ()| Ok(());
+            let μ0 = problem.μinit.clone();
+
+            match (problem.regterm, &problem.auxterm, &problem.auxinit) {
+                (regterm, None, None) => {
+                    debug!("… Extracting data term.");
+                    let dataterm: exp_f64_2::DataTerm<'_> = problem.dataterm.extract(py)?;
+                    debug!("… Extracting plotter.");
+                    let plot_factory: exp_f64_2::PlotFactory<'_> =
+                        if let Some(ref p) = problem.plot_factory {
+                            p.extract(py)?
+                        } else {
+                            PythonPlotFactory::dummy()
+                        };
+                    let make_plotter = |prefix| plot_factory.produce(prefix);
+
+                    debug!("Running.");
+                    self.do_runall(
+                        &prefix,
+                        cli,
+                        algorithms,
+                        make_plotter,
+                        save_extra,
+                        μ0,
+                        |p| match regterm {
+                            RegTermPy::Radon(α) => {
+                                run_fb(&dataterm, &RadonRegTerm(α), &RadonSquared, p, |_| {
+                                    Err(NotImplemented.into())
+                                })
+                                .map(|μ| (μ, ()))
+                            }
+                            RegTermPy::NonnegRadon(α) => {
+                                run_fb(&dataterm, &NonnegRadonRegTerm(α), &RadonSquared, p, |_| {
+                                    Err(NotImplemented.into())
+                                })
+                                .map(|μ| (μ, ()))
+                            }
+                        },
+                    )
+                }
+                (regterm, Some(auxterm), Some(z_)) => {
+                    debug!("… Extracting auxiliary variable.");
+                    let z0: PyResult<exp_f64_2_auxvar::AuxVar<'_>> = z_.extract(py);
+                    match z0 {
+                        Ok(z) => {
+                            debug!("… Extracting data term.");
+                            let dataterm: exp_f64_2_auxvar::DataTerm<'_> =
+                                problem.dataterm.extract(py)?;
+                            debug!("… Extracting plotter.");
+                            let plot_factory: exp_f64_2::PlotFactory<'_> =
+                                if let Some(ref p) = problem.plot_factory {
+                                    p.extract(py)?
+                                } else {
+                                    PythonPlotFactory::dummy()
+                                };
+                            let make_plotter = |prefix| plot_factory.produce(prefix);
+                            debug!("… Extracting auxiliary term.");
+                            let auxterm: exp_f64_2_auxvar::AuxTerm<'_> = auxterm.extract(py)?;
+
+                            debug!("Running.");
+                            self.do_runall(
+                                &prefix,
+                                cli,
+                                algorithms,
+                                make_plotter,
+                                save_extra,
+                                (μ0, z),
+                                |p| match regterm {
+                                    RegTermPy::Radon(α) => run_fb_pair(
+                                        &dataterm,
+                                        &RadonRegTerm(α),
+                                        &RadonSquared,
+                                        &auxterm,
+                                        p,
+                                        |_| Err(NotImplemented.into()),
+                                    )
+                                    .map(|Pair(μ, _)| (μ, ())),
+                                    RegTermPy::NonnegRadon(α) => run_fb_pair(
+                                        &dataterm,
+                                        &NonnegRadonRegTerm(α),
+                                        &RadonSquared,
+                                        &auxterm,
+                                        p,
+                                        |_| Err(NotImplemented.into()),
+                                    )
+                                    .map(|Pair(μ, _)| (μ, ())),
+                                },
+                            )
+                        }
+                        Err(_) => {
+                            let z: exp_f64_2_auxvar_scalar::AuxVar<'_> = z_.extract(py)?;
+                            debug!("… Extracting data term.");
+                            let dataterm: exp_f64_2_auxvar_scalar::DataTerm<'_> =
+                                problem.dataterm.extract(py)?;
+                            debug!("… Extracting plotter.");
+                            let plot_factory: exp_f64_2::PlotFactory<'_> =
+                                if let Some(ref p) = problem.plot_factory {
+                                    p.extract(py)?
+                                } else {
+                                    PythonPlotFactory::dummy()
+                                };
+                            let make_plotter = |prefix| plot_factory.produce(prefix);
+                            debug!("… Extracting auxiliary term.");
+                            let auxterm: exp_f64_2_auxvar_scalar::AuxTerm<'_> =
+                                auxterm.extract(py)?;
+
+                            debug!("Running.");
+                            self.do_runall(
+                                &prefix,
+                                cli,
+                                algorithms,
+                                make_plotter,
+                                save_extra,
+                                (μ0, z),
+                                |p| match regterm {
+                                    RegTermPy::Radon(α) => run_fb_pair(
+                                        &dataterm,
+                                        &RadonRegTerm(α),
+                                        &RadonSquared,
+                                        &auxterm,
+                                        p,
+                                        |_| Err(NotImplemented.into()),
+                                    )
+                                    .map(|Pair(μ, _)| (μ, ())),
+                                    RegTermPy::NonnegRadon(α) => run_fb_pair(
+                                        &dataterm,
+                                        &NonnegRadonRegTerm(α),
+                                        &RadonSquared,
+                                        &auxterm,
+                                        p,
+                                        |_| Err(NotImplemented.into()),
+                                    )
+                                    .map(|Pair(μ, _)| (μ, ())),
+                                },
+                            )
+                        }
+                    }
+                }
+                (_, _, _) => {
+                    bail!("Errors in problem {} setup: auxiliary term or auxiliary variable initialisation given without the other", self.name);
+                }
+            }
+        })?;
+
+        // Should run the experiment here, going through all algorithms.
+        Ok(())
+    }
+
+    fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<f64> {
+        self.algorithm_overrides
+            .get(&alg)
+            .cloned()
+            .unwrap_or(Default::default())
+    }
+}

mercurial