src/main.rs

branch
dev
changeset 35
b087e3eab191
parent 34
efa60bc4f743
child 37
c5d8bd1a7728
--- a/src/main.rs	Thu Aug 29 00:00:00 2024 -0500
+++ b/src/main.rs	Tue Dec 31 09:25:45 2024 -0500
@@ -31,6 +31,7 @@
 pub mod seminorms;
 pub mod transport;
 pub mod forward_model;
+pub mod preadjoint_helper;
 pub mod plot;
 pub mod subproblem;
 pub mod tolerance;
@@ -39,6 +40,8 @@
 pub mod fb;
 pub mod radon_fb;
 pub mod sliding_fb;
+pub mod sliding_pdps;
+pub mod forward_pdps;
 pub mod frank_wolfe;
 pub mod pdps;
 pub mod run;
@@ -166,6 +169,13 @@
     tau0 : Option<F>,
 
     #[arg(long, requires = "algorithm")]
+    /// Second primal step length parameter override for SlidingPDPS.
+    ///
+    /// Only use if running just a single algorithm, as different algorithms have different
+    /// regularisation parameters.
+    sigmap0 : Option<F>,
+
+    #[arg(long, requires = "algorithm")]
     /// Dual step length parameter override for --algorithm.
     ///
     /// Only use if running just a single algorithm, as different algorithms have different
@@ -173,7 +183,7 @@
     sigma0 : Option<F>,
 
     #[arg(long)]
-    /// Normalised transport step length for sliding_fb.
+    /// Normalised transport step length for sliding methods.
     theta0 : Option<F>,
 
     #[arg(long)]
@@ -184,15 +194,23 @@
     /// Transport toleranced wrt. ∇v
     transport_tolerance_dv : Option<F>,
 
+    #[arg(long)]
+    /// Transport adaptation factor. Must be in (0, 1).
+    transport_adaptation : Option<F>,
+
+    #[arg(long)]
+    /// Minimal step length parameter for sliding methods.
+    tau0_min : Option<F>,
+
     #[arg(value_enum, long)]
     /// PDPS acceleration, when available.
     acceleration : Option<pdps::Acceleration>,
 
-    #[arg(long)]
-    /// Perform postprocess weight optimisation for saved iterations
-    ///
-    /// Only affects FB, FISTA, and PDPS.
-    postprocessing : Option<bool>,
+    // #[arg(long)]
+    // /// Perform postprocess weight optimisation for saved iterations
+    // ///
+    // /// Only affects FB, FISTA, and PDPS.
+    // postprocessing : Option<bool>,
 
     #[arg(value_name = "n", long)]
     /// Merging frequency, if merging enabled (every n iterations)
@@ -246,9 +264,14 @@
     for experiment_shorthand in cli.experiments.iter().unique() {
         let experiment = experiment_shorthand.get_experiment(&cli.experiment_overrides).unwrap();
         let mut algs : Vec<Named<AlgorithmConfig<float>>>
-            = cli.algorithm.iter()
-                            .map(|alg| experiment.algorithm_defaults(*alg, &cli.algoritm_overrides))
-                            .collect();
+            = cli.algorithm
+                 .iter()
+                 .map(|alg| alg.to_named(
+                    experiment.algorithm_defaults(*alg)
+                              .unwrap_or_else(|| alg.default_config())
+                              .cli_override(&cli.algoritm_overrides)
+                 ))
+                 .collect();
         for filename in cli.saved_algorithm.iter() {
             let f = std::fs::File::open(filename).unwrap();
             let alg = serde_json::from_reader(f).unwrap();

mercurial