| |
1 /*! |
| |
2 Access to experiments provided in Python code. |
| |
3 */ |
| |
4 |
| |
5 use crate::dolfinx_access::DolfinxPyFunction_f64; |
| |
6 use crate::python_access::{ |
| |
7 process_error, Differentiable, HasProx, PythonMapping, PythonPlotFactory, |
| |
8 }; |
| |
9 use alg_tools::direct_product::Pair; |
| |
10 use alg_tools::error::{DynError, DynResult}; |
| |
11 use alg_tools::loc::Loc; |
| |
12 use anyhow::bail; |
| |
13 use clap::Parser; |
| |
14 use log::debug; |
| |
15 use pointsource_algs::measures::DiscreteMeasure; |
| |
16 use pointsource_algs::prox_penalty::RadonSquared; |
| |
17 use pointsource_algs::regularisation::{NonnegRadonRegTerm, RadonRegTerm, Regularisation}; |
| |
18 use pointsource_algs::run::{ |
| |
19 run_fb, run_fb_pair, AlgorithmConfig, DefaultAlgorithm, Named, RunError::NotImplemented, |
| |
20 RunnableExperiment, RunnableExperimentExtras, |
| |
21 }; |
| |
22 use pointsource_algs::{AlgorithmOverrides, CommandLineArgs, ExperimentSetup}; |
| |
23 use pyo3::ffi::c_str; |
| |
24 use pyo3::prelude::*; |
| |
25 use pyo3::types::{PyDict, PyFunction}; |
| |
26 use serde::{Deserialize, Serialize}; |
| |
27 use serde_json; |
| |
28 use serde_with::skip_serializing_none; |
| |
29 use std::collections::HashMap; |
| |
30 use std::ffi::CString; |
| |
31 use std::path::Path; |
| |
32 |
| |
/// Command line experiment setup overrides
// The `///` comments in this struct are runtime-visible: with clap's derive they become the
// command/argument `--help` text, and presumably pyo3 also uses the field docs for the Python
// attribute docs. Edit them deliberately.
// `skip_serializing_none` omits `None` overrides when serialising.
#[skip_serializing_none]
#[derive(Parser, Debug, Serialize, Deserialize, Default, Clone)]
#[pyclass(module = "pointsource_pde")]
pub struct Experiments {
    /// List of Python setup files of experiments to perform
    // Positional arguments; `runnables` loads each file as a Python module.
    #[arg(value_name = "EXPERIMENT.PY")]
    #[pyo3(get, set)]
    experiments: Vec<String>,

    #[arg(long)]
    #[pyo3(get, set)]
    /// Regularisation parameter override.
    ///
    /// Only use if running just a single experiment, as different experiments have different
    /// regularisation parameters.
    // NOTE(review): nothing here enforces the single-experiment restriction, and `runnables`
    // does not read this field; presumably it is applied elsewhere — confirm.
    alpha: Option<f64>,

    #[arg(long)]
    #[pyo3(get, set)]
    /// Noise variance override
    // NOTE(review): like `alpha`, not read by `runnables`; presumably consumed elsewhere.
    variance: Option<f64>,
}
| |
56 |
| |
/// An experiment implemented in Python code
///
/// Built by `Experiments::runnables` from a single Python file. Only `name`, `filename` and
/// `algorithm_overrides` are serialised; the Python handles are skipped.
#[derive(Serialize)]
pub struct PythonExperiment {
    /// Name of the experiment, read from the module-level `name` attribute of the file
    name: String,
    /// Source file of experiment
    filename: String,
    /// The setup function: the module-level `setup`, which must be a Python function.
    /// `runall` calls it with the prefix returned by `start` and extracts a [`Problem`]
    /// from the result.
    #[serde(skip_serializing)]
    setup: Py<PyFunction>,
    /// The Python module
    // Not read anywhere in this file; storing it keeps a strong reference to the loaded module.
    #[serde(skip_serializing)]
    #[allow(unused)]
    module: Py<PyModule>,
    /// Algorithm overrides, keyed by algorithm, read from the optional module-level
    /// `algorithm_overrides` attribute (empty when the attribute is missing)
    algorithm_overrides: HashMap<DefaultAlgorithm, AlgorithmOverrides<f64>>,
}
| |
74 |
| |
75 impl ExperimentSetup for Experiments { |
| |
76 type FloatType = f64; |
| |
77 |
| |
78 fn runnables(&self) -> DynResult<Vec<Box<dyn RunnableExperiment<Self::FloatType>>>> { |
| |
79 debug!("Loading runnable experiments. Attaching Python."); |
| |
80 |
| |
81 Python::attach(|py| { |
| |
82 debug!("Loading json Python module."); |
| |
83 |
| |
84 // Load the "json" Python module, and "dumps" from it. |
| |
85 let dumps = PyModule::import(py, "json")?.getattr("dumps")?; |
| |
86 |
| |
87 // Load the Python file describing each experiment, and extract |
| |
88 // fields to file `PythonExperiment`. |
| |
89 self.experiments |
| |
90 .iter() |
| |
91 .cloned() |
| |
92 .map(|filename| -> DynResult<Box<dyn RunnableExperiment<_>>> { |
| |
93 debug!("Trying to load experiment from {filename}"); |
| |
94 let code = std::fs::read_to_string(&filename)?; |
| |
95 let modname = filename.as_bytes(); |
| |
96 // Add path of this file to module search path, to allow local dependencies |
| |
97 let parent = Path::new(&filename).parent().unwrap().to_str().unwrap(); |
| |
98 let locals = PyDict::new(py); |
| |
99 locals.set_item("this_path", parent)?; |
| |
100 py.run( |
| |
101 c_str!("import sys; sys.path.insert(0, this_path)"), |
| |
102 None, |
| |
103 Some(&locals), |
| |
104 )?; |
| |
105 // Load module |
| |
106 let module = PyModule::from_code( |
| |
107 py, |
| |
108 CString::new(code)?.as_ref(), |
| |
109 CString::new(filename.as_str())?.as_ref(), |
| |
110 CString::new(modname)?.as_ref(), |
| |
111 )?; |
| |
112 let name = module.getattr("name")?.extract()?; |
| |
113 let setup = module |
| |
114 .getattr("setup")? |
| |
115 .cast_into() // Check that a Pyfunction |
| |
116 .map_err(PyErr::from)? // Can't return a DowncastIntoError |
| |
117 .unbind(); |
| |
118 let algorithm_overrides = match module.getattr("algorithm_overrides") { |
| |
119 Err(_) => Default::default(), |
| |
120 Ok(o) => { |
| |
121 // This passes through JSON and Serde as pyo3 does not allow |
| |
122 // generics in AlgorithmOverides<_> - not even instantiating |
| |
123 // them to specific values. It would be interesting to write |
| |
124 // a proper serde Deserializer to not have to pass through JSON. |
| |
125 // One day. |
| |
126 serde_json::from_str(dumps.call1((o,))?.extract()?)? |
| |
127 } |
| |
128 }; |
| |
129 debug!("… Found {name} with algorithm overrides {algorithm_overrides:?}"); |
| |
130 Ok(Box::new(PythonExperiment { |
| |
131 name, |
| |
132 filename, |
| |
133 setup, |
| |
134 algorithm_overrides, |
| |
135 module: module.unbind(), |
| |
136 })) |
| |
137 }) |
| |
138 .try_collect() |
| |
139 }) |
| |
140 } |
| |
141 } |
| |
142 |
| |
/// Types for experiments in two dimensions with `f64` floats.
///
/// `PythonExperiment::runall` extracts [`DataTerm`] from this module for problems without an
/// auxiliary term or auxiliary initialisation. Its [`PlotFactory`] is used for all problem
/// variants there.
#[allow(non_camel_case_types)]
// Not everything here is used (e.g. `NumpyArrayMarker` is imported but unused).
#[allow(unused)]
mod exp_f64_2 {
    use super::*;
    use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker};

    /// Type of unknowns
    pub type Domain = DiscreteMeasure<Loc<2, f64>, f64>;
    /// Type of derivatives of objective function
    pub type DerivativeCodomain<'py> = DolfinxPyFunction_f64<'py, 2, 2, 1>;
    //pub type AuxTODOSpecifyFlexiblise<'py> = DolfinxPyFunction_f64<'py, 2, 2, 1>;
    /// Marker passed to [`Differentiable`] for [`DataTerm`]. It has the same `2, 2, 1`
    /// parameters as [`DerivativeCodomain`]; presumably it selects that type as the
    /// derivative type — confirm in `python_access`.
    pub type DerivativeMarker = DolfinxPyFunctionMarker<2, 2, 1>;
    /// Type of data terms
    pub type DataTerm<'py> = PythonMapping<'py, Domain, f64, Differentiable<DerivativeMarker>>;
    /// Plotter
    pub type PlotFactory<'py> =
        PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>;
}
| |
162 |
| |
/// Types for experiments in two dimensions with `f64` floats and aux
///
/// Here the auxiliary variable is a nested pair of three Dolfinx functions. For problems with
/// both an auxiliary term and an auxiliary initialisation, `PythonExperiment::runall` tries
/// this variant before the scalar one.
#[allow(non_camel_case_types, non_snake_case)]
// Not everything here is used (e.g. `NumpyArrayMarker` and `Ix2` are imported but unused).
#[allow(unused)]
mod exp_f64_2_auxvar {
    use super::*;
    use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker};
    use numpy::Ix2;

    /// Type of unknowns: a measure paired with the auxiliary variable.
    pub type Domain<'py> = Pair<DiscreteMeasure<Loc<2, f64>, f64>, AuxVar<'py>>;
    /// Type of derivatives of objective function: a Dolfinx function for the measure
    /// component, paired with an [`AuxVar`]-typed derivative for the auxiliary component.
    pub type DerivativeCodomain<'py> = Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, AuxVar<'py>>;
    /// Type of the auxiliary variable: a nested pair of three Dolfinx functions. It is also
    /// the type of derivatives with respect to the auxiliary variable
    /// (see [`DerivativeCodomain`]).
    pub type AuxVar<'py> = Pair<
        DolfinxPyFunction_f64<'py, 2, 2, 1>,
        Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, DolfinxPyFunction_f64<'py, 2, 2, 1>>,
    >;
    /// Marker passed to [`Differentiable`] for [`DataTerm`]. Its nesting mirrors
    /// [`DerivativeCodomain`], with one `DolfinxPyFunctionMarker` per Dolfinx function.
    pub type DerivativeMarker = Pair<
        DolfinxPyFunctionMarker<2, 2, 1>,
        Pair<
            DolfinxPyFunctionMarker<2, 2, 1>,
            Pair<DolfinxPyFunctionMarker<2, 2, 1>, DolfinxPyFunctionMarker<2, 2, 1>>,
        >,
    >;
    /// Type of data terms
    pub type DataTerm<'py> = PythonMapping<'py, Domain<'py>, f64, Differentiable<DerivativeMarker>>;
    /// Auxiliary objective on [`AuxVar`], marked [`HasProx`] (presumably the Python object
    /// must provide a proximal map — confirm in `python_access`).
    pub type AuxTerm<'py> = PythonMapping<'py, AuxVar<'py>, f64, HasProx>;
    /// Plotter
    pub type PlotFactory<'py> =
        PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>;
}
| |
195 |
| |
/// Types for experiments in two dimensions with `f64` floats and scalar aux
///
/// Here the auxiliary variable is a nested pair of three `f64` values. `PythonExperiment::runall`
/// falls back to this variant when the auxiliary initialisation does not extract as the
/// function-valued auxiliary variable.
#[allow(non_camel_case_types, non_snake_case)]
// Not everything here is used (e.g. `NumpyArrayMarker` and `Ix2` are imported but unused).
#[allow(unused)]
mod exp_f64_2_auxvar_scalar {
    use super::*;
    use crate::python_access::{DolfinxPyFunctionMarker, NumpyArrayMarker};
    use numpy::Ix2;

    /// Type of unknowns: a measure paired with the auxiliary variable.
    pub type Domain<'py> = Pair<DiscreteMeasure<Loc<2, f64>, f64>, AuxVar<'py>>;
    /// Type of derivatives of objective function: a Dolfinx function for the measure
    /// component, paired with an [`AuxVar`]-typed derivative for the auxiliary component.
    pub type DerivativeCodomain<'py> = Pair<DolfinxPyFunction_f64<'py, 2, 2, 1>, AuxVar<'py>>;
    /// Type of the auxiliary variable: a nested pair of three scalars. It is also the type
    /// of derivatives with respect to the auxiliary variable (see [`DerivativeCodomain`]).
    /// The `'py` lifetime is unused; it only keeps the alias shape uniform with the other
    /// variants.
    pub type AuxVar<'py> = Pair<f64, Pair<f64, f64>>;
    /// Marker passed to [`Differentiable`] for [`DataTerm`]. Its nesting mirrors
    /// [`DerivativeCodomain`]: a Dolfinx marker for the measure part and plain `f64` for
    /// the scalar auxiliary parts.
    pub type DerivativeMarker = Pair<DolfinxPyFunctionMarker<2, 2, 1>, Pair<f64, Pair<f64, f64>>>;
    /// Type of data terms
    pub type DataTerm<'py> = PythonMapping<'py, Domain<'py>, f64, Differentiable<DerivativeMarker>>;
    /// Auxiliary objective on [`AuxVar`], marked [`HasProx`] (presumably the Python object
    /// must provide a proximal map — confirm in `python_access`).
    pub type AuxTerm<'py> = PythonMapping<'py, AuxVar<'py>, f64, HasProx>;
    /// Plotter
    pub type PlotFactory<'py> =
        PythonPlotFactory<'py, DerivativeCodomain<'py>, DerivativeCodomain<'py>>;
}
| |
219 |
| |
220 #[pyclass(module = "pointsource_pde", name = "RegTerm")] |
| |
221 #[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq)] |
| |
222 pub enum RegTermPy { |
| |
223 // Radon norm with weight `α`. |
| |
224 Radon(f64), |
| |
225 // Radon norm ith weight `α` and a positivity constraint. |
| |
226 NonnegRadon(f64), |
| |
227 } |
| |
228 |
| |
229 impl From<RegTermPy> for Regularisation { |
| |
230 fn from(reg: RegTermPy) -> Regularisation { |
| |
231 match reg { |
| |
232 RegTermPy::Radon(α) => Regularisation::Radon(α), |
| |
233 RegTermPy::NonnegRadon(α) => Regularisation::NonnegRadon(α), |
| |
234 } |
| |
235 } |
| |
236 } |
| |
237 |
| |
238 #[pyclass(module = "pointsource_pde")] |
| |
239 #[derive(Debug)] |
| |
240 pub struct Problem { |
| |
241 /// Data term. On the python side, this should be a `class` that implements |
| |
242 /// `apply` from [`DiscreteMeasure_2_f64`] (TODO: extended parameters) to floats, and `diff` |
| |
243 /// from the space space to [`DolfinxPyFunction_f64<2, 1, 2>`]. |
| |
244 #[pyo3(set)] |
| |
245 dataterm: Py<PyAny>, //exp_f64_2::DataTerm, |
| |
246 |
| |
247 // Regularisation |
| |
248 #[pyo3(set, get)] |
| |
249 regterm: RegTermPy, |
| |
250 |
| |
251 // Auxiliary variable regularisation term or similar |
| |
252 #[pyo3(set, get)] |
| |
253 auxterm: Option<Py<PyAny>>, |
| |
254 |
| |
255 // Initial auxiliary variable |
| |
256 #[pyo3(set, get)] |
| |
257 auxinit: Option<Py<PyAny>>, |
| |
258 |
| |
259 // Initial measure |
| |
260 #[pyo3(set, get)] |
| |
261 μinit: Option<DiscreteMeasure<Loc<2, f64>, f64>>, |
| |
262 |
| |
263 // Plotter |
| |
264 #[pyo3(set, get)] |
| |
265 plot_factory: Option<Py<PyAny>>, |
| |
266 } |
| |
267 |
| |
268 #[pymethods] |
| |
269 impl Problem { |
| |
270 #[new] |
| |
271 #[pyo3(signature = (dataterm, regterm, auxterm = None, auxinit = None, μinit = None, plot_factory=None))] |
| |
272 fn new( |
| |
273 dataterm: Py<PyAny>, |
| |
274 regterm: RegTermPy, |
| |
275 auxterm: Option<Py<PyAny>>, |
| |
276 auxinit: Option<Py<PyAny>>, |
| |
277 μinit: Option<DiscreteMeasure<Loc<2, f64>, f64>>, |
| |
278 plot_factory: Option<Py<PyAny>>, |
| |
279 ) -> Self { |
| |
280 Self { dataterm, regterm, auxterm, auxinit, μinit, plot_factory } |
| |
281 } |
| |
282 |
| |
283 #[getter] |
| |
284 fn get_dataterm<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> { |
| |
285 self.dataterm.bind(py).clone() |
| |
286 } |
| |
287 } |
| |
288 |
| |
289 impl RunnableExperiment<f64> for PythonExperiment { |
| |
290 fn name(&self) -> &str { |
| |
291 &self.name |
| |
292 } |
| |
293 |
| |
294 fn runall( |
| |
295 &self, |
| |
296 cli: &CommandLineArgs, |
| |
297 algs: Option<Vec<Named<AlgorithmConfig<f64>>>>, |
| |
298 ) -> DynError { |
| |
299 // Set up algorithms |
| |
300 let algorithms = algs.unwrap_or_else(|| vec![DefaultAlgorithm::FB.get_named()]); |
| |
301 |
| |
302 debug!( |
| |
303 "Entered PythonExperiment::runall for experimen {}. Attaching Python.", |
| |
304 self.name |
| |
305 ); |
| |
306 |
| |
307 Python::attach(|py| { |
| |
308 debug!("Calling Python-side experiment initialisation."); |
| |
309 |
| |
310 let prefix = self.start(cli)?; |
| |
311 |
| |
312 let problem: PyRef<Problem> = process_error( |
| |
313 "PythonExperiment::runall", |
| |
314 py, |
| |
315 self.setup |
| |
316 .call1(py, (&prefix,)) |
| |
317 .and_then(|r| r.extract(py).map_err(PyErr::from)), |
| |
318 )?; |
| |
319 |
| |
320 let save_extra = |_, ()| Ok(()); |
| |
321 let μ0 = problem.μinit.clone(); |
| |
322 |
| |
323 match (problem.regterm, &problem.auxterm, &problem.auxinit) { |
| |
324 (regterm, None, None) => { |
| |
325 debug!("… Extracting data term."); |
| |
326 let dataterm: exp_f64_2::DataTerm<'_> = problem.dataterm.extract(py)?; |
| |
327 debug!("… Extracting plotter."); |
| |
328 let plot_factory: exp_f64_2::PlotFactory<'_> = |
| |
329 if let Some(ref p) = problem.plot_factory { |
| |
330 p.extract(py)? |
| |
331 } else { |
| |
332 PythonPlotFactory::dummy() |
| |
333 }; |
| |
334 let make_plotter = |prefix| plot_factory.produce(prefix); |
| |
335 |
| |
336 debug!("Running."); |
| |
337 self.do_runall( |
| |
338 &prefix, |
| |
339 cli, |
| |
340 algorithms, |
| |
341 make_plotter, |
| |
342 save_extra, |
| |
343 μ0, |
| |
344 |p| match regterm { |
| |
345 RegTermPy::Radon(α) => { |
| |
346 run_fb(&dataterm, &RadonRegTerm(α), &RadonSquared, p, |_| { |
| |
347 Err(NotImplemented.into()) |
| |
348 }) |
| |
349 .map(|μ| (μ, ())) |
| |
350 } |
| |
351 RegTermPy::NonnegRadon(α) => { |
| |
352 run_fb(&dataterm, &NonnegRadonRegTerm(α), &RadonSquared, p, |_| { |
| |
353 Err(NotImplemented.into()) |
| |
354 }) |
| |
355 .map(|μ| (μ, ())) |
| |
356 } |
| |
357 }, |
| |
358 ) |
| |
359 } |
| |
360 (regterm, Some(auxterm), Some(z_)) => { |
| |
361 debug!("… Extracting auxiliary variable."); |
| |
362 let z0: PyResult<exp_f64_2_auxvar::AuxVar<'_>> = z_.extract(py); |
| |
363 match z0 { |
| |
364 Ok(z) => { |
| |
365 debug!("… Extracting data term."); |
| |
366 let dataterm: exp_f64_2_auxvar::DataTerm<'_> = |
| |
367 problem.dataterm.extract(py)?; |
| |
368 debug!("… Extracting plotter."); |
| |
369 let plot_factory: exp_f64_2::PlotFactory<'_> = |
| |
370 if let Some(ref p) = problem.plot_factory { |
| |
371 p.extract(py)? |
| |
372 } else { |
| |
373 PythonPlotFactory::dummy() |
| |
374 }; |
| |
375 let make_plotter = |prefix| plot_factory.produce(prefix); |
| |
376 debug!("… Extracting auxiliary term."); |
| |
377 let auxterm: exp_f64_2_auxvar::AuxTerm<'_> = auxterm.extract(py)?; |
| |
378 |
| |
379 debug!("Running."); |
| |
380 self.do_runall( |
| |
381 &prefix, |
| |
382 cli, |
| |
383 algorithms, |
| |
384 make_plotter, |
| |
385 save_extra, |
| |
386 (μ0, z), |
| |
387 |p| match regterm { |
| |
388 RegTermPy::Radon(α) => run_fb_pair( |
| |
389 &dataterm, |
| |
390 &RadonRegTerm(α), |
| |
391 &RadonSquared, |
| |
392 &auxterm, |
| |
393 p, |
| |
394 |_| Err(NotImplemented.into()), |
| |
395 ) |
| |
396 .map(|Pair(μ, _)| (μ, ())), |
| |
397 RegTermPy::NonnegRadon(α) => run_fb_pair( |
| |
398 &dataterm, |
| |
399 &NonnegRadonRegTerm(α), |
| |
400 &RadonSquared, |
| |
401 &auxterm, |
| |
402 p, |
| |
403 |_| Err(NotImplemented.into()), |
| |
404 ) |
| |
405 .map(|Pair(μ, _)| (μ, ())), |
| |
406 }, |
| |
407 ) |
| |
408 } |
| |
409 Err(_) => { |
| |
410 let z: exp_f64_2_auxvar_scalar::AuxVar<'_> = z_.extract(py)?; |
| |
411 debug!("… Extracting data term."); |
| |
412 let dataterm: exp_f64_2_auxvar_scalar::DataTerm<'_> = |
| |
413 problem.dataterm.extract(py)?; |
| |
414 debug!("… Extracting plotter."); |
| |
415 let plot_factory: exp_f64_2::PlotFactory<'_> = |
| |
416 if let Some(ref p) = problem.plot_factory { |
| |
417 p.extract(py)? |
| |
418 } else { |
| |
419 PythonPlotFactory::dummy() |
| |
420 }; |
| |
421 let make_plotter = |prefix| plot_factory.produce(prefix); |
| |
422 debug!("… Extracting auxiliary term."); |
| |
423 let auxterm: exp_f64_2_auxvar_scalar::AuxTerm<'_> = |
| |
424 auxterm.extract(py)?; |
| |
425 |
| |
426 debug!("Running."); |
| |
427 self.do_runall( |
| |
428 &prefix, |
| |
429 cli, |
| |
430 algorithms, |
| |
431 make_plotter, |
| |
432 save_extra, |
| |
433 (μ0, z), |
| |
434 |p| match regterm { |
| |
435 RegTermPy::Radon(α) => run_fb_pair( |
| |
436 &dataterm, |
| |
437 &RadonRegTerm(α), |
| |
438 &RadonSquared, |
| |
439 &auxterm, |
| |
440 p, |
| |
441 |_| Err(NotImplemented.into()), |
| |
442 ) |
| |
443 .map(|Pair(μ, _)| (μ, ())), |
| |
444 RegTermPy::NonnegRadon(α) => run_fb_pair( |
| |
445 &dataterm, |
| |
446 &NonnegRadonRegTerm(α), |
| |
447 &RadonSquared, |
| |
448 &auxterm, |
| |
449 p, |
| |
450 |_| Err(NotImplemented.into()), |
| |
451 ) |
| |
452 .map(|Pair(μ, _)| (μ, ())), |
| |
453 }, |
| |
454 ) |
| |
455 } |
| |
456 } |
| |
457 } |
| |
458 (_, _, _) => { |
| |
459 bail!("Errors in problem {} setup: auxiliary term or auxiliary variable initialisation given without the other", self.name); |
| |
460 } |
| |
461 } |
| |
462 })?; |
| |
463 |
| |
464 // Should run the experiment here, going through all algorithms. |
| |
465 Ok(()) |
| |
466 } |
| |
467 |
| |
468 fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<f64> { |
| |
469 self.algorithm_overrides |
| |
470 .get(&alg) |
| |
471 .cloned() |
| |
472 .unwrap_or(Default::default()) |
| |
473 } |
| |
474 } |