--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/experiments.rs Thu Feb 26 09:32:12 2026 -0500 @@ -0,0 +1,474 @@ +/*! +Access to experiments provided in Python code. +*/ + +use crate::dolfinx_access::DolfinxPyFunction_f64; +use crate::python_access::{ + process_error, Differentiable, HasProx, PythonMapping, PythonPlotFactory, +}; +use alg_tools::direct_product::Pair; +use alg_tools::error::{DynError, DynResult}; +use alg_tools::loc::Loc; +use anyhow::bail; +use clap::Parser; +use log::debug; +use pointsource_algs::measures::DiscreteMeasure; +use pointsource_algs::prox_penalty::RadonSquared; +use pointsource_algs::regularisation::{NonnegRadonRegTerm, RadonRegTerm, Regularisation}; +use pointsource_algs::run::{ + run_fb, run_fb_pair, AlgorithmConfig, DefaultAlgorithm, Named, RunError::NotImplemented, + RunnableExperiment, RunnableExperimentExtras, +}; +use pointsource_algs::{AlgorithmOverrides, CommandLineArgs, ExperimentSetup}; +use pyo3::ffi::c_str; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyFunction}; +use serde::{Deserialize, Serialize}; +use serde_json; +use serde_with::skip_serializing_none; +use std::collections::HashMap; +use std::ffi::CString; +use std::path::Path; + +/// Command line experiment setup overrides +#[skip_serializing_none] +#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)] +#[pyclass(module = "pointsource_pde")] +pub struct Experiments { + /// List of Python setup files of experiments to perform + #[arg(value_name = "EXPERIMENT.PY")] + #[pyo3(get, set)] + experiments: Vec<String>, + + #[arg(long)] + #[pyo3(get, set)] + /// Regularisation parameter override. + /// + /// Only use if running just a single experiment, as different experiments have different + /// regularisation parameters. + alpha: Option<f64>, + + #[arg(long)] + #[pyo3(get, set)] + /// Noise variance override + variance: Option<f64>, +} + +/// An experiment implemented in Python code +#[derive(Serialize)] +pub struct PythonExperiment { + /// Name of the experiment + name: String, + /// Source file of experiment + filename: String, + /// The setup function + #[serde(skip_serializing)] + setup: Py<PyFunction>, + /// The Python module + #[serde(skip_serializing)] + #[allow(unused)] + module: Py<PyModule>, + /// Algorithm overrides + algorithm_overrides: HashMap<DefaultAlgorithm, AlgorithmOverrides<f64>>, +} + +impl ExperimentSetup for Experiments { + type FloatType = f64; + + fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>> { + debug!("Loading runnable experiments. Attaching Python."); + + Python::attach(|py| { + debug!("Loading json Python module."); + + // Load the "json" Python module, and "dumps" from it. + let dumps = PyModule::import(py, "json")?.getattr("dumps")?; + + // Load the Python file describing each experiment, and extract + // fields to file `PythonExperiment`. + self.experiments + .iter() + .cloned() + .map(|filename| -> DynResult<Box<dyn RunnableExperiment<_>>> { + debug!("Trying to load experiment from {filename}"); + let code = std::fs::read_to_string(&filename)?; + let modname = filename.as_bytes(); + // Add path of this file to module search path, to allow local dependencies + let parent = Path::new(&filename).parent().unwrap().to_str().unwrap(); + let locals = PyDict::new(py); + locals.set_item("this_path", parent)?; + py.run( + c_str!("import sys; sys.path.insert(0, this_path)"), + None, + Some(&locals), + )?; + // Load module + let module = PyModule::from_code( + py, + CString::new(code)?.as_ref(), + CString::new(filename.as_str())?.as_ref(), + CString::new(modname)?.as_ref(), + )?; + let name = module.getattr("name")?.extract()?; + let setup = module + .getattr("setup")? + .cast_into() // Check that a Pyfunction + .map_err(PyErr::from)? // Can't return a DowncastIntoError + .unbind(); + let algorithm_overrides = match module.getattr("algorithm_overrides") { + Err(_) => Default::default(), + Ok(o) => { + // This passes through JSON and Serde as pyo3 does not allow + // generics in AlgorithmOverides<_> - not even instantiating + // them to specific values. It would be interesting to write + // a proper serde Deserializer to not have to pass through JSON. + // One day. + serde_json::from_str(dumps.call1((o,))?.extract()?)? + } + }; + debug!("… Found {name} with algorithm overrides {algorithm_overrides:?}"); + Ok(Box::new(PythonExperiment { + name, + filename, + setup, + algorithm_overrides, + module: module.unbind(), + })) + }) + .try_collect() + }) + } +} + +/// Types for experiments in two dimensions with `f64` floats. +#[allow(non_camel_case_types)] +#[allow(unused)] +mod exp_f64_2 { + use super::*; + use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker}; + + /// Type of unknowns + pub type Domain = DiscreteMeasure<Loc<2, f64>, f64>; + /// Type of derivatives of objective function + pub type DerivativeCodomain<'py> = DolfinxPyFunction_f64<'py, 2, 2, 1>; + //pub type AuxTODOSpecifyFlexiblise<'py> = DolfinxPyFunction_f64<'py, 2, 2, 1>; + pub type DerivativeMarker = DolfinxPyFunctionMarker<2, 2, 1>; + /// Type of data terms + pub type DataTerm<'py> = PythonMapping<'py, Domain, f64, Differentiable<DerivativeMarker>>; + /// Plotter + pub type PlotFactory<'py> = + PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>; +} + +/// Types for experiments in two dimensions with `f64` floats and aux +#[allow(non_camel_case_types, non_snake_case)] +#[allow(unused)] +mod exp_f64_2_auxvar { + use super::*; + use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker}; + use numpy::Ix2; + + /// Type of unknowns + pub type Domain<'py> = Pair<DiscreteMeasure<Loc<2, f64>, f64>, AuxVar<'py>>; + /// Type of derivatives of objective function + pub type DerivativeCodomain<'py> = Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, AuxVar<'py>>; + /// Type of derivatives wrt. auxiliary variables. + pub type AuxVar<'py> = Pair< + DolfinxPyFunction_f64<'py, 2, 2, 1>, + Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>, + >; + pub type DerivativeMarker = Pair< + DolfinxPyFunctionMarker<2, 2, 1>, + Pair< + DolfinxPyFunctionMarker<2, 2, 1>, + Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>, + >, + >; + /// Type of data terms + pub type DataTerm<'py> = PythonMapping<'py, Domain<'py>, f64, Differentiable<DerivativeMarker>>; + /// Auxiliary objective + pub type AuxTerm<'py> = PythonMapping<'py, AuxVar<'py>, f64, HasProx>; + /// Plotter + pub type PlotFactory<'py> = + PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>; +} + +/// Types for experiments in two dimensions with `f64` floats and scalar aux +#[allow(non_camel_case_types, non_snake_case)] +#[allow(unused)] +mod exp_f64_2_auxvar_scalar { + use super::*; + use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker}; + use numpy::Ix2; + + /// Type of unknowns + pub type Domain<'py> = Pair<DiscreteMeasure<Loc<2, f64>, f64>, AuxVar<'py>>; + /// Type of derivatives of objective function + pub type DerivativeCodomain<'py> = Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, AuxVar<'py>>; + /// Type of derivatives wrt. auxiliary variables. + pub type AuxVar<'py> = Pair<f64, Pair<f64, f64>>; + pub type DerivativeMarker = Pair<DolfinxPyFunctionMarker<2, 2, 1>, Pair<f64, Pair<f64, f64>>>; + /// Type of data terms + pub type DataTerm<'py> = PythonMapping<'py, Domain<'py>, f64, Differentiable<DerivativeMarker>>; + /// Auxiliary objective + pub type AuxTerm<'py> = PythonMapping<'py, AuxVar<'py>, f64, HasProx>; + /// Plotter + pub type PlotFactory<'py> = + PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>; +} + +#[pyclass(module = "pointsource_pde", name = "RegTerm")] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum RegTermPy { + // Radon norm with weight `α`. + Radon(f64), + // Radon norm ith weight `α` and a positivity constraint. + NonnegRadon(f64), +} + +impl From<RegTermPy> for Regularisation { + fn from(reg: RegTermPy) -> Regularisation { + match reg { + RegTermPy::Radon(α) => Regularisation::Radon(α), + RegTermPy::NonnegRadon(α) => Regularisation::NonnegRadon(α), + } + } +} + +#[pyclass(module = "pointsource_pde")] +#[derive(Debug)] +pub struct Problem { + /// Data term. On the python side, this should be a `class` that implements + /// `apply` from [`DiscreteMeasure_2_f64`] (TODO: extended parameters) to floats, and `diff` + /// from the space space to [`DolfinxPyFunction_f64<2, 1, 2>`]. + #[pyo3(set)] + dataterm: Py<PyAny>, //exp_f64_2::DataTerm, + + // Regularisation + #[pyo3(set, get)] + regterm: RegTermPy, + + // Auxiliary variable regularisation term or similar + #[pyo3(set, get)] + auxterm: Option<Py<PyAny>>, + + // Initial auxiliary variable + #[pyo3(set, get)] + auxinit: Option<Py<PyAny>>, + + // Initial measure + #[pyo3(set, get)] + μinit: Option<DiscreteMeasure<Loc<2, f64>, f64>>, + + // Plotter + #[pyo3(set, get)] + plot_factory: Option<Py<PyAny>>, +} + +#[pymethods] +impl Problem { + #[new] + #[pyo3(signature = (dataterm, regterm, auxterm = None, auxinit = None, μinit = None, plot_factory=None))] + fn new( + dataterm: Py<PyAny>, + regterm: RegTermPy, + auxterm: Option<Py<PyAny>>, + auxinit: Option<Py<PyAny>>, + μinit: Option<DiscreteMeasure<Loc<2, f64>, f64>>, + plot_factory: Option<Py<PyAny>>, + ) -> Self { + Self { dataterm, regterm, auxterm, auxinit, μinit, plot_factory } + } + + #[getter] + fn get_dataterm<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> { + self.dataterm.bind(py).clone() + } +} + +impl RunnableExperiment<f64> for PythonExperiment { + fn name(&self) -> &str { + &self.name + } + + fn runall( + &self, + cli: &CommandLineArgs, + algs: Option<Vec<Named<AlgorithmConfig<f64>>>>, + ) -> DynError { + // Set up algorithms + let algorithms = algs.unwrap_or_else(|| vec![DefaultAlgorithm::FB.get_named()]); + + debug!( + "Entered PythonExperiment::runall for experimen {}. Attaching Python.", + self.name + ); + + Python::attach(|py| { + debug!("Calling Python-side experiment initialisation."); + + let prefix = self.start(cli)?; + + let problem: PyRef<Problem> = process_error( + "PythonExperiment::runall", + py, + self.setup + .call1(py, (&prefix,)) + .and_then(|r| r.extract(py).map_err(PyErr::from)), + )?; + + let save_extra = |_, ()| Ok(()); + let μ0 = problem.μinit.clone(); + + match (problem.regterm, &problem.auxterm, &problem.auxinit) { + (regterm, None, None) => { + debug!("… Extracting data term."); + let dataterm: exp_f64_2::DataTerm<'_> = problem.dataterm.extract(py)?; + debug!("… Extracting plotter."); + let plot_factory: exp_f64_2::PlotFactory<'_> = + if let Some(ref p) = problem.plot_factory { + p.extract(py)? + } else { + PythonPlotFactory::dummy() + }; + let make_plotter = |prefix| plot_factory.produce(prefix); + + debug!("Running."); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| match regterm { + RegTermPy::Radon(α) => { + run_fb(&dataterm, &RadonRegTerm(α), &RadonSquared, p, |_| { + Err(NotImplemented.into()) + }) + .map(|μ| (μ, ())) + } + RegTermPy::NonnegRadon(α) => { + run_fb(&dataterm, &NonnegRadonRegTerm(α), &RadonSquared, p, |_| { + Err(NotImplemented.into()) + }) + .map(|μ| (μ, ())) + } + }, + ) + } + (regterm, Some(auxterm), Some(z_)) => { + debug!("… Extracting auxiliary variable."); + let z0: PyResult<exp_f64_2_auxvar::AuxVar<'_>> = z_.extract(py); + match z0 { + Ok(z) => { + debug!("… Extracting data term."); + let dataterm: exp_f64_2_auxvar::DataTerm<'_> = + problem.dataterm.extract(py)?; + debug!("… Extracting plotter."); + let plot_factory: exp_f64_2::PlotFactory<'_> = + if let Some(ref p) = problem.plot_factory { + p.extract(py)? + } else { + PythonPlotFactory::dummy() + }; + let make_plotter = |prefix| plot_factory.produce(prefix); + debug!("… Extracting auxiliary term."); + let auxterm: exp_f64_2_auxvar::AuxTerm<'_> = auxterm.extract(py)?; + + debug!("Running."); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + (μ0, z), + |p| match regterm { + RegTermPy::Radon(α) => run_fb_pair( + &dataterm, + &RadonRegTerm(α), + &RadonSquared, + &auxterm, + p, + |_| Err(NotImplemented.into()), + ) + .map(|Pair(μ, _)| (μ, ())), + RegTermPy::NonnegRadon(α) => run_fb_pair( + &dataterm, + &NonnegRadonRegTerm(α), + &RadonSquared, + &auxterm, + p, + |_| Err(NotImplemented.into()), + ) + .map(|Pair(μ, _)| (μ, ())), + }, + ) + } + Err(_) => { + let z: exp_f64_2_auxvar_scalar::AuxVar<'_> = z_.extract(py)?; + debug!("… Extracting data term."); + let dataterm: exp_f64_2_auxvar_scalar::DataTerm<'_> = + problem.dataterm.extract(py)?; + debug!("… Extracting plotter."); + let plot_factory: exp_f64_2::PlotFactory<'_> = + if let Some(ref p) = problem.plot_factory { + p.extract(py)? + } else { + PythonPlotFactory::dummy() + }; + let make_plotter = |prefix| plot_factory.produce(prefix); + debug!("… Extracting auxiliary term."); + let auxterm: exp_f64_2_auxvar_scalar::AuxTerm<'_> = + auxterm.extract(py)?; + + debug!("Running."); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + (μ0, z), + |p| match regterm { + RegTermPy::Radon(α) => run_fb_pair( + &dataterm, + &RadonRegTerm(α), + &RadonSquared, + &auxterm, + p, + |_| Err(NotImplemented.into()), + ) + .map(|Pair(μ, _)| (μ, ())), + RegTermPy::NonnegRadon(α) => run_fb_pair( + &dataterm, + &NonnegRadonRegTerm(α), + &RadonSquared, + &auxterm, + p, + |_| Err(NotImplemented.into()), + ) + .map(|Pair(μ, _)| (μ, ())), + }, + ) + } + } + } + (_, _, _) => { + bail!("Errors in problem {} setup: auxiliary term or auxiliary variable initialisation given without the other", self.name); + } + } + })?; + + // Should run the experiment here, going through all algorithms. + Ok(()) + } + + fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<f64> { + self.algorithm_overrides + .get(&alg) + .cloned() + .unwrap_or(Default::default()) + } +}