Sketch FBGenericConfig clap draft

Wed, 30 Nov 2022 23:45:04 +0200

author
Tuomo Valkonen <tuomov@iki.fi>
date
Wed, 30 Nov 2022 23:45:04 +0200
changeset 1
d4fd5f32d10e
parent 0
eb3c7813b67a

Sketch FBGenericConfig clap

src/fb.rs file | annotate | diff | comparison | revisions
src/pdps.rs file | annotate | diff | comparison | revisions
src/run.rs file | annotate | diff | comparison | revisions
--- a/src/fb.rs	Thu Dec 01 23:07:35 2022 +0200
+++ b/src/fb.rs	Wed Nov 30 23:45:04 2022 +0200
@@ -84,6 +84,7 @@
 use serde::{Serialize, Deserialize};
 use colored::Colorize;
 use nalgebra::DVector;
+use clap::Parser;
 
 use alg_tools::iterate::{
     AlgIteratorFactory,
@@ -146,6 +147,11 @@
     Zero,
 }
 
+impl Default for InsertionStyle {
+    fn default() -> Self {
+        Self::Reuse
+    }
+}
 /// Meta-algorithm type
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[allow(dead_code)]
@@ -166,10 +172,16 @@
     AfterNth{ n : usize, factor : F },
 }
 
+impl<F : ClapFloat> Default for ErgodicTolerance<F> {
+    fn default() -> Self {
+        Self::NonErgodic
+    }
+}
+
 /// Settings for [`pointsource_fb`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
-pub struct FBConfig<F : Float> {
+pub struct FBConfig<F : ClapFloat> {
     /// Step length scaling
     pub τ0 : F,
     /// Meta-algorithm to apply
@@ -180,24 +192,29 @@
 
 /// Settings for the solution of the stepwise optimality condition in algorithms based on
 /// [`generic_pointsource_fb`].
-#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
+#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug, Parser)]
 #[serde(default)]
-pub struct FBGenericConfig<F : Float> {
+pub struct FBGenericConfig<F : ClapFloat> {
+    #[clap(skip)]
     /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`].
     pub insertion_style : InsertionStyle,
+    #[clap(skip)]
     /// Tolerance for point insertion.
     pub tolerance : Tolerance<F>,
     /// Stop looking for predual maximum (where to isert a new point) below
     /// `tolerance` multiplied by this factor.
     pub insertion_cutoff_factor : F,
+    #[clap(skip)]
     /// Apply tolerance ergodically
     pub ergodic_tolerance : ErgodicTolerance<F>,
+    #[clap(skip)]
     /// Settings for branch and bound refinement when looking for predual maxima
     pub refinement : RefinementSettings<F>,
     /// Maximum insertions within each outer iteration
     pub max_insertions : usize,
     /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
     pub bootstrap_insertions : Option<(usize, usize)>,
+    #[clap(skip)]
     /// Inner method settings
     pub inner : InnerSettings<F>,
     /// Spike merging method
@@ -213,7 +230,7 @@
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> Default for FBConfig<F> {
+impl<F : ClapFloat> Default for FBConfig<F> {
     fn default() -> Self {
         FBConfig {
             τ0 : 0.99,
@@ -224,7 +241,7 @@
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> Default for FBGenericConfig<F> {
+impl<F : ClapFloat> Default for FBGenericConfig<F> {
     fn default() -> Self {
         FBGenericConfig {
             insertion_style : InsertionStyle::Reuse,
@@ -457,7 +474,7 @@
     iterator : I,
     plotter : SeqPlotter<F, N>
 ) -> DiscreteMeasure<Loc<F, N>, F>
-where F : Float + ToNalgebraRealField,
+where F : ClapFloat + ToNalgebraRealField,
       I : AlgIteratorFactory<IterInfo<F, N>>,
       for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
                                   //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
@@ -521,7 +538,7 @@
     mut residual : A::Observable,
     mut specialisation : Spec,
 ) -> DiscreteMeasure<Loc<F, N>, F>
-where F : Float + ToNalgebraRealField,
+where F : ClapFloat + ToNalgebraRealField,
       I : AlgIteratorFactory<IterInfo<F, N>>,
       Spec : FBSpecialisation<F, A::Observable, N>,
       A::Observable : std::ops::MulAssign<F>,
--- a/src/pdps.rs	Thu Dec 01 23:07:35 2022 +0200
+++ b/src/pdps.rs	Wed Nov 30 23:45:04 2022 +0200
@@ -107,7 +107,7 @@
 /// Settings for [`pointsource_pdps`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
-pub struct PDPSConfig<F : Float> {
+pub struct PDPSConfig<F : ClapFloat> {
     /// Primal step length scaling. We must have `τ0 * σ0 < 1`.
     pub τ0 : F,
     /// Dual step length scaling. We must have `τ0 * σ0 < 1`.
@@ -119,7 +119,7 @@
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> Default for PDPSConfig<F> {
+impl<F : ClapFloat> Default for PDPSConfig<F> {
     fn default() -> Self {
         let τ0 = 0.5;
         PDPSConfig {
@@ -311,7 +311,7 @@
     plotter : SeqPlotter<F, N>,
     dataterm : D,
 ) -> DiscreteMeasure<Loc<F, N>, F>
-where F : Float + ToNalgebraRealField,
+where F : ClapFloat + ToNalgebraRealField,
       I : AlgIteratorFactory<IterInfo<F, N>>,
       for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>
                                   + std::ops::Add<A::Observable, Output=A::Observable>,
--- a/src/run.rs	Thu Dec 01 23:07:35 2022 +0200
+++ b/src/run.rs	Wed Nov 30 23:45:04 2022 +0200
@@ -67,7 +67,7 @@
 
 /// Available algorithms and their configurations
 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
-pub enum AlgorithmConfig<F : Float> {
+pub enum AlgorithmConfig<F : ClapFloat> {
     FB(FBConfig<F>),
     FW(FWConfig<F>),
     PDPS(PDPSConfig<F>),
@@ -141,7 +141,7 @@
 
 impl DefaultAlgorithm {
     /// Returns the algorithm configuration corresponding to the algorithm shorthand
-    pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> {
+    pub fn default_config<F : ClapFloat>(&self) -> AlgorithmConfig<F> {
         use DefaultAlgorithm::*;
         match *self {
             FB => AlgorithmConfig::FB(Default::default()),
@@ -159,11 +159,11 @@
     }
 
     /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand
-    pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> {
+    pub fn get_named<F : ClapFloat>(&self) -> Named<AlgorithmConfig<F>> {
         self.to_named(self.default_config())
     }
 
-    pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> {
+    pub fn to_named<F : ClapFloat>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> {
         let name = self.to_possible_value().unwrap().get_name().to_string();
         Named{ name , data : alg }
     }
@@ -196,7 +196,7 @@
 
 #[derive(Clone, Debug, Serialize)]
 #[serde(default)]
-pub struct Configuration<F : Float> {
+pub struct Configuration<F : ClapFloat> {
     /// Algorithms to run
     pub algorithms : Vec<Named<AlgorithmConfig<F>>>,
     /// Options for algorithm step iteration (verbosity, etc.)
@@ -286,7 +286,7 @@
 /// Struct for experiment configurations
 #[derive(Debug, Clone, Serialize)]
 pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize>
-where F : Float,
+where F : ClapFloat,
       [usize; N] : Serialize,
       NoiseDistr : Distribution<F>,
       S : Sensor<F, N>,
@@ -548,7 +548,7 @@
     b : &A::Observable,
     kernel_plot_width : F,
 ) -> DynError
-where F : Float + ToNalgebraRealField,
+where F : ClapFloat + ToNalgebraRealField,
       Sensor : RealMapping<F, N> + Support<F, N> + Clone,
       Spread : RealMapping<F, N> + Support<F, N> + Clone,
       Kernel : RealMapping<F, N> + Support<F, N>,

mercurial