2 |
2 |
3 use numeric_literals::replace_float_literals; |
3 use numeric_literals::replace_float_literals; |
4 |
4 |
5 use colored::ColoredString; |
5 use colored::ColoredString; |
6 use serde::{Serialize, Deserialize}; |
6 use serde::{Serialize, Deserialize}; |
7 use clap::ValueEnum; |
|
8 use alg_tools::iterate::LogRepr; |
7 use alg_tools::iterate::LogRepr; |
9 use alg_tools::euclidean::Euclidean; |
8 use alg_tools::euclidean::Euclidean; |
10 use alg_tools::norms::{Norm, L1}; |
9 use alg_tools::norms::{Norm, L1}; |
11 |
10 |
12 pub use alg_tools::types::*; |
11 pub use alg_tools::types::*; |
13 pub use alg_tools::loc::Loc; |
12 pub use alg_tools::loc::Loc; |
14 pub use alg_tools::sets::Cube; |
13 pub use alg_tools::sets::Cube; |
15 |
14 |
16 use crate::measures::DiscreteMeasure; |
15 // use crate::measures::DiscreteMeasure; |
17 |
16 |
18 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. |
17 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. |
19 pub trait ClapFloat : Float |
18 pub trait ClapFloat : Float |
20 + std::str::FromStr<Err=std::num::ParseFloatError> |
19 + std::str::FromStr<Err=std::num::ParseFloatError> |
21 + std::fmt::Display {} |
20 + std::fmt::Display {} |
25 /// Structure for storing iteration statistics |
24 /// Structure for storing iteration statistics |
26 #[derive(Debug, Clone, Serialize)] |
25 #[derive(Debug, Clone, Serialize)] |
27 pub struct IterInfo<F : Float, const N : usize> { |
26 pub struct IterInfo<F : Float, const N : usize> { |
28 /// Function value |
27 /// Function value |
29 pub value : F, |
28 pub value : F, |
30 /// Number of speaks |
29 /// Number of spikes |
31 pub n_spikes : usize, |
30 pub n_spikes : usize, |
32 /// Number of iterations this statistic covers |
31 /// Number of iterations this statistic covers |
33 pub this_iters : usize, |
32 pub this_iters : usize, |
|
33 /// Number of spikes inserted since last IterInfo statistic |
|
34 pub inserted : usize, |
34 /// Number of spikes removed by merging since last IterInfo statistic |
35 /// Number of spikes removed by merging since last IterInfo statistic |
35 pub merged : usize, |
36 pub merged : usize, |
36 /// Number of spikes removed by pruning since last IterInfo statistic |
37 /// Number of spikes removed by pruning since last IterInfo statistic |
37 pub pruned : usize, |
38 pub pruned : usize, |
38 /// Number of inner iterations since last IterInfo statistic |
39 /// Number of inner iterations since last IterInfo statistic |
39 pub inner_iters : usize, |
40 pub inner_iters : usize, |
|
41 /// Tuple of (transported mass, source mass) |
|
42 pub untransported_fraction : Option<(F, F)>, |
|
43 /// Tuple of (|destination mass - untransported_mass|, transported mass) |
|
44 pub transport_error : Option<(F, F)>, |
40 /// Current tolerance |
45 /// Current tolerance |
41 pub ε : F, |
46 pub ε : F, |
42 /// Solve fin.dim problem for this measure to get the optimal `value`. |
47 // /// Solve fin.dim problem for this measure to get the optimal `value`. |
43 pub postprocessing : Option<DiscreteMeasure<Loc<F, N>, F>>, |
48 // pub postprocessing : Option<RNDM<F, N>>, |
44 } |
49 } |
45 |
50 |
|
51 impl<F : Float, const N : usize> IterInfo<F, N> { |
|
52 /// Initialise statistics with zeros. `ε` and `value` are unspecified. |
|
53 pub fn new() -> Self { |
|
54 IterInfo { |
|
55 value : F::NAN, |
|
56 n_spikes : 0, |
|
57 this_iters : 0, |
|
58 merged : 0, |
|
59 inserted : 0, |
|
60 pruned : 0, |
|
61 inner_iters : 0, |
|
62 ε : F::NAN, |
|
63 // postprocessing : None, |
|
64 untransported_fraction : None, |
|
65 transport_error : None, |
|
66 } |
|
67 } |
|
68 } |
|
69 |
|
70 #[replace_float_literals(F::cast_from(literal))] |
46 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { |
71 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { |
47 fn logrepr(&self) -> ColoredString { |
72 fn logrepr(&self) -> ColoredString { |
48 format!("{}\t| N = {}, ε = {:.8}, inner_iters_mean = {}, merged+pruned_mean = {}+{}", |
73 format!("{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}", |
49 self.value.logrepr(), |
74 self.value.logrepr(), |
50 self.n_spikes, |
75 self.n_spikes, |
51 self.ε, |
76 self.ε, |
52 self.inner_iters as float / self.this_iters as float, |
77 self.inner_iters as float / self.this_iters.max(1) as float, |
53 self.merged as float / self.this_iters as float, |
78 self.inserted as float / self.this_iters.max(1) as float, |
54 self.pruned as float / self.this_iters as float, |
79 self.merged as float / self.this_iters.max(1) as float, |
|
80 self.pruned as float / self.this_iters.max(1) as float, |
|
81 match self.untransported_fraction { |
|
82 None => format!(""), |
|
83 Some((a, b)) => if b > 0.0 { |
|
84 format!(", untransported {:.2}%", 100.0*a/b) |
|
85 } else { |
|
86 format!("") |
|
87 } |
|
88 }, |
|
89 match self.transport_error { |
|
90 None => format!(""), |
|
91 Some((a, b)) => if b > 0.0 { |
|
92 format!(", transport error {:.2}%", 100.0*a/b) |
|
93 } else { |
|
94 format!("") |
|
95 } |
|
96 } |
55 ).as_str().into() |
97 ).as_str().into() |
56 } |
98 } |
57 } |
99 } |
58 |
100 |
59 /// Branch and bound refinement settings |
101 /// Branch and bound refinement settings |
93 Self::L2Squared => z.norm2_squared_div2(), |
135 Self::L2Squared => z.norm2_squared_div2(), |
94 Self::L1 => z.norm(L1), |
136 Self::L1 => z.norm(L1), |
95 } |
137 } |
96 } |
138 } |
97 } |
139 } |
|
140 |
|
141 /// Type for indicating norm-2-squared data fidelity or transport cost. |
|
142 #[derive(Clone, Copy, Serialize, Deserialize)] |
|
143 pub struct L2Squared; |
|
144 |
|
145 /// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `D`. |
|
146 pub trait Lipschitz<M> { |
|
147 /// The type of floats |
|
148 type FloatType : Float; |
|
149 |
|
150 /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `D`. |
|
151 fn lipschitz_factor(&self, seminorm : M) -> Option<Self::FloatType>; |
|
152 } |
|
153 |
|
154 /// Trait for norm-bounded functions. |
|
155 pub trait NormBounded<M> { |
|
156 type FloatType : Float; |
|
157 |
|
158 /// Returns a bound on the values of this function object in the `M`-norm. |
|
159 fn norm_bound(&self, m : M) -> Self::FloatType; |
|
160 } |