--- a/src/run.rs Thu Jan 23 23:35:28 2025 +0100 +++ b/src/run.rs Thu Jan 23 23:34:05 2025 +0100 @@ -56,7 +56,7 @@ use crate::kernels::*; use crate::types::*; use crate::measures::*; -use crate::measures::merging::SpikeMerging; +use crate::measures::merging::{SpikeMerging,SpikeMergingMethod}; use crate::forward_model::*; use crate::forward_model::sensor_grid::{ SensorGrid, @@ -95,7 +95,7 @@ pointsource_fw_reg, //WeightOptim, }; -//use crate::subproblem::InnerSettings; +use crate::subproblem::{InnerSettings, InnerMethod}; use crate::seminorms::*; use crate::plot::*; use crate::{AlgorithmOverrides, CommandLineArgs}; @@ -146,6 +146,13 @@ impl<F : ClapFloat> AlgorithmConfig<F> { /// Override supported parameters based on the command line. pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self { + let override_merging = |g : SpikeMergingMethod<F>| { + SpikeMergingMethod { + enabled : cli.merge.unwrap_or(g.enabled), + radius : cli.merge_radius.unwrap_or(g.radius), + interp : cli.merge_interp.unwrap_or(g.interp), + } + }; let override_fb_generic = |g : FBGenericConfig<F>| { FBGenericConfig { bootstrap_insertions : cli.bootstrap_insertions @@ -153,8 +160,9 @@ .map_or(g.bootstrap_insertions, |n| Some((n[0], n[1]))), merge_every : cli.merge_every.unwrap_or(g.merge_every), - merging : cli.merging.clone().unwrap_or(g.merging), - final_merging : cli.final_merging.clone().unwrap_or(g.final_merging), + merging : override_merging(g.merging), + final_merging : cli.final_merging.unwrap_or(g.final_merging), + fitness_merging : cli.fitness_merging.unwrap_or(g.fitness_merging), tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance), .. g } @@ -162,8 +170,8 @@ let override_transport = |g : TransportConfig<F>| { TransportConfig { θ0 : cli.theta0.unwrap_or(g.θ0), - tolerance_ω: cli.transport_tolerance_omega.unwrap_or(g.tolerance_ω), - tolerance_dv: cli.transport_tolerance_dv.unwrap_or(g.tolerance_dv), + tolerance_mult_pos: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_pos), + tolerance_mult_pri: cli.transport_tolerance_pri.unwrap_or(g.tolerance_mult_pri), adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), .. g } @@ -189,7 +197,7 @@ .. pdps }, prox), FW(fw) => FW(FWConfig { - merging : cli.merging.clone().unwrap_or(fw.merging), + merging : override_merging(fw.merging), tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), .. fw }), @@ -282,6 +290,14 @@ /// Returns the algorithm configuration corresponding to the algorithm shorthand pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { use DefaultAlgorithm::*; + let radon_insertion = FBGenericConfig { + merging : SpikeMergingMethod{ interp : false, .. Default::default() }, + inner : InnerSettings { + method : InnerMethod::PDPS, // SSN not implemented + .. Default::default() + }, + .. Default::default() + }; match *self { FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave), FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave), @@ -297,12 +313,30 @@ // Radon variants - RadonFB => AlgorithmConfig::FB(Default::default(), ProxTerm::RadonSquared), - RadonFISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::RadonSquared), - RadonPDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::RadonSquared), - RadonSlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::RadonSquared), - RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::RadonSquared), - RadonForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::RadonSquared), + RadonFB => AlgorithmConfig::FB( + FBConfig{ generic : radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared + ), + RadonFISTA => AlgorithmConfig::FISTA( + FBConfig{ generic : radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared + ), + RadonPDPS => AlgorithmConfig::PDPS( + PDPSConfig{ generic : radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared + ), + RadonSlidingFB => AlgorithmConfig::SlidingFB( + SlidingFBConfig{ insertion : radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared + ), + RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS( + SlidingPDPSConfig{ insertion : radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared + ), + RadonForwardPDPS => AlgorithmConfig::ForwardPDPS( + ForwardPDPSConfig{ insertion : radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared + ), } } @@ -340,6 +374,12 @@ Iter, } +impl Default for PlotLevel { + fn default() -> Self { + Self::Data + } +} + type DefaultBT<F, const N : usize> = BT< DynamicDepth, F, @@ -418,7 +458,7 @@ /// Struct for experiment configurations #[derive(Debug, Clone, Serialize)] pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize> -where F : Float, +where F : Float + ClapFloat, [usize; N] : Serialize, NoiseDistr : Distribution<F>, S : Sensor<F, N>, @@ -448,13 +488,14 @@ /// Data term pub dataterm : DataTerm, /// A map of default configurations for algorithms - #[serde(skip)] - pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, + pub algorithm_overrides : HashMap<DefaultAlgorithm, AlgorithmOverrides<F>>, + /// Default merge radius + pub default_merge_radius : F, } #[derive(Debug, Clone, Serialize)] pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N : usize> -where F : Float, +where F : Float + ClapFloat, [usize; N] : Serialize, NoiseDistr : Distribution<F>, S : Sensor<F, N>, @@ -477,7 +518,7 @@ algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; /// Return algorithm default config - fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>>; + fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F>; } /// Helper function to print experiment start message and save setup. @@ -494,8 +535,8 @@ let Named { name : experiment_name, data } = experiment; println!("{}\n{}", - format!("Performing experiment {}…", experiment_name).cyan(), - format!("{:?}", data).bright_black()); + format!("Performing experiment {}…", experiment_name).cyan(), + format!("Experiment settings: {}", serde_json::to_string(&data)?).bright_black()); // Set up output directory let prefix = format!("{}/{}/", cli.outdir, experiment_name); @@ -525,7 +566,7 @@ >; /// Helper function to run all algorithms on an experiment. -fn do_runall<F : Float, Z, const N : usize>( +fn do_runall<F : Float + for<'b> Deserialize<'b>, Z, const N : usize>( experiment_name : &String, prefix : &String, cli : &CommandLineArgs, @@ -547,7 +588,7 @@ let iterator_options = AlgIteratorOptions{ max_iter : cli.max_iter, verbose_iter : cli.verbose_iter - .map_or(Verbose::Logarithmic(10), + .map_or(Verbose::LogarithmicCap{base : 10, cap : 2}, |n| Verbose::Every(n)), quiet : cli.quiet, }; @@ -565,8 +606,8 @@ let running = if !cli.quiet { format!("{}\n{}\n{}\n", format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), - format!("{:?}", iterator_options).bright_black(), - format!("{:?}", alg).bright_black()) + format!("Iteration settings: {}", serde_json::to_string(&iterator_options)?).bright_black(), + format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black()) } else { "".to_string() }; @@ -614,16 +655,17 @@ save_extra(mkname(""), z)?; //logger.write_csv(mkname("log.txt"))?; logs.push((mkname("log.txt"), logger)); - } + } - save_logs(logs) + save_logs(logs, format!("{prefix}valuerange.json"), cli.load_valuerange) } #[replace_float_literals(F::cast_from(literal))] impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N : usize> RunnableExperiment<F> for Named<ExperimentV2<F, NoiseDistr, S, K, P, N>> where - F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, + F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F> + + Default + for<'b> Deserialize<'b>, [usize; N] : Serialize, S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, @@ -655,8 +697,11 @@ // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, { - fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> { - self.data.algorithm_defaults.get(&alg).cloned() + fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> { + AlgorithmOverrides { + merge_radius : Some(self.data.default_merge_radius), + .. self.data.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) + } } fn runall(&self, cli : &CommandLineArgs, @@ -887,7 +932,8 @@ impl<F, NoiseDistr, S, K, P, B, /*PreadjointCodomain,*/ const N : usize> RunnableExperiment<F> for Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>> where - F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, + F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F> + + Default + for<'b> Deserialize<'b>, [usize; N] : Serialize, S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, @@ -920,8 +966,11 @@ // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, { - fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> { - self.data.base.algorithm_defaults.get(&alg).cloned() + fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> { + AlgorithmOverrides { + merge_radius : Some(self.data.base.default_merge_radius), + .. self.data.base.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) + } } fn runall(&self, cli : &CommandLineArgs, @@ -1077,16 +1126,31 @@ } } +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +struct ValueRange<F : Float> { + ini : F, + min : F, +} + +impl<F : Float> ValueRange<F> { + fn expand_with(self, other : Self) -> Self { + ValueRange { + ini : self.ini.max(other.ini), + min : self.min.min(other.min), + } + } +} /// Calculative minimum and maximum values of all the `logs`, and save them into /// corresponding file names given as the first elements of the tuples in the vectors. -fn save_logs<F : Float, const N : usize>( - logs : Vec<(String, Logger<Timed<IterInfo<F, N>>>)> +fn save_logs<F : Float + for<'b> Deserialize<'b>, const N : usize>( + logs : Vec<(String, Logger<Timed<IterInfo<F, N>>>)>, + valuerange_file : String, + load_valuerange : bool, ) -> DynError { // Process logs for relative values println!("{}", "Processing logs…"); - // Find minimum value and initial value within a single log let proc_single_log = |log : &Logger<Timed<IterInfo<F, N>>>| { let d = log.data(); @@ -1094,15 +1158,22 @@ .map(|i| i.data.value) .reduce(NumTraitsFloat::min); d.first() - .map(|i| i.data.value) - .zip(mi) + .map(|i| i.data.value) + .zip(mi) + .map(|(ini, min)| ValueRange{ ini, min }) }; // Find minimum and maximum value over all logs - let (v_ini, v_min) = logs.iter() - .filter_map(|&(_, ref log)| proc_single_log(log)) - .reduce(|(i1, m1), (i2, m2)| (i1.max(i2), m1.min(m2))) - .ok_or(anyhow!("No algorithms found"))?; + let mut v = logs.iter() + .filter_map(|&(_, ref log)| proc_single_log(log)) + .reduce(|v1, v2| v1.expand_with(v2)) + .ok_or(anyhow!("No algorithms found"))?; + + // Load existing range + if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() { + let data = std::fs::read_to_string(&valuerange_file)?; + v = v.expand_with(serde_json::from_str(&data)?); + } let logmap = |Timed { cpu_time, iter, data }| { let IterInfo { @@ -1128,7 +1199,7 @@ // }, // _ => value, // }; - let relative_value = (value - v_min)/(v_ini - v_min); + let relative_value = (value - v.min)/(v.ini - v.min); CSVLog { iter, value, @@ -1145,6 +1216,8 @@ println!("{}", "Saving logs …".green()); + serde_json::to_writer_pretty(std::fs::File::create(&valuerange_file)?, &v)?; + for (name, logger) in logs { logger.map(logmap).write_csv(name)?; }