--- a/src/measures/merging.rs Sun Apr 27 15:03:51 2025 -0500 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,333 +0,0 @@ -/*! -Spike merging heuristics for [`DiscreteMeasure`]s. - -This module primarily provides the [`SpikeMerging`] trait, and within it, -the [`SpikeMerging::merge_spikes`] method. The trait is implemented on -[`DiscreteMeasure<Loc<F, N>, F>`]s in dimensions `N=1` and `N=2`. -*/ - -use numeric_literals::replace_float_literals; -use serde::{Deserialize, Serialize}; -use std::cmp::Ordering; -//use clap::builder::{PossibleValuesParser, PossibleValue}; -use alg_tools::nanleast::NaNLeast; - -use super::delta::*; -use super::discrete::*; -use crate::types::*; - -/// Spike merging heuristic selection -#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] -#[allow(dead_code)] -pub struct SpikeMergingMethod<F> { - // Merging radius - pub(crate) radius: F, - // Enabled - pub(crate) enabled: bool, - // Interpolate merged points - pub(crate) interp: bool, -} - -#[replace_float_literals(F::cast_from(literal))] -impl<F: Float> Default for SpikeMergingMethod<F> { - fn default() -> Self { - SpikeMergingMethod { - radius: 0.01, - enabled: false, - interp: true, - } - } -} - -/// Trait for dimension-dependent implementation of heuristic peak merging strategies. -pub trait SpikeMerging<F> { - /// Attempt spike merging according to [`SpikeMerging`] method. - /// - /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure - /// `accept` if any merging is performed. The closure should accept as its only parameter a - /// new candidate measure (it will generally be internally mutated `self`, although this is - /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of - /// an arbitrary value. This method will return that value for the *last* accepted merge, or - /// [`None`] if no merge was accepted. - /// - /// This method is stable with respect to spike locations: on merge, the weights of existing - /// removed spikes is set to zero, new ones inserted at the end of the spike vector. - /// They merge may also be performed by increasing the weights of the existing spikes, - /// without inserting new spikes. - fn merge_spikes<G>(&mut self, method: SpikeMergingMethod<F>, accept: G) -> usize - where - G: FnMut(&'_ Self) -> bool, - { - if method.enabled { - self.do_merge_spikes_radius(method.radius, method.interp, accept) - } else { - 0 - } - } - - /// Attempt to merge spikes based on a value and a fitness function. - /// - /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of - /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` - // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial - /// `self` is returned. also the number of merges is returned; - fn merge_spikes_fitness<G, H, V, O>( - &mut self, - method: SpikeMergingMethod<F>, - value: G, - fitness: H, - ) -> (V, usize) - where - G: Fn(&'_ Self) -> V, - H: Fn(&'_ V) -> O, - O: PartialOrd, - { - let mut res = value(self); - let initial_fitness = fitness(&res); - let count = self.merge_spikes(method, |μ| { - res = value(μ); - fitness(&res) <= initial_fitness - }); - (res, count) - } - - /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). - /// - /// This method implements [`SpikeMerging::merge_spikes`]. - fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, accept: G) -> usize - where - G: FnMut(&'_ Self) -> bool; -} - -#[replace_float_literals(F::cast_from(literal))] -impl<F: Float, const N: usize> DiscreteMeasure<Loc<F, N>, F> { - /// Attempts to merge spikes with indices `i` and `j`. - /// - /// This assumes that the weights of the two spikes have already been checked not to be zero. - /// - /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. - /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its - /// return value. - /// - /// Returns the index of `self.spikes` storing the new spike. - fn attempt_merge<G>( - &mut self, - i: usize, - j: usize, - interp: bool, - accept: &mut G, - ) -> Option<usize> - where - G: FnMut(&'_ Self) -> bool, - { - let &DeltaMeasure { x: xi, α: αi } = &self.spikes[i]; - let &DeltaMeasure { x: xj, α: αj } = &self.spikes[j]; - - if interp { - // Merge inplace - self.spikes[i].α = 0.0; - self.spikes[j].α = 0.0; - let αia = αi.abs(); - let αja = αj.abs(); - self.spikes.push(DeltaMeasure { - α: αi + αj, - x: (xi * αia + xj * αja) / (αia + αja), - }); - if accept(self) { - Some(self.spikes.len() - 1) - } else { - // Merge not accepted, restore modification - self.spikes[i].α = αi; - self.spikes[j].α = αj; - self.spikes.pop(); - None - } - } else { - // Attempt merge inplace, first combination - self.spikes[i].α = αi + αj; - self.spikes[j].α = 0.0; - if accept(self) { - // Merge accepted - Some(i) - } else { - // Attempt merge inplace, second combination - self.spikes[i].α = 0.0; - self.spikes[j].α = αi + αj; - if accept(self) { - // Merge accepted - Some(j) - } else { - // Merge not accepted, restore modification - self.spikes[i].α = αi; - self.spikes[j].α = αj; - None - } - } - } - } -} - -/// Sorts a vector of indices into `slice` by `compare`. -/// -/// The closure `compare` operators on references to elements of `slice`. -/// Returns the sorted vector of indices into `slice`. -pub fn sort_indices_by<V, F>(slice: &[V], mut compare: F) -> Vec<usize> -where - F: FnMut(&V, &V) -> Ordering, -{ - let mut indices = Vec::from_iter(0..slice.len()); - indices.sort_by(|&i, &j| compare(&slice[i], &slice[j])); - indices -} - -#[replace_float_literals(F::cast_from(literal))] -impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> { - fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize - where - G: FnMut(&'_ Self) -> bool, - { - // Sort by coordinate into an indexing array. - let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { - let &Loc([x1]) = &δ1.x; - let &Loc([x2]) = &δ2.x; - // nan-ignoring ordering of floats - NaNLeast(x1).cmp(&NaNLeast(x2)) - }); - - // Initialise result - let mut count = 0; - - // Scan consecutive pairs and merge if close enough and accepted by `accept`. - if indices.len() == 0 { - return count; - } - for k in 0..(indices.len() - 1) { - let i = indices[k]; - let j = indices[k + 1]; - let &DeltaMeasure { - x: Loc([xi]), - α: αi, - } = &self.spikes[i]; - let &DeltaMeasure { - x: Loc([xj]), - α: αj, - } = &self.spikes[j]; - debug_assert!(xi <= xj); - // If close enough, attempt merging - if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { - if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) { - // For this to work (the debug_assert! to not trigger above), the new - // coordinate produced by attempt_merge has to be at most xj. - indices[k + 1] = l; - count += 1 - } - } - } - - count - } -} - -/// Orders `δ1` and `δ1` according to the first coordinate. -fn compare_first_coordinate<F: Float>( - δ1: &DeltaMeasure<Loc<F, 2>, F>, - δ2: &DeltaMeasure<Loc<F, 2>, F>, -) -> Ordering { - let &Loc([x11, ..]) = &δ1.x; - let &Loc([x21, ..]) = &δ2.x; - // nan-ignoring ordering of floats - NaNLeast(x11).cmp(&NaNLeast(x21)) -} - -#[replace_float_literals(F::cast_from(literal))] -impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> { - fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize - where - G: FnMut(&'_ Self) -> bool, - { - // Sort by first coordinate into an indexing array. - let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); - - // Initialise result - let mut count = 0; - let mut start_scan_2nd = 0; - - // Scan in order - if indices.len() == 0 { - return count; - } - for k in 0..indices.len() - 1 { - let i = indices[k]; - let &DeltaMeasure { - x: Loc([xi1, xi2]), - α: αi, - } = &self[i]; - - if αi == 0.0 { - // Nothin to be done if the weight is already zero - continue; - } - - let mut closest = None; - - // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd` - // the smallest invalid merging index on the previous loop iteration, because a - // the _closest_ mergeable spike might have index less than `k` in `indices`, and a - // merge with it might have not been attempted with this spike if a different closer - // spike was discovered based on the second coordinate. - 'scan_2nd: for l in (start_scan_2nd + 1)..indices.len() { - if l == k { - // Do not attempt to merge a spike with itself - continue; - } - let j = indices[l]; - let &DeltaMeasure { - x: Loc([xj1, xj2]), - α: αj, - } = &self[j]; - - if xj1 < xi1 - ρ { - // Spike `j = indices[l]` has too low first coordinate. Update starting index - // for next iteration, and continue scanning. - start_scan_2nd = l; - continue 'scan_2nd; - } else if xj1 > xi1 + ρ { - // Break out: spike `j = indices[l]` has already too high first coordinate, no - // more close enough spikes can be found due to the sorting of `indices`. - break 'scan_2nd; - } - - // If also second coordinate is close enough, attempt merging if closer than - // previously discovered mergeable spikes. - let d2 = (xi2 - xj2).abs(); - if αj != 0.0 && d2 <= ρ { - let r1 = xi1 - xj1; - let d = (d2 * d2 + r1 * r1).sqrt(); - match closest { - None => closest = Some((l, j, d)), - Some((_, _, r)) if r > d => closest = Some((l, j, d)), - _ => {} - } - } - } - - // Attempt merging closest close-enough spike - if let Some((l, j, _)) = closest { - if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) { - // If merge was succesfull, make new spike candidate for merging. - indices[l] = n; - count += 1; - let compare = |i, j| compare_first_coordinate(&self.spikes[i], &self.spikes[j]); - // Re-sort relevant range of indices - if l < k { - indices[l..k].sort_by(|&i, &j| compare(i, j)); - } else { - indices[k + 1..=l].sort_by(|&i, &j| compare(i, j)); - } - } - } - } - - count - } -}