Wed, 22 Mar 2023 20:37:49 +0200
Bump version
/*! Spike merging heuristics for [`DiscreteMeasure`]s. This module primarily provides the [`SpikeMerging`] trait, and within it, the [`SpikeMerging::merge_spikes`] method. The trait is implemented on [`DiscreteMeasure<Loc<F, N>, F>`]s in dimensions `N=1` and `N=2`. */ use numeric_literals::replace_float_literals; use std::cmp::Ordering; use serde::{Serialize, Deserialize}; //use clap::builder::{PossibleValuesParser, PossibleValue}; use alg_tools::nanleast::NaNLeast; use crate::types::*; use super::delta::*; use super::discrete::*; /// Spike merging heuristic selection #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[allow(dead_code)] pub enum SpikeMergingMethod<F> { /// Try to merge spikes within a given radius of eachother HeuristicRadius(F), /// No merging None, } // impl<F : Float> SpikeMergingMethod<F> { // /// This is for [`clap`] to display command line help. // pub fn value_parser() -> PossibleValuesParser { // PossibleValuesParser::new([ // PossibleValue::new("none").help("No merging"), // PossibleValue::new("<radius>").help("Heuristic merging within indicated radius") // ]) // } // } impl<F : ClapFloat> std::fmt::Display for SpikeMergingMethod<F> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { match self { Self::None => write!(f, "none"), Self::HeuristicRadius(r) => std::fmt::Display::fmt(r, f), } } } impl<F : ClapFloat> std::str::FromStr for SpikeMergingMethod<F> { type Err = F::Err; fn from_str(s: &str) -> Result<Self, Self::Err> { if s == "none" { Ok(Self::None) } else { Ok(Self::HeuristicRadius(F::from_str(s)?)) } } } #[replace_float_literals(F::cast_from(literal))] impl<F : Float> Default for SpikeMergingMethod<F> { fn default() -> Self { SpikeMergingMethod::HeuristicRadius(0.02) } } /// Trait for dimension-dependent implementation of heuristic peak merging strategies. pub trait SpikeMerging<F> { /// Attempt spike merging according to [`SpikeMerging`] method. /// /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure /// `accept` if any merging is performed. The closure should accept as its only parameter a /// new candidate measure (it will generally be internally mutated `self`, although this is /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of /// an arbitrary value. This method will return that value for the *last* accepted merge, or /// [`None`] if no merge was accepted. /// /// This method is stable with respect to spike locations: on merge, the weight of existing /// spikes is set to zero, and a new one inserted at the end of the spike vector. fn merge_spikes<G, V>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> Option<V> where G : Fn(&'_ Self) -> Option<V> { match method { SpikeMergingMethod::HeuristicRadius(ρ) => self.do_merge_spikes_radius(ρ, accept), SpikeMergingMethod::None => None, } } /// Attempt to merge spikes based on a value and a fitness function. /// /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial /// `self` is returned. fn merge_spikes_fitness<G, H, V, O>( &mut self, method : SpikeMergingMethod<F>, value : G, fitness : H ) -> V where G : Fn(&'_ Self) -> V, H : Fn(&'_ V) -> O, O : PartialOrd { let initial_res = value(self); let initial_fitness = fitness(&initial_res); self.merge_spikes(method, |μ| { let res = value(μ); (fitness(&res) <= initial_fitness).then_some(res) }).unwrap_or(initial_res) } /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). /// /// This method implements [`SpikeMerging::merge_spikes`] for /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are /// as for that method. fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> where G : Fn(&'_ Self) -> Option<V>; } #[replace_float_literals(F::cast_from(literal))] impl<F : Float, const N : usize> DiscreteMeasure<Loc<F, N>, F> { /// Attempts to merge spikes with indices `i` and `j`. /// /// This assumes that the weights of the two spikes have already been checked not to be zero. /// /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its /// return value. fn attempt_merge<G, V>( &mut self, res : &mut Option<V>, i : usize, j : usize, accept : &G ) -> bool where G : Fn(&'_ Self) -> Option<V> { let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; // Merge inplace self.spikes[i].α = 0.0; self.spikes[j].α = 0.0; //self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }); self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αi + xj * αj) / (αi + αj) }); match accept(self) { some@Some(..) => { // Merge accepted, update our return value *res = some; // On next iteration process the newly merged spike. //indices[k+1] = self.spikes.len() - 1; true }, None => { // Merge not accepted, restore modification self.spikes[i].α = αi; self.spikes[j].α = αj; self.spikes.pop(); false } } } /* /// Attempts to merge spikes with indices i and j, acceptance through a delta. fn attempt_merge_change<G, V>( &mut self, res : &mut Option<V>, i : usize, j : usize, accept_change : &G ) -> bool where G : Fn(&'_ Self) -> Option<V> { let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; let δ = DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }; let λ = [-self.spikes[i], -self.spikes[j], δ.clone()].into(); match accept_change(&λ) { some@Some(..) => { // Merge accepted, update our return value *res = some; self.spikes[i].α = 0.0; self.spikes[j].α = 0.0; self.spikes.push(δ); true }, None => { false } } }*/ } /// Sorts a vector of indices into `slice` by `compare`. /// /// The closure `compare` operators on references to elements of `slice`. /// Returns the sorted vector of indices into `slice`. pub fn sort_indices_by<V, F>(slice : &[V], mut compare : F) -> Vec<usize> where F : FnMut(&V, &V) -> Ordering { let mut indices = Vec::from_iter(0..slice.len()); indices.sort_by(|&i, &j| compare(&slice[i], &slice[j])); indices } #[replace_float_literals(F::cast_from(literal))] impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> { fn do_merge_spikes_radius<G, V>( &mut self, ρ : F, accept : G ) -> Option<V> where G : Fn(&'_ Self) -> Option<V> { // Sort by coordinate into an indexing array. let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { let &Loc([x1]) = &δ1.x; let &Loc([x2]) = &δ2.x; // nan-ignoring ordering of floats NaNLeast(x1).cmp(&NaNLeast(x2)) }); // Initialise result let mut res = None; // Scan consecutive pairs and merge if close enough and accepted by `accept`. if indices.len() == 0 { return res } for k in 0..(indices.len()-1) { let i = indices[k]; let j = indices[k+1]; let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i]; let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j]; debug_assert!(xi <= xj); // If close enough, attempt merging if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { if self.attempt_merge(&mut res, i, j, &accept) { indices[k+1] = self.spikes.len() - 1; } } } res } } /// Orders `δ1` and `δ1` according to the first coordinate. fn compare_first_coordinate<F : Float>( δ1 : &DeltaMeasure<Loc<F, 2>, F>, δ2 : &DeltaMeasure<Loc<F, 2>, F> ) -> Ordering { let &Loc([x11, ..]) = &δ1.x; let &Loc([x21, ..]) = &δ2.x; // nan-ignoring ordering of floats NaNLeast(x11).cmp(&NaNLeast(x21)) } #[replace_float_literals(F::cast_from(literal))] impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> { fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> where G : Fn(&'_ Self) -> Option<V> { // Sort by first coordinate into an indexing array. let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); // Initialise result let mut res = None; let mut start_scan_2nd = 0; // Scan in order if indices.len() == 0 { return res } for k in 0..indices.len()-1 { let i = indices[k]; let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i]; if αi == 0.0 { // Nothin to be done if the weight is already zero continue } let mut closest = None; // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd` // the smallest invalid merging index on the previous loop iteration, because a // the _closest_ mergeable spike might have index less than `k` in `indices`, and a // merge with it might have not been attempted with this spike if a different closer // spike was discovered based on the second coordinate. 'scan_2nd: for l in (start_scan_2nd+1)..indices.len() { if l == k { // Do not attempt to merge a spike with itself continue } let j = indices[l]; let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j]; if xj1 < xi1 - ρ { // Spike `j = indices[l]` has too low first coordinate. Update starting index // for next iteration, and continue scanning. start_scan_2nd = l; continue 'scan_2nd } else if xj1 > xi1 + ρ { // Break out: spike `j = indices[l]` has already too high first coordinate, no // more close enough spikes can be found due to the sorting of `indices`. break 'scan_2nd } // If also second coordinate is close enough, attempt merging if closer than // previously discovered mergeable spikes. let d2 = (xi2-xj2).abs(); if αj != 0.0 && d2 <= ρ { let r1 = xi1-xj1; let d = (d2*d2 + r1*r1).sqrt(); match closest { None => closest = Some((l, j, d)), Some((_, _, r)) if r > d => closest = Some((l, j, d)), _ => {}, } } } // Attempt merging closest close-enough spike if let Some((l, j, _)) = closest { if self.attempt_merge(&mut res, i, j, &accept) { // If merge was succesfull, make new spike candidate for merging. indices[l] = self.spikes.len() - 1; let compare = |i, j| compare_first_coordinate(&self.spikes[i], &self.spikes[j]); // Re-sort relevant range of indices if l < k { indices[l..k].sort_by(|&i, &j| compare(i, j)); } else { indices[k+1..=l].sort_by(|&i, &j| compare(i, j)); } } } } res } }