src/types.rs

branch
dev
changeset 61
4f468d35fa29
parent 35
b087e3eab191
child 63
7a8a55fd41c0
equal deleted inserted replaced
60:9738b51d90d7 61:4f468d35fa29
1 //! Type definitions and re-exports 1 //! Type definitions and re-exports
2 2
3 use numeric_literals::replace_float_literals; 3 use numeric_literals::replace_float_literals;
4 4
5 use alg_tools::iterate::LogRepr;
5 use colored::ColoredString; 6 use colored::ColoredString;
6 use serde::{Serialize, Deserialize}; 7 use serde::{Deserialize, Serialize};
7 use alg_tools::iterate::LogRepr;
8 use alg_tools::euclidean::Euclidean;
9 use alg_tools::norms::{Norm, L1};
10 8
11 pub use alg_tools::types::*; 9 pub use alg_tools::error::DynResult;
12 pub use alg_tools::loc::Loc; 10 pub use alg_tools::loc::Loc;
13 pub use alg_tools::sets::Cube; 11 pub use alg_tools::sets::Cube;
12 pub use alg_tools::types::*;
14 13
15 // use crate::measures::DiscreteMeasure; 14 // use crate::measures::DiscreteMeasure;
16 15
17 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. 16 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up.
18 pub trait ClapFloat : Float 17 pub trait ClapFloat:
19 + std::str::FromStr<Err=std::num::ParseFloatError> 18 Float + std::str::FromStr<Err = std::num::ParseFloatError> + std::fmt::Display
20 + std::fmt::Display {} 19 {
20 }
21 impl ClapFloat for f32 {} 21 impl ClapFloat for f32 {}
22 impl ClapFloat for f64 {} 22 impl ClapFloat for f64 {}
23 23
24 /// Structure for storing iteration statistics 24 /// Structure for storing iteration statistics
25 #[derive(Debug, Clone, Serialize)] 25 #[derive(Debug, Clone, Serialize)]
26 pub struct IterInfo<F : Float, const N : usize> { 26 pub struct IterInfo<F: Float = f64> {
27 /// Function value 27 /// Function value
28 pub value : F, 28 pub value: F,
29 /// Number of spikes 29 /// Number of spikes
30 pub n_spikes : usize, 30 pub n_spikes: usize,
31 /// Number of iterations this statistic covers 31 /// Number of iterations this statistic covers
32 pub this_iters : usize, 32 pub this_iters: usize,
33 /// Number of spikes inserted since last IterInfo statistic 33 /// Number of spikes inserted since last IterInfo statistic
34 pub inserted : usize, 34 pub inserted: usize,
35 /// Number of spikes removed by merging since last IterInfo statistic 35 /// Number of spikes removed by merging since last IterInfo statistic
36 pub merged : usize, 36 pub merged: usize,
37 /// Number of spikes removed by pruning since last IterInfo statistic 37 /// Number of spikes removed by pruning since last IterInfo statistic
38 pub pruned : usize, 38 pub pruned: usize,
39 /// Number of inner iterations since last IterInfo statistic 39 /// Number of inner iterations since last IterInfo statistic
40 pub inner_iters : usize, 40 pub inner_iters: usize,
41 /// Tuple of (transported mass, source mass) 41 /// Tuple of (transported mass, source mass)
42 pub untransported_fraction : Option<(F, F)>, 42 pub untransported_fraction: Option<(F, F)>,
43 /// Tuple of (|destination mass - untransported_mass|, transported mass) 43 /// Tuple of (|destination mass - untransported_mass|, transported mass)
44 pub transport_error : Option<(F, F)>, 44 pub transport_error: Option<(F, F)>,
45 /// Current tolerance 45 /// Current tolerance
46 pub ε : F, 46 pub ε: F,
47 // /// Solve fin.dim problem for this measure to get the optimal `value`. 47 // /// Solve fin.dim problem for this measure to get the optimal `value`.
48 // pub postprocessing : Option<RNDM<F, N>>, 48 // pub postprocessing : Option<RNDM<N, F>>,
49 } 49 }
50 50
51 impl<F : Float, const N : usize> IterInfo<F, N> { 51 impl<F: Float> IterInfo<F> {
52 /// Initialise statistics with zeros. `ε` and `value` are unspecified. 52 /// Initialise statistics with zeros. `ε` and `value` are unspecified.
53 pub fn new() -> Self { 53 pub fn new() -> Self {
54 IterInfo { 54 IterInfo {
55 value : F::NAN, 55 value: F::NAN,
56 n_spikes : 0, 56 n_spikes: 0,
57 this_iters : 0, 57 this_iters: 0,
58 merged : 0, 58 merged: 0,
59 inserted : 0, 59 inserted: 0,
60 pruned : 0, 60 pruned: 0,
61 inner_iters : 0, 61 inner_iters: 0,
62 ε : F::NAN, 62 ε: F::NAN,
63 // postprocessing : None, 63 // postprocessing : None,
64 untransported_fraction : None, 64 untransported_fraction: None,
65 transport_error : None, 65 transport_error: None,
66 } 66 }
67 } 67 }
68 } 68 }
69 69
70 #[replace_float_literals(F::cast_from(literal))] 70 #[replace_float_literals(F::cast_from(literal))]
71 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { 71 impl<F> LogRepr for IterInfo<F>
72 where
73 F: LogRepr + Float,
74 {
72 fn logrepr(&self) -> ColoredString { 75 fn logrepr(&self) -> ColoredString {
73 format!("{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}", 76 format!(
74 self.value.logrepr(), 77 "{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}",
75 self.n_spikes, 78 self.value.logrepr(),
76 self.ε, 79 self.n_spikes,
77 self.inner_iters as float / self.this_iters.max(1) as float, 80 self.ε,
78 self.inserted as float / self.this_iters.max(1) as float, 81 self.inner_iters as float / self.this_iters.max(1) as float,
79 self.merged as float / self.this_iters.max(1) as float, 82 self.inserted as float / self.this_iters.max(1) as float,
80 self.pruned as float / self.this_iters.max(1) as float, 83 self.merged as float / self.this_iters.max(1) as float,
81 match self.untransported_fraction { 84 self.pruned as float / self.this_iters.max(1) as float,
82 None => format!(""), 85 match self.untransported_fraction {
83 Some((a, b)) => if b > 0.0 { 86 None => format!(""),
84 format!(", untransported {:.2}%", 100.0*a/b) 87 Some((a, b)) =>
88 if b > 0.0 {
89 format!(", untransported {:.2}%", 100.0 * a / b)
85 } else { 90 } else {
86 format!("") 91 format!("")
87 } 92 },
88 }, 93 },
89 match self.transport_error { 94 match self.transport_error {
90 None => format!(""), 95 None => format!(""),
91 Some((a, b)) => if b > 0.0 { 96 Some((a, b)) =>
92 format!(", transport error {:.2}%", 100.0*a/b) 97 if b > 0.0 {
98 format!(", transport error {:.2}%", 100.0 * a / b)
93 } else { 99 } else {
94 format!("") 100 format!("")
95 } 101 },
96 } 102 }
97 ).as_str().into() 103 )
104 .as_str()
105 .into()
98 } 106 }
99 } 107 }
100 108
101 /// Branch and bound refinement settings 109 /// Branch and bound refinement settings
102 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 110 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
103 #[serde(default)] 111 #[serde(default)]
104 pub struct RefinementSettings<F : Float> { 112 pub struct RefinementSettings<F: Float> {
105 /// Function value tolerance multiplier for bisection tree refinement in 113 /// Function value tolerance multiplier for bisection tree refinement in
106 /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions. 114 /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions.
107 pub tolerance_mult : F, 115 pub tolerance_mult: F,
108 /// Maximum branch and bound steps 116 /// Maximum branch and bound steps
109 pub max_steps : usize, 117 pub max_steps: usize,
110 } 118 }
111 119
112 #[replace_float_literals(F::cast_from(literal))] 120 #[replace_float_literals(F::cast_from(literal))]
113 impl<F : Float> Default for RefinementSettings<F> { 121 impl<F: Float> Default for RefinementSettings<F> {
114 fn default() -> Self { 122 fn default() -> Self {
115 RefinementSettings { 123 RefinementSettings {
116 tolerance_mult : 0.1, 124 tolerance_mult: 0.1,
117 max_steps : 50000, 125 max_steps: 50000,
118 } 126 }
119 } 127 }
120 } 128 }
121 129
122 /// Data term type 130 /// Data term type
123 #[derive(Clone, Copy, PartialEq, Serialize, Deserialize, Debug)] 131 #[derive(Clone, Copy, PartialEq, Serialize, Deserialize, Debug)]
124 pub enum DataTerm { 132 pub enum DataTermType {
125 /// $\\|z\\|\_2^2/2$ 133 /// $\\|z\\|\_2^2/2$
126 L2Squared, 134 L222,
127 /// $\\|z\\|\_1$ 135 /// $\\|z\\|\_1$
128 L1, 136 L1,
129 } 137 }
130 138
131 impl DataTerm { 139 pub use alg_tools::mapping::Lipschitz;
132 /// Calculate the data term value at residual $z=Aμ - b$.
133 pub fn value_at_residual<F : Float, E : Euclidean<F> + Norm<F, L1>>(&self, z : E) -> F {
134 match self {
135 Self::L2Squared => z.norm2_squared_div2(),
136 Self::L1 => z.norm(L1),
137 }
138 }
139 }
140
141 /// Type for indicating norm-2-squared data fidelity or transport cost.
142 #[derive(Clone, Copy, Serialize, Deserialize)]
143 pub struct L2Squared;
144
145 /// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `D`.
146 pub trait Lipschitz<M> {
147 /// The type of floats
148 type FloatType : Float;
149
150 /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `D`.
151 fn lipschitz_factor(&self, seminorm : M) -> Option<Self::FloatType>;
152 }
153 140
154 /// Trait for norm-bounded functions. 141 /// Trait for norm-bounded functions.
155 pub trait NormBounded<M> { 142 pub trait NormBounded<M> {
156 type FloatType : Float; 143 type FloatType: Float;
157 144
158 /// Returns a bound on the values of this function object in the `M`-norm. 145 /// Returns a bound on the values of this function object in the `M`-norm.
159 fn norm_bound(&self, m : M) -> Self::FloatType; 146 fn norm_bound(&self, m: M) -> Self::FloatType;
160 } 147 }

mercurial