diff -r 000000000000 -r eb3c7813b67a src/measures/merging.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/measures/merging.rs Thu Dec 01 23:07:35 2022 +0200 @@ -0,0 +1,345 @@ +/*! +Spike merging heuristics for [`DiscreteMeasure`]s. + +This module primarily provides the [`SpikeMerging`] trait, and within it, +the [`SpikeMerging::merge_spikes`] method. The trait is implemented on +[`DiscreteMeasure, F>`]s in dimensions `N=1` and `N=2`. +*/ + +use numeric_literals::replace_float_literals; +use std::cmp::Ordering; +use serde::{Serialize, Deserialize}; +//use clap::builder::{PossibleValuesParser, PossibleValue}; +use alg_tools::nanleast::NaNLeast; + +use crate::types::*; +use super::delta::*; +use super::discrete::*; + +/// Spike merging heuristic selection +#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] +#[allow(dead_code)] +pub enum SpikeMergingMethod { + /// Try to merge spikes within a given radius of eachother + HeuristicRadius(F), + /// No merging + None, +} + +// impl SpikeMergingMethod { +// /// This is for [`clap`] to display command line help. +// pub fn value_parser() -> PossibleValuesParser { +// PossibleValuesParser::new([ +// PossibleValue::new("none").help("No merging"), +// PossibleValue::new("").help("Heuristic merging within indicated radius") +// ]) +// } +// } + +impl std::fmt::Display for SpikeMergingMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + match self { + Self::None => write!(f, "none"), + Self::HeuristicRadius(r) => std::fmt::Display::fmt(r, f), + } + } +} + +impl std::str::FromStr for SpikeMergingMethod { + type Err = F::Err; + + fn from_str(s: &str) -> Result { + if s == "none" { + Ok(Self::None) + } else { + Ok(Self::HeuristicRadius(F::from_str(s)?)) + } + } +} + +#[replace_float_literals(F::cast_from(literal))] +impl Default for SpikeMergingMethod { + fn default() -> Self { + SpikeMergingMethod::HeuristicRadius(0.02) + } +} + +/// Trait for dimension-dependent implementation of heuristic peak merging strategies. +pub trait SpikeMerging { + /// Attempt spike merging according to [`SpikeMerging`] method. + /// + /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure + /// `accept` if any merging is performed. The closure should accept as its only parameter a + /// new candidate measure (it will generally be internally mutated `self`, although this is + /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of + /// an arbitrary value. This method will return that value for the *last* accepted merge, or + /// [`None`] if no merge was accepted. + /// + /// This method is stable with respect to spike locations: on merge, the weight of existing + /// spikes is set to zero, and a new one inserted at the end of the spike vector. + fn merge_spikes(&mut self, method : SpikeMergingMethod, accept : G) -> Option + where G : Fn(&'_ Self) -> Option { + match method { + SpikeMergingMethod::HeuristicRadius(ρ) => self.do_merge_spikes_radius(ρ, accept), + SpikeMergingMethod::None => None, + } + } + + /// Attempt to merge spikes based on a value and a fitness function. + /// + /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of + /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` + // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial + /// `self` is returned. + fn merge_spikes_fitness( + &mut self, + method : SpikeMergingMethod, + value : G, + fitness : H + ) -> V + where G : Fn(&'_ Self) -> V, + H : Fn(&'_ V) -> O, + O : PartialOrd { + let initial_res = value(self); + let initial_fitness = fitness(&initial_res); + self.merge_spikes(method, |μ| { + let res = value(μ); + (fitness(&res) <= initial_fitness).then_some(res) + }).unwrap_or(initial_res) + } + + /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). + /// + /// This method implements [`SpikeMerging::merge_spikes`] for + /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are + /// as for that method. + fn do_merge_spikes_radius(&mut self, ρ : F, accept : G) -> Option + where G : Fn(&'_ Self) -> Option; +} + +#[replace_float_literals(F::cast_from(literal))] +impl DiscreteMeasure, F> { + /// Attempts to merge spikes with indices `i` and `j`. + /// + /// This assumes that the weights of the two spikes have already been checked not to be zero. + /// + /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. + /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its + /// return value. + fn attempt_merge( + &mut self, + res : &mut Option, + i : usize, + j : usize, + accept : &G + ) -> bool + where G : Fn(&'_ Self) -> Option { + let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; + let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; + + // Merge inplace + self.spikes[i].α = 0.0; + self.spikes[j].α = 0.0; + //self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }); + self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αi + xj * αj) / (αi + αj) }); + match accept(self) { + some@Some(..) => { + // Merge accepted, update our return value + *res = some; + // On next iteration process the newly merged spike. + //indices[k+1] = self.spikes.len() - 1; + true + }, + None => { + // Merge not accepted, restore modification + self.spikes[i].α = αi; + self.spikes[j].α = αj; + self.spikes.pop(); + false + } + } + } + + /* + /// Attempts to merge spikes with indices i and j, acceptance through a delta. + fn attempt_merge_change( + &mut self, + res : &mut Option, + i : usize, + j : usize, + accept_change : &G + ) -> bool + where G : Fn(&'_ Self) -> Option { + let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; + let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; + let δ = DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }; + let λ = [-self.spikes[i], -self.spikes[j], δ.clone()].into(); + + match accept_change(&λ) { + some@Some(..) => { + // Merge accepted, update our return value + *res = some; + self.spikes[i].α = 0.0; + self.spikes[j].α = 0.0; + self.spikes.push(δ); + true + }, + None => { + false + } + } + }*/ + +} + +/// Sorts a vector of indices into `slice` by `compare`. +/// +/// The closure `compare` operators on references to elements of `slice`. +/// Returns the sorted vector of indices into `slice`. +pub fn sort_indices_by(slice : &[V], mut compare : F) -> Vec +where F : FnMut(&V, &V) -> Ordering +{ + let mut indices = Vec::from_iter(0..slice.len()); + indices.sort_by(|&i, &j| compare(&slice[i], &slice[j])); + indices +} + +#[replace_float_literals(F::cast_from(literal))] +impl SpikeMerging for DiscreteMeasure, F> { + + fn do_merge_spikes_radius( + &mut self, + ρ : F, + accept : G + ) -> Option + where G : Fn(&'_ Self) -> Option { + // Sort by coordinate into an indexing array. + let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { + let &Loc([x1]) = &δ1.x; + let &Loc([x2]) = &δ2.x; + // nan-ignoring ordering of floats + NaNLeast(x1).cmp(&NaNLeast(x2)) + }); + + // Initialise result + let mut res = None; + + // Scan consecutive pairs and merge if close enough and accepted by `accept`. + if indices.len() == 0 { + return res + } + for k in 0..(indices.len()-1) { + let i = indices[k]; + let j = indices[k+1]; + let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i]; + let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j]; + debug_assert!(xi <= xj); + // If close enough, attempt merging + if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { + if self.attempt_merge(&mut res, i, j, &accept) { + indices[k+1] = self.spikes.len() - 1; + } + } + } + + res + } +} + +/// Orders `δ1` and `δ1` according to the first coordinate. +fn compare_first_coordinate( + δ1 : &DeltaMeasure, F>, + δ2 : &DeltaMeasure, F> +) -> Ordering { + let &Loc([x11, ..]) = &δ1.x; + let &Loc([x21, ..]) = &δ2.x; + // nan-ignoring ordering of floats + NaNLeast(x11).cmp(&NaNLeast(x21)) +} + +#[replace_float_literals(F::cast_from(literal))] +impl SpikeMerging for DiscreteMeasure, F> { + + fn do_merge_spikes_radius(&mut self, ρ : F, accept : G) -> Option + where G : Fn(&'_ Self) -> Option { + // Sort by first coordinate into an indexing array. + let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); + + // Initialise result + let mut res = None; + let mut start_scan_2nd = 0; + + // Scan in order + if indices.len() == 0 { + return res + } + for k in 0..indices.len()-1 { + let i = indices[k]; + let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i]; + + if αi == 0.0 { + // Nothin to be done if the weight is already zero + continue + } + + let mut closest = None; + + // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd` + // the smallest invalid merging index on the previous loop iteration, because a + // the _closest_ mergeable spike might have index less than `k` in `indices`, and a + // merge with it might have not been attempted with this spike if a different closer + // spike was discovered based on the second coordinate. + 'scan_2nd: for l in (start_scan_2nd+1)..indices.len() { + if l == k { + // Do not attempt to merge a spike with itself + continue + } + let j = indices[l]; + let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j]; + + if xj1 < xi1 - ρ { + // Spike `j = indices[l]` has too low first coordinate. Update starting index + // for next iteration, and continue scanning. + start_scan_2nd = l; + continue 'scan_2nd + } else if xj1 > xi1 + ρ { + // Break out: spike `j = indices[l]` has already too high first coordinate, no + // more close enough spikes can be found due to the sorting of `indices`. + break 'scan_2nd + } + + // If also second coordinate is close enough, attempt merging if closer than + // previously discovered mergeable spikes. + let d2 = (xi2-xj2).abs(); + if αj != 0.0 && d2 <= ρ { + let r1 = xi1-xj1; + let d = (d2*d2 + r1*r1).sqrt(); + match closest { + None => closest = Some((l, j, d)), + Some((_, _, r)) if r > d => closest = Some((l, j, d)), + _ => {}, + } + } + } + + // Attempt merging closest close-enough spike + if let Some((l, j, _)) = closest { + if self.attempt_merge(&mut res, i, j, &accept) { + // If merge was succesfull, make new spike candidate for merging. + indices[l] = self.spikes.len() - 1; + let compare = |i, j| compare_first_coordinate(&self.spikes[i], + &self.spikes[j]); + // Re-sort relevant range of indices + if l < k { + indices[l..k].sort_by(|&i, &j| compare(i, j)); + } else { + indices[k+1..=l].sort_by(|&i, &j| compare(i, j)); + } + } + } + } + + res + } +} +