| 1 /*! |
|
| 2 Spike merging heuristics for [`DiscreteMeasure`]s. |
|
| 3 |
|
| 4 This module primarily provides the [`SpikeMerging`] trait, and within it, |
|
| 5 the [`SpikeMerging::merge_spikes`] method. The trait is implemented on |
|
| 6 [`DiscreteMeasure<Loc<F, N>, F>`]s in dimensions `N=1` and `N=2`. |
|
| 7 */ |
|
| 8 |
|
| 9 use numeric_literals::replace_float_literals; |
|
| 10 use serde::{Deserialize, Serialize}; |
|
| 11 use std::cmp::Ordering; |
|
| 12 //use clap::builder::{PossibleValuesParser, PossibleValue}; |
|
| 13 use alg_tools::nanleast::NaNLeast; |
|
| 14 |
|
| 15 use super::delta::*; |
|
| 16 use super::discrete::*; |
|
| 17 use crate::types::*; |
|
| 18 |
|
| 19 /// Spike merging heuristic selection |
|
| 20 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
|
| 21 #[allow(dead_code)] |
|
| 22 pub struct SpikeMergingMethod<F> { |
|
| 23 // Merging radius |
|
| 24 pub(crate) radius: F, |
|
| 25 // Enabled |
|
| 26 pub(crate) enabled: bool, |
|
| 27 // Interpolate merged points |
|
| 28 pub(crate) interp: bool, |
|
| 29 } |
|
| 30 |
|
| 31 #[replace_float_literals(F::cast_from(literal))] |
|
| 32 impl<F: Float> Default for SpikeMergingMethod<F> { |
|
| 33 fn default() -> Self { |
|
| 34 SpikeMergingMethod { |
|
| 35 radius: 0.01, |
|
| 36 enabled: false, |
|
| 37 interp: true, |
|
| 38 } |
|
| 39 } |
|
| 40 } |
|
| 41 |
|
| 42 /// Trait for dimension-dependent implementation of heuristic peak merging strategies. |
|
| 43 pub trait SpikeMerging<F> { |
|
| 44 /// Attempt spike merging according to [`SpikeMerging`] method. |
|
| 45 /// |
|
| 46 /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure |
|
| 47 /// `accept` if any merging is performed. The closure should accept as its only parameter a |
|
| 48 /// new candidate measure (it will generally be internally mutated `self`, although this is |
|
| 49 /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of |
|
| 50 /// an arbitrary value. This method will return that value for the *last* accepted merge, or |
|
| 51 /// [`None`] if no merge was accepted. |
|
| 52 /// |
|
| 53 /// This method is stable with respect to spike locations: on merge, the weights of existing |
|
| 54 /// removed spikes is set to zero, new ones inserted at the end of the spike vector. |
|
| 55 /// They merge may also be performed by increasing the weights of the existing spikes, |
|
| 56 /// without inserting new spikes. |
|
| 57 fn merge_spikes<G>(&mut self, method: SpikeMergingMethod<F>, accept: G) -> usize |
|
| 58 where |
|
| 59 G: FnMut(&'_ Self) -> bool, |
|
| 60 { |
|
| 61 if method.enabled { |
|
| 62 self.do_merge_spikes_radius(method.radius, method.interp, accept) |
|
| 63 } else { |
|
| 64 0 |
|
| 65 } |
|
| 66 } |
|
| 67 |
|
| 68 /// Attempt to merge spikes based on a value and a fitness function. |
|
| 69 /// |
|
| 70 /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of |
|
| 71 /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` |
|
| 72 // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial |
|
| 73 /// `self` is returned. also the number of merges is returned; |
|
| 74 fn merge_spikes_fitness<G, H, V, O>( |
|
| 75 &mut self, |
|
| 76 method: SpikeMergingMethod<F>, |
|
| 77 value: G, |
|
| 78 fitness: H, |
|
| 79 ) -> (V, usize) |
|
| 80 where |
|
| 81 G: Fn(&'_ Self) -> V, |
|
| 82 H: Fn(&'_ V) -> O, |
|
| 83 O: PartialOrd, |
|
| 84 { |
|
| 85 let mut res = value(self); |
|
| 86 let initial_fitness = fitness(&res); |
|
| 87 let count = self.merge_spikes(method, |μ| { |
|
| 88 res = value(μ); |
|
| 89 fitness(&res) <= initial_fitness |
|
| 90 }); |
|
| 91 (res, count) |
|
| 92 } |
|
| 93 |
|
| 94 /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). |
|
| 95 /// |
|
| 96 /// This method implements [`SpikeMerging::merge_spikes`]. |
|
| 97 fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, accept: G) -> usize |
|
| 98 where |
|
| 99 G: FnMut(&'_ Self) -> bool; |
|
| 100 } |
|
| 101 |
|
| 102 #[replace_float_literals(F::cast_from(literal))] |
|
| 103 impl<F: Float, const N: usize> DiscreteMeasure<Loc<F, N>, F> { |
|
| 104 /// Attempts to merge spikes with indices `i` and `j`. |
|
| 105 /// |
|
| 106 /// This assumes that the weights of the two spikes have already been checked not to be zero. |
|
| 107 /// |
|
| 108 /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. |
|
| 109 /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its |
|
| 110 /// return value. |
|
| 111 /// |
|
| 112 /// Returns the index of `self.spikes` storing the new spike. |
|
| 113 fn attempt_merge<G>( |
|
| 114 &mut self, |
|
| 115 i: usize, |
|
| 116 j: usize, |
|
| 117 interp: bool, |
|
| 118 accept: &mut G, |
|
| 119 ) -> Option<usize> |
|
| 120 where |
|
| 121 G: FnMut(&'_ Self) -> bool, |
|
| 122 { |
|
| 123 let &DeltaMeasure { x: xi, α: αi } = &self.spikes[i]; |
|
| 124 let &DeltaMeasure { x: xj, α: αj } = &self.spikes[j]; |
|
| 125 |
|
| 126 if interp { |
|
| 127 // Merge inplace |
|
| 128 self.spikes[i].α = 0.0; |
|
| 129 self.spikes[j].α = 0.0; |
|
| 130 let αia = αi.abs(); |
|
| 131 let αja = αj.abs(); |
|
| 132 self.spikes.push(DeltaMeasure { |
|
| 133 α: αi + αj, |
|
| 134 x: (xi * αia + xj * αja) / (αia + αja), |
|
| 135 }); |
|
| 136 if accept(self) { |
|
| 137 Some(self.spikes.len() - 1) |
|
| 138 } else { |
|
| 139 // Merge not accepted, restore modification |
|
| 140 self.spikes[i].α = αi; |
|
| 141 self.spikes[j].α = αj; |
|
| 142 self.spikes.pop(); |
|
| 143 None |
|
| 144 } |
|
| 145 } else { |
|
| 146 // Attempt merge inplace, first combination |
|
| 147 self.spikes[i].α = αi + αj; |
|
| 148 self.spikes[j].α = 0.0; |
|
| 149 if accept(self) { |
|
| 150 // Merge accepted |
|
| 151 Some(i) |
|
| 152 } else { |
|
| 153 // Attempt merge inplace, second combination |
|
| 154 self.spikes[i].α = 0.0; |
|
| 155 self.spikes[j].α = αi + αj; |
|
| 156 if accept(self) { |
|
| 157 // Merge accepted |
|
| 158 Some(j) |
|
| 159 } else { |
|
| 160 // Merge not accepted, restore modification |
|
| 161 self.spikes[i].α = αi; |
|
| 162 self.spikes[j].α = αj; |
|
| 163 None |
|
| 164 } |
|
| 165 } |
|
| 166 } |
|
| 167 } |
|
| 168 } |
|
| 169 |
|
/// Returns the indices `0..slice.len()` arranged so that the referenced elements of `slice`
/// appear in the order given by `compare`.
///
/// The closure `compare` operates on references to elements of `slice`. The sort is stable,
/// so indices of elements that compare equal stay in ascending order.
pub fn sort_indices_by<V, F>(slice: &[V], mut compare: F) -> Vec<usize>
where
    F: FnMut(&V, &V) -> Ordering,
{
    let mut order: Vec<usize> = (0..slice.len()).collect();
    order.sort_by(|&a, &b| compare(&slice[a], &slice[b]));
    order
}
|
| 182 |
|
| 183 #[replace_float_literals(F::cast_from(literal))] |
|
| 184 impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> { |
|
| 185 fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize |
|
| 186 where |
|
| 187 G: FnMut(&'_ Self) -> bool, |
|
| 188 { |
|
| 189 // Sort by coordinate into an indexing array. |
|
| 190 let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { |
|
| 191 let &Loc([x1]) = &δ1.x; |
|
| 192 let &Loc([x2]) = &δ2.x; |
|
| 193 // nan-ignoring ordering of floats |
|
| 194 NaNLeast(x1).cmp(&NaNLeast(x2)) |
|
| 195 }); |
|
| 196 |
|
| 197 // Initialise result |
|
| 198 let mut count = 0; |
|
| 199 |
|
| 200 // Scan consecutive pairs and merge if close enough and accepted by `accept`. |
|
| 201 if indices.len() == 0 { |
|
| 202 return count; |
|
| 203 } |
|
| 204 for k in 0..(indices.len() - 1) { |
|
| 205 let i = indices[k]; |
|
| 206 let j = indices[k + 1]; |
|
| 207 let &DeltaMeasure { |
|
| 208 x: Loc([xi]), |
|
| 209 α: αi, |
|
| 210 } = &self.spikes[i]; |
|
| 211 let &DeltaMeasure { |
|
| 212 x: Loc([xj]), |
|
| 213 α: αj, |
|
| 214 } = &self.spikes[j]; |
|
| 215 debug_assert!(xi <= xj); |
|
| 216 // If close enough, attempt merging |
|
| 217 if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { |
|
| 218 if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) { |
|
| 219 // For this to work (the debug_assert! to not trigger above), the new |
|
| 220 // coordinate produced by attempt_merge has to be at most xj. |
|
| 221 indices[k + 1] = l; |
|
| 222 count += 1 |
|
| 223 } |
|
| 224 } |
|
| 225 } |
|
| 226 |
|
| 227 count |
|
| 228 } |
|
| 229 } |
|
| 230 |
|
| 231 /// Orders `δ1` and `δ1` according to the first coordinate. |
|
| 232 fn compare_first_coordinate<F: Float>( |
|
| 233 δ1: &DeltaMeasure<Loc<F, 2>, F>, |
|
| 234 δ2: &DeltaMeasure<Loc<F, 2>, F>, |
|
| 235 ) -> Ordering { |
|
| 236 let &Loc([x11, ..]) = &δ1.x; |
|
| 237 let &Loc([x21, ..]) = &δ2.x; |
|
| 238 // nan-ignoring ordering of floats |
|
| 239 NaNLeast(x11).cmp(&NaNLeast(x21)) |
|
| 240 } |
|
| 241 |
|
| 242 #[replace_float_literals(F::cast_from(literal))] |
|
| 243 impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> { |
|
| 244 fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize |
|
| 245 where |
|
| 246 G: FnMut(&'_ Self) -> bool, |
|
| 247 { |
|
| 248 // Sort by first coordinate into an indexing array. |
|
| 249 let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); |
|
| 250 |
|
| 251 // Initialise result |
|
| 252 let mut count = 0; |
|
| 253 let mut start_scan_2nd = 0; |
|
| 254 |
|
| 255 // Scan in order |
|
| 256 if indices.len() == 0 { |
|
| 257 return count; |
|
| 258 } |
|
| 259 for k in 0..indices.len() - 1 { |
|
| 260 let i = indices[k]; |
|
| 261 let &DeltaMeasure { |
|
| 262 x: Loc([xi1, xi2]), |
|
| 263 α: αi, |
|
| 264 } = &self[i]; |
|
| 265 |
|
| 266 if αi == 0.0 { |
|
| 267 // Nothin to be done if the weight is already zero |
|
| 268 continue; |
|
| 269 } |
|
| 270 |
|
| 271 let mut closest = None; |
|
| 272 |
|
| 273 // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd` |
|
| 274 // the smallest invalid merging index on the previous loop iteration, because a |
|
| 275 // the _closest_ mergeable spike might have index less than `k` in `indices`, and a |
|
| 276 // merge with it might have not been attempted with this spike if a different closer |
|
| 277 // spike was discovered based on the second coordinate. |
|
| 278 'scan_2nd: for l in (start_scan_2nd + 1)..indices.len() { |
|
| 279 if l == k { |
|
| 280 // Do not attempt to merge a spike with itself |
|
| 281 continue; |
|
| 282 } |
|
| 283 let j = indices[l]; |
|
| 284 let &DeltaMeasure { |
|
| 285 x: Loc([xj1, xj2]), |
|
| 286 α: αj, |
|
| 287 } = &self[j]; |
|
| 288 |
|
| 289 if xj1 < xi1 - ρ { |
|
| 290 // Spike `j = indices[l]` has too low first coordinate. Update starting index |
|
| 291 // for next iteration, and continue scanning. |
|
| 292 start_scan_2nd = l; |
|
| 293 continue 'scan_2nd; |
|
| 294 } else if xj1 > xi1 + ρ { |
|
| 295 // Break out: spike `j = indices[l]` has already too high first coordinate, no |
|
| 296 // more close enough spikes can be found due to the sorting of `indices`. |
|
| 297 break 'scan_2nd; |
|
| 298 } |
|
| 299 |
|
| 300 // If also second coordinate is close enough, attempt merging if closer than |
|
| 301 // previously discovered mergeable spikes. |
|
| 302 let d2 = (xi2 - xj2).abs(); |
|
| 303 if αj != 0.0 && d2 <= ρ { |
|
| 304 let r1 = xi1 - xj1; |
|
| 305 let d = (d2 * d2 + r1 * r1).sqrt(); |
|
| 306 match closest { |
|
| 307 None => closest = Some((l, j, d)), |
|
| 308 Some((_, _, r)) if r > d => closest = Some((l, j, d)), |
|
| 309 _ => {} |
|
| 310 } |
|
| 311 } |
|
| 312 } |
|
| 313 |
|
| 314 // Attempt merging closest close-enough spike |
|
| 315 if let Some((l, j, _)) = closest { |
|
| 316 if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) { |
|
| 317 // If merge was succesfull, make new spike candidate for merging. |
|
| 318 indices[l] = n; |
|
| 319 count += 1; |
|
| 320 let compare = |i, j| compare_first_coordinate(&self.spikes[i], &self.spikes[j]); |
|
| 321 // Re-sort relevant range of indices |
|
| 322 if l < k { |
|
| 323 indices[l..k].sort_by(|&i, &j| compare(i, j)); |
|
| 324 } else { |
|
| 325 indices[k + 1..=l].sort_by(|&i, &j| compare(i, j)); |
|
| 326 } |
|
| 327 } |
|
| 328 } |
|
| 329 } |
|
| 330 |
|
| 331 count |
|
| 332 } |
|
| 333 } |
|