src/types.rs

Thu, 19 Mar 2026 18:21:17 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 19 Mar 2026 18:21:17 -0500
branch
dev
changeset 67
95bb12bdb6ac
parent 63
7a8a55fd41c0
permissions
-rw-r--r--

Remove apparently now unused dec2flt feature request

//! Type definitions and re-exports

use numeric_literals::replace_float_literals;

use alg_tools::iterate::LogRepr;
use colored::ColoredString;
use serde::{Deserialize, Serialize};

pub use alg_tools::error::DynResult;
pub use alg_tools::loc::Loc;
pub use alg_tools::sets::Cube;
pub use alg_tools::types::*;

// use crate::measures::DiscreteMeasure;

/// Floating point type that can also be parsed from and printed as a string.
///
/// On top of [`Float`], implementors provide [`std::str::FromStr`] (failing with
/// [`std::num::ParseFloatError`]) and [`std::fmt::Display`], so that [`clap`] does not
/// choke on them as command-line arguments.
pub trait ClapFloat
where
    Self: Float + std::str::FromStr<Err = std::num::ParseFloatError> + std::fmt::Display,
{
}

impl ClapFloat for f64 {}
impl ClapFloat for f32 {}

/// Structure for storing transport statistics
#[derive(Debug, Clone, Serialize)]
pub struct TransportInfo<F: Float = f64> {
    /// Tuple of (untransported mass, source mass)
    pub untransported_fraction: (F, F),
    /// Tuple of (|destination mass - transported_mass|, transported mass)
    pub transport_error: (F, F),
    /// Number of readjustment iterations for transport
    pub readjustment_iters: usize,
    /// Tuple of ($∫ c_2 dγ$, $∫ dγ$)
    pub dist: (F, F),
}

#[replace_float_literals(F::cast_from(literal))]
impl<F: Float> TransportInfo<F> {
    /// Initialise transport statistics, with every sum and counter set to zero.
    pub fn new() -> Self {
        TransportInfo {
            untransported_fraction: (0.0, 0.0),
            transport_error: (0.0, 0.0),
            readjustment_iters: 0,
            dist: (0.0, 0.0),
        }
    }
}

/// Zeroed transport statistics; identical to [`TransportInfo::new`].
impl<F: Float> Default for TransportInfo<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Structure for storing iteration statistics
#[derive(Debug, Clone, Serialize)]
pub struct IterInfo<F: Float = f64> {
    /// Function value
    pub value: F,
    /// Number of spikes
    pub n_spikes: usize,
    /// Number of iterations this statistic covers
    pub this_iters: usize,
    /// Number of spikes inserted since last IterInfo statistic
    pub inserted: usize,
    /// Number of spikes removed by merging since last IterInfo statistic
    pub merged: usize,
    /// Number of spikes removed by pruning since last IterInfo statistic
    pub pruned: usize,
    /// Number of inner iterations since last IterInfo statistic
    pub inner_iters: usize,
    /// Transport statistics, if any; [`IterInfo::get_transport_mut`] creates them on demand.
    pub transport: Option<TransportInfo<F>>,
    /// Current tolerance
    pub ε: F,
    // /// Solve fin.dim problem for this measure to get the optimal `value`.
    // pub postprocessing : Option<RNDM<N, F>>,
}

impl<F: Float> IterInfo<F> {
    /// Initialise statistics with zeros and no transport statistics.
    ///
    /// `ε` and `value` are unspecified; they are set to NaN as placeholders.
    pub fn new() -> Self {
        IterInfo {
            value: F::NAN,
            n_spikes: 0,
            this_iters: 0,
            inserted: 0,
            merged: 0,
            pruned: 0,
            inner_iters: 0,
            transport: None,
            ε: F::NAN,
            // postprocessing : None,
        }
    }

    /// Get mutable reference to transport statistics, first inserting zeroed
    /// statistics ([`TransportInfo::new`]) if there are none yet.
    pub fn get_transport_mut(&mut self) -> &mut TransportInfo<F> {
        self.transport.get_or_insert_with(TransportInfo::new)
    }
}

/// Same as [`IterInfo::new`]: zero counters, NaN placeholders for `ε` and `value`.
impl<F: Float> Default for IterInfo<F> {
    fn default() -> Self {
        Self::new()
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F> LogRepr for IterInfo<F>
where
    F: LogRepr + Float,
{
    /// One-line log summary.
    ///
    /// Shows the function value, the number of spikes and the tolerance `ε`, followed
    /// by per-iteration averages of inner iterations and of inserted/merged/pruned
    /// spikes. When transport statistics are present, it also appends the untransported
    /// and transport-error percentages, the `dist` ratio and the average number of
    /// readjustment iterations.
    fn logrepr(&self) -> ColoredString {
        // Averages are taken over the iterations this statistic covers; `max(1)`
        // avoids dividing by zero when it covers none.
        let iters = self.this_iters.max(1) as float;

        let transport = match &self.transport {
            None => String::new(),
            Some(t) => {
                let (a1, b1) = t.untransported_fraction;
                let (a2, b2) = t.transport_error;
                let (a3, b3) = t.dist;
                // Ratios with a non-positive denominator are reported as NaN.
                format!(
                    ", γ-un/er/d/it = {:.2}%/{:.2}%/{:.2e}/{:.2}",
                    if b1 > 0.0 { 100.0 * a1 / b1 } else { F::NAN },
                    if b2 > 0.0 { 100.0 * a2 / b2 } else { F::NAN },
                    if b3 > 0.0 { a3 / b3 } else { F::NAN },
                    t.readjustment_iters as float / iters,
                )
            }
        };

        format!(
            "{}\t| N = {}, ε = {:.2e}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}",
            self.value.logrepr(),
            self.n_spikes,
            self.ε,
            self.inner_iters as float / iters,
            self.inserted as float / iters,
            self.merged as float / iters,
            self.pruned as float / iters,
            transport,
        )
        .as_str()
        .into()
    }
}

/// Settings for branch-and-bound refinement.
///
/// Fields missing when deserialising are taken from the [`Default`] implementation.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct RefinementSettings<F>
where
    F: Float,
{
    /// Multiplier applied to the function value tolerance in bisection tree refinement
    /// by [`alg_tools::bisection_tree::BTFN::maximise`] and related functions.
    pub tolerance_mult: F,
    /// Upper limit on the number of branch-and-bound steps.
    pub max_steps: usize,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F> Default for RefinementSettings<F>
where
    F: Float,
{
    /// Defaults: `tolerance_mult = 0.1`, `max_steps = 50000`.
    fn default() -> Self {
        Self {
            tolerance_mult: 0.1,
            max_steps: 50_000,
        }
    }
}

/// Selects which norm the data term is built from.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum DataTermType {
    /// Half the squared 2-norm: $\\|z\\|\_2^2/2$
    L222,
    /// The 1-norm: $\\|z\\|\_1$
    L1,
}

pub use alg_tools::mapping::Lipschitz;

/// Function objects whose values can be bounded in a norm described by `M`.
pub trait NormBounded<M> {
    /// Floating point type in which the bound is returned.
    type FloatType: Float;

    /// Returns a bound on the values of `self`, measured in the `M`-norm selected by `norm`.
    fn norm_bound(&self, norm: M) -> Self::FloatType;
}

mercurial