src/measures/merging.rs

changeset 52
f0e8704d3f0e
parent 51
0693cc9ba9f0
--- a/src/measures/merging.rs	Tue Aug 01 10:25:09 2023 +0300
+++ b/src/measures/merging.rs	Mon Feb 17 13:54:53 2025 -0500
@@ -7,60 +7,35 @@
 */
 
 use numeric_literals::replace_float_literals;
+use serde::{Deserialize, Serialize};
 use std::cmp::Ordering;
-use serde::{Serialize, Deserialize};
 //use clap::builder::{PossibleValuesParser, PossibleValue};
 use alg_tools::nanleast::NaNLeast;
 
-use crate::types::*;
 use super::delta::*;
 use super::discrete::*;
+use crate::types::*;
 
 /// Spike merging heuristic selection
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
 #[allow(dead_code)]
-pub enum SpikeMergingMethod<F> {
-    /// Try to merge spikes within a given radius of eachother
-    HeuristicRadius(F),
-    /// No merging
-    None,
-}
-
-// impl<F : Float> SpikeMergingMethod<F> {
-//     /// This is for [`clap`] to display command line help.
-//     pub fn value_parser() -> PossibleValuesParser {
-//         PossibleValuesParser::new([
-//             PossibleValue::new("none").help("No merging"),
-//             PossibleValue::new("<radius>").help("Heuristic merging within indicated radius")
-//         ])
-//     }
-// }
-
-impl<F : ClapFloat> std::fmt::Display for SpikeMergingMethod<F> {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
-        match self {
-            Self::None => write!(f, "none"),
-            Self::HeuristicRadius(r) => std::fmt::Display::fmt(r, f),
-        }
-    }
-}
-
-impl<F : ClapFloat> std::str::FromStr for SpikeMergingMethod<F> {
-    type Err = F::Err;
-
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        if s == "none" {
-            Ok(Self::None)
-        } else {
-            Ok(Self::HeuristicRadius(F::from_str(s)?))
-        }
-    }
+pub struct SpikeMergingMethod<F> {
+    // Merging radius
+    pub(crate) radius: F,
+    // Enabled
+    pub(crate) enabled: bool,
+    // Interpolate merged points
+    pub(crate) interp: bool,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> Default for SpikeMergingMethod<F> {
+impl<F: Float> Default for SpikeMergingMethod<F> {
     fn default() -> Self {
-        SpikeMergingMethod::HeuristicRadius(0.02)
+        SpikeMergingMethod {
+            radius: 0.01,
+            enabled: false,
+            interp: true,
+        }
     }
 }
 
@@ -75,13 +50,18 @@
     /// an arbitrary value. This method will return that value for the *last* accepted merge, or
     /// [`None`] if no merge was accepted.
     ///
-    /// This method is stable with respect to spike locations:  on merge, the weight of existing
-    /// spikes is set to zero, and a new one inserted at the end of the spike vector.
-    fn merge_spikes<G, V>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> Option<V>
-    where G : Fn(&'_ Self) -> Option<V> {
-        match method {
-            SpikeMergingMethod::HeuristicRadius(ρ) => self.do_merge_spikes_radius(ρ, accept),
-            SpikeMergingMethod::None => None,
+    /// This method is stable with respect to spike locations:  on merge, the weights of existing
+    /// removed spikes is set to zero, new ones inserted at the end of the spike vector.
+    /// They merge may also be performed by increasing the weights of the existing spikes,
+    /// without inserting new spikes.
+    fn merge_spikes<G>(&mut self, method: SpikeMergingMethod<F>, accept: G) -> usize
+    where
+        G: FnMut(&'_ Self) -> bool,
+    {
+        if method.enabled {
+            self.do_merge_spikes_radius(method.radius, method.interp, accept)
+        } else {
+            0
         }
     }
 
@@ -90,35 +70,37 @@
     /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of
     /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value`
     // for a merge  accepted by `fitness`. If no merge was accepted, `value` applied to the initial
-    /// `self` is returned.
+    /// `self` is returned. also the number of merges is returned;
     fn merge_spikes_fitness<G, H, V, O>(
         &mut self,
-        method : SpikeMergingMethod<F>,
-        value : G,
-        fitness : H
-    ) -> V
-    where G : Fn(&'_ Self) -> V,
-          H : Fn(&'_ V) -> O,
-          O : PartialOrd {
-        let initial_res = value(self);
-        let initial_fitness = fitness(&initial_res);
-        self.merge_spikes(method, |μ| {
-            let res = value(μ);
-            (fitness(&res) <= initial_fitness).then_some(res)
-        }).unwrap_or(initial_res)
+        method: SpikeMergingMethod<F>,
+        value: G,
+        fitness: H,
+    ) -> (V, usize)
+    where
+        G: Fn(&'_ Self) -> V,
+        H: Fn(&'_ V) -> O,
+        O: PartialOrd,
+    {
+        let mut res = value(self);
+        let initial_fitness = fitness(&res);
+        let count = self.merge_spikes(method, |μ| {
+            res = value(μ);
+            fitness(&res) <= initial_fitness
+        });
+        (res, count)
     }
 
     /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm).
     ///
-    /// This method implements [`SpikeMerging::merge_spikes`] for
-    /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are
-    /// as for that method.
-    fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V>
-    where G : Fn(&'_ Self) -> Option<V>;
+    /// This method implements [`SpikeMerging::merge_spikes`].
+    fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, accept: G) -> usize
+    where
+        G: FnMut(&'_ Self) -> bool;
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float, const N : usize>  DiscreteMeasure<Loc<F, N>, F> {
+impl<F: Float, const N: usize> DiscreteMeasure<Loc<F, N>, F> {
     /// Attempts to merge spikes with indices `i` and `j`.
     ///
     /// This assumes that the weights of the two spikes have already been checked not to be zero.
@@ -126,78 +108,72 @@
     /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`].
     /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its
     /// return value.
-    fn attempt_merge<G, V>(
+    ///
+    /// Returns the index of `self.spikes` storing the new spike.
+    fn attempt_merge<G>(
         &mut self,
-        res : &mut Option<V>,
-        i : usize,
-        j : usize,
-        accept : &G
-    ) -> bool
-    where G : Fn(&'_ Self) -> Option<V> {
-        let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i];
-        let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j];
+        i: usize,
+        j: usize,
+        interp: bool,
+        accept: &mut G,
+    ) -> Option<usize>
+    where
+        G: FnMut(&'_ Self) -> bool,
+    {
+        let &DeltaMeasure { x: xi, α: αi } = &self.spikes[i];
+        let &DeltaMeasure { x: xj, α: αj } = &self.spikes[j];
 
-        // Merge inplace
-        self.spikes[i].α = 0.0;
-        self.spikes[j].α = 0.0;
-        //self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 });
-        self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αi + xj * αj) / (αi + αj) });
-        match accept(self) {
-            some@Some(..) => {
-                // Merge accepted, update our return value
-                *res = some;
-                // On next iteration process the newly merged spike.
-                //indices[k+1] = self.spikes.len() - 1;
-                true
-            },
-            None => {
+        if interp {
+            // Merge inplace
+            self.spikes[i].α = 0.0;
+            self.spikes[j].α = 0.0;
+            let αia = αi.abs();
+            let αja = αj.abs();
+            self.spikes.push(DeltaMeasure {
+                α: αi + αj,
+                x: (xi * αia + xj * αja) / (αia + αja),
+            });
+            if accept(self) {
+                Some(self.spikes.len() - 1)
+            } else {
                 // Merge not accepted, restore modification
                 self.spikes[i].α = αi;
                 self.spikes[j].α = αj;
                 self.spikes.pop();
-                false
+                None
+            }
+        } else {
+            // Attempt merge inplace, first combination
+            self.spikes[i].α = αi + αj;
+            self.spikes[j].α = 0.0;
+            if accept(self) {
+                // Merge accepted
+                Some(i)
+            } else {
+                // Attempt merge inplace, second combination
+                self.spikes[i].α = 0.0;
+                self.spikes[j].α = αi + αj;
+                if accept(self) {
+                    // Merge accepted
+                    Some(j)
+                } else {
+                    // Merge not accepted, restore modification
+                    self.spikes[i].α = αi;
+                    self.spikes[j].α = αj;
+                    None
+                }
             }
         }
     }
-
-    /*
-    /// Attempts to merge spikes with indices i and j, acceptance through a delta.
-    fn attempt_merge_change<G, V>(
-        &mut self,
-        res : &mut Option<V>,
-        i : usize,
-        j : usize,
-        accept_change : &G
-    ) -> bool
-    where G : Fn(&'_ Self) -> Option<V> {
-        let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i];
-        let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j];
-        let δ = DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 };
-        let λ = [-self.spikes[i], -self.spikes[j], δ.clone()].into();
-
-        match accept_change(&λ) {
-            some@Some(..) => {
-                // Merge accepted, update our return value
-                *res = some;
-                self.spikes[i].α = 0.0;
-                self.spikes[j].α = 0.0;
-                self.spikes.push(δ);
-                true
-            },
-            None => {
-                false
-            }
-        }
-    }*/
-
 }
 
 /// Sorts a vector of indices into `slice` by `compare`.
 ///
 /// The closure `compare` operators on references to elements of `slice`.
 /// Returns the sorted vector of indices into `slice`.
-pub fn sort_indices_by<V, F>(slice : &[V], mut compare : F) -> Vec<usize>
-where F : FnMut(&V, &V) -> Ordering
+pub fn sort_indices_by<V, F>(slice: &[V], mut compare: F) -> Vec<usize>
+where
+    F: FnMut(&V, &V) -> Ordering,
 {
     let mut indices = Vec::from_iter(0..slice.len());
     indices.sort_by(|&i, &j| compare(&slice[i], &slice[j]));
@@ -205,14 +181,11 @@
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> {
-
-    fn do_merge_spikes_radius<G, V>(
-        &mut self,
-        ρ : F,
-        accept : G
-    ) -> Option<V>
-    where G : Fn(&'_ Self) -> Option<V> {
+impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> {
+    fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize
+    where
+        G: FnMut(&'_ Self) -> bool,
+    {
         // Sort by coordinate into an indexing array.
         let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| {
             let &Loc([x1]) = &δ1.x;
@@ -222,34 +195,43 @@
         });
 
         // Initialise result
-        let mut res = None;
+        let mut count = 0;
 
         // Scan consecutive pairs and merge if close enough and accepted by `accept`.
         if indices.len() == 0 {
-            return res
+            return count;
         }
-        for k in 0..(indices.len()-1) {
+        for k in 0..(indices.len() - 1) {
             let i = indices[k];
-            let j = indices[k+1];
-            let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i];
-            let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j];
+            let j = indices[k + 1];
+            let &DeltaMeasure {
+                x: Loc([xi]),
+                α: αi,
+            } = &self.spikes[i];
+            let &DeltaMeasure {
+                x: Loc([xj]),
+                α: αj,
+            } = &self.spikes[j];
             debug_assert!(xi <= xj);
             // If close enough, attempt merging
             if αi != 0.0 && αj != 0.0 && xj <= xi + ρ {
-                if self.attempt_merge(&mut res, i, j, &accept) {
-                    indices[k+1] = self.spikes.len() - 1;
+                if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) {
+                    // For this to work (the debug_assert! to not trigger above), the new
+                    // coordinate produced by attempt_merge has to be at most xj.
+                    indices[k + 1] = l;
+                    count += 1
                 }
             }
         }
 
-        res
+        count
     }
 }
 
 /// Orders `δ1` and `δ1` according to the first coordinate.
-fn compare_first_coordinate<F : Float>(
-    δ1 : &DeltaMeasure<Loc<F, 2>, F>,
-    δ2 : &DeltaMeasure<Loc<F, 2>, F>
+fn compare_first_coordinate<F: Float>(
+    δ1: &DeltaMeasure<Loc<F, 2>, F>,
+    δ2: &DeltaMeasure<Loc<F, 2>, F>,
 ) -> Ordering {
     let &Loc([x11, ..]) = &δ1.x;
     let &Loc([x21, ..]) = &δ2.x;
@@ -258,28 +240,32 @@
 }
 
 #[replace_float_literals(F::cast_from(literal))]
-impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> {
-
-    fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V>
-    where G : Fn(&'_ Self) -> Option<V> {
+impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> {
+    fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize
+    where
+        G: FnMut(&'_ Self) -> bool,
+    {
         // Sort by first coordinate into an indexing array.
         let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate);
 
         // Initialise result
-        let mut res = None;
+        let mut count = 0;
         let mut start_scan_2nd = 0;
 
         // Scan in order
         if indices.len() == 0 {
-            return res
+            return count;
         }
-        for k in 0..indices.len()-1 {
+        for k in 0..indices.len() - 1 {
             let i = indices[k];
-            let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i];
+            let &DeltaMeasure {
+                x: Loc([xi1, xi2]),
+                α: αi,
+            } = &self[i];
 
             if αi == 0.0 {
                 // Nothin to be done if the weight is already zero
-                continue
+                continue;
             }
 
             let mut closest = None;
@@ -289,57 +275,59 @@
             // the _closest_ mergeable spike might have index less than `k` in `indices`, and a
             // merge with it might have not been attempted with this spike if a different closer
             // spike was discovered based on the second coordinate.
-            'scan_2nd: for l in (start_scan_2nd+1)..indices.len() {
+            'scan_2nd: for l in (start_scan_2nd + 1)..indices.len() {
                 if l == k {
                     // Do not attempt to merge a spike with itself
-                    continue
+                    continue;
                 }
                 let j = indices[l];
-                let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j];
+                let &DeltaMeasure {
+                    x: Loc([xj1, xj2]),
+                    α: αj,
+                } = &self[j];
 
                 if xj1 < xi1 - ρ {
                     // Spike `j = indices[l]` has too low first coordinate. Update starting index
                     // for next iteration, and continue scanning.
                     start_scan_2nd = l;
-                    continue 'scan_2nd
+                    continue 'scan_2nd;
                 } else if xj1 > xi1 + ρ {
                     // Break out: spike `j = indices[l]` has already too high first coordinate, no
                     // more close enough spikes can be found due to the sorting of `indices`.
-                    break 'scan_2nd
+                    break 'scan_2nd;
                 }
 
                 // If also second coordinate is close enough, attempt merging if closer than
                 // previously discovered mergeable spikes.
-                let d2 = (xi2-xj2).abs();
+                let d2 = (xi2 - xj2).abs();
                 if αj != 0.0 && d2 <= ρ {
-                    let r1 = xi1-xj1;
-                    let d = (d2*d2 + r1*r1).sqrt();
+                    let r1 = xi1 - xj1;
+                    let d = (d2 * d2 + r1 * r1).sqrt();
                     match closest {
                         None => closest = Some((l, j, d)),
                         Some((_, _, r)) if r > d => closest = Some((l, j, d)),
-                        _ => {},
+                        _ => {}
                     }
                 }
             }
 
             // Attempt merging closest close-enough spike
             if let Some((l, j, _)) = closest {
-                if self.attempt_merge(&mut res, i, j, &accept) {
+                if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) {
                     // If merge was succesfull, make new spike candidate for merging.
-                    indices[l] = self.spikes.len() - 1;
-                    let compare = |i, j| compare_first_coordinate(&self.spikes[i],
-                                                                  &self.spikes[j]);
+                    indices[l] = n;
+                    count += 1;
+                    let compare = |i, j| compare_first_coordinate(&self.spikes[i], &self.spikes[j]);
                     // Re-sort relevant range of indices
                     if l < k {
                         indices[l..k].sort_by(|&i, &j| compare(i, j));
                     } else {
-                        indices[k+1..=l].sort_by(|&i, &j| compare(i, j));
+                        indices[k + 1..=l].sort_by(|&i, &j| compare(i, j));
                     }
                 }
             }
         }
 
-        res
+        count
     }
 }
-

mercurial