src/run.rs

branch
dev
changeset 39
6316d68b58af
parent 37
c5d8bd1a7728
child 41
b6bdb6cb4d44
--- a/src/run.rs	Thu Jan 23 23:35:28 2025 +0100
+++ b/src/run.rs	Thu Jan 23 23:34:05 2025 +0100
@@ -56,7 +56,7 @@
 use crate::kernels::*;
 use crate::types::*;
 use crate::measures::*;
-use crate::measures::merging::SpikeMerging;
+use crate::measures::merging::{SpikeMerging,SpikeMergingMethod};
 use crate::forward_model::*;
 use crate::forward_model::sensor_grid::{
     SensorGrid,
@@ -95,7 +95,7 @@
     pointsource_fw_reg,
     //WeightOptim,
 };
-//use crate::subproblem::InnerSettings;
+use crate::subproblem::{InnerSettings, InnerMethod};
 use crate::seminorms::*;
 use crate::plot::*;
 use crate::{AlgorithmOverrides, CommandLineArgs};
@@ -146,6 +146,13 @@
 impl<F : ClapFloat> AlgorithmConfig<F> {
     /// Override supported parameters based on the command line.
     pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self {
+        let override_merging = |g : SpikeMergingMethod<F>| {
+            SpikeMergingMethod {
+                enabled : cli.merge.unwrap_or(g.enabled),
+                radius : cli.merge_radius.unwrap_or(g.radius),
+                interp : cli.merge_interp.unwrap_or(g.interp),
+            }
+        };
         let override_fb_generic = |g : FBGenericConfig<F>| {
             FBGenericConfig {
                 bootstrap_insertions : cli.bootstrap_insertions
@@ -153,8 +160,9 @@
                                           .map_or(g.bootstrap_insertions,
                                                   |n| Some((n[0], n[1]))),
                 merge_every : cli.merge_every.unwrap_or(g.merge_every),
-                merging : cli.merging.clone().unwrap_or(g.merging),
-                final_merging : cli.final_merging.clone().unwrap_or(g.final_merging),
+                merging : override_merging(g.merging),
+                final_merging : cli.final_merging.unwrap_or(g.final_merging),
+                fitness_merging : cli.fitness_merging.unwrap_or(g.fitness_merging),
                 tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance),
                 .. g
             }
@@ -162,8 +170,8 @@
         let override_transport = |g : TransportConfig<F>| {
             TransportConfig {
                 θ0 : cli.theta0.unwrap_or(g.θ0),
-                tolerance_ω: cli.transport_tolerance_omega.unwrap_or(g.tolerance_ω),
-                tolerance_dv: cli.transport_tolerance_dv.unwrap_or(g.tolerance_dv),
+                tolerance_mult_pos: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_pos),
+                tolerance_mult_pri: cli.transport_tolerance_pri.unwrap_or(g.tolerance_mult_pri),
                 adaptation: cli.transport_adaptation.unwrap_or(g.adaptation),
                 .. g
             }
@@ -189,7 +197,7 @@
                 .. pdps
             }, prox),
             FW(fw) => FW(FWConfig {
-                merging : cli.merging.clone().unwrap_or(fw.merging),
+                merging : override_merging(fw.merging),
                 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance),
                 .. fw
             }),
@@ -282,6 +290,14 @@
     /// Returns the algorithm configuration corresponding to the algorithm shorthand
     pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> {
         use DefaultAlgorithm::*;
+        let radon_insertion = FBGenericConfig {
+            merging : SpikeMergingMethod{ interp : false, .. Default::default() },
+            inner : InnerSettings {
+                method : InnerMethod::PDPS, // SSN not implemented
+                .. Default::default()
+            },
+            .. Default::default()
+        };
         match *self {
             FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave),
             FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave),
@@ -297,12 +313,30 @@
 
             // Radon variants
 
-            RadonFB => AlgorithmConfig::FB(Default::default(), ProxTerm::RadonSquared),
-            RadonFISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::RadonSquared),
-            RadonPDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::RadonSquared),
-            RadonSlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::RadonSquared),
-            RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::RadonSquared),
-            RadonForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::RadonSquared),
+            RadonFB => AlgorithmConfig::FB(
+                FBConfig{ generic : radon_insertion, ..Default::default() },
+                ProxTerm::RadonSquared
+            ),
+            RadonFISTA => AlgorithmConfig::FISTA(
+                FBConfig{ generic : radon_insertion, ..Default::default() },
+                ProxTerm::RadonSquared
+            ),
+            RadonPDPS => AlgorithmConfig::PDPS(
+                PDPSConfig{ generic : radon_insertion, ..Default::default() },
+                ProxTerm::RadonSquared
+            ),
+            RadonSlidingFB => AlgorithmConfig::SlidingFB(
+                SlidingFBConfig{ insertion : radon_insertion, ..Default::default() },
+                ProxTerm::RadonSquared
+            ),
+            RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS(
+                SlidingPDPSConfig{ insertion : radon_insertion, ..Default::default() },
+                ProxTerm::RadonSquared
+            ),
+            RadonForwardPDPS => AlgorithmConfig::ForwardPDPS(
+                ForwardPDPSConfig{ insertion : radon_insertion, ..Default::default() },
+                ProxTerm::RadonSquared
+            ),
         }
     }
 
@@ -340,6 +374,12 @@
     Iter,
 }
 
+impl Default for PlotLevel {
+    fn default() -> Self {
+        Self::Data
+    }
+}
+
 type DefaultBT<F, const N : usize> = BT<
     DynamicDepth,
     F,
@@ -418,7 +458,7 @@
 /// Struct for experiment configurations
 #[derive(Debug, Clone, Serialize)]
 pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize>
-where F : Float,
+where F : Float + ClapFloat,
       [usize; N] : Serialize,
       NoiseDistr : Distribution<F>,
       S : Sensor<F, N>,
@@ -448,13 +488,14 @@
     /// Data term
     pub dataterm : DataTerm,
     /// A map of default configurations for algorithms
-    #[serde(skip)]
-    pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>,
+    pub algorithm_overrides : HashMap<DefaultAlgorithm, AlgorithmOverrides<F>>,
+    /// Default merge radius
+    pub default_merge_radius : F,
 }
 
 #[derive(Debug, Clone, Serialize)]
 pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N : usize>
-where F : Float,
+where F : Float + ClapFloat,
       [usize; N] : Serialize,
       NoiseDistr : Distribution<F>,
       S : Sensor<F, N>,
@@ -477,7 +518,7 @@
               algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError;
 
     /// Return algorithm default config
-    fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>>;
+    fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F>;
 }
 
 /// Helper function to print experiment start message and save setup.
@@ -494,8 +535,8 @@
     let Named { name : experiment_name, data } = experiment;
 
     println!("{}\n{}",
-                format!("Performing experiment {}…", experiment_name).cyan(),
-                format!("{:?}", data).bright_black());
+             format!("Performing experiment {}…", experiment_name).cyan(),
+             format!("Experiment settings: {}", serde_json::to_string(&data)?).bright_black());
 
     // Set up output directory
     let prefix = format!("{}/{}/", cli.outdir, experiment_name);
@@ -525,7 +566,7 @@
 >;
 
 /// Helper function to run all algorithms on an experiment.
-fn do_runall<F : Float, Z, const N : usize>(
+fn do_runall<F : Float + for<'b> Deserialize<'b>, Z, const N : usize>(
     experiment_name : &String,
     prefix : &String,
     cli : &CommandLineArgs,
@@ -547,7 +588,7 @@
     let iterator_options = AlgIteratorOptions{
             max_iter : cli.max_iter,
             verbose_iter : cli.verbose_iter
-                                .map_or(Verbose::Logarithmic(10),
+                                .map_or(Verbose::LogarithmicCap{base : 10, cap : 2},
                                         |n| Verbose::Every(n)),
             quiet : cli.quiet,
     };
@@ -565,8 +606,8 @@
         let running = if !cli.quiet {
             format!("{}\n{}\n{}\n",
                     format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
-                    format!("{:?}", iterator_options).bright_black(),
-                    format!("{:?}", alg).bright_black())
+                    format!("Iteration settings: {}", serde_json::to_string(&iterator_options)?).bright_black(),
+                    format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black())
         } else {
             "".to_string()
         };
@@ -614,16 +655,17 @@
         save_extra(mkname(""), z)?;
         //logger.write_csv(mkname("log.txt"))?;
         logs.push((mkname("log.txt"), logger));
-            }
+    }
 
-    save_logs(logs)
+    save_logs(logs, format!("{prefix}valuerange.json"), cli.load_valuerange)
 }
 
 #[replace_float_literals(F::cast_from(literal))]
 impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N : usize> RunnableExperiment<F> for
 Named<ExperimentV2<F, NoiseDistr, S, K, P, N>>
 where
-    F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
+    F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>
+        + Default + for<'b> Deserialize<'b>,
     [usize; N] : Serialize,
     S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
     P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
@@ -655,8 +697,11 @@
     // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
 {
 
-    fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> {
-        self.data.algorithm_defaults.get(&alg).cloned()
+    fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> {
+        AlgorithmOverrides {
+            merge_radius : Some(self.data.default_merge_radius),
+            .. self.data.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default())
+        }
     }
 
     fn runall(&self, cli : &CommandLineArgs,
@@ -887,7 +932,8 @@
 impl<F, NoiseDistr, S, K, P, B, /*PreadjointCodomain,*/ const N : usize> RunnableExperiment<F> for
 Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>>
 where
-    F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
+    F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>
+        + Default + for<'b> Deserialize<'b>,
     [usize; N] : Serialize,
     S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
     P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
@@ -920,8 +966,11 @@
     // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
 {
 
-    fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> {
-        self.data.base.algorithm_defaults.get(&alg).cloned()
+    fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> {
+        AlgorithmOverrides {
+            merge_radius : Some(self.data.base.default_merge_radius),
+            .. self.data.base.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default())
+        }
     }
 
     fn runall(&self, cli : &CommandLineArgs,
@@ -1077,16 +1126,31 @@
     }
 }
 
+#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
+struct ValueRange<F : Float> {
+    ini : F,
+    min : F,
+}
+
+impl<F : Float> ValueRange<F> {
+    fn expand_with(self, other : Self) -> Self {
+        ValueRange {
+            ini : self.ini.max(other.ini),
+            min : self.min.min(other.min),
+        }
+    }
+}
 
 /// Calculative minimum and maximum values of all the `logs`, and save them into
 /// corresponding file names given as the first elements of the tuples in the vectors.
-fn save_logs<F : Float, const N : usize>(
-    logs : Vec<(String, Logger<Timed<IterInfo<F, N>>>)>
+fn save_logs<F : Float + for<'b> Deserialize<'b>, const N : usize>(
+    logs : Vec<(String, Logger<Timed<IterInfo<F, N>>>)>,
+    valuerange_file : String,
+    load_valuerange : bool,
 ) -> DynError {
     // Process logs for relative values
     println!("{}", "Processing logs…");
 
-
     // Find minimum value and initial value within a single log
     let proc_single_log = |log : &Logger<Timed<IterInfo<F, N>>>| {
         let d = log.data();
@@ -1094,15 +1158,22 @@
                   .map(|i| i.data.value)
                   .reduce(NumTraitsFloat::min);
         d.first()
-            .map(|i| i.data.value)
-            .zip(mi)
+         .map(|i| i.data.value)
+         .zip(mi)
+         .map(|(ini, min)| ValueRange{ ini, min })
     };
 
     // Find minimum and maximum value over all logs
-    let (v_ini, v_min) = logs.iter()
-                             .filter_map(|&(_, ref log)| proc_single_log(log))
-                             .reduce(|(i1, m1), (i2, m2)| (i1.max(i2), m1.min(m2)))
-                             .ok_or(anyhow!("No algorithms found"))?;
+    let mut v = logs.iter()
+                    .filter_map(|&(_, ref log)| proc_single_log(log))
+                    .reduce(|v1, v2| v1.expand_with(v2))
+                    .ok_or(anyhow!("No algorithms found"))?;
+
+    // Load existing range
+    if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() {
+        let data = std::fs::read_to_string(&valuerange_file)?;
+        v = v.expand_with(serde_json::from_str(&data)?);
+    }
 
     let logmap = |Timed { cpu_time, iter, data }| {
         let IterInfo {
@@ -1128,7 +1199,7 @@
         //     },
         //     _ => value,
         // };
-        let relative_value = (value - v_min)/(v_ini - v_min);
+        let relative_value = (value - v.min)/(v.ini - v.min);
         CSVLog {
             iter,
             value,
@@ -1145,6 +1216,8 @@
 
     println!("{}", "Saving logs …".green());
 
+    serde_json::to_writer_pretty(std::fs::File::create(&valuerange_file)?, &v)?;
+
     for (name, logger) in logs {
         logger.map(logmap).write_csv(name)?;
     }

mercurial