| 1 //! Type definitions and re-exports |
1 //! Type definitions and re-exports |
| 2 |
2 |
| 3 use numeric_literals::replace_float_literals; |
3 use numeric_literals::replace_float_literals; |
| 4 |
4 |
| |
5 use alg_tools::iterate::LogRepr; |
| 5 use colored::ColoredString; |
6 use colored::ColoredString; |
| 6 use serde::{Serialize, Deserialize}; |
7 use serde::{Deserialize, Serialize}; |
| 7 use alg_tools::iterate::LogRepr; |
|
| 8 use alg_tools::euclidean::Euclidean; |
|
| 9 use alg_tools::norms::{Norm, L1}; |
|
| 10 |
8 |
| 11 pub use alg_tools::types::*; |
9 pub use alg_tools::error::DynResult; |
| 12 pub use alg_tools::loc::Loc; |
10 pub use alg_tools::loc::Loc; |
| 13 pub use alg_tools::sets::Cube; |
11 pub use alg_tools::sets::Cube; |
| |
12 pub use alg_tools::types::*; |
| 14 |
13 |
| 15 // use crate::measures::DiscreteMeasure; |
14 // use crate::measures::DiscreteMeasure; |
| 16 |
15 |
| 17 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. |
16 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. |
| 18 pub trait ClapFloat : Float |
17 pub trait ClapFloat: |
| 19 + std::str::FromStr<Err=std::num::ParseFloatError> |
18 Float + std::str::FromStr<Err = std::num::ParseFloatError> + std::fmt::Display |
| 20 + std::fmt::Display {} |
19 { |
| |
20 } |
| 21 impl ClapFloat for f32 {} |
21 impl ClapFloat for f32 {} |
| 22 impl ClapFloat for f64 {} |
22 impl ClapFloat for f64 {} |
| 23 |
23 |
| 24 /// Structure for storing iteration statistics |
24 /// Structure for storing iteration statistics |
| 25 #[derive(Debug, Clone, Serialize)] |
25 #[derive(Debug, Clone, Serialize)] |
| 26 pub struct IterInfo<F : Float, const N : usize> { |
26 pub struct IterInfo<F: Float = f64> { |
| 27 /// Function value |
27 /// Function value |
| 28 pub value : F, |
28 pub value: F, |
| 29 /// Number of spikes |
29 /// Number of spikes |
| 30 pub n_spikes : usize, |
30 pub n_spikes: usize, |
| 31 /// Number of iterations this statistic covers |
31 /// Number of iterations this statistic covers |
| 32 pub this_iters : usize, |
32 pub this_iters: usize, |
| 33 /// Number of spikes inserted since last IterInfo statistic |
33 /// Number of spikes inserted since last IterInfo statistic |
| 34 pub inserted : usize, |
34 pub inserted: usize, |
| 35 /// Number of spikes removed by merging since last IterInfo statistic |
35 /// Number of spikes removed by merging since last IterInfo statistic |
| 36 pub merged : usize, |
36 pub merged: usize, |
| 37 /// Number of spikes removed by pruning since last IterInfo statistic |
37 /// Number of spikes removed by pruning since last IterInfo statistic |
| 38 pub pruned : usize, |
38 pub pruned: usize, |
| 39 /// Number of inner iterations since last IterInfo statistic |
39 /// Number of inner iterations since last IterInfo statistic |
| 40 pub inner_iters : usize, |
40 pub inner_iters: usize, |
| 41 /// Tuple of (transported mass, source mass) |
41 /// Tuple of (transported mass, source mass) |
| 42 pub untransported_fraction : Option<(F, F)>, |
42 pub untransported_fraction: Option<(F, F)>, |
| 43 /// Tuple of (|destination mass - untransported_mass|, transported mass) |
43 /// Tuple of (|destination mass - untransported_mass|, transported mass) |
| 44 pub transport_error : Option<(F, F)>, |
44 pub transport_error: Option<(F, F)>, |
| 45 /// Current tolerance |
45 /// Current tolerance |
| 46 pub ε : F, |
46 pub ε: F, |
| 47 // /// Solve fin.dim problem for this measure to get the optimal `value`. |
47 // /// Solve fin.dim problem for this measure to get the optimal `value`. |
| 48 // pub postprocessing : Option<RNDM<F, N>>, |
48 // pub postprocessing : Option<RNDM<N, F>>, |
| 49 } |
49 } |
| 50 |
50 |
| 51 impl<F : Float, const N : usize> IterInfo<F, N> { |
51 impl<F: Float> IterInfo<F> { |
| 52 /// Initialise statistics with zeros. `ε` and `value` are unspecified. |
52 /// Initialise statistics with zeros. `ε` and `value` are unspecified. |
| 53 pub fn new() -> Self { |
53 pub fn new() -> Self { |
| 54 IterInfo { |
54 IterInfo { |
| 55 value : F::NAN, |
55 value: F::NAN, |
| 56 n_spikes : 0, |
56 n_spikes: 0, |
| 57 this_iters : 0, |
57 this_iters: 0, |
| 58 merged : 0, |
58 merged: 0, |
| 59 inserted : 0, |
59 inserted: 0, |
| 60 pruned : 0, |
60 pruned: 0, |
| 61 inner_iters : 0, |
61 inner_iters: 0, |
| 62 ε : F::NAN, |
62 ε: F::NAN, |
| 63 // postprocessing : None, |
63 // postprocessing : None, |
| 64 untransported_fraction : None, |
64 untransported_fraction: None, |
| 65 transport_error : None, |
65 transport_error: None, |
| 66 } |
66 } |
| 67 } |
67 } |
| 68 } |
68 } |
| 69 |
69 |
| 70 #[replace_float_literals(F::cast_from(literal))] |
70 #[replace_float_literals(F::cast_from(literal))] |
| 71 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { |
71 impl<F> LogRepr for IterInfo<F> |
| |
72 where |
| |
73 F: LogRepr + Float, |
| |
74 { |
| 72 fn logrepr(&self) -> ColoredString { |
75 fn logrepr(&self) -> ColoredString { |
| 73 format!("{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}", |
76 format!( |
| 74 self.value.logrepr(), |
77 "{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}", |
| 75 self.n_spikes, |
78 self.value.logrepr(), |
| 76 self.ε, |
79 self.n_spikes, |
| 77 self.inner_iters as float / self.this_iters.max(1) as float, |
80 self.ε, |
| 78 self.inserted as float / self.this_iters.max(1) as float, |
81 self.inner_iters as float / self.this_iters.max(1) as float, |
| 79 self.merged as float / self.this_iters.max(1) as float, |
82 self.inserted as float / self.this_iters.max(1) as float, |
| 80 self.pruned as float / self.this_iters.max(1) as float, |
83 self.merged as float / self.this_iters.max(1) as float, |
| 81 match self.untransported_fraction { |
84 self.pruned as float / self.this_iters.max(1) as float, |
| 82 None => format!(""), |
85 match self.untransported_fraction { |
| 83 Some((a, b)) => if b > 0.0 { |
86 None => format!(""), |
| 84 format!(", untransported {:.2}%", 100.0*a/b) |
87 Some((a, b)) => |
| |
88 if b > 0.0 { |
| |
89 format!(", untransported {:.2}%", 100.0 * a / b) |
| 85 } else { |
90 } else { |
| 86 format!("") |
91 format!("") |
| 87 } |
92 }, |
| 88 }, |
93 }, |
| 89 match self.transport_error { |
94 match self.transport_error { |
| 90 None => format!(""), |
95 None => format!(""), |
| 91 Some((a, b)) => if b > 0.0 { |
96 Some((a, b)) => |
| 92 format!(", transport error {:.2}%", 100.0*a/b) |
97 if b > 0.0 { |
| |
98 format!(", transport error {:.2}%", 100.0 * a / b) |
| 93 } else { |
99 } else { |
| 94 format!("") |
100 format!("") |
| 95 } |
101 }, |
| 96 } |
102 } |
| 97 ).as_str().into() |
103 ) |
| |
104 .as_str() |
| |
105 .into() |
| 98 } |
106 } |
| 99 } |
107 } |
| 100 |
108 |
| 101 /// Branch and bound refinement settings |
109 /// Branch and bound refinement settings |
| 102 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
110 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 103 #[serde(default)] |
111 #[serde(default)] |
| 104 pub struct RefinementSettings<F : Float> { |
112 pub struct RefinementSettings<F: Float> { |
| 105 /// Function value tolerance multiplier for bisection tree refinement in |
113 /// Function value tolerance multiplier for bisection tree refinement in |
| 106 /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions. |
114 /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions. |
| 107 pub tolerance_mult : F, |
115 pub tolerance_mult: F, |
| 108 /// Maximum branch and bound steps |
116 /// Maximum branch and bound steps |
| 109 pub max_steps : usize, |
117 pub max_steps: usize, |
| 110 } |
118 } |
| 111 |
119 |
| 112 #[replace_float_literals(F::cast_from(literal))] |
120 #[replace_float_literals(F::cast_from(literal))] |
| 113 impl<F : Float> Default for RefinementSettings<F> { |
121 impl<F: Float> Default for RefinementSettings<F> { |
| 114 fn default() -> Self { |
122 fn default() -> Self { |
| 115 RefinementSettings { |
123 RefinementSettings { |
| 116 tolerance_mult : 0.1, |
124 tolerance_mult: 0.1, |
| 117 max_steps : 50000, |
125 max_steps: 50000, |
| 118 } |
126 } |
| 119 } |
127 } |
| 120 } |
128 } |
| 121 |
129 |
| 122 /// Data term type |
130 /// Data term type |
| 123 #[derive(Clone, Copy, PartialEq, Serialize, Deserialize, Debug)] |
131 #[derive(Clone, Copy, PartialEq, Serialize, Deserialize, Debug)] |
| 124 pub enum DataTerm { |
132 pub enum DataTermType { |
| 125 /// $\\|z\\|\_2^2/2$ |
133 /// $\\|z\\|\_2^2/2$ |
| 126 L2Squared, |
134 L222, |
| 127 /// $\\|z\\|\_1$ |
135 /// $\\|z\\|\_1$ |
| 128 L1, |
136 L1, |
| 129 } |
137 } |
| 130 |
138 |
| 131 impl DataTerm { |
139 pub use alg_tools::mapping::Lipschitz; |
| 132 /// Calculate the data term value at residual $z=Aμ - b$. |
|
| 133 pub fn value_at_residual<F : Float, E : Euclidean<F> + Norm<F, L1>>(&self, z : E) -> F { |
|
| 134 match self { |
|
| 135 Self::L2Squared => z.norm2_squared_div2(), |
|
| 136 Self::L1 => z.norm(L1), |
|
| 137 } |
|
| 138 } |
|
| 139 } |
|
| 140 |
|
| 141 /// Type for indicating norm-2-squared data fidelity or transport cost. |
|
| 142 #[derive(Clone, Copy, Serialize, Deserialize)] |
|
| 143 pub struct L2Squared; |
|
| 144 |
|
| 145 /// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `D`. |
|
| 146 pub trait Lipschitz<M> { |
|
| 147 /// The type of floats |
|
| 148 type FloatType : Float; |
|
| 149 |
|
| 150 /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `D`. |
|
| 151 fn lipschitz_factor(&self, seminorm : M) -> Option<Self::FloatType>; |
|
| 152 } |
|
| 153 |
140 |
| 154 /// Trait for norm-bounded functions. |
141 /// Trait for norm-bounded functions. |
| 155 pub trait NormBounded<M> { |
142 pub trait NormBounded<M> { |
| 156 type FloatType : Float; |
143 type FloatType: Float; |
| 157 |
144 |
| 158 /// Returns a bound on the values of this function object in the `M`-norm. |
145 /// Returns a bound on the values of this function object in the `M`-norm. |
| 159 fn norm_bound(&self, m : M) -> Self::FloatType; |
146 fn norm_bound(&self, m: M) -> Self::FloatType; |
| 160 } |
147 } |