Wed, 30 Nov 2022 23:45:04 +0200
Sketch FBGenericConfig clap
src/fb.rs | file | annotate | diff | comparison | revisions | |
src/pdps.rs | file | annotate | diff | comparison | revisions | |
src/run.rs | file | annotate | diff | comparison | revisions |
--- a/src/fb.rs Thu Dec 01 23:07:35 2022 +0200 +++ b/src/fb.rs Wed Nov 30 23:45:04 2022 +0200 @@ -84,6 +84,7 @@ use serde::{Serialize, Deserialize}; use colored::Colorize; use nalgebra::DVector; +use clap::Parser; use alg_tools::iterate::{ AlgIteratorFactory, @@ -146,6 +147,11 @@ Zero, } +impl Default for InsertionStyle { + fn default() -> Self { + Self::Reuse + } +} /// Meta-algorithm type #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[allow(dead_code)] @@ -166,10 +172,16 @@ AfterNth{ n : usize, factor : F }, } +impl<F : ClapFloat> Default for ErgodicTolerance<F> { + fn default() -> Self { + Self::NonErgodic + } +} + /// Settings for [`pointsource_fb`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct FBConfig<F : Float> { +pub struct FBConfig<F : ClapFloat> { /// Step length scaling pub τ0 : F, /// Meta-algorithm to apply @@ -180,24 +192,29 @@ /// Settings for the solution of the stepwise optimality condition in algorithms based on /// [`generic_pointsource_fb`]. -#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] +#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug, Parser)] #[serde(default)] -pub struct FBGenericConfig<F : Float> { +pub struct FBGenericConfig<F : ClapFloat> { + #[clap(skip)] /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`]. pub insertion_style : InsertionStyle, + #[clap(skip)] /// Tolerance for point insertion. pub tolerance : Tolerance<F>, /// Stop looking for predual maximum (where to isert a new point) below /// `tolerance` multiplied by this factor. pub insertion_cutoff_factor : F, + #[clap(skip)] /// Apply tolerance ergodically pub ergodic_tolerance : ErgodicTolerance<F>, + #[clap(skip)] /// Settings for branch and bound refinement when looking for predual maxima pub refinement : RefinementSettings<F>, /// Maximum insertions within each outer iteration pub max_insertions : usize, /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. pub bootstrap_insertions : Option<(usize, usize)>, + #[clap(skip)] /// Inner method settings pub inner : InnerSettings<F>, /// Spike merging method @@ -213,7 +230,7 @@ } #[replace_float_literals(F::cast_from(literal))] -impl<F : Float> Default for FBConfig<F> { +impl<F : ClapFloat> Default for FBConfig<F> { fn default() -> Self { FBConfig { τ0 : 0.99, @@ -224,7 +241,7 @@ } #[replace_float_literals(F::cast_from(literal))] -impl<F : Float> Default for FBGenericConfig<F> { +impl<F : ClapFloat> Default for FBGenericConfig<F> { fn default() -> Self { FBGenericConfig { insertion_style : InsertionStyle::Reuse, @@ -457,7 +474,7 @@ iterator : I, plotter : SeqPlotter<F, N> ) -> DiscreteMeasure<Loc<F, N>, F> -where F : Float + ToNalgebraRealField, +where F : ClapFloat + ToNalgebraRealField, I : AlgIteratorFactory<IterInfo<F, N>>, for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow @@ -521,7 +538,7 @@ mut residual : A::Observable, mut specialisation : Spec, ) -> DiscreteMeasure<Loc<F, N>, F> -where F : Float + ToNalgebraRealField, +where F : ClapFloat + ToNalgebraRealField, I : AlgIteratorFactory<IterInfo<F, N>>, Spec : FBSpecialisation<F, A::Observable, N>, A::Observable : std::ops::MulAssign<F>,
--- a/src/pdps.rs Thu Dec 01 23:07:35 2022 +0200 +++ b/src/pdps.rs Wed Nov 30 23:45:04 2022 +0200 @@ -107,7 +107,7 @@ /// Settings for [`pointsource_pdps`]. #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct PDPSConfig<F : Float> { +pub struct PDPSConfig<F : ClapFloat> { /// Primal step length scaling. We must have `τ0 * σ0 < 1`. pub τ0 : F, /// Dual step length scaling. We must have `τ0 * σ0 < 1`. @@ -119,7 +119,7 @@ } #[replace_float_literals(F::cast_from(literal))] -impl<F : Float> Default for PDPSConfig<F> { +impl<F : ClapFloat> Default for PDPSConfig<F> { fn default() -> Self { let τ0 = 0.5; PDPSConfig { @@ -311,7 +311,7 @@ plotter : SeqPlotter<F, N>, dataterm : D, ) -> DiscreteMeasure<Loc<F, N>, F> -where F : Float + ToNalgebraRealField, +where F : ClapFloat + ToNalgebraRealField, I : AlgIteratorFactory<IterInfo<F, N>>, for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + std::ops::Add<A::Observable, Output=A::Observable>,
--- a/src/run.rs Thu Dec 01 23:07:35 2022 +0200 +++ b/src/run.rs Wed Nov 30 23:45:04 2022 +0200 @@ -67,7 +67,7 @@ /// Available algorithms and their configurations #[derive(Copy, Clone, Debug, Serialize, Deserialize)] -pub enum AlgorithmConfig<F : Float> { +pub enum AlgorithmConfig<F : ClapFloat> { FB(FBConfig<F>), FW(FWConfig<F>), PDPS(PDPSConfig<F>), @@ -141,7 +141,7 @@ impl DefaultAlgorithm { /// Returns the algorithm configuration corresponding to the algorithm shorthand - pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { + pub fn default_config<F : ClapFloat>(&self) -> AlgorithmConfig<F> { use DefaultAlgorithm::*; match *self { FB => AlgorithmConfig::FB(Default::default()), @@ -159,11 +159,11 @@ } /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand - pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { + pub fn get_named<F : ClapFloat>(&self) -> Named<AlgorithmConfig<F>> { self.to_named(self.default_config()) } - pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { + pub fn to_named<F : ClapFloat>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { let name = self.to_possible_value().unwrap().get_name().to_string(); Named{ name , data : alg } } @@ -196,7 +196,7 @@ #[derive(Clone, Debug, Serialize)] #[serde(default)] -pub struct Configuration<F : Float> { +pub struct Configuration<F : ClapFloat> { /// Algorithms to run pub algorithms : Vec<Named<AlgorithmConfig<F>>>, /// Options for algorithm step iteration (verbosity, etc.) @@ -286,7 +286,7 @@ /// Struct for experiment configurations #[derive(Debug, Clone, Serialize)] pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize> -where F : Float, +where F : ClapFloat, [usize; N] : Serialize, NoiseDistr : Distribution<F>, S : Sensor<F, N>, @@ -548,7 +548,7 @@ b : &A::Observable, kernel_plot_width : F, ) -> DynError -where F : Float + ToNalgebraRealField, +where F : ClapFloat + ToNalgebraRealField, Sensor : RealMapping<F, N> + Support<F, N> + Clone, Spread : RealMapping<F, N> + Support<F, N> + Clone, Kernel : RealMapping<F, N> + Support<F, N>,