src/types.rs

Thu, 29 Aug 2024 00:00:00 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 29 Aug 2024 00:00:00 -0500
branch
dev
changeset 34
efa60bc4f743
parent 32
56c8adc32b09
child 35
b087e3eab191
permissions
-rw-r--r--

Radon FB + sliding improvements

//! Type definitions and re-exports

use numeric_literals::replace_float_literals;

use colored::ColoredString;
use serde::{Serialize, Deserialize};
use clap::ValueEnum;
use alg_tools::iterate::LogRepr;
use alg_tools::euclidean::Euclidean;
use alg_tools::norms::{Norm, L1};

pub use alg_tools::types::*;
pub use alg_tools::loc::Loc;
pub use alg_tools::sets::Cube;

use crate::measures::DiscreteMeasure;

/// A [`Float`] that can additionally be parsed from a string and displayed,
/// which is what [`clap`] needs in order to accept it as an argument value.
///
/// Implemented for the two primitive floating-point types.
pub trait ClapFloat:
    Float + std::str::FromStr<Err = std::num::ParseFloatError> + std::fmt::Display
{
}

impl ClapFloat for f64 {}
impl ClapFloat for f32 {}

/// Structure for storing iteration statistics
///
/// Derives `Serialize`; the [`LogRepr`] implementation in this module formats
/// it as a single log line, showing the counters as means over `this_iters`.
#[derive(Debug, Clone, Serialize)]
pub struct IterInfo<F : Float, const N : usize> {
    /// Function value
    pub value : F,
    /// Number of spikes
    pub n_spikes : usize,
    /// Number of iterations this statistic covers
    pub this_iters : usize,
    /// Number of spikes removed by merging since last IterInfo statistic
    pub merged : usize,
    /// Number of spikes removed by pruning since last IterInfo statistic
    pub pruned : usize,
    /// Number of inner iterations since last IterInfo statistic
    pub inner_iters : usize,
    /// Tuple `(a, b)` logged as "untransported" percentage `100·a/b` (only when `b > 0`).
    ///
    /// NOTE(review): the previous doc said "(transported mass, source mass)", but the
    /// field name and the log output treat `a` as the *untransported* mass;
    /// presumably `(untransported mass, source mass)` — confirm against the producer.
    pub untransported_fraction : Option<(F, F)>,
    /// Tuple of (|destination mass - untransported_mass|, transported mass),
    /// logged as "transport error" percentage `100·a/b` (only when `b > 0`).
    pub transport_error : Option<(F, F)>,
    /// Current tolerance
    pub ε : F,
    /// Solve fin.dim problem for this measure to get the optimal `value`.
    /// `None` when there is no such measure.
    pub postprocessing : Option<DiscreteMeasure<Loc<F, N>, F>>,
}

impl<F : Float, const N : usize> IterInfo<F, N> {
    /// Creates an empty statistics record.
    ///
    /// All counters are zero and every optional field is `None`. The `value`
    /// and `ε` fields are set to NaN, i.e. left unspecified.
    pub fn new() -> Self {
        Self {
            // Unspecified floating-point fields.
            value : F::NAN,
            ε : F::NAN,
            // Counters.
            n_spikes : 0,
            this_iters : 0,
            inner_iters : 0,
            merged : 0,
            pruned : 0,
            // Optional diagnostics.
            untransported_fraction : None,
            transport_error : None,
            postprocessing : None,
        }
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float {
    /// Formats the statistics as a single log line.
    ///
    /// The inner-iteration, merge and prune counts are shown as means over
    /// `this_iters` iterations. The untransported and transport-error
    /// percentages are appended only when present and their denominator is
    /// positive.
    fn logrepr(&self) -> ColoredString {
        // Number of iterations the counters were accumulated over.
        // NOTE: if `this_iters` is zero the means below come out as NaN or
        // infinity (floating-point division by zero), as before.
        let iters = self.this_iters as float;

        // Optional suffixes. They are empty when the statistic is absent or its
        // denominator is not positive, which also avoids dividing by zero.
        let untransported = match self.untransported_fraction {
            Some((a, b)) if b > 0.0 => format!(", untransported {:.2}%", 100.0*a/b),
            _ => String::new(),
        };
        let transport_error = match self.transport_error {
            Some((a, b)) if b > 0.0 => format!(", transport error {:.2}%", 100.0*a/b),
            _ => String::new(),
        };

        format!("{}\t| N = {}, ε = {:.8}, inner_iters_mean = {}, merged+pruned_mean = {}+{}{}{}",
                self.value.logrepr(),
                self.n_spikes,
                self.ε,
                self.inner_iters as float / iters,
                self.merged as float / iters,
                self.pruned as float / iters,
                untransported,
                transport_error,
        ).as_str().into()
    }
}

/// Branch and bound refinement settings
///
/// Because of `#[serde(default)]`, any field missing during deserialisation is
/// taken from the [`Default`] implementation.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct RefinementSettings<F : Float> {
    /// Function value tolerance multiplier for bisection tree refinement in
    /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions.
    pub tolerance_mult : F,
    /// Maximum branch and bound steps
    pub max_steps : usize,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for RefinementSettings<F> {
    /// Default refinement: a tolerance multiplier of one tenth and at most
    /// fifty thousand branch and bound steps.
    fn default() -> Self {
        Self {
            max_steps : 50_000,
            tolerance_mult : 0.1,
        }
    }
}

/// Data term type
///
/// Derives [`clap::ValueEnum`] so it can be parsed as a command-line argument
/// value, and `Serialize`/`Deserialize` for configuration storage.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug, ValueEnum)]
pub enum DataTerm {
    /// $\\|z\\|\_2^2/2$
    L2Squared,
    /// $\\|z\\|\_1$
    L1,
}

impl DataTerm {
    /// Calculate the data term value at residual $z=Aμ - b$.
    pub fn value_at_residual<F : Float, E : Euclidean<F> + Norm<F, L1>>(&self, z : E) -> F {
        match self {
            Self::L2Squared => z.norm2_squared_div2(),
            Self::L1 => z.norm(L1),
        }
    }
}

/// Type for indicating norm-2-squared data fidelity or transport cost.
///
/// A zero-sized marker type carrying no data. It derives `Debug`, `PartialEq`
/// and `Eq` like the other public types in this module (e.g. [`DataTerm`]),
/// so it can be printed and compared.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct L2Squared;

/// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `D`.
///
/// The (semi)norm is selected by passing a value of type `D` to
/// [`Lipschitz::lipschitz_factor`]; presumably `D` is a marker type such as
/// [`L2Squared`] — NOTE(review): confirm against implementors.
pub trait Lipschitz<D> {
    /// The type of floats
    type FloatType : Float;

    /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `D`.
    ///
    /// NOTE(review): `None` presumably means no factor is known/available —
    /// confirm with implementors.
    fn lipschitz_factor(&self, seminorm : D) -> Option<Self::FloatType>;
}

mercurial