src/measures/merging.rs

branch
dev
changeset 34
efa60bc4f743
parent 0
eb3c7813b67a
child 39
6316d68b58af
--- a/src/measures/merging.rs	Tue Aug 01 10:32:12 2023 +0300
+++ b/src/measures/merging.rs	Thu Aug 29 00:00:00 2024 -0500
@@ -20,8 +20,10 @@
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[allow(dead_code)]
 pub enum SpikeMergingMethod<F> {
-    /// Try to merge spikes within a given radius of eachother
+    /// Try to merge spikes within a given radius of each other, averaging the location
     HeuristicRadius(F),
+    /// Try to merge spikes within a given radius of each other, attempting original locations
+    HeuristicRadiusNoInterp(F),
     /// No merging
     None,
 }
@@ -40,7 +42,8 @@
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
         match self {
             Self::None => write!(f, "none"),
-            Self::HeuristicRadius(r) => std::fmt::Display::fmt(r, f),
+            Self::HeuristicRadius(r) => write!(f, "i:{}", r),
+            Self::HeuristicRadiusNoInterp(r) => write!(f, "n:{}", r),
         }
     }
 }
@@ -52,7 +55,19 @@
         if s == "none" {
             Ok(Self::None)
         } else {
-            Ok(Self::HeuristicRadius(F::from_str(s)?))
+            let mut subs = s.split(':');
+            match subs.next() {
+                None =>  Ok(Self::HeuristicRadius(F::from_str(s)?)),
+                Some(t) if t == "n" => match subs.next() {
+                    None => Err(core::num::dec2flt::pfe_invalid()),
+                    Some(v) => Ok(Self::HeuristicRadiusNoInterp(F::from_str(v)?))
+                },
+                Some(t) if t == "i" => match subs.next() {
+                    None => Err(core::num::dec2flt::pfe_invalid()),
+                    Some(v) => Ok(Self::HeuristicRadius(F::from_str(v)?))
+                },
+                Some(v) => Ok(Self::HeuristicRadius(F::from_str(v)?))
+            }
         }
     }
 }
@@ -77,11 +92,14 @@
     ///
     /// This method is stable with respect to spike locations:  on merge, the weight of existing
     /// spikes is set to zero, and a new one inserted at the end of the spike vector.
-    fn merge_spikes<G, V>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> Option<V>
-    where G : Fn(&'_ Self) -> Option<V> {
+    fn merge_spikes<G>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> usize
+    where G : FnMut(&'_ Self) -> bool {
         match method {
-            SpikeMergingMethod::HeuristicRadius(ρ) => self.do_merge_spikes_radius(ρ, accept),
-            SpikeMergingMethod::None => None,
+            SpikeMergingMethod::HeuristicRadius(ρ) =>
+                self.do_merge_spikes_radius(ρ, true, accept),
+            SpikeMergingMethod::HeuristicRadiusNoInterp(ρ) =>
+                self.do_merge_spikes_radius(ρ, false, accept),
+            SpikeMergingMethod::None => 0,
         }
     }
 
@@ -90,22 +108,23 @@
     /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of
     /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value`
     // for a merge  accepted by `fitness`. If no merge was accepted, `value` applied to the initial
-    /// `self` is returned.
+    /// `self` is returned. also the number of merges is returned;
     fn merge_spikes_fitness<G, H, V, O>(
         &mut self,
         method : SpikeMergingMethod<F>,
         value : G,
         fitness : H
-    ) -> V
+    ) -> (V, usize)
     where G : Fn(&'_ Self) -> V,
           H : Fn(&'_ V) -> O,
           O : PartialOrd {
-        let initial_res = value(self);
-        let initial_fitness = fitness(&initial_res);
-        self.merge_spikes(method, |μ| {
-            let res = value(μ);
-            (fitness(&res) <= initial_fitness).then_some(res)
-        }).unwrap_or(initial_res)
+        let mut res = value(self);
+        let initial_fitness = fitness(&res);
+        let count = self.merge_spikes(method, |μ| {
+            res = value(μ);
+            fitness(&res) <= initial_fitness
+        });
+        (res, count)
     }
 
     /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm).
@@ -113,8 +132,8 @@
     /// This method implements [`SpikeMerging::merge_spikes`] for
     /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are
     /// as for that method.
-    fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V>
-    where G : Fn(&'_ Self) -> Option<V>;
+    fn do_merge_spikes_radius<G>(&mut self, ρ : F, interp : bool, accept : G) -> usize
+    where G : FnMut(&'_ Self) -> bool;
 }
 
 #[replace_float_literals(F::cast_from(literal))]
@@ -126,70 +145,58 @@
     /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`].
     /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its
     /// return value.
-    fn attempt_merge<G, V>(
+    ///
+    /// Returns the index of `self.spikes` storing the new spike.
+    fn attempt_merge<G>(
         &mut self,
-        res : &mut Option<V>,
         i : usize,
         j : usize,
-        accept : &G
-    ) -> bool
-    where G : Fn(&'_ Self) -> Option<V> {
+        interp : bool,
+        accept : &mut G
+    ) -> Option<usize>
+    where G : FnMut(&'_ Self) -> bool {
         let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i];
         let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j];
 
-        // Merge inplace
-        self.spikes[i].α = 0.0;
-        self.spikes[j].α = 0.0;
-        //self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 });
-        self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αi + xj * αj) / (αi + αj) });
-        match accept(self) {
-            some@Some(..) => {
-                // Merge accepted, update our return value
-                *res = some;
-                // On next iteration process the newly merged spike.
-                //indices[k+1] = self.spikes.len() - 1;
-                true
-            },
-            None => {
+        if interp {
+            // Merge inplace
+            self.spikes[i].α = 0.0;
+            self.spikes[j].α = 0.0;
+            let αia = αi.abs();
+            let αja = αj.abs();
+            self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αia + xj * αja) / (αia + αja) });
+            if accept(self) {
+                Some(self.spikes.len()-1)
+            } else {
                 // Merge not accepted, restore modification
                 self.spikes[i].α = αi;
                 self.spikes[j].α = αj;
                 self.spikes.pop();
-                false
+                None
+            }
+        } else {
+            // Attempt merge inplace, first combination
+            self.spikes[i].α = αi + αj;
+            self.spikes[j].α = 0.0;
+            if accept(self) {
+                // Merge accepted
+                Some(i)
+            } else {
+                // Attempt merge inplace, second combination
+                self.spikes[i].α = 0.0;
+                self.spikes[j].α = αi + αj;
+                if accept(self) {
+                    // Merge accepted
+                    Some(j)
+                } else {
+                    // Merge not accepted, restore modification
+                    self.spikes[i].α = αi;
+                    self.spikes[j].α = αj;
+                    None
+                }
             }
         }
     }
-
-    /*
-    /// Attempts to merge spikes with indices i and j, acceptance through a delta.
-    fn attempt_merge_change<G, V>(
-        &mut self,
-        res : &mut Option<V>,
-        i : usize,
-        j : usize,
-        accept_change : &G
-    ) -> bool
-    where G : Fn(&'_ Self) -> Option<V> {
-        let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i];
-        let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j];
-        let δ = DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 };
-        let λ = [-self.spikes[i], -self.spikes[j], δ.clone()].into();
-
-        match accept_change(&λ) {
-            some@Some(..) => {
-                // Merge accepted, update our return value
-                *res = some;
-                self.spikes[i].α = 0.0;
-                self.spikes[j].α = 0.0;
-                self.spikes.push(δ);
-                true
-            },
-            None => {
-                false
-            }
-        }
-    }*/
-
 }
 
 /// Sorts a vector of indices into `slice` by `compare`.
@@ -207,12 +214,13 @@
 #[replace_float_literals(F::cast_from(literal))]
 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> {
 
-    fn do_merge_spikes_radius<G, V>(
+    fn do_merge_spikes_radius<G>(
         &mut self,
         ρ : F,
-        accept : G
-    ) -> Option<V>
-    where G : Fn(&'_ Self) -> Option<V> {
+        interp : bool,
+        mut accept : G
+    ) -> usize
+    where G : FnMut(&'_ Self) -> bool {
         // Sort by coordinate into an indexing array.
         let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| {
             let &Loc([x1]) = &δ1.x;
@@ -222,11 +230,11 @@
         });
 
         // Initialise result
-        let mut res = None;
+        let mut count = 0;
 
         // Scan consecutive pairs and merge if close enough and accepted by `accept`.
         if indices.len() == 0 {
-            return res
+            return count
         }
         for k in 0..(indices.len()-1) {
             let i = indices[k];
@@ -236,13 +244,16 @@
             debug_assert!(xi <= xj);
             // If close enough, attempt merging
             if αi != 0.0 && αj != 0.0 && xj <= xi + ρ {
-                if self.attempt_merge(&mut res, i, j, &accept) {
-                    indices[k+1] = self.spikes.len() - 1;
+                if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) {
+                    // For this to work (the debug_assert! to not trigger above), the new
+                    // coordinate produced by attempt_merge has to be at most xj.
+                    indices[k+1] = l;
+                    count += 1
                 }
             }
         }
 
-        res
+        count
     }
 }
 
@@ -260,18 +271,18 @@
 #[replace_float_literals(F::cast_from(literal))]
 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> {
 
-    fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V>
-    where G : Fn(&'_ Self) -> Option<V> {
+    fn do_merge_spikes_radius<G>(&mut self, ρ : F, interp : bool, mut accept : G) -> usize
+    where G : FnMut(&'_ Self) -> bool {
         // Sort by first coordinate into an indexing array.
         let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate);
 
         // Initialise result
-        let mut res = None;
+        let mut count = 0;
         let mut start_scan_2nd = 0;
 
         // Scan in order
         if indices.len() == 0 {
-            return res
+            return count
         }
         for k in 0..indices.len()-1 {
             let i = indices[k];
@@ -324,9 +335,10 @@
 
             // Attempt merging closest close-enough spike
             if let Some((l, j, _)) = closest {
-                if self.attempt_merge(&mut res, i, j, &accept) {
+                if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) {
                     // If merge was succesfull, make new spike candidate for merging.
-                    indices[l] = self.spikes.len() - 1;
+                    indices[l] = n;
+                    count += 1;
                     let compare = |i, j| compare_first_coordinate(&self.spikes[i],
                                                                   &self.spikes[j]);
                     // Re-sort relevant range of indices
@@ -339,7 +351,7 @@
             }
         }
 
-        res
+        count
     }
 }
 

mercurial