| 2 |
2 |
| 3 use numeric_literals::replace_float_literals; |
3 use numeric_literals::replace_float_literals; |
| 4 |
4 |
| 5 use colored::ColoredString; |
5 use colored::ColoredString; |
| 6 use serde::{Serialize, Deserialize}; |
6 use serde::{Serialize, Deserialize}; |
| 7 use clap::ValueEnum; |
|
| 8 use alg_tools::iterate::LogRepr; |
7 use alg_tools::iterate::LogRepr; |
| 9 use alg_tools::euclidean::Euclidean; |
8 use alg_tools::euclidean::Euclidean; |
| 10 use alg_tools::norms::{Norm, L1}; |
9 use alg_tools::norms::{Norm, L1}; |
| 11 |
10 |
| 12 pub use alg_tools::types::*; |
11 pub use alg_tools::types::*; |
| 13 pub use alg_tools::loc::Loc; |
12 pub use alg_tools::loc::Loc; |
| 14 pub use alg_tools::sets::Cube; |
13 pub use alg_tools::sets::Cube; |
| 15 |
14 |
| 16 use crate::measures::DiscreteMeasure; |
15 // use crate::measures::DiscreteMeasure; |
| 17 |
16 |
| 18 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. |
17 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. |
| 19 pub trait ClapFloat : Float |
18 pub trait ClapFloat : Float |
| 20 + std::str::FromStr<Err=std::num::ParseFloatError> |
19 + std::str::FromStr<Err=std::num::ParseFloatError> |
| 21 + std::fmt::Display {} |
20 + std::fmt::Display {} |
| 25 /// Structure for storing iteration statistics |
24 /// Structure for storing iteration statistics |
| 26 #[derive(Debug, Clone, Serialize)] |
25 #[derive(Debug, Clone, Serialize)] |
| 27 pub struct IterInfo<F : Float, const N : usize> { |
26 pub struct IterInfo<F : Float, const N : usize> { |
| 28 /// Function value |
27 /// Function value |
| 29 pub value : F, |
28 pub value : F, |
| 30 /// Number of speaks |
29 /// Number of spikes |
| 31 pub n_spikes : usize, |
30 pub n_spikes : usize, |
| 32 /// Number of iterations this statistic covers |
31 /// Number of iterations this statistic covers |
| 33 pub this_iters : usize, |
32 pub this_iters : usize, |
| |
33 /// Number of spikes inserted since last IterInfo statistic |
| |
34 pub inserted : usize, |
| 34 /// Number of spikes removed by merging since last IterInfo statistic |
35 /// Number of spikes removed by merging since last IterInfo statistic |
| 35 pub merged : usize, |
36 pub merged : usize, |
| 36 /// Number of spikes removed by pruning since last IterInfo statistic |
37 /// Number of spikes removed by pruning since last IterInfo statistic |
| 37 pub pruned : usize, |
38 pub pruned : usize, |
| 38 /// Number of inner iterations since last IterInfo statistic |
39 /// Number of inner iterations since last IterInfo statistic |
| 39 pub inner_iters : usize, |
40 pub inner_iters : usize, |
| |
41 /// Tuple of (untransported mass, source mass) |
| |
42 pub untransported_fraction : Option<(F, F)>, |
| |
43 /// Tuple of (|destination mass - untransported_mass|, transported mass) |
| |
44 pub transport_error : Option<(F, F)>, |
| 40 /// Current tolerance |
45 /// Current tolerance |
| 41 pub ε : F, |
46 pub ε : F, |
| 42 /// Solve fin.dim problem for this measure to get the optimal `value`. |
47 // /// Solve fin.dim problem for this measure to get the optimal `value`. |
| 43 pub postprocessing : Option<DiscreteMeasure<Loc<F, N>, F>>, |
48 // pub postprocessing : Option<RNDM<F, N>>, |
| 44 } |
49 } |
| 45 |
50 |
| |
51 impl<F : Float, const N : usize> IterInfo<F, N> { |
| |
52 /// Initialise statistics with zeros. `ε` and `value` are left unspecified (currently NaN); the transport statistics start as `None`. |
| |
53 pub fn new() -> Self { |
| |
54 IterInfo { |
| |
55 value : F::NAN, |
| |
56 n_spikes : 0, |
| |
57 this_iters : 0, |
| |
58 merged : 0, |
| |
59 inserted : 0, |
| |
60 pruned : 0, |
| |
61 inner_iters : 0, |
| |
62 ε : F::NAN, |
| |
63 // postprocessing : None, |
| |
64 untransported_fraction : None, |
| |
65 transport_error : None, |
| |
66 } |
| |
67 } |
| |
68 } |
| |
69 |
| |
70 #[replace_float_literals(F::cast_from(literal))] |
| 46 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { |
71 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { |
| 47 fn logrepr(&self) -> ColoredString { |
72 fn logrepr(&self) -> ColoredString { |
| 48 format!("{}\t| N = {}, ε = {:.8}, inner_iters_mean = {}, merged+pruned_mean = {}+{}", |
73 format!("{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}", |
| 49 self.value.logrepr(), |
74 self.value.logrepr(), |
| 50 self.n_spikes, |
75 self.n_spikes, |
| 51 self.ε, |
76 self.ε, |
| 52 self.inner_iters as float / self.this_iters as float, |
77 self.inner_iters as float / self.this_iters.max(1) as float, |
| 53 self.merged as float / self.this_iters as float, |
78 self.inserted as float / self.this_iters.max(1) as float, |
| 54 self.pruned as float / self.this_iters as float, |
79 self.merged as float / self.this_iters.max(1) as float, |
| |
80 self.pruned as float / self.this_iters.max(1) as float, |
| |
81 match self.untransported_fraction { |
| |
82 None => format!(""), |
| |
83 Some((a, b)) => if b > 0.0 { |
| |
84 format!(", untransported {:.2}%", 100.0*a/b) |
| |
85 } else { |
| |
86 format!("") |
| |
87 } |
| |
88 }, |
| |
89 match self.transport_error { |
| |
90 None => format!(""), |
| |
91 Some((a, b)) => if b > 0.0 { |
| |
92 format!(", transport error {:.2}%", 100.0*a/b) |
| |
93 } else { |
| |
94 format!("") |
| |
95 } |
| |
96 } |
| 55 ).as_str().into() |
97 ).as_str().into() |
| 56 } |
98 } |
| 57 } |
99 } |
| 58 |
100 |
| 59 /// Branch and bound refinement settings |
101 /// Branch and bound refinement settings |
| 93 Self::L2Squared => z.norm2_squared_div2(), |
135 Self::L2Squared => z.norm2_squared_div2(), |
| 94 Self::L1 => z.norm(L1), |
136 Self::L1 => z.norm(L1), |
| 95 } |
137 } |
| 96 } |
138 } |
| 97 } |
139 } |
| |
140 |
| |
141 /// Type for indicating norm-2-squared data fidelity or transport cost. |
| |
142 #[derive(Clone, Copy, Serialize, Deserialize)] |
| |
143 pub struct L2Squared; |
| |
144 |
| |
145 /// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `M`. |
| |
146 pub trait Lipschitz<M> { |
| |
147 /// The type of floats |
| |
148 type FloatType : Float; |
| |
149 |
| |
150 /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `seminorm`. |
| |
151 fn lipschitz_factor(&self, seminorm : M) -> Option<Self::FloatType>; |
| |
152 } |
| |
153 |
| |
154 /// Trait for norm-bounded functions. |
| |
155 pub trait NormBounded<M> { |
| |
156 type FloatType : Float; |
| |
157 |
| |
158 /// Returns a bound on the values of this function object in the `M`-norm. |
| |
159 fn norm_bound(&self, m : M) -> Self::FloatType; |
| |
160 } |