src/types.rs

changeset 70
ed16d0f10d08
parent 63
7a8a55fd41c0
--- a/src/types.rs	Tue Apr 08 13:31:39 2025 -0500
+++ b/src/types.rs	Fri May 08 16:47:58 2026 -0500
@@ -2,159 +2,171 @@
 
 use numeric_literals::replace_float_literals;
 
+use alg_tools::iterate::LogRepr;
 use colored::ColoredString;
-use serde::{Serialize, Deserialize};
-use alg_tools::iterate::LogRepr;
-use alg_tools::euclidean::Euclidean;
-use alg_tools::norms::{Norm, L1};
+use serde::{Deserialize, Serialize};
 
-pub use alg_tools::types::*;
+pub use alg_tools::error::DynResult;
 pub use alg_tools::loc::Loc;
 pub use alg_tools::sets::Cube;
+pub use alg_tools::types::*;
 
 // use crate::measures::DiscreteMeasure;
 
 /// [`Float`] with extra display and string conversion traits such that [`clap`] doesn't choke up.
-pub trait ClapFloat : Float
-                      + std::str::FromStr<Err=std::num::ParseFloatError>
-                      + std::fmt::Display {}
+pub trait ClapFloat:
+    Float + std::str::FromStr<Err = std::num::ParseFloatError> + std::fmt::Display
+{
+}
 impl ClapFloat for f32 {}
 impl ClapFloat for f64 {}
 
-/// Structure for storing iteration statistics
+/// Structure for storing transport statistics
 #[derive(Debug, Clone, Serialize)]
-pub struct IterInfo<F : Float, const N : usize> {
-    /// Function value
-    pub value : F,
-    /// Number of spikes
-    pub n_spikes : usize,
-    /// Number of iterations this statistic covers
-    pub this_iters : usize,
-    /// Number of spikes inserted since last IterInfo statistic
-    pub inserted : usize,
-    /// Number of spikes removed by merging since last IterInfo statistic
-    pub merged : usize,
-    /// Number of spikes removed by pruning since last IterInfo statistic
-    pub pruned : usize,
-    /// Number of inner iterations since last IterInfo statistic
-    pub inner_iters : usize,
-    /// Tuple of (transported mass, source mass)
-    pub untransported_fraction : Option<(F, F)>,
-    /// Tuple of (|destination mass - untransported_mass|, transported mass)
-    pub transport_error : Option<(F, F)>,
-    /// Current tolerance
-    pub ε : F,
-    // /// Solve fin.dim problem for this measure to get the optimal `value`.
-    // pub postprocessing : Option<RNDM<F, N>>,
+pub struct TransportInfo<F: Float = f64> {
+    /// Tuple of (untransported mass, source mass)
+    pub untransported_fraction: (F, F),
+    /// Tuple of (|destination mass - transported_mass|, transported mass)
+    pub transport_error: (F, F),
+    /// Number of readjustment iterations for transport
+    pub readjustment_iters: usize,
+    /// ($∫ c_2 dγ , ∫ dγ$)
+    pub dist: (F, F),
 }
 
-impl<F : Float, const N : usize>  IterInfo<F, N> {
-    /// Initialise statistics with zeros. `ε` and `value` are unspecified.
+#[replace_float_literals(F::cast_from(literal))]
+impl<F: Float> TransportInfo<F> {
+    /// Initialise transport statistics
     pub fn new() -> Self {
-        IterInfo {
-            value : F::NAN,
-            n_spikes : 0,
-            this_iters : 0,
-            merged : 0,
-            inserted : 0,
-            pruned : 0,
-            inner_iters : 0,
-            ε : F::NAN,
-            // postprocessing : None,
-            untransported_fraction : None,
-            transport_error : None,
+        TransportInfo {
+            untransported_fraction: (0.0, 0.0),
+            transport_error: (0.0, 0.0),
+            readjustment_iters: 0,
+            dist: (0.0, 0.0),
         }
     }
 }
 
+/// Structure for storing iteration statistics
+#[derive(Debug, Clone, Serialize)]
+pub struct IterInfo<F: Float = f64> {
+    /// Function value
+    pub value: F,
+    /// Number of spikes
+    pub n_spikes: usize,
+    /// Number of iterations this statistic covers
+    pub this_iters: usize,
+    /// Number of spikes inserted since last IterInfo statistic
+    pub inserted: usize,
+    /// Number of spikes removed by merging since last IterInfo statistic
+    pub merged: usize,
+    /// Number of spikes removed by pruning since last IterInfo statistic
+    pub pruned: usize,
+    /// Number of inner iterations since last IterInfo statistic
+    pub inner_iters: usize,
+    /// Transport statistis
+    pub transport: Option<TransportInfo<F>>,
+    /// Current tolerance
+    pub ε: F,
+    // /// Solve fin.dim problem for this measure to get the optimal `value`.
+    // pub postprocessing : Option<RNDM<N, F>>,
+}
+
+impl<F: Float> IterInfo<F> {
+    /// Initialise statistics with zeros. `ε` and `value` are unspecified.
+    pub fn new() -> Self {
+        IterInfo {
+            value: F::NAN,
+            n_spikes: 0,
+            this_iters: 0,
+            merged: 0,
+            inserted: 0,
+            pruned: 0,
+            inner_iters: 0,
+            ε: F::NAN,
+            // postprocessing : None,
+            transport: None,
+        }
+    }
+
+    /// Get mutable reference to transport statistics, creating it if it is `None`.
+    pub fn get_transport_mut(&mut self) -> &mut TransportInfo<F> {
+        if self.transport.is_none() {
+            self.transport = Some(TransportInfo::new());
+        }
+        self.transport.as_mut().unwrap()
+    }
+}
+
 #[replace_float_literals(F::cast_from(literal))]
-impl<F, const N : usize> LogRepr for IterInfo<F, N> where F : LogRepr + Float {
+impl<F> LogRepr for IterInfo<F>
+where
+    F: LogRepr + Float,
+{
     fn logrepr(&self) -> ColoredString {
-        format!("{}\t| N = {}, ε = {:.8}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}{}",
-                self.value.logrepr(),
-                self.n_spikes,
-                self.ε,
-                self.inner_iters as float / self.this_iters.max(1) as float,
-                self.inserted as float / self.this_iters.max(1) as float,
-                self.merged as float / self.this_iters.max(1) as float,
-                self.pruned as float / self.this_iters.max(1) as float,
-                match self.untransported_fraction {
-                    None => format!(""),
-                    Some((a, b)) => if b > 0.0 {
-                        format!(", untransported {:.2}%", 100.0*a/b)
-                    } else {
-                        format!("")
-                    }
-                },
-                match self.transport_error {
-                    None => format!(""),
-                    Some((a, b)) => if b > 0.0 {
-                        format!(", transport error {:.2}%", 100.0*a/b)
-                    } else {
-                        format!("")
-                    }
+        format!(
+            "{}\t| N = {}, ε = {:.2e}, 𝔼inner_it = {}, 𝔼ins/mer/pru = {}/{}/{}{}",
+            self.value.logrepr(),
+            self.n_spikes,
+            self.ε,
+            self.inner_iters as float / self.this_iters.max(1) as float,
+            self.inserted as float / self.this_iters.max(1) as float,
+            self.merged as float / self.this_iters.max(1) as float,
+            self.pruned as float / self.this_iters.max(1) as float,
+            match &self.transport {
+                None => format!(""),
+                Some(t) => {
+                    let (a1, b1) = t.untransported_fraction;
+                    let (a2, b2) = t.transport_error;
+                    let (a3, b3) = t.dist;
+                    format!(
+                        ", γ-un/er/d/it = {:.2}%/{:.2}%/{:.2e}/{:.2}",
+                        if b1 > 0.0 { 100.0 * a1 / b1 } else { F::NAN },
+                        if b2 > 0.0 { 100.0 * a2 / b2 } else { F::NAN },
+                        if b3 > 0.0 { a3 / b3 } else { F::NAN },
+                        t.readjustment_iters as float / self.this_iters.max(1) as float,
+                    )
                 }
-        ).as_str().into()
+            }
+        )
+        .as_str()
+        .into()
     }
 }
 
 /// Branch and bound refinement settings
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[serde(default)]
-pub struct RefinementSettings<F : Float> {
+pub struct RefinementSettings<F: Float> {
     /// Function value tolerance multiplier for bisection tree refinement in
     /// [`alg_tools::bisection_tree::BTFN::maximise`] and related functions.
-    pub tolerance_mult : F,
+    pub tolerance_mult: F,
     /// Maximum branch and bound steps
-    pub max_steps : usize,
+    pub max_steps: usize,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> Default for RefinementSettings<F> {
+impl<F: Float> Default for RefinementSettings<F> {
     fn default() -> Self {
-        RefinementSettings {
-            tolerance_mult : 0.1,
-            max_steps : 50000,
-        }
+        RefinementSettings { tolerance_mult: 0.1, max_steps: 50000 }
     }
 }
 
 /// Data term type
 #[derive(Clone, Copy, PartialEq, Serialize, Deserialize, Debug)]
-pub enum DataTerm {
+pub enum DataTermType {
     /// $\\|z\\|\_2^2/2$
-    L2Squared,
+    L222,
     /// $\\|z\\|\_1$
     L1,
 }
 
-impl DataTerm {
-    /// Calculate the data term value at residual $z=Aμ - b$.
-    pub fn value_at_residual<F : Float, E : Euclidean<F> + Norm<F, L1>>(&self, z : E) -> F {
-        match self {
-            Self::L2Squared => z.norm2_squared_div2(),
-            Self::L1 => z.norm(L1),
-        }
-    }
-}
-
-/// Type for indicating norm-2-squared data fidelity or transport cost.
-#[derive(Clone, Copy, Serialize, Deserialize)]
-pub struct L2Squared;
-
-/// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `D`.
-pub trait Lipschitz<M> {
-    /// The type of floats
-    type FloatType : Float;
-
-    /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `D`.
-    fn lipschitz_factor(&self, seminorm : M) -> Option<Self::FloatType>;
-}
+pub use alg_tools::mapping::Lipschitz;
 
 /// Trait for norm-bounded functions.
 pub trait NormBounded<M> {
-    type FloatType : Float;
+    type FloatType: Float;
 
     /// Returns a bound on the values of this function object in the `M`-norm.
-    fn norm_bound(&self, m : M) -> Self::FloatType;
+    fn norm_bound(&self, m: M) -> Self::FloatType;
 }

mercurial