Fri, 08 May 2026 17:28:21 -0500
Change README title
/*! Access to experiments provided in Python code. */ use crate::dolfinx_access::DolfinxPyFunction_f64; use crate::python_access::{ process_error, Differentiable, HasProx, PythonMapping, PythonPlotFactory, }; use alg_tools::direct_product::Pair; use alg_tools::error::{DynError, DynResult}; use alg_tools::loc::Loc; use anyhow::bail; use clap::Parser; use log::debug; use pointsource_algs::measures::DiscreteMeasure; use pointsource_algs::prox_penalty::RadonSquared; use pointsource_algs::regularisation::{NonnegRadonRegTerm, RadonRegTerm, Regularisation}; use pointsource_algs::run::{ run_fb, run_fb_pair, AlgorithmConfig, DefaultAlgorithm, Named, RunError::NotImplemented, RunnableExperiment, RunnableExperimentExtras, }; use pointsource_algs::{AlgorithmOverrides, CommandLineArgs, ExperimentSetup}; use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::types::{PyDict, PyFunction}; use serde::{Deserialize, Serialize}; use serde_json; use serde_with::skip_serializing_none; use std::collections::HashMap; use std::ffi::CString; use std::path::Path; /// Command line experiment setup overrides #[skip_serializing_none] #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] #[pyclass(module = "pointsource_pde")] pub struct Experiments { /// List of Python setup files of experiments to perform #[arg(value_name = "EXPERIMENT.PY")] #[pyo3(get, set)] experiments: Vec<String>, #[arg(long)] #[pyo3(get, set)] /// Regularisation parameter override. /// /// Only use if running just a single experiment, as different experiments have different /// regularisation parameters. alpha: Option<f64>, #[arg(long)] #[pyo3(get, set)] /// Noise variance override variance: Option<f64>, } /// An experiment implemented in Python code #[derive(Serialize)] pub struct PythonExperiment { /// Name of the experiment name: String, /// Source file of experiment filename: String, /// The setup function #[serde(skip_serializing)] setup: Py<PyFunction>, /// The Python module #[serde(skip_serializing)] #[allow(unused)] module: Py<PyModule>, /// Algorithm overrides algorithm_overrides: HashMap<DefaultAlgorithm, AlgorithmOverrides<f64>>, } impl ExperimentSetup for Experiments { type FloatType = f64; fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>> { debug!("Loading runnable experiments. Attaching Python."); Python::attach(|py| { debug!("Loading json Python module."); // Load the "json" Python module, and "dumps" from it. let dumps = PyModule::import(py, "json")?.getattr("dumps")?; // Load the Python file describing each experiment, and extract // fields to file `PythonExperiment`. self.experiments .iter() .cloned() .map(|filename| -> DynResult<Box<dyn RunnableExperiment<_>>> { debug!("Trying to load experiment from {filename}"); let code = std::fs::read_to_string(&filename)?; let modname = filename.as_bytes(); // Add path of this file to module search path, to allow local dependencies let parent = Path::new(&filename).parent().unwrap().to_str().unwrap(); let locals = PyDict::new(py); locals.set_item("this_path", parent)?; py.run( c_str!("import sys; sys.path.insert(0, this_path)"), None, Some(&locals), )?; // Load module let module = PyModule::from_code( py, CString::new(code)?.as_ref(), CString::new(filename.as_str())?.as_ref(), CString::new(modname)?.as_ref(), )?; let name = module.getattr("name")?.extract()?; let setup = module .getattr("setup")? .cast_into() // Check that a Pyfunction .map_err(PyErr::from)? // Can't return a DowncastIntoError .unbind(); let algorithm_overrides = match module.getattr("algorithm_overrides") { Err(_) => Default::default(), Ok(o) => { // This passes through JSON and Serde as pyo3 does not allow // generics in AlgorithmOverides<_> - not even instantiating // them to specific values. It would be interesting to write // a proper serde Deserializer to not have to pass through JSON. // One day. serde_json::from_str(dumps.call1((o,))?.extract()?)? } }; debug!("… Found {name} with algorithm overrides {algorithm_overrides:?}"); Ok(Box::new(PythonExperiment { name, filename, setup, algorithm_overrides, module: module.unbind(), })) }) .try_collect() }) } } /// Types for experiments in two dimensions with `f64` floats. #[allow(non_camel_case_types)] #[allow(unused)] mod exp_f64_2 { use super::*; use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker}; /// Type of unknowns pub type Domain = DiscreteMeasure<Loc<2, f64>, f64>; /// Type of derivatives of objective function pub type DerivativeCodomain<'py> = DolfinxPyFunction_f64<'py, 2, 2, 1>; //pub type AuxTODOSpecifyFlexiblise<'py> = DolfinxPyFunction_f64<'py, 2, 2, 1>; pub type DerivativeMarker = DolfinxPyFunctionMarker<2, 2, 1>; /// Type of data terms pub type DataTerm<'py> = PythonMapping<'py, Domain, f64, Differentiable<DerivativeMarker>>; /// Plotter pub type PlotFactory<'py> = PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>; } /// Types for experiments in two dimensions with `f64` floats and aux #[allow(non_camel_case_types, non_snake_case)] #[allow(unused)] mod exp_f64_2_auxvar { use super::*; use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker}; use numpy::Ix2; /// Type of unknowns pub type Domain<'py> = Pair<DiscreteMeasure<Loc<2, f64>, f64>, AuxVar<'py>>; /// Type of derivatives of objective function pub type DerivativeCodomain<'py> = Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, AuxVar<'py>>; /// Type of derivatives wrt. auxiliary variables. pub type AuxVar<'py> = Pair< DolfinxPyFunction_f64<'py, 2, 2, 1>, Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>, >; pub type DerivativeMarker = Pair< DolfinxPyFunctionMarker<2, 2, 1>, Pair< DolfinxPyFunctionMarker<2, 2, 1>, Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>, >, >; /// Type of data terms pub type DataTerm<'py> = PythonMapping<'py, Domain<'py>, f64, Differentiable<DerivativeMarker>>; /// Auxiliary objective pub type AuxTerm<'py> = PythonMapping<'py, AuxVar<'py>, f64, HasProx>; /// Plotter pub type PlotFactory<'py> = PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>; } /// Types for experiments in two dimensions with `f64` floats and scalar aux #[allow(non_camel_case_types, non_snake_case)] #[allow(unused)] mod exp_f64_2_auxvar_scalar { use super::*; use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker}; use numpy::Ix2; /// Type of unknowns pub type Domain<'py> = Pair<DiscreteMeasure<Loc<2, f64>, f64>, AuxVar<'py>>; /// Type of derivatives of objective function pub type DerivativeCodomain<'py> = Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, AuxVar<'py>>; /// Type of derivatives wrt. auxiliary variables. pub type AuxVar<'py> = Pair<f64, Pair<f64, f64>>; pub type DerivativeMarker = Pair<DolfinxPyFunctionMarker<2, 2, 1>, Pair<f64, Pair<f64, f64>>>; /// Type of data terms pub type DataTerm<'py> = PythonMapping<'py, Domain<'py>, f64, Differentiable<DerivativeMarker>>; /// Auxiliary objective pub type AuxTerm<'py> = PythonMapping<'py, AuxVar<'py>, f64, HasProx>; /// Plotter pub type PlotFactory<'py> = PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>; } /// Available regularisation terms, exported to the Python side #[pyclass(module = "pointsource_pde", name = "RegTerm")] #[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq)] pub enum RegTermPy { /// Radon norm with weight `α`. Radon(f64), /// Radon norm ith weight `α` and a positivity constraint. NonnegRadon(f64), } impl From<RegTermPy> for Regularisation { fn from(reg: RegTermPy) -> Regularisation { match reg { RegTermPy::Radon(α) => Regularisation::Radon(α), RegTermPy::NonnegRadon(α) => Regularisation::NonnegRadon(α), } } } /// Problem description, exported to the Python side to be filled there. #[pyclass(module = "pointsource_pde")] #[derive(Debug)] pub struct Problem { /// Data term. On the python side, this should be a `class` that implements /// `apply` and `diff`. #[pyo3(set)] dataterm: Py<PyAny>, /// Regularisation #[pyo3(set, get)] regterm: RegTermPy, /// Auxiliary variable regularisation term or similar #[pyo3(set, get)] auxterm: Option<Py<PyAny>>, /// Initial auxiliary variable #[pyo3(set, get)] auxinit: Option<Py<PyAny>>, /// Initial measure #[pyo3(set, get)] μinit: Option<DiscreteMeasure<Loc<2, f64>, f64>>, /// Plotter #[pyo3(set, get)] plot_factory: Option<Py<PyAny>>, } #[pymethods] impl Problem { #[new] #[pyo3(signature = (dataterm, regterm, auxterm = None, auxinit = None, μinit = None, plot_factory=None))] fn new( dataterm: Py<PyAny>, regterm: RegTermPy, auxterm: Option<Py<PyAny>>, auxinit: Option<Py<PyAny>>, μinit: Option<DiscreteMeasure<Loc<2, f64>, f64>>, plot_factory: Option<Py<PyAny>>, ) -> Self { Self { dataterm, regterm, auxterm, auxinit, μinit, plot_factory } } #[getter] fn get_dataterm<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> { self.dataterm.bind(py).clone() } } impl RunnableExperiment<f64> for PythonExperiment { fn name(&self) -> &str { &self.name } fn runall( &self, cli: &CommandLineArgs, algs: Option<Vec<Named<AlgorithmConfig<f64>>>>, ) -> DynError { // Set up algorithms let algorithms = algs.unwrap_or_else(|| vec![DefaultAlgorithm::FB.get_named()]); debug!( "Entered PythonExperiment::runall for experimen {}. Attaching Python.", self.name ); Python::attach(|py| { debug!("Calling Python-side experiment initialisation."); let prefix = self.start(cli)?; let problem: PyRef<Problem> = process_error( "PythonExperiment::runall", py, self.setup .call1(py, (&prefix,)) .and_then(|r| r.extract(py).map_err(PyErr::from)), )?; let save_extra = |_, ()| Ok(()); let μ0 = problem.μinit.clone(); match (problem.regterm, &problem.auxterm, &problem.auxinit) { (regterm, None, None) => { debug!("… Extracting data term."); let dataterm: exp_f64_2::DataTerm<'_> = problem.dataterm.extract(py)?; debug!("… Extracting plotter."); let plot_factory: exp_f64_2::PlotFactory<'_> = if let Some(ref p) = problem.plot_factory { p.extract(py)? } else { PythonPlotFactory::dummy() }; let make_plotter = |prefix| plot_factory.produce(prefix); debug!("Running."); self.do_runall( &prefix, cli, algorithms, make_plotter, save_extra, μ0, |p| match regterm { RegTermPy::Radon(α) => { run_fb(&dataterm, &RadonRegTerm(α), &RadonSquared, p, |_| { Err(NotImplemented.into()) }) .map(|μ| (μ, ())) } RegTermPy::NonnegRadon(α) => { run_fb(&dataterm, &NonnegRadonRegTerm(α), &RadonSquared, p, |_| { Err(NotImplemented.into()) }) .map(|μ| (μ, ())) } }, ) } (regterm, Some(auxterm), Some(z_)) => { debug!("… Extracting auxiliary variable."); let z0: PyResult<exp_f64_2_auxvar::AuxVar<'_>> = z_.extract(py); match z0 { Ok(z) => { debug!("… Extracting data term."); let dataterm: exp_f64_2_auxvar::DataTerm<'_> = problem.dataterm.extract(py)?; debug!("… Extracting plotter."); let plot_factory: exp_f64_2::PlotFactory<'_> = if let Some(ref p) = problem.plot_factory { p.extract(py)? } else { PythonPlotFactory::dummy() }; let make_plotter = |prefix| plot_factory.produce(prefix); debug!("… Extracting auxiliary term."); let auxterm: exp_f64_2_auxvar::AuxTerm<'_> = auxterm.extract(py)?; debug!("Running."); self.do_runall( &prefix, cli, algorithms, make_plotter, save_extra, (μ0, z), |p| match regterm { RegTermPy::Radon(α) => run_fb_pair( &dataterm, &RadonRegTerm(α), &RadonSquared, &auxterm, p, |_| Err(NotImplemented.into()), ) .map(|Pair(μ, _)| (μ, ())), RegTermPy::NonnegRadon(α) => run_fb_pair( &dataterm, &NonnegRadonRegTerm(α), &RadonSquared, &auxterm, p, |_| Err(NotImplemented.into()), ) .map(|Pair(μ, _)| (μ, ())), }, ) } Err(_) => { let z: exp_f64_2_auxvar_scalar::AuxVar<'_> = z_.extract(py)?; debug!("… Extracting data term."); let dataterm: exp_f64_2_auxvar_scalar::DataTerm<'_> = problem.dataterm.extract(py)?; debug!("… Extracting plotter."); let plot_factory: exp_f64_2::PlotFactory<'_> = if let Some(ref p) = problem.plot_factory { p.extract(py)? } else { PythonPlotFactory::dummy() }; let make_plotter = |prefix| plot_factory.produce(prefix); debug!("… Extracting auxiliary term."); let auxterm: exp_f64_2_auxvar_scalar::AuxTerm<'_> = auxterm.extract(py)?; debug!("Running."); self.do_runall( &prefix, cli, algorithms, make_plotter, save_extra, (μ0, z), |p| match regterm { RegTermPy::Radon(α) => run_fb_pair( &dataterm, &RadonRegTerm(α), &RadonSquared, &auxterm, p, |_| Err(NotImplemented.into()), ) .map(|Pair(μ, _)| (μ, ())), RegTermPy::NonnegRadon(α) => run_fb_pair( &dataterm, &NonnegRadonRegTerm(α), &RadonSquared, &auxterm, p, |_| Err(NotImplemented.into()), ) .map(|Pair(μ, _)| (μ, ())), }, ) } } } (_, _, _) => { bail!("Errors in problem {} setup: auxiliary term or auxiliary variable initialisation given without the other", self.name); } } })?; // Should run the experiment here, going through all algorithms. Ok(()) } fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<f64> { self.algorithm_overrides .get(&alg) .cloned() .unwrap_or(Default::default()) } }