src/types.rs

changeset 0
eb3c7813b67a
child 8
ea3ca78873e8
equal deleted inserted replaced
-1:000000000000 0:eb3c7813b67a
1 //! Type definitions and re-exports
2
3 use numeric_literals::replace_float_literals;
4
5 use colored::ColoredString;
6 use serde::{Serialize, Deserialize};
7 use clap::ValueEnum;
8 use alg_tools::iterate::LogRepr;
9 use alg_tools::euclidean::Euclidean;
10 use alg_tools::norms::{Norm, L1};
11
12 pub use alg_tools::types::*;
13 pub use alg_tools::loc::Loc;
14 pub use alg_tools::sets::Cube;
15
16 use crate::measures::DiscreteMeasure;
17
18 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up.
19 pub trait ClapFloat : Float
20 + std::str::FromStr<Err=std::num::ParseFloatError>
21 + std::fmt::Display {}
22 impl ClapFloat for f32 {}
23 impl ClapFloat for f64 {}
24
25 /// Structure for storing iteration statistics
26 #[derive(Debug, Clone, Serialize)]
27 pub struct IterInfo<F : Float, const N : usize> {
28 /// Function value
29 pub value : F,
30 /// Number of speaks
31 pub n_spikes : usize,
32 /// Number of iterations this statistic covers
33 pub this_iters : usize,
34 /// Number of spikes removed by merging since last IterInfo statistic
35 pub merged : usize,
36 /// Number of spikes removed by pruning since last IterInfo statistic
37 pub pruned : usize,
38 /// Number of inner iterations since last IterInfo statistic
39 pub inner_iters : usize,
40 /// Current tolerance
41 pub ε : F,
42 /// Strict tolerance update if one was used
43 pub maybe_ε1 : Option<F>,
44 /// Solve fin.dim problem for this measure to get the optimal `value`.
45 pub postprocessing : Option<DiscreteMeasure<Loc<F, N>, F>>,
46 }
47
48 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float {
49 fn logrepr(&self) -> ColoredString {
50 let eqsign = match self.maybe_ε1 {
51 Some(ε1) if ε1 < self.ε => '≛',
52 _ => '=',
53 };
54 format!("{}\t| N = {}, ε {} {:.8}, inner_iters_mean = {}, merged+pruned_mean = {}+{}",
55 self.value.logrepr(),
56 self.n_spikes,
57 eqsign,
58 self.ε,
59 self.inner_iters as float / self.this_iters as float,
60 self.merged as float / self.this_iters as float,
61 self.pruned as float / self.this_iters as float,
62 ).as_str().into()
63 }
64 }
65
66 /// Branch and bound refinement settings
67 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
68 #[serde(default)]
69 pub struct RefinementSettings<F : Float> {
70 /// Function value tolerance multiplier for bisection tree refinement in
71 /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions.
72 pub tolerance_mult : F,
73 /// Maximum branch and bound steps
74 pub max_steps : usize,
75 }
76
77 #[replace_float_literals(F::cast_from(literal))]
78 impl<F : Float> Default for RefinementSettings<F> {
79 fn default() -> Self {
80 RefinementSettings {
81 tolerance_mult : 0.1,
82 max_steps : 50000,
83 }
84 }
85 }
86
87 /// Data term type
88 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug, ValueEnum)]
89 pub enum DataTerm {
90 /// $\\|z\\|\_2^2/2$
91 L2Squared,
92 /// $\\|z\\|\_1$
93 L1,
94 }
95
96 impl DataTerm {
97 /// Calculate the data term value at residual $z=Aμ - b$.
98 pub fn value_at_residual<F : Float, E : Euclidean<F> + Norm<F, L1>>(&self, z : E) -> F {
99 match self {
100 Self::L2Squared => z.norm2_squared_div2(),
101 Self::L1 => z.norm(L1),
102 }
103 }
104 }
105

mercurial