Fri, 02 Dec 2022 18:08:40 +0200
Remove ergodic tolerance; it's not useful.
| 0 | 1 | //! Type definitions and re-exports |
| 2 | ||
| 3 | use numeric_literals::replace_float_literals; | |
| 4 | ||
| 5 | use colored::ColoredString; | |
| 6 | use serde::{Serialize, Deserialize}; | |
| 7 | use clap::ValueEnum; | |
| 8 | use alg_tools::iterate::LogRepr; | |
| 9 | use alg_tools::euclidean::Euclidean; | |
| 10 | use alg_tools::norms::{Norm, L1}; | |
| 11 | ||
| 12 | pub use alg_tools::types::*; | |
| 13 | pub use alg_tools::loc::Loc; | |
| 14 | pub use alg_tools::sets::Cube; | |
| 15 | ||
| 16 | use crate::measures::DiscreteMeasure; | |
| 17 | ||
| 18 | /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. | |
| 19 | pub trait ClapFloat : Float | |
| 20 | + std::str::FromStr<Err=std::num::ParseFloatError> | |
| 21 | + std::fmt::Display {} | |
| 22 | impl ClapFloat for f32 {} | |
| 23 | impl ClapFloat for f64 {} | |
| 24 | ||
| 25 | /// Structure for storing iteration statistics | |
| 26 | #[derive(Debug, Clone, Serialize)] | |
| 27 | pub struct IterInfo<F : Float, const N : usize> { | |
| 28 | /// Function value | |
| 29 | pub value : F, | |
| 30 | /// Number of speaks | |
| 31 | pub n_spikes : usize, | |
| 32 | /// Number of iterations this statistic covers | |
| 33 | pub this_iters : usize, | |
| 34 | /// Number of spikes removed by merging since last IterInfo statistic | |
| 35 | pub merged : usize, | |
| 36 | /// Number of spikes removed by pruning since last IterInfo statistic | |
| 37 | pub pruned : usize, | |
| 38 | /// Number of inner iterations since last IterInfo statistic | |
| 39 | pub inner_iters : usize, | |
| 40 | /// Current tolerance | |
| 41 | pub ε : F, | |
| 42 | /// Strict tolerance update if one was used | |
| 43 | pub maybe_ε1 : Option<F>, | |
| 44 | /// Solve fin.dim problem for this measure to get the optimal `value`. | |
| 45 | pub postprocessing : Option<DiscreteMeasure<Loc<F, N>, F>>, | |
| 46 | } | |
| 47 | ||
| 48 | impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { | |
| 49 | fn logrepr(&self) -> ColoredString { | |
| 50 | let eqsign = match self.maybe_ε1 { | |
| 51 | Some(ε1) if ε1 < self.ε => '≛', | |
| 52 | _ => '=', | |
| 53 | }; | |
| 54 | format!("{}\t| N = {}, ε {} {:.8}, inner_iters_mean = {}, merged+pruned_mean = {}+{}", | |
| 55 | self.value.logrepr(), | |
| 56 | self.n_spikes, | |
| 57 | eqsign, | |
| 58 | self.ε, | |
| 59 | self.inner_iters as float / self.this_iters as float, | |
| 60 | self.merged as float / self.this_iters as float, | |
| 61 | self.pruned as float / self.this_iters as float, | |
| 62 | ).as_str().into() | |
| 63 | } | |
| 64 | } | |
| 65 | ||
| 66 | /// Branch and bound refinement settings | |
| 67 | #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] | |
| 68 | #[serde(default)] | |
| 69 | pub struct RefinementSettings<F : Float> { | |
| 70 | /// Function value tolerance multiplier for bisection tree refinement in | |
| 71 | /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions. | |
| 72 | pub tolerance_mult : F, | |
| 73 | /// Maximum branch and bound steps | |
| 74 | pub max_steps : usize, | |
| 75 | } | |
| 76 | ||
| 77 | #[replace_float_literals(F::cast_from(literal))] | |
| 78 | impl<F : Float> Default for RefinementSettings<F> { | |
| 79 | fn default() -> Self { | |
| 80 | RefinementSettings { | |
| 81 | tolerance_mult : 0.1, | |
| 82 | max_steps : 50000, | |
| 83 | } | |
| 84 | } | |
| 85 | } | |
| 86 | ||
/// Data term type
///
/// Selects how the residual $z = Aμ - b$ is measured; see [`DataTerm::value_at_residual`].
// NOTE(review): clap's `ValueEnum` derive typically uses the variant doc comments below as
// possible-value help text, and the variant order as the listing order. Editing either
// may change user-visible CLI output, so confirm before touching them.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug, ValueEnum)]
pub enum DataTerm {
    /// $\\|z\\|\_2^2/2$
    L2Squared,
    /// $\\|z\\|\_1$
    L1,
}
| 95 | ||
| 96 | impl DataTerm { | |
| 97 | /// Calculate the data term value at residual $z=Aμ - b$. | |
| 98 | pub fn value_at_residual<F : Float, E : Euclidean<F> + Norm<F, L1>>(&self, z : E) -> F { | |
| 99 | match self { | |
| 100 | Self::L2Squared => z.norm2_squared_div2(), | |
| 101 | Self::L1 => z.norm(L1), | |
| 102 | } | |
| 103 | } | |
| 104 | } | |
| 105 |