src/experiments.rs

changeset 1
a4137aedcb3a
child 3
c3a4f4bb87f7
equal deleted inserted replaced
0:7ec1cfe19a24 1:a4137aedcb3a
1 /*!
2 Access to experiments provided in Python code.
3 */
4
5 use crate::dolfinx_access::DolfinxPyFunction_f64;
6 use crate::python_access::{
7 process_error, Differentiable, HasProx, PythonMapping, PythonPlotFactory,
8 };
9 use alg_tools::direct_product::Pair;
10 use alg_tools::error::{DynError, DynResult};
11 use alg_tools::loc::Loc;
12 use anyhow::bail;
13 use clap::Parser;
14 use log::debug;
15 use pointsource_algs::measures::DiscreteMeasure;
16 use pointsource_algs::prox_penalty::RadonSquared;
17 use pointsource_algs::regularisation::{NonnegRadonRegTerm, RadonRegTerm, Regularisation};
18 use pointsource_algs::run::{
19 run_fb, run_fb_pair, AlgorithmConfig, DefaultAlgorithm, Named, RunError::NotImplemented,
20 RunnableExperiment, RunnableExperimentExtras,
21 };
22 use pointsource_algs::{AlgorithmOverrides, CommandLineArgs, ExperimentSetup};
23 use pyo3::ffi::c_str;
24 use pyo3::prelude::*;
25 use pyo3::types::{PyDict, PyFunction};
26 use serde::{Deserialize, Serialize};
27 use serde_json;
28 use serde_with::skip_serializing_none;
29 use std::collections::HashMap;
30 use std::ffi::CString;
31 use std::path::Path;
32
33 /// Command line experiment setup overrides
34 #[skip_serializing_none]
35 #[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
36 #[pyclass(module = "pointsource_pde")]
37 pub struct Experiments {
38 /// List of Python setup files of experiments to perform
39 #[arg(value_name = "EXPERIMENT.PY")]
40 #[pyo3(get, set)]
41 experiments: Vec<String>,
42
43 #[arg(long)]
44 #[pyo3(get, set)]
45 /// Regularisation parameter override.
46 ///
47 /// Only use if running just a single experiment, as different experiments have different
48 /// regularisation parameters.
49 alpha: Option<f64>,
50
51 #[arg(long)]
52 #[pyo3(get, set)]
53 /// Noise variance override
54 variance: Option<f64>,
55 }
56
57 /// An experiment implemented in Python code
58 #[derive(Serialize)]
59 pub struct PythonExperiment {
60 /// Name of the experiment
61 name: String,
62 /// Source file of experiment
63 filename: String,
64 /// The setup function
65 #[serde(skip_serializing)]
66 setup: Py<PyFunction>,
67 /// The Python module
68 #[serde(skip_serializing)]
69 #[allow(unused)]
70 module: Py<PyModule>,
71 /// Algorithm overrides
72 algorithm_overrides: HashMap<DefaultAlgorithm, AlgorithmOverrides<f64>>,
73 }
74
75 impl ExperimentSetup for Experiments {
76 type FloatType = f64;
77
78 fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>> {
79 debug!("Loading runnable experiments. Attaching Python.");
80
81 Python::attach(|py| {
82 debug!("Loading json Python module.");
83
84 // Load the "json" Python module, and "dumps" from it.
85 let dumps = PyModule::import(py, "json")?.getattr("dumps")?;
86
87 // Load the Python file describing each experiment, and extract
88 // fields to file `PythonExperiment`.
89 self.experiments
90 .iter()
91 .cloned()
92 .map(|filename| -> DynResult<Box<dyn RunnableExperiment<_>>> {
93 debug!("Trying to load experiment from {filename}");
94 let code = std::fs::read_to_string(&filename)?;
95 let modname = filename.as_bytes();
96 // Add path of this file to module search path, to allow local dependencies
97 let parent = Path::new(&filename).parent().unwrap().to_str().unwrap();
98 let locals = PyDict::new(py);
99 locals.set_item("this_path", parent)?;
100 py.run(
101 c_str!("import sys; sys.path.insert(0, this_path)"),
102 None,
103 Some(&locals),
104 )?;
105 // Load module
106 let module = PyModule::from_code(
107 py,
108 CString::new(code)?.as_ref(),
109 CString::new(filename.as_str())?.as_ref(),
110 CString::new(modname)?.as_ref(),
111 )?;
112 let name = module.getattr("name")?.extract()?;
113 let setup = module
114 .getattr("setup")?
115 .cast_into() // Check that a Pyfunction
116 .map_err(PyErr::from)? // Can't return a DowncastIntoError
117 .unbind();
118 let algorithm_overrides = match module.getattr("algorithm_overrides") {
119 Err(_) => Default::default(),
120 Ok(o) => {
121 // This passes through JSON and Serde as pyo3 does not allow
122 // generics in AlgorithmOverides<_> - not even instantiating
123 // them to specific values. It would be interesting to write
124 // a proper serde Deserializer to not have to pass through JSON.
125 // One day.
126 serde_json::from_str(dumps.call1((o,))?.extract()?)?
127 }
128 };
129 debug!("… Found {name} with algorithm overrides {algorithm_overrides:?}");
130 Ok(Box::new(PythonExperiment {
131 name,
132 filename,
133 setup,
134 algorithm_overrides,
135 module: module.unbind(),
136 }))
137 })
138 .try_collect()
139 })
140 }
141 }
142
143 /// Types for experiments in two dimensions with `f64` floats.
144 #[allow(non_camel_case_types)]
145 #[allow(unused)]
146 mod exp_f64_2 {
147 use super::*;
148 use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker};
149
150 /// Type of unknowns
151 pub type Domain = DiscreteMeasure<Loc<2, f64>, f64>;
152 /// Type of derivatives of objective function
153 pub type DerivativeCodomain<'py> = DolfinxPyFunction_f64<'py, 2, 2, 1>;
154 //pub type AuxTODOSpecifyFlexiblise<'py> = DolfinxPyFunction_f64<'py, 2, 2, 1>;
155 pub type DerivativeMarker = DolfinxPyFunctionMarker<2, 2, 1>;
156 /// Type of data terms
157 pub type DataTerm<'py> = PythonMapping<'py, Domain, f64, Differentiable<DerivativeMarker>>;
158 /// Plotter
159 pub type PlotFactory<'py> =
160 PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>;
161 }
162
163 /// Types for experiments in two dimensions with `f64` floats and aux
164 #[allow(non_camel_case_types, non_snake_case)]
165 #[allow(unused)]
166 mod exp_f64_2_auxvar {
167 use super::*;
168 use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker};
169 use numpy::Ix2;
170
171 /// Type of unknowns
172 pub type Domain<'py> = Pair<DiscreteMeasure<Loc<2, f64>, f64>, AuxVar<'py>>;
173 /// Type of derivatives of objective function
174 pub type DerivativeCodomain<'py> = Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, AuxVar<'py>>;
175 /// Type of derivatives wrt. auxiliary variables.
176 pub type AuxVar<'py> = Pair<
177 DolfinxPyFunction_f64<'py, 2, 2, 1>,
178 Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>,
179 >;
180 pub type DerivativeMarker = Pair<
181 DolfinxPyFunctionMarker<2, 2, 1>,
182 Pair<
183 DolfinxPyFunctionMarker<2, 2, 1>,
184 Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>,
185 >,
186 >;
187 /// Type of data terms
188 pub type DataTerm<'py> = PythonMapping<'py, Domain<'py>, f64, Differentiable<DerivativeMarker>>;
189 /// Auxiliary objective
190 pub type AuxTerm<'py> = PythonMapping<'py, AuxVar<'py>, f64, HasProx>;
191 /// Plotter
192 pub type PlotFactory<'py> =
193 PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>;
194 }
195
196 /// Types for experiments in two dimensions with `f64` floats and scalar aux
197 #[allow(non_camel_case_types, non_snake_case)]
198 #[allow(unused)]
199 mod exp_f64_2_auxvar_scalar {
200 use super::*;
201 use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker};
202 use numpy::Ix2;
203
204 /// Type of unknowns
205 pub type Domain<'py> = Pair<DiscreteMeasure<Loc<2, f64>, f64>, AuxVar<'py>>;
206 /// Type of derivatives of objective function
207 pub type DerivativeCodomain<'py> = Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, AuxVar<'py>>;
208 /// Type of derivatives wrt. auxiliary variables.
209 pub type AuxVar<'py> = Pair<f64, Pair<f64, f64>>;
210 pub type DerivativeMarker = Pair<DolfinxPyFunctionMarker<2, 2, 1>, Pair<f64, Pair<f64, f64>>>;
211 /// Type of data terms
212 pub type DataTerm<'py> = PythonMapping<'py, Domain<'py>, f64, Differentiable<DerivativeMarker>>;
213 /// Auxiliary objective
214 pub type AuxTerm<'py> = PythonMapping<'py, AuxVar<'py>, f64, HasProx>;
215 /// Plotter
216 pub type PlotFactory<'py> =
217 PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>;
218 }
219
220 #[pyclass(module = "pointsource_pde", name = "RegTerm")]
221 #[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq)]
222 pub enum RegTermPy {
223 // Radon norm with weight `α`.
224 Radon(f64),
225 // Radon norm ith weight `α` and a positivity constraint.
226 NonnegRadon(f64),
227 }
228
229 impl From<RegTermPy> for Regularisation {
230 fn from(reg: RegTermPy) -> Regularisation {
231 match reg {
232 RegTermPy::Radon(α) => Regularisation::Radon(α),
233 RegTermPy::NonnegRadon(α) => Regularisation::NonnegRadon(α),
234 }
235 }
236 }
237
238 #[pyclass(module = "pointsource_pde")]
239 #[derive(Debug)]
240 pub struct Problem {
241 /// Data term. On the python side, this should be a `class` that implements
242 /// `apply` from [`DiscreteMeasure_2_f64`] (TODO: extended parameters) to floats, and `diff`
243 /// from the space space to [`DolfinxPyFunction_f64<2, 1, 2>`].
244 #[pyo3(set)]
245 dataterm: Py<PyAny>, //exp_f64_2::DataTerm,
246
247 // Regularisation
248 #[pyo3(set, get)]
249 regterm: RegTermPy,
250
251 // Auxiliary variable regularisation term or similar
252 #[pyo3(set, get)]
253 auxterm: Option<Py<PyAny>>,
254
255 // Initial auxiliary variable
256 #[pyo3(set, get)]
257 auxinit: Option<Py<PyAny>>,
258
259 // Initial measure
260 #[pyo3(set, get)]
261 μinit: Option<DiscreteMeasure<Loc<2, f64>, f64>>,
262
263 // Plotter
264 #[pyo3(set, get)]
265 plot_factory: Option<Py<PyAny>>,
266 }
267
268 #[pymethods]
269 impl Problem {
270 #[new]
271 #[pyo3(signature = (dataterm, regterm, auxterm = None, auxinit = None, μinit = None, plot_factory=None))]
272 fn new(
273 dataterm: Py<PyAny>,
274 regterm: RegTermPy,
275 auxterm: Option<Py<PyAny>>,
276 auxinit: Option<Py<PyAny>>,
277 μinit: Option<DiscreteMeasure<Loc<2, f64>, f64>>,
278 plot_factory: Option<Py<PyAny>>,
279 ) -> Self {
280 Self { dataterm, regterm, auxterm, auxinit, μinit, plot_factory }
281 }
282
283 #[getter]
284 fn get_dataterm<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
285 self.dataterm.bind(py).clone()
286 }
287 }
288
289 impl RunnableExperiment<f64> for PythonExperiment {
290 fn name(&self) -> &str {
291 &self.name
292 }
293
294 fn runall(
295 &self,
296 cli: &CommandLineArgs,
297 algs: Option<Vec<Named<AlgorithmConfig<f64>>>>,
298 ) -> DynError {
299 // Set up algorithms
300 let algorithms = algs.unwrap_or_else(|| vec![DefaultAlgorithm::FB.get_named()]);
301
302 debug!(
303 "Entered PythonExperiment::runall for experimen {}. Attaching Python.",
304 self.name
305 );
306
307 Python::attach(|py| {
308 debug!("Calling Python-side experiment initialisation.");
309
310 let prefix = self.start(cli)?;
311
312 let problem: PyRef<Problem> = process_error(
313 "PythonExperiment::runall",
314 py,
315 self.setup
316 .call1(py, (&prefix,))
317 .and_then(|r| r.extract(py).map_err(PyErr::from)),
318 )?;
319
320 let save_extra = |_, ()| Ok(());
321 let μ0 = problem.μinit.clone();
322
323 match (problem.regterm, &problem.auxterm, &problem.auxinit) {
324 (regterm, None, None) => {
325 debug!("… Extracting data term.");
326 let dataterm: exp_f64_2::DataTerm<'_> = problem.dataterm.extract(py)?;
327 debug!("… Extracting plotter.");
328 let plot_factory: exp_f64_2::PlotFactory<'_> =
329 if let Some(ref p) = problem.plot_factory {
330 p.extract(py)?
331 } else {
332 PythonPlotFactory::dummy()
333 };
334 let make_plotter = |prefix| plot_factory.produce(prefix);
335
336 debug!("Running.");
337 self.do_runall(
338 &prefix,
339 cli,
340 algorithms,
341 make_plotter,
342 save_extra,
343 μ0,
344 |p| match regterm {
345 RegTermPy::Radon(α) => {
346 run_fb(&dataterm, &RadonRegTerm(α), &RadonSquared, p, |_| {
347 Err(NotImplemented.into())
348 })
349 .map(|μ| (μ, ()))
350 }
351 RegTermPy::NonnegRadon(α) => {
352 run_fb(&dataterm, &NonnegRadonRegTerm(α), &RadonSquared, p, |_| {
353 Err(NotImplemented.into())
354 })
355 .map(|μ| (μ, ()))
356 }
357 },
358 )
359 }
360 (regterm, Some(auxterm), Some(z_)) => {
361 debug!("… Extracting auxiliary variable.");
362 let z0: PyResult<exp_f64_2_auxvar::AuxVar<'_>> = z_.extract(py);
363 match z0 {
364 Ok(z) => {
365 debug!("… Extracting data term.");
366 let dataterm: exp_f64_2_auxvar::DataTerm<'_> =
367 problem.dataterm.extract(py)?;
368 debug!("… Extracting plotter.");
369 let plot_factory: exp_f64_2::PlotFactory<'_> =
370 if let Some(ref p) = problem.plot_factory {
371 p.extract(py)?
372 } else {
373 PythonPlotFactory::dummy()
374 };
375 let make_plotter = |prefix| plot_factory.produce(prefix);
376 debug!("… Extracting auxiliary term.");
377 let auxterm: exp_f64_2_auxvar::AuxTerm<'_> = auxterm.extract(py)?;
378
379 debug!("Running.");
380 self.do_runall(
381 &prefix,
382 cli,
383 algorithms,
384 make_plotter,
385 save_extra,
386 (μ0, z),
387 |p| match regterm {
388 RegTermPy::Radon(α) => run_fb_pair(
389 &dataterm,
390 &RadonRegTerm(α),
391 &RadonSquared,
392 &auxterm,
393 p,
394 |_| Err(NotImplemented.into()),
395 )
396 .map(|Pair(μ, _)| (μ, ())),
397 RegTermPy::NonnegRadon(α) => run_fb_pair(
398 &dataterm,
399 &NonnegRadonRegTerm(α),
400 &RadonSquared,
401 &auxterm,
402 p,
403 |_| Err(NotImplemented.into()),
404 )
405 .map(|Pair(μ, _)| (μ, ())),
406 },
407 )
408 }
409 Err(_) => {
410 let z: exp_f64_2_auxvar_scalar::AuxVar<'_> = z_.extract(py)?;
411 debug!("… Extracting data term.");
412 let dataterm: exp_f64_2_auxvar_scalar::DataTerm<'_> =
413 problem.dataterm.extract(py)?;
414 debug!("… Extracting plotter.");
415 let plot_factory: exp_f64_2::PlotFactory<'_> =
416 if let Some(ref p) = problem.plot_factory {
417 p.extract(py)?
418 } else {
419 PythonPlotFactory::dummy()
420 };
421 let make_plotter = |prefix| plot_factory.produce(prefix);
422 debug!("… Extracting auxiliary term.");
423 let auxterm: exp_f64_2_auxvar_scalar::AuxTerm<'_> =
424 auxterm.extract(py)?;
425
426 debug!("Running.");
427 self.do_runall(
428 &prefix,
429 cli,
430 algorithms,
431 make_plotter,
432 save_extra,
433 (μ0, z),
434 |p| match regterm {
435 RegTermPy::Radon(α) => run_fb_pair(
436 &dataterm,
437 &RadonRegTerm(α),
438 &RadonSquared,
439 &auxterm,
440 p,
441 |_| Err(NotImplemented.into()),
442 )
443 .map(|Pair(μ, _)| (μ, ())),
444 RegTermPy::NonnegRadon(α) => run_fb_pair(
445 &dataterm,
446 &NonnegRadonRegTerm(α),
447 &RadonSquared,
448 &auxterm,
449 p,
450 |_| Err(NotImplemented.into()),
451 )
452 .map(|Pair(μ, _)| (μ, ())),
453 },
454 )
455 }
456 }
457 }
458 (_, _, _) => {
459 bail!("Errors in problem {} setup: auxiliary term or auxiliary variable initialisation given without the other", self.name);
460 }
461 }
462 })?;
463
464 // Should run the experiment here, going through all algorithms.
465 Ok(())
466 }
467
468 fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<f64> {
469 self.algorithm_overrides
470 .get(&alg)
471 .cloned()
472 .unwrap_or(Default::default())
473 }
474 }

mercurial