src/types.rs

changeset 70
ed16d0f10d08
parent 63
7a8a55fd41c0
equal deleted inserted replaced
58:6099ba025aac 70:ed16d0f10d08
1 //! Type definitions and re-exports 1 //! Type definitions and re-exports
2 2
3 use numeric_literals::replace_float_literals; 3 use numeric_literals::replace_float_literals;
4 4
5 use alg_tools::iterate::LogRepr;
5 use colored::ColoredString; 6 use colored::ColoredString;
6 use serde::{Serialize, Deserialize}; 7 use serde::{Deserialize, Serialize};
7 use alg_tools::iterate::LogRepr;
8 use alg_tools::euclidean::Euclidean;
9 use alg_tools::norms::{Norm, L1};
10 8
11 pub use alg_tools::types::*; 9 pub use alg_tools::error::DynResult;
12 pub use alg_tools::loc::Loc; 10 pub use alg_tools::loc::Loc;
13 pub use alg_tools::sets::Cube; 11 pub use alg_tools::sets::Cube;
12 pub use alg_tools::types::*;
14 13
15 // use crate::measures::DiscreteMeasure; 14 // use crate::measures::DiscreteMeasure;
16 15
17 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. 16 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up.
18 pub trait ClapFloat : Float 17 pub trait ClapFloat:
19 + std::str::FromStr<Err=std::num::ParseFloatError> 18 Float + std::str::FromStr<Err = std::num::ParseFloatError> + std::fmt::Display
20 + std::fmt::Display {} 19 {
20 }
21 impl ClapFloat for f32 {} 21 impl ClapFloat for f32 {}
22 impl ClapFloat for f64 {} 22 impl ClapFloat for f64 {}
23 23
24 /// Structure for storing iteration statistics 24 /// Structure for storing transport statistics
25 #[derive(Debug, Clone, Serialize)] 25 #[derive(Debug, Clone, Serialize)]
26 pub struct IterInfo<F : Float, const N : usize> { 26 pub struct TransportInfo<F: Float = f64> {
27 /// Function value 27 /// Tuple of (untransported mass, source mass)
28 pub value : F, 28 pub untransported_fraction: (F, F),
29 /// Number of spikes 29 /// Tuple of (|destination mass - transported_mass|, transported mass)
30 pub n_spikes : usize, 30 pub transport_error: (F, F),
31 /// Number of iterations this statistic covers 31 /// Number of readjustment iterations for transport
32 pub this_iters : usize, 32 pub readjustment_iters: usize,
33 /// Number of spikes inserted since last IterInfo statistic 33 /// ($∫ c_2 dγ , ∫ dγ$)
34 pub inserted : usize, 34 pub dist: (F, F),
35 /// Number of spikes removed by merging since last IterInfo statistic
36 pub merged : usize,
37 /// Number of spikes removed by pruning since last IterInfo statistic
38 pub pruned : usize,
39 /// Number of inner iterations since last IterInfo statistic
40 pub inner_iters : usize,
41 /// Tuple of (transported mass, source mass)
42 pub untransported_fraction : Option<(F, F)>,
43 /// Tuple of (|destination mass - untransported_mass|, transported mass)
44 pub transport_error : Option<(F, F)>,
45 /// Current tolerance
46 pub ε : F,
47 // /// Solve fin.dim problem for this measure to get the optimal `value`.
48 // pub postprocessing : Option<RNDM<F, N>>,
49 } 35 }
50 36
51 impl<F : Float, const N : usize> IterInfo<F, N> { 37 #[replace_float_literals(F::cast_from(literal))]
52 /// Initialise statistics with zeros. `ε` and `value` are unspecified. 38 impl<F: Float> TransportInfo<F> {
39 /// Initialise transport statistics
53 pub fn new() -> Self { 40 pub fn new() -> Self {
54 IterInfo { 41 TransportInfo {
55 value : F::NAN, 42 untransported_fraction: (0.0, 0.0),
56 n_spikes : 0, 43 transport_error: (0.0, 0.0),
57 this_iters : 0, 44 readjustment_iters: 0,
58 merged : 0, 45 dist: (0.0, 0.0),
59 inserted : 0,
60 pruned : 0,
61 inner_iters : 0,
62 ε : F::NAN,
63 // postprocessing : None,
64 untransported_fraction : None,
65 transport_error : None,
66 } 46 }
67 } 47 }
68 } 48 }
69 49
50 /// Structure for storing iteration statistics
51 #[derive(Debug, Clone, Serialize)]
52 pub struct IterInfo<F: Float = f64> {
53 /// Function value
54 pub value: F,
55 /// Number of spikes
56 pub n_spikes: usize,
57 /// Number of iterations this statistic covers
58 pub this_iters: usize,
59 /// Number of spikes inserted since last IterInfo statistic
60 pub inserted: usize,
61 /// Number of spikes removed by merging since last IterInfo statistic
62 pub merged: usize,
63 /// Number of spikes removed by pruning since last IterInfo statistic
64 pub pruned: usize,
65 /// Number of inner iterations since last IterInfo statistic
66 pub inner_iters: usize,
67 /// Transport statistis
68 pub transport: Option<TransportInfo<F>>,
69 /// Current tolerance
70 pub ε: F,
71 // /// Solve fin.dim problem for this measure to get the optimal `value`.
72 // pub postprocessing : Option<RNDM<N, F>>,
73 }
74
75 impl<F: Float> IterInfo<F> {
76 /// Initialise statistics with zeros. `ε` and `value` are unspecified.
77 pub fn new() -> Self {
78 IterInfo {
79 value: F::NAN,
80 n_spikes: 0,
81 this_iters: 0,
82 merged: 0,
83 inserted: 0,
84 pruned: 0,
85 inner_iters: 0,
86 ε: F::NAN,
87 // postprocessing : None,
88 transport: None,
89 }
90 }
91
92 /// Get mutable reference to transport statistics, creating it if it is `None`.
93 pub fn get_transport_mut(&mut self) -> &mut TransportInfo<F> {
94 if self.transport.is_none() {
95 self.transport = Some(TransportInfo::new());
96 }
97 self.transport.as_mut().unwrap()
98 }
99 }
100
70 #[replace_float_literals(F::cast_from(literal))] 101 #[replace_float_literals(F::cast_from(literal))]
71 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { 102 impl<F> LogRepr for IterInfo<F>
103 where
104 F: LogRepr + Float,
105 {
72 fn logrepr(&self) -> ColoredString { 106 fn logrepr(&self) -> ColoredString {
73 format!("{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}", 107 format!(
74 self.value.logrepr(), 108 "{}\t| N = {}, ε = {:.2e}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}",
75 self.n_spikes, 109 self.value.logrepr(),
76 self.ε, 110 self.n_spikes,
77 self.inner_iters as float / self.this_iters.max(1) as float, 111 self.ε,
78 self.inserted as float / self.this_iters.max(1) as float, 112 self.inner_iters as float / self.this_iters.max(1) as float,
79 self.merged as float / self.this_iters.max(1) as float, 113 self.inserted as float / self.this_iters.max(1) as float,
80 self.pruned as float / self.this_iters.max(1) as float, 114 self.merged as float / self.this_iters.max(1) as float,
81 match self.untransported_fraction { 115 self.pruned as float / self.this_iters.max(1) as float,
82 None => format!(""), 116 match &self.transport {
83 Some((a, b)) => if b > 0.0 { 117 None => format!(""),
84 format!(", untransported {:.2}%", 100.0*a/b) 118 Some(t) => {
85 } else { 119 let (a1, b1) = t.untransported_fraction;
86 format!("") 120 let (a2, b2) = t.transport_error;
87 } 121 let (a3, b3) = t.dist;
88 }, 122 format!(
89 match self.transport_error { 123 ", γ-un/er/d/it = {:.2}%/{:.2}%/{:.2e}/{:.2}",
90 None => format!(""), 124 if b1 > 0.0 { 100.0 * a1 / b1 } else { F::NAN },
91 Some((a, b)) => if b > 0.0 { 125 if b2 > 0.0 { 100.0 * a2 / b2 } else { F::NAN },
92 format!(", transport error {:.2}%", 100.0*a/b) 126 if b3 > 0.0 { a3 / b3 } else { F::NAN },
93 } else { 127 t.readjustment_iters as float / self.this_iters.max(1) as float,
94 format!("") 128 )
95 }
96 } 129 }
97 ).as_str().into() 130 }
131 )
132 .as_str()
133 .into()
98 } 134 }
99 } 135 }
100 136
101 /// Branch and bound refinement settings 137 /// Branch and bound refinement settings
102 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 138 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
103 #[serde(default)] 139 #[serde(default)]
104 pub struct RefinementSettings<F : Float> { 140 pub struct RefinementSettings<F: Float> {
105 /// Function value tolerance multiplier for bisection tree refinement in 141 /// Function value tolerance multiplier for bisection tree refinement in
106 /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions. 142 /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions.
107 pub tolerance_mult : F, 143 pub tolerance_mult: F,
108 /// Maximum branch and bound steps 144 /// Maximum branch and bound steps
109 pub max_steps : usize, 145 pub max_steps: usize,
110 } 146 }
111 147
112 #[replace_float_literals(F::cast_from(literal))] 148 #[replace_float_literals(F::cast_from(literal))]
113 impl<F : Float> Default for RefinementSettings<F> { 149 impl<F: Float> Default for RefinementSettings<F> {
114 fn default() -> Self { 150 fn default() -> Self {
115 RefinementSettings { 151 RefinementSettings { tolerance_mult: 0.1, max_steps: 50000 }
116 tolerance_mult : 0.1,
117 max_steps : 50000,
118 }
119 } 152 }
120 } 153 }
121 154
122 /// Data term type 155 /// Data term type
123 #[derive(Clone, Copy, PartialEq, Serialize, Deserialize, Debug)] 156 #[derive(Clone, Copy, PartialEq, Serialize, Deserialize, Debug)]
124 pub enum DataTerm { 157 pub enum DataTermType {
125 /// $\\|z\\|\_2^2/2$ 158 /// $\\|z\\|\_2^2/2$
126 L2Squared, 159 L222,
127 /// $\\|z\\|\_1$ 160 /// $\\|z\\|\_1$
128 L1, 161 L1,
129 } 162 }
130 163
131 impl DataTerm { 164 pub use alg_tools::mapping::Lipschitz;
132 /// Calculate the data term value at residual $z=Aμ - b$.
133 pub fn value_at_residual<F : Float, E : Euclidean<F> + Norm<F, L1>>(&self, z : E) -> F {
134 match self {
135 Self::L2Squared => z.norm2_squared_div2(),
136 Self::L1 => z.norm(L1),
137 }
138 }
139 }
140
141 /// Type for indicating norm-2-squared data fidelity or transport cost.
142 #[derive(Clone, Copy, Serialize, Deserialize)]
143 pub struct L2Squared;
144
145 /// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `D`.
146 pub trait Lipschitz<M> {
147 /// The type of floats
148 type FloatType : Float;
149
150 /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `D`.
151 fn lipschitz_factor(&self, seminorm : M) -> Option<Self::FloatType>;
152 }
153 165
154 /// Trait for norm-bounded functions. 166 /// Trait for norm-bounded functions.
155 pub trait NormBounded<M> { 167 pub trait NormBounded<M> {
156 type FloatType : Float; 168 type FloatType: Float;
157 169
158 /// Returns a bound on the values of this function object in the `M`-norm. 170 /// Returns a bound on the values of this function object in the `M`-norm.
159 fn norm_bound(&self, m : M) -> Self::FloatType; 171 fn norm_bound(&self, m: M) -> Self::FloatType;
160 } 172 }

mercurial