| |
1 //! Type definitions and re-exports |
| |
2 |
| |
3 use numeric_literals::replace_float_literals; |
| |
4 |
| |
5 use colored::ColoredString; |
| |
6 use serde::{Serialize, Deserialize}; |
| |
7 use clap::ValueEnum; |
| |
8 use alg_tools::iterate::LogRepr; |
| |
9 use alg_tools::euclidean::Euclidean; |
| |
10 use alg_tools::norms::{Norm, L1}; |
| |
11 |
| |
12 pub use alg_tools::types::*; |
| |
13 pub use alg_tools::loc::Loc; |
| |
14 pub use alg_tools::sets::Cube; |
| |
15 |
| |
16 use crate::measures::DiscreteMeasure; |
| |
17 |
| |
18 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. |
| |
19 pub trait ClapFloat : Float |
| |
20 + std::str::FromStr<Err=std::num::ParseFloatError> |
| |
21 + std::fmt::Display {} |
| |
22 impl ClapFloat for f32 {} |
| |
23 impl ClapFloat for f64 {} |
| |
24 |
| |
25 /// Structure for storing iteration statistics |
| |
26 #[derive(Debug, Clone, Serialize)] |
| |
27 pub struct IterInfo<F : Float, const N : usize> { |
| |
28 /// Function value |
| |
29 pub value : F, |
| |
30 /// Number of speaks |
| |
31 pub n_spikes : usize, |
| |
32 /// Number of iterations this statistic covers |
| |
33 pub this_iters : usize, |
| |
34 /// Number of spikes removed by merging since last IterInfo statistic |
| |
35 pub merged : usize, |
| |
36 /// Number of spikes removed by pruning since last IterInfo statistic |
| |
37 pub pruned : usize, |
| |
38 /// Number of inner iterations since last IterInfo statistic |
| |
39 pub inner_iters : usize, |
| |
40 /// Current tolerance |
| |
41 pub ε : F, |
| |
42 /// Strict tolerance update if one was used |
| |
43 pub maybe_ε1 : Option<F>, |
| |
44 /// Solve fin.dim problem for this measure to get the optimal `value`. |
| |
45 pub postprocessing : Option<DiscreteMeasure<Loc<F, N>, F>>, |
| |
46 } |
| |
47 |
| |
48 impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { |
| |
49 fn logrepr(&self) -> ColoredString { |
| |
50 let eqsign = match self.maybe_ε1 { |
| |
51 Some(ε1) if ε1 < self.ε => '≛', |
| |
52 _ => '=', |
| |
53 }; |
| |
54 format!("{}\t| N = {}, ε {} {:.8}, inner_iters_mean = {}, merged+pruned_mean = {}+{}", |
| |
55 self.value.logrepr(), |
| |
56 self.n_spikes, |
| |
57 eqsign, |
| |
58 self.ε, |
| |
59 self.inner_iters as float / self.this_iters as float, |
| |
60 self.merged as float / self.this_iters as float, |
| |
61 self.pruned as float / self.this_iters as float, |
| |
62 ).as_str().into() |
| |
63 } |
| |
64 } |
| |
65 |
| |
66 /// Branch and bound refinement settings |
| |
67 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| |
68 #[serde(default)] |
| |
69 pub struct RefinementSettings<F : Float> { |
| |
70 /// Function value tolerance multiplier for bisection tree refinement in |
| |
71 /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions. |
| |
72 pub tolerance_mult : F, |
| |
73 /// Maximum branch and bound steps |
| |
74 pub max_steps : usize, |
| |
75 } |
| |
76 |
| |
77 #[replace_float_literals(F::cast_from(literal))] |
| |
78 impl<F : Float> Default for RefinementSettings<F> { |
| |
79 fn default() -> Self { |
| |
80 RefinementSettings { |
| |
81 tolerance_mult : 0.1, |
| |
82 max_steps : 50000, |
| |
83 } |
| |
84 } |
| |
85 } |
| |
86 |
| |
87 /// Data term type |
| |
88 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug, ValueEnum)] |
| |
89 pub enum DataTerm { |
| |
90 /// $\\|z\\|\_2^2/2$ |
| |
91 L2Squared, |
| |
92 /// $\\|z\\|\_1$ |
| |
93 L1, |
| |
94 } |
| |
95 |
| |
96 impl DataTerm { |
| |
97 /// Calculate the data term value at residual $z=Aμ - b$. |
| |
98 pub fn value_at_residual<F : Float, E : Euclidean<F> + Norm<F, L1>>(&self, z : E) -> F { |
| |
99 match self { |
| |
100 Self::L2Squared => z.norm2_squared_div2(), |
| |
101 Self::L1 => z.norm(L1), |
| |
102 } |
| |
103 } |
| |
104 } |
| |
105 |