src/merging.rs

changeset 0
e8f3b6c55ce7
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/merging.rs	Fri Nov 28 12:48:17 2025 -0500
@@ -0,0 +1,316 @@
+/*!
+Spike merging heuristics for [`DiscreteMeasure`]s.
+
+This module primarily provides the [`SpikeMerging`] trait, and within it,
+the [`SpikeMerging::merge_spikes`] method. The trait is implemented on
+[`DiscreteMeasure<Loc<N, F>, F>`]s in dimensions `N=1` and `N=2`.
+*/
+
+use numeric_literals::replace_float_literals;
+use serde::{Deserialize, Serialize};
+use std::cmp::Ordering;
+//use clap::builder::{PossibleValuesParser, PossibleValue};
+use alg_tools::nanleast::NaNLeast;
+
+use super::delta::*;
+use super::discrete::*;
+use alg_tools::loc::Loc;
+use alg_tools::types::*;
+
+/// Spike merging heuristic selection
+#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
+#[allow(dead_code)]
+pub struct SpikeMergingMethod<F> {
+    // Merging radius
+    pub radius: F,
+    // Enabled
+    pub enabled: bool,
+    // Interpolate merged points
+    pub interp: bool,
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F: Float> Default for SpikeMergingMethod<F> {
+    fn default() -> Self {
+        SpikeMergingMethod { radius: 0.01, enabled: false, interp: true }
+    }
+}
+
+/// Trait for dimension-dependent implementation of heuristic peak merging strategies.
+pub trait SpikeMerging<F> {
+    /// Attempt spike merging according to [`SpikeMerging`] method.
+    ///
+    /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure
+    /// `accept` if any merging is performed. The closure should accept as its only parameter a
+    /// new candidate measure (it will generally be internally mutated `self`, although this is
+    /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of
+    /// an arbitrary value. This method will return that value for the *last* accepted merge, or
+    /// [`None`] if no merge was accepted.
+    ///
+    /// This method is stable with respect to spike locations:  on merge, the weights of existing
+    /// removed spikes is set to zero, new ones inserted at the end of the spike vector.
+    /// They merge may also be performed by increasing the weights of the existing spikes,
+    /// without inserting new spikes.
+    fn merge_spikes<G>(&mut self, method: SpikeMergingMethod<F>, accept: G) -> usize
+    where
+        G: FnMut(&'_ Self) -> bool,
+    {
+        if method.enabled {
+            self.do_merge_spikes_radius(method.radius, method.interp, accept)
+        } else {
+            0
+        }
+    }
+
+    /// Attempt to merge spikes based on a value and a fitness function.
+    ///
+    /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of
+    /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value`
+    // for a merge  accepted by `fitness`. If no merge was accepted, `value` applied to the initial
+    /// `self` is returned. also the number of merges is returned;
+    fn merge_spikes_fitness<G, H, V, O>(
+        &mut self,
+        method: SpikeMergingMethod<F>,
+        value: G,
+        fitness: H,
+    ) -> (V, usize)
+    where
+        G: Fn(&'_ Self) -> V,
+        H: Fn(&'_ V) -> O,
+        O: PartialOrd,
+    {
+        let mut res = value(self);
+        let initial_fitness = fitness(&res);
+        let count = self.merge_spikes(method, |μ| {
+            res = value(μ);
+            fitness(&res) <= initial_fitness
+        });
+        (res, count)
+    }
+
+    /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm).
+    ///
+    /// This method implements [`SpikeMerging::merge_spikes`].
+    fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, accept: G) -> usize
+    where
+        G: FnMut(&'_ Self) -> bool;
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F: Float, const N: usize> DiscreteMeasure<Loc<N, F>, F> {
+    /// Attempts to merge spikes with indices `i` and `j`.
+    ///
+    /// This assumes that the weights of the two spikes have already been checked not to be zero.
+    ///
+    /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`].
+    /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its
+    /// return value.
+    ///
+    /// Returns the index of `self.spikes` storing the new spike.
+    fn attempt_merge<G>(
+        &mut self,
+        i: usize,
+        j: usize,
+        interp: bool,
+        accept: &mut G,
+    ) -> Option<usize>
+    where
+        G: FnMut(&'_ Self) -> bool,
+    {
+        let &DeltaMeasure { x: xi, α: αi } = &self.spikes[i];
+        let &DeltaMeasure { x: xj, α: αj } = &self.spikes[j];
+
+        if interp {
+            // Merge inplace
+            self.spikes[i].α = 0.0;
+            self.spikes[j].α = 0.0;
+            let αia = αi.abs();
+            let αja = αj.abs();
+            self.spikes
+                .push(DeltaMeasure { α: αi + αj, x: (xi * αia + xj * αja) / (αia + αja) });
+            if accept(self) {
+                Some(self.spikes.len() - 1)
+            } else {
+                // Merge not accepted, restore modification
+                self.spikes[i].α = αi;
+                self.spikes[j].α = αj;
+                self.spikes.pop();
+                None
+            }
+        } else {
+            // Attempt merge inplace, first combination
+            self.spikes[i].α = αi + αj;
+            self.spikes[j].α = 0.0;
+            if accept(self) {
+                // Merge accepted
+                Some(i)
+            } else {
+                // Attempt merge inplace, second combination
+                self.spikes[i].α = 0.0;
+                self.spikes[j].α = αi + αj;
+                if accept(self) {
+                    // Merge accepted
+                    Some(j)
+                } else {
+                    // Merge not accepted, restore modification
+                    self.spikes[i].α = αi;
+                    self.spikes[j].α = αj;
+                    None
+                }
+            }
+        }
+    }
+}
+
+/// Sorts a vector of indices into `slice` by `compare`.
+///
+/// The closure `compare` operators on references to elements of `slice`.
+/// Returns the sorted vector of indices into `slice`.
+pub fn sort_indices_by<V, F>(slice: &[V], mut compare: F) -> Vec<usize>
+where
+    F: FnMut(&V, &V) -> Ordering,
+{
+    let mut indices = Vec::from_iter(0..slice.len());
+    indices.sort_by(|&i, &j| compare(&slice[i], &slice[j]));
+    indices
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<1, F>, F> {
+    fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize
+    where
+        G: FnMut(&'_ Self) -> bool,
+    {
+        // Sort by coordinate into an indexing array.
+        let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| {
+            let &Loc([x1]) = &δ1.x;
+            let &Loc([x2]) = &δ2.x;
+            // nan-ignoring ordering of floats
+            NaNLeast(x1).cmp(&NaNLeast(x2))
+        });
+
+        // Initialise result
+        let mut count = 0;
+
+        // Scan consecutive pairs and merge if close enough and accepted by `accept`.
+        if indices.len() == 0 {
+            return count;
+        }
+        for k in 0..(indices.len() - 1) {
+            let i = indices[k];
+            let j = indices[k + 1];
+            let &DeltaMeasure { x: Loc([xi]), α: αi } = &self.spikes[i];
+            let &DeltaMeasure { x: Loc([xj]), α: αj } = &self.spikes[j];
+            debug_assert!(xi <= xj);
+            // If close enough, attempt merging
+            if αi != 0.0 && αj != 0.0 && xj <= xi + ρ {
+                if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) {
+                    // For this to work (the debug_assert! to not trigger above), the new
+                    // coordinate produced by attempt_merge has to be at most xj.
+                    indices[k + 1] = l;
+                    count += 1
+                }
+            }
+        }
+
+        count
+    }
+}
+
+/// Orders `δ1` and `δ1` according to the first coordinate.
+fn compare_first_coordinate<F: Float>(
+    δ1: &DeltaMeasure<Loc<2, F>, F>,
+    δ2: &DeltaMeasure<Loc<2, F>, F>,
+) -> Ordering {
+    let &Loc([x11, ..]) = &δ1.x;
+    let &Loc([x21, ..]) = &δ2.x;
+    // nan-ignoring ordering of floats
+    NaNLeast(x11).cmp(&NaNLeast(x21))
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<2, F>, F> {
+    fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize
+    where
+        G: FnMut(&'_ Self) -> bool,
+    {
+        // Sort by first coordinate into an indexing array.
+        let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate);
+
+        // Initialise result
+        let mut count = 0;
+        let mut start_scan_2nd = 0;
+
+        // Scan in order
+        if indices.len() == 0 {
+            return count;
+        }
+        for k in 0..indices.len() - 1 {
+            let i = indices[k];
+            let &DeltaMeasure { x: Loc([xi1, xi2]), α: αi } = &self[i];
+
+            if αi == 0.0 {
+                // Nothin to be done if the weight is already zero
+                continue;
+            }
+
+            let mut closest = None;
+
+            // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd`
+            // the smallest invalid merging index on the previous loop iteration, because a
+            // the _closest_ mergeable spike might have index less than `k` in `indices`, and a
+            // merge with it might have not been attempted with this spike if a different closer
+            // spike was discovered based on the second coordinate.
+            'scan_2nd: for l in (start_scan_2nd + 1)..indices.len() {
+                if l == k {
+                    // Do not attempt to merge a spike with itself
+                    continue;
+                }
+                let j = indices[l];
+                let &DeltaMeasure { x: Loc([xj1, xj2]), α: αj } = &self[j];
+
+                if xj1 < xi1 - ρ {
+                    // Spike `j = indices[l]` has too low first coordinate. Update starting index
+                    // for next iteration, and continue scanning.
+                    start_scan_2nd = l;
+                    continue 'scan_2nd;
+                } else if xj1 > xi1 + ρ {
+                    // Break out: spike `j = indices[l]` has already too high first coordinate, no
+                    // more close enough spikes can be found due to the sorting of `indices`.
+                    break 'scan_2nd;
+                }
+
+                // If also second coordinate is close enough, attempt merging if closer than
+                // previously discovered mergeable spikes.
+                let d2 = (xi2 - xj2).abs();
+                if αj != 0.0 && d2 <= ρ {
+                    let r1 = xi1 - xj1;
+                    let d = (d2 * d2 + r1 * r1).sqrt();
+                    match closest {
+                        None => closest = Some((l, j, d)),
+                        Some((_, _, r)) if r > d => closest = Some((l, j, d)),
+                        _ => {}
+                    }
+                }
+            }
+
+            // Attempt merging closest close-enough spike
+            if let Some((l, j, _)) = closest {
+                if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) {
+                    // If merge was succesfull, make new spike candidate for merging.
+                    indices[l] = n;
+                    count += 1;
+                    let compare = |i, j| compare_first_coordinate(&self.spikes[i], &self.spikes[j]);
+                    // Re-sort relevant range of indices
+                    if l < k {
+                        indices[l..k].sort_by(|&i, &j| compare(i, j));
+                    } else {
+                        indices[k + 1..=l].sort_by(|&i, &j| compare(i, j));
+                    }
+                }
+            }
+        }
+
+        count
+    }
+}

mercurial