--- a/src/measures/merging.rs Tue Aug 01 10:25:09 2023 +0300 +++ b/src/measures/merging.rs Mon Feb 17 13:54:53 2025 -0500 @@ -7,60 +7,35 @@ */ use numeric_literals::replace_float_literals; +use serde::{Deserialize, Serialize}; use std::cmp::Ordering; -use serde::{Serialize, Deserialize}; //use clap::builder::{PossibleValuesParser, PossibleValue}; use alg_tools::nanleast::NaNLeast; -use crate::types::*; use super::delta::*; use super::discrete::*; +use crate::types::*; /// Spike merging heuristic selection #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] #[allow(dead_code)] -pub enum SpikeMergingMethod<F> { - /// Try to merge spikes within a given radius of eachother - HeuristicRadius(F), - /// No merging - None, -} - -// impl<F : Float> SpikeMergingMethod<F> { -// /// This is for [`clap`] to display command line help. -// pub fn value_parser() -> PossibleValuesParser { -// PossibleValuesParser::new([ -// PossibleValue::new("none").help("No merging"), -// PossibleValue::new("<radius>").help("Heuristic merging within indicated radius") -// ]) -// } -// } - -impl<F : ClapFloat> std::fmt::Display for SpikeMergingMethod<F> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - match self { - Self::None => write!(f, "none"), - Self::HeuristicRadius(r) => std::fmt::Display::fmt(r, f), - } - } -} - -impl<F : ClapFloat> std::str::FromStr for SpikeMergingMethod<F> { - type Err = F::Err; - - fn from_str(s: &str) -> Result<Self, Self::Err> { - if s == "none" { - Ok(Self::None) - } else { - Ok(Self::HeuristicRadius(F::from_str(s)?)) - } - } +pub struct SpikeMergingMethod<F> { + // Merging radius + pub(crate) radius: F, + // Enabled + pub(crate) enabled: bool, + // Interpolate merged points + pub(crate) interp: bool, } #[replace_float_literals(F::cast_from(literal))] -impl<F : Float> Default for SpikeMergingMethod<F> { +impl<F: Float> Default for SpikeMergingMethod<F> { fn default() -> Self { - SpikeMergingMethod::HeuristicRadius(0.02) + SpikeMergingMethod { + radius: 0.01, + enabled: false, + interp: true, + } } } @@ -75,13 +50,18 @@ /// an arbitrary value. This method will return that value for the *last* accepted merge, or /// [`None`] if no merge was accepted. /// - /// This method is stable with respect to spike locations: on merge, the weight of existing - /// spikes is set to zero, and a new one inserted at the end of the spike vector. - fn merge_spikes<G, V>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> Option<V> - where G : Fn(&'_ Self) -> Option<V> { - match method { - SpikeMergingMethod::HeuristicRadius(ρ) => self.do_merge_spikes_radius(ρ, accept), - SpikeMergingMethod::None => None, + /// This method is stable with respect to spike locations: on merge, the weights of existing + /// removed spikes is set to zero, new ones inserted at the end of the spike vector. + /// They merge may also be performed by increasing the weights of the existing spikes, + /// without inserting new spikes. + fn merge_spikes<G>(&mut self, method: SpikeMergingMethod<F>, accept: G) -> usize + where + G: FnMut(&'_ Self) -> bool, + { + if method.enabled { + self.do_merge_spikes_radius(method.radius, method.interp, accept) + } else { + 0 } } @@ -90,35 +70,37 @@ /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial - /// `self` is returned. + /// `self` is returned. also the number of merges is returned; fn merge_spikes_fitness<G, H, V, O>( &mut self, - method : SpikeMergingMethod<F>, - value : G, - fitness : H - ) -> V - where G : Fn(&'_ Self) -> V, - H : Fn(&'_ V) -> O, - O : PartialOrd { - let initial_res = value(self); - let initial_fitness = fitness(&initial_res); - self.merge_spikes(method, |μ| { - let res = value(μ); - (fitness(&res) <= initial_fitness).then_some(res) - }).unwrap_or(initial_res) + method: SpikeMergingMethod<F>, + value: G, + fitness: H, + ) -> (V, usize) + where + G: Fn(&'_ Self) -> V, + H: Fn(&'_ V) -> O, + O: PartialOrd, + { + let mut res = value(self); + let initial_fitness = fitness(&res); + let count = self.merge_spikes(method, |μ| { + res = value(μ); + fitness(&res) <= initial_fitness + }); + (res, count) } /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). /// - /// This method implements [`SpikeMerging::merge_spikes`] for - /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are - /// as for that method. - fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> - where G : Fn(&'_ Self) -> Option<V>; + /// This method implements [`SpikeMerging::merge_spikes`]. + fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, accept: G) -> usize + where + G: FnMut(&'_ Self) -> bool; } #[replace_float_literals(F::cast_from(literal))] -impl<F : Float, const N : usize> DiscreteMeasure<Loc<F, N>, F> { +impl<F: Float, const N: usize> DiscreteMeasure<Loc<F, N>, F> { /// Attempts to merge spikes with indices `i` and `j`. /// /// This assumes that the weights of the two spikes have already been checked not to be zero. @@ -126,78 +108,72 @@ /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its /// return value. - fn attempt_merge<G, V>( + /// + /// Returns the index of `self.spikes` storing the new spike. + fn attempt_merge<G>( &mut self, - res : &mut Option<V>, - i : usize, - j : usize, - accept : &G - ) -> bool - where G : Fn(&'_ Self) -> Option<V> { - let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; - let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; + i: usize, + j: usize, + interp: bool, + accept: &mut G, + ) -> Option<usize> + where + G: FnMut(&'_ Self) -> bool, + { + let &DeltaMeasure { x: xi, α: αi } = &self.spikes[i]; + let &DeltaMeasure { x: xj, α: αj } = &self.spikes[j]; - // Merge inplace - self.spikes[i].α = 0.0; - self.spikes[j].α = 0.0; - //self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }); - self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αi + xj * αj) / (αi + αj) }); - match accept(self) { - some@Some(..) => { - // Merge accepted, update our return value - *res = some; - // On next iteration process the newly merged spike. - //indices[k+1] = self.spikes.len() - 1; - true - }, - None => { + if interp { + // Merge inplace + self.spikes[i].α = 0.0; + self.spikes[j].α = 0.0; + let αia = αi.abs(); + let αja = αj.abs(); + self.spikes.push(DeltaMeasure { + α: αi + αj, + x: (xi * αia + xj * αja) / (αia + αja), + }); + if accept(self) { + Some(self.spikes.len() - 1) + } else { // Merge not accepted, restore modification self.spikes[i].α = αi; self.spikes[j].α = αj; self.spikes.pop(); - false + None + } + } else { + // Attempt merge inplace, first combination + self.spikes[i].α = αi + αj; + self.spikes[j].α = 0.0; + if accept(self) { + // Merge accepted + Some(i) + } else { + // Attempt merge inplace, second combination + self.spikes[i].α = 0.0; + self.spikes[j].α = αi + αj; + if accept(self) { + // Merge accepted + Some(j) + } else { + // Merge not accepted, restore modification + self.spikes[i].α = αi; + self.spikes[j].α = αj; + None + } } } } - - /* - /// Attempts to merge spikes with indices i and j, acceptance through a delta. - fn attempt_merge_change<G, V>( - &mut self, - res : &mut Option<V>, - i : usize, - j : usize, - accept_change : &G - ) -> bool - where G : Fn(&'_ Self) -> Option<V> { - let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; - let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; - let δ = DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }; - let λ = [-self.spikes[i], -self.spikes[j], δ.clone()].into(); - - match accept_change(&λ) { - some@Some(..) => { - // Merge accepted, update our return value - *res = some; - self.spikes[i].α = 0.0; - self.spikes[j].α = 0.0; - self.spikes.push(δ); - true - }, - None => { - false - } - } - }*/ - } /// Sorts a vector of indices into `slice` by `compare`. /// /// The closure `compare` operators on references to elements of `slice`. /// Returns the sorted vector of indices into `slice`. -pub fn sort_indices_by<V, F>(slice : &[V], mut compare : F) -> Vec<usize> -where F : FnMut(&V, &V) -> Ordering +pub fn sort_indices_by<V, F>(slice: &[V], mut compare: F) -> Vec<usize> +where + F: FnMut(&V, &V) -> Ordering, { let mut indices = Vec::from_iter(0..slice.len()); indices.sort_by(|&i, &j| compare(&slice[i], &slice[j])); @@ -205,14 +181,11 @@ } #[replace_float_literals(F::cast_from(literal))] -impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> { - - fn do_merge_spikes_radius<G, V>( - &mut self, - ρ : F, - accept : G - ) -> Option<V> - where G : Fn(&'_ Self) -> Option<V> { +impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> { + fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize + where + G: FnMut(&'_ Self) -> bool, + { // Sort by coordinate into an indexing array. let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { let &Loc([x1]) = &δ1.x; @@ -222,34 +195,43 @@ }); // Initialise result - let mut res = None; + let mut count = 0; // Scan consecutive pairs and merge if close enough and accepted by `accept`. if indices.len() == 0 { - return res + return count; } - for k in 0..(indices.len()-1) { + for k in 0..(indices.len() - 1) { let i = indices[k]; - let j = indices[k+1]; - let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i]; - let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j]; + let j = indices[k + 1]; + let &DeltaMeasure { + x: Loc([xi]), + α: αi, + } = &self.spikes[i]; + let &DeltaMeasure { + x: Loc([xj]), + α: αj, + } = &self.spikes[j]; debug_assert!(xi <= xj); // If close enough, attempt merging if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { - if self.attempt_merge(&mut res, i, j, &accept) { - indices[k+1] = self.spikes.len() - 1; + if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) { + // For this to work (the debug_assert! to not trigger above), the new + // coordinate produced by attempt_merge has to be at most xj. + indices[k + 1] = l; + count += 1 } } } - res + count } } /// Orders `δ1` and `δ1` according to the first coordinate. -fn compare_first_coordinate<F : Float>( - δ1 : &DeltaMeasure<Loc<F, 2>, F>, - δ2 : &DeltaMeasure<Loc<F, 2>, F> +fn compare_first_coordinate<F: Float>( + δ1: &DeltaMeasure<Loc<F, 2>, F>, + δ2: &DeltaMeasure<Loc<F, 2>, F>, ) -> Ordering { let &Loc([x11, ..]) = &δ1.x; let &Loc([x21, ..]) = &δ2.x; @@ -258,28 +240,32 @@ } #[replace_float_literals(F::cast_from(literal))] -impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> { - - fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> - where G : Fn(&'_ Self) -> Option<V> { +impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> { + fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize + where + G: FnMut(&'_ Self) -> bool, + { // Sort by first coordinate into an indexing array. let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); // Initialise result - let mut res = None; + let mut count = 0; let mut start_scan_2nd = 0; // Scan in order if indices.len() == 0 { - return res + return count; } - for k in 0..indices.len()-1 { + for k in 0..indices.len() - 1 { let i = indices[k]; - let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i]; + let &DeltaMeasure { + x: Loc([xi1, xi2]), + α: αi, + } = &self[i]; if αi == 0.0 { // Nothin to be done if the weight is already zero - continue + continue; } let mut closest = None; @@ -289,57 +275,59 @@ // the _closest_ mergeable spike might have index less than `k` in `indices`, and a // merge with it might have not been attempted with this spike if a different closer // spike was discovered based on the second coordinate. - 'scan_2nd: for l in (start_scan_2nd+1)..indices.len() { + 'scan_2nd: for l in (start_scan_2nd + 1)..indices.len() { if l == k { // Do not attempt to merge a spike with itself - continue + continue; } let j = indices[l]; - let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j]; + let &DeltaMeasure { + x: Loc([xj1, xj2]), + α: αj, + } = &self[j]; if xj1 < xi1 - ρ { // Spike `j = indices[l]` has too low first coordinate. Update starting index // for next iteration, and continue scanning. start_scan_2nd = l; - continue 'scan_2nd + continue 'scan_2nd; } else if xj1 > xi1 + ρ { // Break out: spike `j = indices[l]` has already too high first coordinate, no // more close enough spikes can be found due to the sorting of `indices`. - break 'scan_2nd + break 'scan_2nd; } // If also second coordinate is close enough, attempt merging if closer than // previously discovered mergeable spikes. - let d2 = (xi2-xj2).abs(); + let d2 = (xi2 - xj2).abs(); if αj != 0.0 && d2 <= ρ { - let r1 = xi1-xj1; - let d = (d2*d2 + r1*r1).sqrt(); + let r1 = xi1 - xj1; + let d = (d2 * d2 + r1 * r1).sqrt(); match closest { None => closest = Some((l, j, d)), Some((_, _, r)) if r > d => closest = Some((l, j, d)), - _ => {}, + _ => {} } } } // Attempt merging closest close-enough spike if let Some((l, j, _)) = closest { - if self.attempt_merge(&mut res, i, j, &accept) { + if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) { // If merge was succesfull, make new spike candidate for merging. - indices[l] = self.spikes.len() - 1; - let compare = |i, j| compare_first_coordinate(&self.spikes[i], - &self.spikes[j]); + indices[l] = n; + count += 1; + let compare = |i, j| compare_first_coordinate(&self.spikes[i], &self.spikes[j]); // Re-sort relevant range of indices if l < k { indices[l..k].sort_by(|&i, &j| compare(i, j)); } else { - indices[k+1..=l].sort_by(|&i, &j| compare(i, j)); + indices[k + 1..=l].sort_by(|&i, &j| compare(i, j)); } } } } - res + count } } -