src/measures/merging.rs

branch
dev
changeset 51
0693cc9ba9f0
parent 39
6316d68b58af
--- a/src/measures/merging.rs	Mon Feb 17 13:45:11 2025 -0500
+++ b/src/measures/merging.rs	Mon Feb 17 13:51:50 2025 -0500
@@ -7,35 +7,34 @@
 */
 
 use numeric_literals::replace_float_literals;
+use serde::{Deserialize, Serialize};
 use std::cmp::Ordering;
-use serde::{Serialize, Deserialize};
 //use clap::builder::{PossibleValuesParser, PossibleValue};
 use alg_tools::nanleast::NaNLeast;
 
-use crate::types::*;
 use super::delta::*;
 use super::discrete::*;
+use crate::types::*;
 
 /// Spike merging heuristic selection
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[allow(dead_code)]
 pub struct SpikeMergingMethod<F> {
     // Merging radius
-    pub(crate) radius : F,
+    pub(crate) radius: F,
     // Enabled
-    pub(crate) enabled : bool,
+    pub(crate) enabled: bool,
     // Interpolate merged points
-    pub(crate) interp : bool,
+    pub(crate) interp: bool,
 }
 
-
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> Default for SpikeMergingMethod<F> {
+impl<F: Float> Default for SpikeMergingMethod<F> {
     fn default() -> Self {
-        SpikeMergingMethod{
-            radius : 0.01,
-            enabled : false,
-            interp : true,
+        SpikeMergingMethod {
+            radius: 0.01,
+            enabled: false,
+            interp: true,
         }
     }
 }
@@ -55,8 +54,10 @@
     /// removed spikes is set to zero, new ones inserted at the end of the spike vector.
     /// They merge may also be performed by increasing the weights of the existing spikes,
     /// without inserting new spikes.
-    fn merge_spikes<G>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> usize
-    where G : FnMut(&'_ Self) -> bool {
+    fn merge_spikes<G>(&mut self, method: SpikeMergingMethod<F>, accept: G) -> usize
+    where
+        G: FnMut(&'_ Self) -> bool,
+    {
         if method.enabled {
             self.do_merge_spikes_radius(method.radius, method.interp, accept)
         } else {
@@ -72,13 +73,15 @@
     /// `self` is returned. also the number of merges is returned;
     fn merge_spikes_fitness<G, H, V, O>(
         &mut self,
-        method : SpikeMergingMethod<F>,
-        value : G,
-        fitness : H
+        method: SpikeMergingMethod<F>,
+        value: G,
+        fitness: H,
     ) -> (V, usize)
-    where G : Fn(&'_ Self) -> V,
-          H : Fn(&'_ V) -> O,
-          O : PartialOrd {
+    where
+        G: Fn(&'_ Self) -> V,
+        H: Fn(&'_ V) -> O,
+        O: PartialOrd,
+    {
         let mut res = value(self);
         let initial_fitness = fitness(&res);
         let count = self.merge_spikes(method, |μ| {
@@ -90,15 +93,14 @@
 
     /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm).
     ///
-    /// This method implements [`SpikeMerging::merge_spikes`] for
-    /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are
-    /// as for that method.
-    fn do_merge_spikes_radius<G>(&mut self, ρ : F, interp : bool, accept : G) -> usize
-    where G : FnMut(&'_ Self) -> bool;
+    /// This method implements [`SpikeMerging::merge_spikes`].
+    fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, accept: G) -> usize
+    where
+        G: FnMut(&'_ Self) -> bool;
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float, const N : usize>  DiscreteMeasure<Loc<F, N>, F> {
+impl<F: Float, const N: usize> DiscreteMeasure<Loc<F, N>, F> {
     /// Attempts to merge spikes with indices `i` and `j`.
     ///
     /// This assumes that the weights of the two spikes have already been checked not to be zero.
@@ -110,14 +112,16 @@
     /// Returns the index of `self.spikes` storing the new spike.
     fn attempt_merge<G>(
         &mut self,
-        i : usize,
-        j : usize,
-        interp : bool,
-        accept : &mut G
+        i: usize,
+        j: usize,
+        interp: bool,
+        accept: &mut G,
     ) -> Option<usize>
-    where G : FnMut(&'_ Self) -> bool {
-        let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i];
-        let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j];
+    where
+        G: FnMut(&'_ Self) -> bool,
+    {
+        let &DeltaMeasure { x: xi, α: αi } = &self.spikes[i];
+        let &DeltaMeasure { x: xj, α: αj } = &self.spikes[j];
 
         if interp {
             // Merge inplace
@@ -125,9 +129,12 @@
             self.spikes[j].α = 0.0;
             let αia = αi.abs();
             let αja = αj.abs();
-            self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αia + xj * αja) / (αia + αja) });
+            self.spikes.push(DeltaMeasure {
+                α: αi + αj,
+                x: (xi * αia + xj * αja) / (αia + αja),
+            });
             if accept(self) {
-                Some(self.spikes.len()-1)
+                Some(self.spikes.len() - 1)
             } else {
                 // Merge not accepted, restore modification
                 self.spikes[i].α = αi;
@@ -164,8 +171,9 @@
 ///
 /// The closure `compare` operators on references to elements of `slice`.
 /// Returns the sorted vector of indices into `slice`.
-pub fn sort_indices_by<V, F>(slice : &[V], mut compare : F) -> Vec<usize>
-where F : FnMut(&V, &V) -> Ordering
+pub fn sort_indices_by<V, F>(slice: &[V], mut compare: F) -> Vec<usize>
+where
+    F: FnMut(&V, &V) -> Ordering,
 {
     let mut indices = Vec::from_iter(0..slice.len());
     indices.sort_by(|&i, &j| compare(&slice[i], &slice[j]));
@@ -173,15 +181,11 @@
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> {
-
-    fn do_merge_spikes_radius<G>(
-        &mut self,
-        ρ : F,
-        interp : bool,
-        mut accept : G
-    ) -> usize
-    where G : FnMut(&'_ Self) -> bool {
+impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> {
+    fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize
+    where
+        G: FnMut(&'_ Self) -> bool,
+    {
         // Sort by coordinate into an indexing array.
         let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| {
             let &Loc([x1]) = &δ1.x;
@@ -195,20 +199,26 @@
 
         // Scan consecutive pairs and merge if close enough and accepted by `accept`.
         if indices.len() == 0 {
-            return count
+            return count;
         }
-        for k in 0..(indices.len()-1) {
+        for k in 0..(indices.len() - 1) {
             let i = indices[k];
-            let j = indices[k+1];
-            let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i];
-            let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j];
+            let j = indices[k + 1];
+            let &DeltaMeasure {
+                x: Loc([xi]),
+                α: αi,
+            } = &self.spikes[i];
+            let &DeltaMeasure {
+                x: Loc([xj]),
+                α: αj,
+            } = &self.spikes[j];
             debug_assert!(xi <= xj);
             // If close enough, attempt merging
             if αi != 0.0 && αj != 0.0 && xj <= xi + ρ {
                 if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) {
                     // For this to work (the debug_assert! to not trigger above), the new
                     // coordinate produced by attempt_merge has to be at most xj.
-                    indices[k+1] = l;
+                    indices[k + 1] = l;
                     count += 1
                 }
             }
@@ -219,9 +229,9 @@
 }
 
 /// Orders `δ1` and `δ1` according to the first coordinate.
-fn compare_first_coordinate<F : Float>(
-    δ1 : &DeltaMeasure<Loc<F, 2>, F>,
-    δ2 : &DeltaMeasure<Loc<F, 2>, F>
+fn compare_first_coordinate<F: Float>(
+    δ1: &DeltaMeasure<Loc<F, 2>, F>,
+    δ2: &DeltaMeasure<Loc<F, 2>, F>,
 ) -> Ordering {
     let &Loc([x11, ..]) = &δ1.x;
     let &Loc([x21, ..]) = &δ2.x;
@@ -230,10 +240,11 @@
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> {
-
-    fn do_merge_spikes_radius<G>(&mut self, ρ : F, interp : bool, mut accept : G) -> usize
-    where G : FnMut(&'_ Self) -> bool {
+impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> {
+    fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize
+    where
+        G: FnMut(&'_ Self) -> bool,
+    {
         // Sort by first coordinate into an indexing array.
         let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate);
 
@@ -243,15 +254,18 @@
 
         // Scan in order
         if indices.len() == 0 {
-            return count
+            return count;
         }
-        for k in 0..indices.len()-1 {
+        for k in 0..indices.len() - 1 {
             let i = indices[k];
-            let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i];
+            let &DeltaMeasure {
+                x: Loc([xi1, xi2]),
+                α: αi,
+            } = &self[i];
 
             if αi == 0.0 {
                 // Nothin to be done if the weight is already zero
-                continue
+                continue;
             }
 
             let mut closest = None;
@@ -261,35 +275,38 @@
             // the _closest_ mergeable spike might have index less than `k` in `indices`, and a
             // merge with it might have not been attempted with this spike if a different closer
             // spike was discovered based on the second coordinate.
-            'scan_2nd: for l in (start_scan_2nd+1)..indices.len() {
+            'scan_2nd: for l in (start_scan_2nd + 1)..indices.len() {
                 if l == k {
                     // Do not attempt to merge a spike with itself
-                    continue
+                    continue;
                 }
                 let j = indices[l];
-                let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j];
+                let &DeltaMeasure {
+                    x: Loc([xj1, xj2]),
+                    α: αj,
+                } = &self[j];
 
                 if xj1 < xi1 - ρ {
                     // Spike `j = indices[l]` has too low first coordinate. Update starting index
                     // for next iteration, and continue scanning.
                     start_scan_2nd = l;
-                    continue 'scan_2nd
+                    continue 'scan_2nd;
                 } else if xj1 > xi1 + ρ {
                     // Break out: spike `j = indices[l]` has already too high first coordinate, no
                     // more close enough spikes can be found due to the sorting of `indices`.
-                    break 'scan_2nd
+                    break 'scan_2nd;
                 }
 
                 // If also second coordinate is close enough, attempt merging if closer than
                 // previously discovered mergeable spikes.
-                let d2 = (xi2-xj2).abs();
+                let d2 = (xi2 - xj2).abs();
                 if αj != 0.0 && d2 <= ρ {
-                    let r1 = xi1-xj1;
-                    let d = (d2*d2 + r1*r1).sqrt();
+                    let r1 = xi1 - xj1;
+                    let d = (d2 * d2 + r1 * r1).sqrt();
                     match closest {
                         None => closest = Some((l, j, d)),
                         Some((_, _, r)) if r > d => closest = Some((l, j, d)),
-                        _ => {},
+                        _ => {}
                     }
                 }
             }
@@ -300,13 +317,12 @@
                     // If merge was succesfull, make new spike candidate for merging.
                     indices[l] = n;
                     count += 1;
-                    let compare = |i, j| compare_first_coordinate(&self.spikes[i],
-                                                                  &self.spikes[j]);
+                    let compare = |i, j| compare_first_coordinate(&self.spikes[i], &self.spikes[j]);
                     // Re-sort relevant range of indices
                     if l < k {
                         indices[l..k].sort_by(|&i, &j| compare(i, j));
                     } else {
-                        indices[k+1..=l].sort_by(|&i, &j| compare(i, j));
+                        indices[k + 1..=l].sort_by(|&i, &j| compare(i, j));
                     }
                 }
             }
@@ -315,4 +331,3 @@
         count
     }
 }
-

mercurial