| 2 | 2 | 
| 3 use numeric_literals::replace_float_literals; | 3 use numeric_literals::replace_float_literals; | 
| 4 | 4 | 
| 5 use colored::ColoredString; | 5 use colored::ColoredString; | 
| 6 use serde::{Serialize, Deserialize}; | 6 use serde::{Serialize, Deserialize}; | 
| 7 use clap::ValueEnum; |  | 
| 8 use alg_tools::iterate::LogRepr; | 7 use alg_tools::iterate::LogRepr; | 
| 9 use alg_tools::euclidean::Euclidean; | 8 use alg_tools::euclidean::Euclidean; | 
| 10 use alg_tools::norms::{Norm, L1}; | 9 use alg_tools::norms::{Norm, L1}; | 
| 11 | 10 | 
| 12 pub use alg_tools::types::*; | 11 pub use alg_tools::types::*; | 
| 13 pub use alg_tools::loc::Loc; | 12 pub use alg_tools::loc::Loc; | 
| 14 pub use alg_tools::sets::Cube; | 13 pub use alg_tools::sets::Cube; | 
| 15 | 14 | 
| 16 use crate::measures::DiscreteMeasure; | 15 // use crate::measures::DiscreteMeasure; | 
| 17 | 16 | 
| 18 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. | 17 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. | 
| 19 pub trait ClapFloat : Float | 18 pub trait ClapFloat : Float | 
| 20                       + std::str::FromStr<Err=std::num::ParseFloatError> | 19                       + std::str::FromStr<Err=std::num::ParseFloatError> | 
| 21                       + std::fmt::Display {} | 20                       + std::fmt::Display {} | 
| 25 /// Structure for storing iteration statistics | 24 /// Structure for storing iteration statistics | 
| 26 #[derive(Debug, Clone, Serialize)] | 25 #[derive(Debug, Clone, Serialize)] | 
| 27 pub struct IterInfo<F : Float, const N : usize> { | 26 pub struct IterInfo<F : Float, const N : usize> { | 
| 28     /// Function value | 27     /// Function value | 
| 29     pub value : F, | 28     pub value : F, | 
| 30     /// Number of speaks | 29     /// Number of spikes | 
| 31     pub n_spikes : usize, | 30     pub n_spikes : usize, | 
| 32     /// Number of iterations this statistic covers | 31     /// Number of iterations this statistic covers | 
| 33     pub this_iters : usize, | 32     pub this_iters : usize, | 
|  | 33     /// Number of spikes inserted since last IterInfo statistic | 
|  | 34     pub inserted : usize, | 
| 34     /// Number of spikes removed by merging since last IterInfo statistic | 35     /// Number of spikes removed by merging since last IterInfo statistic | 
| 35     pub merged : usize, | 36     pub merged : usize, | 
| 36     /// Number of spikes removed by pruning since last IterInfo statistic | 37     /// Number of spikes removed by pruning since last IterInfo statistic | 
| 37     pub pruned : usize, | 38     pub pruned : usize, | 
| 38     /// Number of inner iterations since last IterInfo statistic | 39     /// Number of inner iterations since last IterInfo statistic | 
| 39     pub inner_iters : usize, | 40     pub inner_iters : usize, | 
|  | 41     /// Tuple of (transported mass, source mass) | 
|  | 42     pub untransported_fraction : Option<(F, F)>, | 
|  | 43     /// Tuple of (|destination mass - untransported_mass|, transported mass) | 
|  | 44     pub transport_error : Option<(F, F)>, | 
| 40     /// Current tolerance | 45     /// Current tolerance | 
| 41     pub ε : F, | 46     pub ε : F, | 
| 42     /// Solve fin.dim problem for this measure to get the optimal `value`. | 47     // /// Solve fin.dim problem for this measure to get the optimal `value`. | 
| 43     pub postprocessing : Option<DiscreteMeasure<Loc<F, N>, F>>, | 48     // pub postprocessing : Option<RNDM<F, N>>, | 
| 44 } | 49 } | 
| 45 | 50 | 
|  | 51 impl<F : Float, const N : usize>  IterInfo<F, N> { | 
|  | 52     /// Initialise statistics with zeros. `ε` and `value` are unspecified. | 
|  | 53     pub fn new() -> Self { | 
|  | 54         IterInfo { | 
|  | 55             value : F::NAN, | 
|  | 56             n_spikes : 0, | 
|  | 57             this_iters : 0, | 
|  | 58             merged : 0, | 
|  | 59             inserted : 0, | 
|  | 60             pruned : 0, | 
|  | 61             inner_iters : 0, | 
|  | 62             ε : F::NAN, | 
|  | 63             // postprocessing : None, | 
|  | 64             untransported_fraction : None, | 
|  | 65             transport_error : None, | 
|  | 66         } | 
|  | 67     } | 
|  | 68 } | 
|  | 69 | 
|  | 70 #[replace_float_literals(F::cast_from(literal))] | 
| 46 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { | 71 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { | 
| 47     fn logrepr(&self) -> ColoredString { | 72     fn logrepr(&self) -> ColoredString { | 
| 48         format!("{}\t| N = {}, ε = {:.8}, inner_iters_mean = {}, merged+pruned_mean = {}+{}", | 73         format!("{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}", | 
| 49                 self.value.logrepr(), | 74                 self.value.logrepr(), | 
| 50                 self.n_spikes, | 75                 self.n_spikes, | 
| 51                 self.ε, | 76                 self.ε, | 
| 52                 self.inner_iters as float / self.this_iters as float, | 77                 self.inner_iters as float / self.this_iters.max(1) as float, | 
| 53                 self.merged as float / self.this_iters as float, | 78                 self.inserted as float / self.this_iters.max(1) as float, | 
| 54                 self.pruned as float / self.this_iters as float, | 79                 self.merged as float / self.this_iters.max(1) as float, | 
|  | 80                 self.pruned as float / self.this_iters.max(1) as float, | 
|  | 81                 match self.untransported_fraction { | 
|  | 82                     None => format!(""), | 
|  | 83                     Some((a, b)) => if b > 0.0 { | 
|  | 84                         format!(", untransported {:.2}%", 100.0*a/b) | 
|  | 85                     } else { | 
|  | 86                         format!("") | 
|  | 87                     } | 
|  | 88                 }, | 
|  | 89                 match self.transport_error { | 
|  | 90                     None => format!(""), | 
|  | 91                     Some((a, b)) => if b > 0.0 { | 
|  | 92                         format!(", transport error {:.2}%", 100.0*a/b) | 
|  | 93                     } else { | 
|  | 94                         format!("") | 
|  | 95                     } | 
|  | 96                 } | 
| 55         ).as_str().into() | 97         ).as_str().into() | 
| 56     } | 98     } | 
| 57 } | 99 } | 
| 58 | 100 | 
| 59 /// Branch and bound refinement settings | 101 /// Branch and bound refinement settings | 
| 93             Self::L2Squared => z.norm2_squared_div2(), | 135             Self::L2Squared => z.norm2_squared_div2(), | 
| 94             Self::L1 => z.norm(L1), | 136             Self::L1 => z.norm(L1), | 
| 95         } | 137         } | 
| 96     } | 138     } | 
| 97 } | 139 } | 
|  | 140 | 
|  | 141 /// Type for indicating norm-2-squared data fidelity or transport cost. | 
|  | 142 #[derive(Clone, Copy, Serialize, Deserialize)] | 
|  | 143 pub struct L2Squared; | 
|  | 144 | 
|  | 145 /// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `D`. | 
|  | 146 pub trait Lipschitz<M> { | 
|  | 147     /// The type of floats | 
|  | 148     type FloatType : Float; | 
|  | 149 | 
|  | 150     /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `D`. | 
|  | 151     fn lipschitz_factor(&self, seminorm : M) -> Option<Self::FloatType>; | 
|  | 152 } | 
|  | 153 | 
|  | 154 /// Trait for norm-bounded functions. | 
|  | 155 pub trait NormBounded<M> { | 
|  | 156     type FloatType : Float; | 
|  | 157 | 
|  | 158     /// Returns a bound on the values of this function object in the `M`-norm. | 
|  | 159     fn norm_bound(&self, m : M) -> Self::FloatType; | 
|  | 160 } |