--- a/src/types.rs Tue Apr 08 13:31:39 2025 -0500 +++ b/src/types.rs Fri May 08 16:47:58 2026 -0500 @@ -2,159 +2,171 @@ use numeric_literals::replace_float_literals; +use alg_tools::iterate::LogRepr; use colored::ColoredString; -use serde::{Serialize, Deserialize}; -use alg_tools::iterate::LogRepr; -use alg_tools::euclidean::Euclidean; -use alg_tools::norms::{Norm, L1}; +use serde::{Deserialize, Serialize}; -pub use alg_tools::types::*; +pub use alg_tools::error::DynResult; pub use alg_tools::loc::Loc; pub use alg_tools::sets::Cube; +pub use alg_tools::types::*; // use crate::measures::DiscreteMeasure; /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up. -pub trait ClapFloat : Float - + std::str::FromStr<Err=std::num::ParseFloatError> - + std::fmt::Display {} +pub trait ClapFloat: + Float + std::str::FromStr<Err = std::num::ParseFloatError> + std::fmt::Display +{ +} impl ClapFloat for f32 {} impl ClapFloat for f64 {} -/// Structure for storing iteration statistics +/// Structure for storing transport statistics #[derive(Debug, Clone, Serialize)] -pub struct IterInfo<F : Float, const N : usize> { - /// Function value - pub value : F, - /// Number of spikes - pub n_spikes : usize, - /// Number of iterations this statistic covers - pub this_iters : usize, - /// Number of spikes inserted since last IterInfo statistic - pub inserted : usize, - /// Number of spikes removed by merging since last IterInfo statistic - pub merged : usize, - /// Number of spikes removed by pruning since last IterInfo statistic - pub pruned : usize, - /// Number of inner iterations since last IterInfo statistic - pub inner_iters : usize, - /// Tuple of (transported mass, source mass) - pub untransported_fraction : Option<(F, F)>, - /// Tuple of (|destination mass - untransported_mass|, transported mass) - pub transport_error : Option<(F, F)>, - /// Current tolerance - pub ε : F, - // /// Solve fin.dim problem for this measure to get the optimal `value`. - // pub postprocessing : Option<RNDM<F, N>>, +pub struct TransportInfo<F: Float = f64> { + /// Tuple of (untransported mass, source mass) + pub untransported_fraction: (F, F), + /// Tuple of (|destination mass - transported_mass|, transported mass) + pub transport_error: (F, F), + /// Number of readjustment iterations for transport + pub readjustment_iters: usize, + /// ($∫ c_2 dγ , ∫ dγ$) + pub dist: (F, F), } -impl<F : Float, const N : usize> IterInfo<F, N> { - /// Initialise statistics with zeros. `ε` and `value` are unspecified. +#[replace_float_literals(F::cast_from(literal))] +impl<F: Float> TransportInfo<F> { + /// Initialise transport statistics pub fn new() -> Self { - IterInfo { - value : F::NAN, - n_spikes : 0, - this_iters : 0, - merged : 0, - inserted : 0, - pruned : 0, - inner_iters : 0, - ε : F::NAN, - // postprocessing : None, - untransported_fraction : None, - transport_error : None, + TransportInfo { + untransported_fraction: (0.0, 0.0), + transport_error: (0.0, 0.0), + readjustment_iters: 0, + dist: (0.0, 0.0), } } } +/// Structure for storing iteration statistics +#[derive(Debug, Clone, Serialize)] +pub struct IterInfo<F: Float = f64> { + /// Function value + pub value: F, + /// Number of spikes + pub n_spikes: usize, + /// Number of iterations this statistic covers + pub this_iters: usize, + /// Number of spikes inserted since last IterInfo statistic + pub inserted: usize, + /// Number of spikes removed by merging since last IterInfo statistic + pub merged: usize, + /// Number of spikes removed by pruning since last IterInfo statistic + pub pruned: usize, + /// Number of inner iterations since last IterInfo statistic + pub inner_iters: usize, + /// Transport statistis + pub transport: Option<TransportInfo<F>>, + /// Current tolerance + pub ε: F, + // /// Solve fin.dim problem for this measure to get the optimal `value`. + // pub postprocessing : Option<RNDM<N, F>>, +} + +impl<F: Float> IterInfo<F> { + /// Initialise statistics with zeros. `ε` and `value` are unspecified. + pub fn new() -> Self { + IterInfo { + value: F::NAN, + n_spikes: 0, + this_iters: 0, + merged: 0, + inserted: 0, + pruned: 0, + inner_iters: 0, + ε: F::NAN, + // postprocessing : None, + transport: None, + } + } + + /// Get mutable reference to transport statistics, creating it if it is `None`. + pub fn get_transport_mut(&mut self) -> &mut TransportInfo<F> { + if self.transport.is_none() { + self.transport = Some(TransportInfo::new()); + } + self.transport.as_mut().unwrap() + } +} + #[replace_float_literals(F::cast_from(literal))] -impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float { +impl<F> LogRepr for IterInfo<F> +where + F: LogRepr + Float, +{ fn logrepr(&self) -> ColoredString { - format!("{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}", - self.value.logrepr(), - self.n_spikes, - self.ε, - self.inner_iters as float / self.this_iters.max(1) as float, - self.inserted as float / self.this_iters.max(1) as float, - self.merged as float / self.this_iters.max(1) as float, - self.pruned as float / self.this_iters.max(1) as float, - match self.untransported_fraction { - None => format!(""), - Some((a, b)) => if b > 0.0 { - format!(", untransported {:.2}%", 100.0*a/b) - } else { - format!("") - } - }, - match self.transport_error { - None => format!(""), - Some((a, b)) => if b > 0.0 { - format!(", transport error {:.2}%", 100.0*a/b) - } else { - format!("") - } + format!( + "{}\t| N = {}, ε = {:.2e}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}", + self.value.logrepr(), + self.n_spikes, + self.ε, + self.inner_iters as float / self.this_iters.max(1) as float, + self.inserted as float / self.this_iters.max(1) as float, + self.merged as float / self.this_iters.max(1) as float, + self.pruned as float / self.this_iters.max(1) as float, + match &self.transport { + None => format!(""), + Some(t) => { + let (a1, b1) = t.untransported_fraction; + let (a2, b2) = t.transport_error; + let (a3, b3) = t.dist; + format!( + ", γ-un/er/d/it = {:.2}%/{:.2}%/{:.2e}/{:.2}", + if b1 > 0.0 { 100.0 * a1 / b1 } else { F::NAN }, + if b2 > 0.0 { 100.0 * a2 / b2 } else { F::NAN }, + if b3 > 0.0 { a3 / b3 } else { F::NAN }, + t.readjustment_iters as float / self.this_iters.max(1) as float, + ) } - ).as_str().into() + } + ) + .as_str() + .into() } } /// Branch and bound refinement settings #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[serde(default)] -pub struct RefinementSettings<F : Float> { +pub struct RefinementSettings<F: Float> { /// Function value tolerance multiplier for bisection tree refinement in /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions. - pub tolerance_mult : F, + pub tolerance_mult: F, /// Maximum branch and bound steps - pub max_steps : usize, + pub max_steps: usize, } #[replace_float_literals(F::cast_from(literal))] -impl<F : Float> Default for RefinementSettings<F> { +impl<F: Float> Default for RefinementSettings<F> { fn default() -> Self { - RefinementSettings { - tolerance_mult : 0.1, - max_steps : 50000, - } + RefinementSettings { tolerance_mult: 0.1, max_steps: 50000 } } } /// Data term type #[derive(Clone, Copy, PartialEq, Serialize, Deserialize, Debug)] -pub enum DataTerm { +pub enum DataTermType { /// $\\|z\\|\_2^2/2$ - L2Squared, + L222, /// $\\|z\\|\_1$ L1, } -impl DataTerm { - /// Calculate the data term value at residual $z=Aμ - b$. - pub fn value_at_residual<F : Float, E : Euclidean<F> + Norm<F, L1>>(&self, z : E) -> F { - match self { - Self::L2Squared => z.norm2_squared_div2(), - Self::L1 => z.norm(L1), - } - } -} - -/// Type for indicating norm-2-squared data fidelity or transport cost. -#[derive(Clone, Copy, Serialize, Deserialize)] -pub struct L2Squared; - -/// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `D`. -pub trait Lipschitz<M> { - /// The type of floats - type FloatType : Float; - - /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `D`. - fn lipschitz_factor(&self, seminorm : M) -> Option<Self::FloatType>; -} +pub use alg_tools::mapping::Lipschitz; /// Trait for norm-bounded functions. pub trait NormBounded<M> { - type FloatType : Float; + type FloatType: Float; /// Returns a bound on the values of this function object in the `M`-norm. - fn norm_bound(&self, m : M) -> Self::FloatType; + fn norm_bound(&self, m: M) -> Self::FloatType; }