src/types.rs

changeset 52
f0e8704d3f0e
parent 35
b087e3eab191
equal deleted inserted replaced
31:6105b5cd8d89 52:f0e8704d3f0e
2 2
3 use numeric_literals::replace_float_literals; 3 use numeric_literals::replace_float_literals;
4 4
5 use colored::ColoredString; 5 use colored::ColoredString;
6 use serde::{Serialize, Deserialize}; 6 use serde::{Serialize, Deserialize};
7 use clap::ValueEnum;
8 use alg_tools::iterate::LogRepr; 7 use alg_tools::iterate::LogRepr;
9 use alg_tools::euclidean::Euclidean; 8 use alg_tools::euclidean::Euclidean;
10 use alg_tools::norms::{Norm, L1}; 9 use alg_tools::norms::{Norm, L1};
11 10
12 pub use alg_tools::types::*; 11 pub use alg_tools::types::*;
13 pub use alg_tools::loc::Loc; 12 pub use alg_tools::loc::Loc;
14 pub use alg_tools::sets::Cube; 13 pub use alg_tools::sets::Cube;
15 14
16 use crate::measures::DiscreteMeasure; 15 // use crate::measures::DiscreteMeasure;
17 16
18 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. 17 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up.
19 pub trait ClapFloat : Float 18 pub trait ClapFloat : Float
20 + std::str::FromStr<Err=std::num::ParseFloatError> 19 + std::str::FromStr<Err=std::num::ParseFloatError>
21 + std::fmt::Display {} 20 + std::fmt::Display {}
25 /// Structure for storing iteration statistics 24 /// Structure for storing iteration statistics
26 #[derive(Debug, Clone, Serialize)] 25 #[derive(Debug, Clone, Serialize)]
27 pub struct IterInfo<F : Float, const N : usize> { 26 pub struct IterInfo<F : Float, const N : usize> {
28 /// Function value 27 /// Function value
29 pub value : F, 28 pub value : F,
30 /// Number of speaks 29 /// Number of spikes
31 pub n_spikes : usize, 30 pub n_spikes : usize,
32 /// Number of iterations this statistic covers 31 /// Number of iterations this statistic covers
33 pub this_iters : usize, 32 pub this_iters : usize,
33 /// Number of spikes inserted since last IterInfo statistic
34 pub inserted : usize,
34 /// Number of spikes removed by merging since last IterInfo statistic 35 /// Number of spikes removed by merging since last IterInfo statistic
35 pub merged : usize, 36 pub merged : usize,
36 /// Number of spikes removed by pruning since last IterInfo statistic 37 /// Number of spikes removed by pruning since last IterInfo statistic
37 pub pruned : usize, 38 pub pruned : usize,
38 /// Number of inner iterations since last IterInfo statistic 39 /// Number of inner iterations since last IterInfo statistic
39 pub inner_iters : usize, 40 pub inner_iters : usize,
41 /// Tuple of (transported mass, source mass)
42 pub untransported_fraction : Option<(F, F)>,
43 /// Tuple of (|destination mass - untransported_mass|, transported mass)
44 pub transport_error : Option<(F, F)>,
40 /// Current tolerance 45 /// Current tolerance
41 pub ε : F, 46 pub ε : F,
42 /// Solve fin.dim problem for this measure to get the optimal `value`. 47 // /// Solve fin.dim problem for this measure to get the optimal `value`.
43 pub postprocessing : Option<DiscreteMeasure<Loc<F, N>, F>>, 48 // pub postprocessing : Option<RNDM<F, N>>,
44 } 49 }
45 50
51 impl<F : Float, const N : usize> IterInfo<F, N> {
52 /// Initialise statistics with zeros. `ε` and `value` are unspecified.
53 pub fn new() -> Self {
54 IterInfo {
55 value : F::NAN,
56 n_spikes : 0,
57 this_iters : 0,
58 merged : 0,
59 inserted : 0,
60 pruned : 0,
61 inner_iters : 0,
62 ε : F::NAN,
63 // postprocessing : None,
64 untransported_fraction : None,
65 transport_error : None,
66 }
67 }
68 }
69
70 #[replace_float_literals(F::cast_from(literal))]
46 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { 71 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float {
47 fn logrepr(&self) -> ColoredString { 72 fn logrepr(&self) -> ColoredString {
48 format!("{}\t| N = {}, ε = {:.8}, inner_iters_mean = {}, merged+pruned_mean = {}+{}", 73 format!("{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}",
49 self.value.logrepr(), 74 self.value.logrepr(),
50 self.n_spikes, 75 self.n_spikes,
51 self.ε, 76 self.ε,
52 self.inner_iters as float / self.this_iters as float, 77 self.inner_iters as float / self.this_iters.max(1) as float,
53 self.merged as float / self.this_iters as float, 78 self.inserted as float / self.this_iters.max(1) as float,
54 self.pruned as float / self.this_iters as float, 79 self.merged as float / self.this_iters.max(1) as float,
80 self.pruned as float / self.this_iters.max(1) as float,
81 match self.untransported_fraction {
82 None => format!(""),
83 Some((a, b)) => if b > 0.0 {
84 format!(", untransported {:.2}%", 100.0*a/b)
85 } else {
86 format!("")
87 }
88 },
89 match self.transport_error {
90 None => format!(""),
91 Some((a, b)) => if b > 0.0 {
92 format!(", transport error {:.2}%", 100.0*a/b)
93 } else {
94 format!("")
95 }
96 }
55 ).as_str().into() 97 ).as_str().into()
56 } 98 }
57 } 99 }
58 100
59 /// Branch and bound refinement settings 101 /// Branch and bound refinement settings
76 } 118 }
77 } 119 }
78 } 120 }
79 121
80 /// Data term type 122 /// Data term type
81 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug, ValueEnum)] 123 #[derive(Clone, Copy, PartialEq, Serialize, Deserialize, Debug)]
82 pub enum DataTerm { 124 pub enum DataTerm {
83 /// $\\|z\\|\_2^2/2$ 125 /// $\\|z\\|\_2^2/2$
84 L2Squared, 126 L2Squared,
85 /// $\\|z\\|\_1$ 127 /// $\\|z\\|\_1$
86 L1, 128 L1,
93 Self::L2Squared => z.norm2_squared_div2(), 135 Self::L2Squared => z.norm2_squared_div2(),
94 Self::L1 => z.norm(L1), 136 Self::L1 => z.norm(L1),
95 } 137 }
96 } 138 }
97 } 139 }
140
141 /// Type for indicating norm-2-squared data fidelity or transport cost.
142 #[derive(Clone, Copy, Serialize, Deserialize)]
143 pub struct L2Squared;
144
145 /// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `D`.
146 pub trait Lipschitz<M> {
147 /// The type of floats
148 type FloatType : Float;
149
150 /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `D`.
151 fn lipschitz_factor(&self, seminorm : M) -> Option<Self::FloatType>;
152 }
153
154 /// Trait for norm-bounded functions.
155 pub trait NormBounded<M> {
156 type FloatType : Float;
157
158 /// Returns a bound on the values of this function object in the `M`-norm.
159 fn norm_bound(&self, m : M) -> Self::FloatType;
160 }

mercurial