src/run.rs

changeset 24
d29d1fcf5423
parent 23
9869fa1e0ccd
child 25
79943be70720
--- a/src/run.rs	Sun Dec 11 23:19:17 2022 +0200
+++ b/src/run.rs	Sun Dec 11 23:25:53 2022 +0200
@@ -34,7 +34,7 @@
 use alg_tools::mapping::RealMapping;
 use alg_tools::nalgebra_support::ToNalgebraRealField;
 use alg_tools::euclidean::Euclidean;
-use alg_tools::norms::{Norm, L1};
+use alg_tools::norms::L1;
 use alg_tools::lingrid::lingrid;
 use alg_tools::sets::SetOrd;
 
@@ -45,13 +45,14 @@
 use crate::forward_model::*;
 use crate::fb::{
     FBConfig,
-    pointsource_fb,
-    FBMetaAlgorithm, FBGenericConfig,
+    pointsource_fb_reg,
+    FBMetaAlgorithm,
+    FBGenericConfig,
 };
 use crate::pdps::{
     PDPSConfig,
     L2Squared,
-    pointsource_pdps,
+    pointsource_pdps_reg,
 };
 use crate::frank_wolfe::{
     FWConfig,
@@ -65,6 +66,7 @@
 use crate::plot::*;
 use crate::{AlgorithmOverrides, CommandLineArgs};
 use crate::tolerance::Tolerance;
+use crate::regularisation::{Regularisation, RadonRegTerm, NonnegRadonRegTerm};
 
 /// Available algorithms and their configurations
 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
@@ -276,7 +278,7 @@
 
 /// Struct for experiment configurations
 #[derive(Debug, Clone, Serialize)]
-pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize>
+pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize>
 where F : Float,
       [usize; N] : Serialize,
       NoiseDistr : Distribution<F>,
@@ -300,8 +302,8 @@
     pub kernel : K,
     /// True point sources
     pub μ_hat : DiscreteMeasure<Loc<F, N>, F>,
-    /// Regularisation parameter
-    pub α : F,
+    /// Regularisation term and parameter
+    pub regularisation : Regularisation<F>,
     /// For plotting : how wide should the kernels be plotted
     pub kernel_plot_width : F,
     /// Data term
@@ -322,8 +324,12 @@
     -> Named<AlgorithmConfig<F>>;
 }
 
+// *** macro boilerplate ***
+macro_rules! impl_experiment {
+($type:ident, $reg_field:ident, $reg_convert:path) => {
+// *** macro ***
 impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for
-Named<Experiment<F, NoiseDistr, S, K, P, N>>
+Named<$type<F, NoiseDistr, S, K, P, N>>
 where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
       [usize; N] : Serialize,
       S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
@@ -356,12 +362,14 @@
         // Get experiment configuration
         let &Named {
             name : ref experiment_name,
-            data : Experiment {
+            data : $type {
                 domain, sensor_count, ref noise_distr, sensor, spread, kernel,
-                ref μ_hat, α, kernel_plot_width, dataterm, noise_seed,
+                ref μ_hat, /*regularisation,*/ kernel_plot_width, dataterm, noise_seed,
                 ..
             }
         } = self;
+        #[allow(deprecated)]
+        let regularisation = $reg_convert(self.data.$reg_field);
 
         println!("{}\n{}",
                  format!("Performing experiment {}…", experiment_name).cyan(),
@@ -420,7 +428,12 @@
                         format!("{:?}", iterator_options).bright_black(),
                         format!("{:?}", alg).bright_black());
             };
-
+            let not_implemented = || {
+                let msg = format!("Algorithm “{alg_name}” not implemented for \
+                                   dataterm {dataterm:?} and regularisation {regularisation:?}. \
+                                   Skipping.").red();
+                eprintln!("{}", msg);
+            };
             // Create Logger and IteratorFactory
             let mut logger = Logger::new();
             let findim_data = prepare_optimise_weights(&opA);
@@ -437,20 +450,18 @@
                     this_iters,
                     ..
                 } = data;
-                let post_value = match postprocessing {
-                    None => value,
-                    Some(mut μ) => {
-                        match dataterm {
-                            DataTerm::L2Squared => {
-                                optimise_weights(
-                                    &mut μ, &opA, &b, α, &findim_data, &inner_config,
-                                    inner_it
-                                );
-                                dataterm.value_at_residual(opA.apply(&μ) - &b) + α * μ.norm(Radon)
-                            },
-                            _ => value,
-                        }
-                    }
+                let post_value = match (postprocessing, dataterm, regularisation) {
+                    (Some(mut μ), DataTerm::L2Squared, Regularisation::Radon(α)) => {
+                        // Comparison postprocessing is only implemented for the case handled
+                        // by the FW variants.
+                        optimise_weights(
+                            &mut μ, &opA, &b, α, &findim_data, &inner_config,
+                            inner_it
+                        );
+                        dataterm.value_at_residual(opA.apply(&μ) - &b)
+                            + regularisation.apply(&μ)
+                    },
+                    _ => value,
                 };
                 CSVLog {
                     iter,
@@ -477,30 +488,72 @@
             // Run the algorithm
             let start = Instant::now();
             let start_cpu = ProcessTime::now();
-            let μ : DiscreteMeasure<Loc<F, N>, F> = match (alg, dataterm) {
-                (AlgorithmConfig::FB(ref algconfig), DataTerm::L2Squared) => {
-                    running();
-                    pointsource_fb(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter)
+            let μ = match alg {
+                AlgorithmConfig::FB(ref algconfig) => {
+                    match (regularisation, dataterm) {
+                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
+                            running();
+                            pointsource_fb_reg(
+                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
+                                iterator, plotter
+                            )
+                        },
+                        (Regularisation::Radon(α), DataTerm::L2Squared) => {
+                            running();
+                            pointsource_fb_reg(
+                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
+                                iterator, plotter
+                            )
+                        },
+                        _ => {
+                            not_implemented();
+                            continue
+                        }
+                    }
                 },
-                (AlgorithmConfig::FW(ref algconfig), DataTerm::L2Squared) => {
-                    running();
-                    pointsource_fw(&opA, &b, α, &algconfig, iterator, plotter)
-                },
-                (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L2Squared) => {
+                AlgorithmConfig::PDPS(ref algconfig) => {
                     running();
-                    pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L2Squared)
-                },
-                (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L1) => {
-                    running();
-                    pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1)
+                    match (regularisation, dataterm) {
+                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
+                            pointsource_pdps_reg(
+                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
+                                iterator, plotter, L2Squared
+                            )
+                        },
+                        (Regularisation::Radon(α),DataTerm::L2Squared) => {
+                            pointsource_pdps_reg(
+                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
+                                iterator, plotter, L2Squared
+                            )
+                        },
+                        (Regularisation::NonnegRadon(α), DataTerm::L1) => {
+                            pointsource_pdps_reg(
+                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
+                                iterator, plotter, L1
+                            )
+                        },
+                        (Regularisation::Radon(α), DataTerm::L1) => {
+                            pointsource_pdps_reg(
+                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
+                                iterator, plotter, L1
+                            )
+                        },
+                    }
                 },
-                _ =>  {
-                    let msg = format!("Algorithm “{alg_name}” not implemented for \
-                                       dataterm {dataterm:?}. Skipping.").red();
-                    eprintln!("{}", msg);
-                    continue
+                AlgorithmConfig::FW(ref algconfig) => {
+                    match (regularisation, dataterm) {
+                        (Regularisation::Radon(α), DataTerm::L2Squared) => {
+                            running();
+                            pointsource_fw(&opA, &b, α, algconfig, iterator, plotter)
+                        },
+                        _ => {
+                            not_implemented();
+                            continue
+                        }
+                    }
                 }
             };
+
             let elapsed = start.elapsed().as_secs_f64();
             let cpu_time = start_cpu.elapsed().as_secs_f64();
 
@@ -520,6 +573,11 @@
         Ok(())
     }
 }
+// *** macro end boiler plate ***
+}}
+// *** actual code ***
+
+impl_experiment!(ExperimentV2, regularisation, std::convert::identity);
 
 /// Plot experiment setup
 #[replace_float_literals(F::cast_from(literal))]
@@ -589,3 +647,46 @@
     opA.write_observable(&b, pfx("b_noisy"))
 }
 
+//
+// Deprecated interface
+//
+
+/// Struct for experiment configurations
+#[derive(Debug, Clone, Serialize)]
+pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize>
+where F : Float,
+      [usize; N] : Serialize,
+      NoiseDistr : Distribution<F>,
+      S : Sensor<F, N>,
+      P : Spread<F, N>,
+      K : SimpleConvolutionKernel<F, N>,
+{
+    /// Domain $Ω$.
+    pub domain : Cube<F, N>,
+    /// Number of sensors along each dimension
+    pub sensor_count : [usize; N],
+    /// Noise distribution
+    pub noise_distr : NoiseDistr,
+    /// Seed for random noise generation (for repeatable experiments)
+    pub noise_seed : u64,
+    /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$.
+    pub sensor : S,
+    /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$.
+    pub spread : P,
+    /// Kernel $ρ$ of $𝒟$.
+    pub kernel : K,
+    /// True point sources
+    pub μ_hat : DiscreteMeasure<Loc<F, N>, F>,
+    /// Regularisation parameter
+    #[deprecated(note = "Use [`ExperimentV2`], which replaces `α` by more generic `regularisation`")]
+    pub α : F,
+    /// For plotting : how wide should the kernels be plotted
+    pub kernel_plot_width : F,
+    /// Data term
+    pub dataterm : DataTerm,
+    /// A map of default configurations for algorithms
+    #[serde(skip)]
+    pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>,
+}
+
+impl_experiment!(Experiment, α, Regularisation::NonnegRadon);

mercurial