| 1 //! Type definitions and re-exports |
1 //! Type definitions and re-exports |
| 2 |
2 |
| 3 use numeric_literals::replace_float_literals; |
3 use numeric_literals::replace_float_literals; |
| 4 |
4 |
| |
5 use alg_tools::iterate::LogRepr; |
| 5 use colored::ColoredString; |
6 use colored::ColoredString; |
| 6 use serde::{Serialize, Deserialize}; |
7 use serde::{Deserialize, Serialize}; |
| 7 use alg_tools::iterate::LogRepr; |
|
| 8 use alg_tools::euclidean::Euclidean; |
|
| 9 use alg_tools::norms::{Norm, L1}; |
|
| 10 |
8 |
| 11 pub use alg_tools::types::*; |
9 pub use alg_tools::error::DynResult; |
| 12 pub use alg_tools::loc::Loc; |
10 pub use alg_tools::loc::Loc; |
| 13 pub use alg_tools::sets::Cube; |
11 pub use alg_tools::sets::Cube; |
| |
12 pub use alg_tools::types::*; |
| 14 |
13 |
| 15 // use crate::measures::DiscreteMeasure; |
14 // use crate::measures::DiscreteMeasure; |
| 16 |
15 |
| 17 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. |
16 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. |
| 18 pub trait ClapFloat : Float |
17 pub trait ClapFloat: |
| 19 + std::str::FromStr<Err=std::num::ParseFloatError> |
18 Float + std::str::FromStr<Err = std::num::ParseFloatError> + std::fmt::Display |
| 20 + std::fmt::Display {} |
19 { |
| |
20 } |
| 21 impl ClapFloat for f32 {} |
21 impl ClapFloat for f32 {} |
| 22 impl ClapFloat for f64 {} |
22 impl ClapFloat for f64 {} |
| 23 |
23 |
| 24 /// Structure for storing iteration statistics |
24 /// Structure for storing transport statistics |
| 25 #[derive(Debug, Clone, Serialize)] |
25 #[derive(Debug, Clone, Serialize)] |
| 26 pub struct IterInfo<F : Float, const N : usize> { |
26 pub struct TransportInfo<F: Float = f64> { |
| 27 /// Function value |
27 /// Tuple of (untransported mass, source mass) |
| 28 pub value : F, |
28 pub untransported_fraction: (F, F), |
| 29 /// Number of spikes |
29 /// Tuple of (|destination mass - transported_mass|, transported mass) |
| 30 pub n_spikes : usize, |
30 pub transport_error: (F, F), |
| 31 /// Number of iterations this statistic covers |
31 /// Number of readjustment iterations for transport |
| 32 pub this_iters : usize, |
32 pub readjustment_iters: usize, |
| 33 /// Number of spikes inserted since last IterInfo statistic |
33 /// ($∫ c_2 dγ , ∫ dγ$) |
| 34 pub inserted : usize, |
34 pub dist: (F, F), |
| 35 /// Number of spikes removed by merging since last IterInfo statistic |
|
| 36 pub merged : usize, |
|
| 37 /// Number of spikes removed by pruning since last IterInfo statistic |
|
| 38 pub pruned : usize, |
|
| 39 /// Number of inner iterations since last IterInfo statistic |
|
| 40 pub inner_iters : usize, |
|
| 41 /// Tuple of (transported mass, source mass) |
|
| 42 pub untransported_fraction : Option<(F, F)>, |
|
| 43 /// Tuple of (|destination mass - untransported_mass|, transported mass) |
|
| 44 pub transport_error : Option<(F, F)>, |
|
| 45 /// Current tolerance |
|
| 46 pub ε : F, |
|
| 47 // /// Solve fin.dim problem for this measure to get the optimal `value`. |
|
| 48 // pub postprocessing : Option<RNDM<F, N>>, |
|
| 49 } |
35 } |
| 50 |
36 |
| 51 impl<F : Float, const N : usize> IterInfo<F, N> { |
37 #[replace_float_literals(F::cast_from(literal))] |
| 52 /// Initialise statistics with zeros. `ε` and `value` are unspecified. |
38 impl<F: Float> TransportInfo<F> { |
| |
39 /// Initialise transport statistics |
| 53 pub fn new() -> Self { |
40 pub fn new() -> Self { |
| 54 IterInfo { |
41 TransportInfo { |
| 55 value : F::NAN, |
42 untransported_fraction: (0.0, 0.0), |
| 56 n_spikes : 0, |
43 transport_error: (0.0, 0.0), |
| 57 this_iters : 0, |
44 readjustment_iters: 0, |
| 58 merged : 0, |
45 dist: (0.0, 0.0), |
| 59 inserted : 0, |
|
| 60 pruned : 0, |
|
| 61 inner_iters : 0, |
|
| 62 ε : F::NAN, |
|
| 63 // postprocessing : None, |
|
| 64 untransported_fraction : None, |
|
| 65 transport_error : None, |
|
| 66 } |
46 } |
| 67 } |
47 } |
| 68 } |
48 } |
| 69 |
49 |
| |
50 /// Structure for storing iteration statistics |
| |
51 #[derive(Debug, Clone, Serialize)] |
| |
52 pub struct IterInfo<F: Float = f64> { |
| |
53 /// Function value |
| |
54 pub value: F, |
| |
55 /// Number of spikes |
| |
56 pub n_spikes: usize, |
| |
57 /// Number of iterations this statistic covers |
| |
58 pub this_iters: usize, |
| |
59 /// Number of spikes inserted since last IterInfo statistic |
| |
60 pub inserted: usize, |
| |
61 /// Number of spikes removed by merging since last IterInfo statistic |
| |
62 pub merged: usize, |
| |
63 /// Number of spikes removed by pruning since last IterInfo statistic |
| |
64 pub pruned: usize, |
| |
65 /// Number of inner iterations since last IterInfo statistic |
| |
66 pub inner_iters: usize, |
| |
67 /// Transport statistis |
| |
68 pub transport: Option<TransportInfo<F>>, |
| |
69 /// Current tolerance |
| |
70 pub ε: F, |
| |
71 // /// Solve fin.dim problem for this measure to get the optimal `value`. |
| |
72 // pub postprocessing : Option<RNDM<N, F>>, |
| |
73 } |
| |
74 |
| |
75 impl<F: Float> IterInfo<F> { |
| |
76 /// Initialise statistics with zeros. `ε` and `value` are unspecified. |
| |
77 pub fn new() -> Self { |
| |
78 IterInfo { |
| |
79 value: F::NAN, |
| |
80 n_spikes: 0, |
| |
81 this_iters: 0, |
| |
82 merged: 0, |
| |
83 inserted: 0, |
| |
84 pruned: 0, |
| |
85 inner_iters: 0, |
| |
86 ε: F::NAN, |
| |
87 // postprocessing : None, |
| |
88 transport: None, |
| |
89 } |
| |
90 } |
| |
91 |
| |
92 /// Get mutable reference to transport statistics, creating it if it is `None`. |
| |
93 pub fn get_transport_mut(&mut self) -> &mut TransportInfo<F> { |
| |
94 if self.transport.is_none() { |
| |
95 self.transport = Some(TransportInfo::new()); |
| |
96 } |
| |
97 self.transport.as_mut().unwrap() |
| |
98 } |
| |
99 } |
| |
100 |
| 70 #[replace_float_literals(F::cast_from(literal))] |
101 #[replace_float_literals(F::cast_from(literal))] |
| 71 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { |
102 impl<F> LogRepr for IterInfo<F> |
| |
103 where |
| |
104 F: LogRepr + Float, |
| |
105 { |
| 72 fn logrepr(&self) -> ColoredString { |
106 fn logrepr(&self) -> ColoredString { |
| 73 format!("{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}", |
107 format!( |
| 74 self.value.logrepr(), |
108 "{}\t| N = {}, ε = {:.2e}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}", |
| 75 self.n_spikes, |
109 self.value.logrepr(), |
| 76 self.ε, |
110 self.n_spikes, |
| 77 self.inner_iters as float / self.this_iters.max(1) as float, |
111 self.ε, |
| 78 self.inserted as float / self.this_iters.max(1) as float, |
112 self.inner_iters as float / self.this_iters.max(1) as float, |
| 79 self.merged as float / self.this_iters.max(1) as float, |
113 self.inserted as float / self.this_iters.max(1) as float, |
| 80 self.pruned as float / self.this_iters.max(1) as float, |
114 self.merged as float / self.this_iters.max(1) as float, |
| 81 match self.untransported_fraction { |
115 self.pruned as float / self.this_iters.max(1) as float, |
| 82 None => format!(""), |
116 match &self.transport { |
| 83 Some((a, b)) => if b > 0.0 { |
117 None => format!(""), |
| 84 format!(", untransported {:.2}%", 100.0*a/b) |
118 Some(t) => { |
| 85 } else { |
119 let (a1, b1) = t.untransported_fraction; |
| 86 format!("") |
120 let (a2, b2) = t.transport_error; |
| 87 } |
121 let (a3, b3) = t.dist; |
| 88 }, |
122 format!( |
| 89 match self.transport_error { |
123 ", γ-un/er/d/it = {:.2}%/{:.2}%/{:.2e}/{:.2}", |
| 90 None => format!(""), |
124 if b1 > 0.0 { 100.0 * a1 / b1 } else { F::NAN }, |
| 91 Some((a, b)) => if b > 0.0 { |
125 if b2 > 0.0 { 100.0 * a2 / b2 } else { F::NAN }, |
| 92 format!(", transport error {:.2}%", 100.0*a/b) |
126 if b3 > 0.0 { a3 / b3 } else { F::NAN }, |
| 93 } else { |
127 t.readjustment_iters as float / self.this_iters.max(1) as float, |
| 94 format!("") |
128 ) |
| 95 } |
|
| 96 } |
129 } |
| 97 ).as_str().into() |
130 } |
| |
131 ) |
| |
132 .as_str() |
| |
133 .into() |
| 98 } |
134 } |
| 99 } |
135 } |
| 100 |
136 |
| 101 /// Branch and bound refinement settings |
137 /// Branch and bound refinement settings |
| 102 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
138 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 103 #[serde(default)] |
139 #[serde(default)] |
| 104 pub struct RefinementSettings<F : Float> { |
140 pub struct RefinementSettings<F: Float> { |
| 105 /// Function value tolerance multiplier for bisection tree refinement in |
141 /// Function value tolerance multiplier for bisection tree refinement in |
| 106 /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions. |
142 /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions. |
| 107 pub tolerance_mult : F, |
143 pub tolerance_mult: F, |
| 108 /// Maximum branch and bound steps |
144 /// Maximum branch and bound steps |
| 109 pub max_steps : usize, |
145 pub max_steps: usize, |
| 110 } |
146 } |
| 111 |
147 |
| 112 #[replace_float_literals(F::cast_from(literal))] |
148 #[replace_float_literals(F::cast_from(literal))] |
| 113 impl<F : Float> Default for RefinementSettings<F> { |
149 impl<F: Float> Default for RefinementSettings<F> { |
| 114 fn default() -> Self { |
150 fn default() -> Self { |
| 115 RefinementSettings { |
151 RefinementSettings { tolerance_mult: 0.1, max_steps: 50000 } |
| 116 tolerance_mult : 0.1, |
|
| 117 max_steps : 50000, |
|
| 118 } |
|
| 119 } |
152 } |
| 120 } |
153 } |
| 121 |
154 |
| 122 /// Data term type |
155 /// Data term type |
| 123 #[derive(Clone, Copy, PartialEq, Serialize, Deserialize, Debug)] |
156 #[derive(Clone, Copy, PartialEq, Serialize, Deserialize, Debug)] |
| 124 pub enum DataTerm { |
157 pub enum DataTermType { |
| 125 /// $\\|z\\|\_2^2/2$ |
158 /// $\\|z\\|\_2^2/2$ |
| 126 L2Squared, |
159 L222, |
| 127 /// $\\|z\\|\_1$ |
160 /// $\\|z\\|\_1$ |
| 128 L1, |
161 L1, |
| 129 } |
162 } |
| 130 |
163 |
| 131 impl DataTerm { |
164 pub use alg_tools::mapping::Lipschitz; |
| 132 /// Calculate the data term value at residual $z=Aμ - b$. |
|
| 133 pub fn value_at_residual<F : Float, E : Euclidean<F> + Norm<F, L1>>(&self, z : E) -> F { |
|
| 134 match self { |
|
| 135 Self::L2Squared => z.norm2_squared_div2(), |
|
| 136 Self::L1 => z.norm(L1), |
|
| 137 } |
|
| 138 } |
|
| 139 } |
|
| 140 |
|
| 141 /// Type for indicating norm-2-squared data fidelity or transport cost. |
|
| 142 #[derive(Clone, Copy, Serialize, Deserialize)] |
|
| 143 pub struct L2Squared; |
|
| 144 |
|
| 145 /// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `D`. |
|
| 146 pub trait Lipschitz<M> { |
|
| 147 /// The type of floats |
|
| 148 type FloatType : Float; |
|
| 149 |
|
| 150 /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `D`. |
|
| 151 fn lipschitz_factor(&self, seminorm : M) -> Option<Self::FloatType>; |
|
| 152 } |
|
| 153 |
165 |
| 154 /// Trait for norm-bounded functions. |
166 /// Trait for norm-bounded functions. |
| 155 pub trait NormBounded<M> { |
167 pub trait NormBounded<M> { |
| 156 type FloatType : Float; |
168 type FloatType: Float; |
| 157 |
169 |
| 158 /// Returns a bound on the values of this function object in the `M`-norm. |
170 /// Returns a bound on the values of this function object in the `M`-norm. |
| 159 fn norm_bound(&self, m : M) -> Self::FloatType; |
171 fn norm_bound(&self, m: M) -> Self::FloatType; |
| 160 } |
172 } |