| |
1 /*! |
| |
2 Spike merging heuristics for [`DiscreteMeasure`]s. |
| |
3 |
| |
4 This module primarily provides the [`SpikeMerging`] trait, and within it, |
| |
5 the [`SpikeMerging::merge_spikes`] method. The trait is implemented on |
| |
6 [`DiscreteMeasure<Loc<N, F>, F>`]s in dimensions `N=1` and `N=2`. |
| |
7 */ |
| |
8 |
| |
9 use numeric_literals::replace_float_literals; |
| |
10 use serde::{Deserialize, Serialize}; |
| |
11 use std::cmp::Ordering; |
| |
12 //use clap::builder::{PossibleValuesParser, PossibleValue}; |
| |
13 use alg_tools::nanleast::NaNLeast; |
| |
14 |
| |
15 use super::delta::*; |
| |
16 use super::discrete::*; |
| |
17 use alg_tools::loc::Loc; |
| |
18 use alg_tools::types::*; |
| |
19 |
| |
20 /// Spike merging heuristic selection |
| |
21 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| |
22 #[allow(dead_code)] |
| |
23 pub struct SpikeMergingMethod<F> { |
| |
24 // Merging radius |
| |
25 pub radius: F, |
| |
26 // Enabled |
| |
27 pub enabled: bool, |
| |
28 // Interpolate merged points |
| |
29 pub interp: bool, |
| |
30 } |
| |
31 |
| |
32 #[replace_float_literals(F::cast_from(literal))] |
| |
33 impl<F: Float> Default for SpikeMergingMethod<F> { |
| |
34 fn default() -> Self { |
| |
35 SpikeMergingMethod { radius: 0.01, enabled: false, interp: true } |
| |
36 } |
| |
37 } |
| |
38 |
| |
39 /// Trait for dimension-dependent implementation of heuristic peak merging strategies. |
| |
40 pub trait SpikeMerging<F> { |
| |
41 /// Attempt spike merging according to [`SpikeMerging`] method. |
| |
42 /// |
| |
43 /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure |
| |
44 /// `accept` if any merging is performed. The closure should accept as its only parameter a |
| |
45 /// new candidate measure (it will generally be internally mutated `self`, although this is |
| |
46 /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of |
| |
47 /// an arbitrary value. This method will return that value for the *last* accepted merge, or |
| |
48 /// [`None`] if no merge was accepted. |
| |
49 /// |
| |
50 /// This method is stable with respect to spike locations: on merge, the weights of existing |
| |
51 /// removed spikes is set to zero, new ones inserted at the end of the spike vector. |
| |
52 /// They merge may also be performed by increasing the weights of the existing spikes, |
| |
53 /// without inserting new spikes. |
| |
54 fn merge_spikes<G>(&mut self, method: SpikeMergingMethod<F>, accept: G) -> usize |
| |
55 where |
| |
56 G: FnMut(&'_ Self) -> bool, |
| |
57 { |
| |
58 if method.enabled { |
| |
59 self.do_merge_spikes_radius(method.radius, method.interp, accept) |
| |
60 } else { |
| |
61 0 |
| |
62 } |
| |
63 } |
| |
64 |
| |
65 /// Attempt to merge spikes based on a value and a fitness function. |
| |
66 /// |
| |
67 /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of |
| |
68 /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` |
| |
69 // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial |
| |
70 /// `self` is returned. also the number of merges is returned; |
| |
71 fn merge_spikes_fitness<G, H, V, O>( |
| |
72 &mut self, |
| |
73 method: SpikeMergingMethod<F>, |
| |
74 value: G, |
| |
75 fitness: H, |
| |
76 ) -> (V, usize) |
| |
77 where |
| |
78 G: Fn(&'_ Self) -> V, |
| |
79 H: Fn(&'_ V) -> O, |
| |
80 O: PartialOrd, |
| |
81 { |
| |
82 let mut res = value(self); |
| |
83 let initial_fitness = fitness(&res); |
| |
84 let count = self.merge_spikes(method, |μ| { |
| |
85 res = value(μ); |
| |
86 fitness(&res) <= initial_fitness |
| |
87 }); |
| |
88 (res, count) |
| |
89 } |
| |
90 |
| |
91 /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). |
| |
92 /// |
| |
93 /// This method implements [`SpikeMerging::merge_spikes`]. |
| |
94 fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, accept: G) -> usize |
| |
95 where |
| |
96 G: FnMut(&'_ Self) -> bool; |
| |
97 } |
| |
98 |
| |
99 #[replace_float_literals(F::cast_from(literal))] |
| |
100 impl<F: Float, const N: usize> DiscreteMeasure<Loc<N, F>, F> { |
| |
101 /// Attempts to merge spikes with indices `i` and `j`. |
| |
102 /// |
| |
103 /// This assumes that the weights of the two spikes have already been checked not to be zero. |
| |
104 /// |
| |
105 /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. |
| |
106 /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its |
| |
107 /// return value. |
| |
108 /// |
| |
109 /// Returns the index of `self.spikes` storing the new spike. |
| |
110 fn attempt_merge<G>( |
| |
111 &mut self, |
| |
112 i: usize, |
| |
113 j: usize, |
| |
114 interp: bool, |
| |
115 accept: &mut G, |
| |
116 ) -> Option<usize> |
| |
117 where |
| |
118 G: FnMut(&'_ Self) -> bool, |
| |
119 { |
| |
120 let &DeltaMeasure { x: xi, α: αi } = &self.spikes[i]; |
| |
121 let &DeltaMeasure { x: xj, α: αj } = &self.spikes[j]; |
| |
122 |
| |
123 if interp { |
| |
124 // Merge inplace |
| |
125 self.spikes[i].α = 0.0; |
| |
126 self.spikes[j].α = 0.0; |
| |
127 let αia = αi.abs(); |
| |
128 let αja = αj.abs(); |
| |
129 self.spikes |
| |
130 .push(DeltaMeasure { α: αi + αj, x: (xi * αia + xj * αja) / (αia + αja) }); |
| |
131 if accept(self) { |
| |
132 Some(self.spikes.len() - 1) |
| |
133 } else { |
| |
134 // Merge not accepted, restore modification |
| |
135 self.spikes[i].α = αi; |
| |
136 self.spikes[j].α = αj; |
| |
137 self.spikes.pop(); |
| |
138 None |
| |
139 } |
| |
140 } else { |
| |
141 // Attempt merge inplace, first combination |
| |
142 self.spikes[i].α = αi + αj; |
| |
143 self.spikes[j].α = 0.0; |
| |
144 if accept(self) { |
| |
145 // Merge accepted |
| |
146 Some(i) |
| |
147 } else { |
| |
148 // Attempt merge inplace, second combination |
| |
149 self.spikes[i].α = 0.0; |
| |
150 self.spikes[j].α = αi + αj; |
| |
151 if accept(self) { |
| |
152 // Merge accepted |
| |
153 Some(j) |
| |
154 } else { |
| |
155 // Merge not accepted, restore modification |
| |
156 self.spikes[i].α = αi; |
| |
157 self.spikes[j].α = αj; |
| |
158 None |
| |
159 } |
| |
160 } |
| |
161 } |
| |
162 } |
| |
163 } |
| |
164 |
| |
/// Returns the indices `0..slice.len()` ordered so that the referenced elements of `slice`
/// are sorted according to `compare`.
///
/// The closure `compare` operates on references to elements of `slice`. The sort is stable:
/// indices of elements that compare equal keep their original relative order.
pub fn sort_indices_by<V, F>(slice: &[V], mut compare: F) -> Vec<usize>
where
    F: FnMut(&V, &V) -> Ordering,
{
    let mut order: Vec<usize> = (0..slice.len()).collect();
    order.sort_by(|a, b| compare(&slice[*a], &slice[*b]));
    order
}
| |
177 |
| |
178 #[replace_float_literals(F::cast_from(literal))] |
| |
179 impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<1, F>, F> { |
| |
180 fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize |
| |
181 where |
| |
182 G: FnMut(&'_ Self) -> bool, |
| |
183 { |
| |
184 // Sort by coordinate into an indexing array. |
| |
185 let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { |
| |
186 let &Loc([x1]) = &δ1.x; |
| |
187 let &Loc([x2]) = &δ2.x; |
| |
188 // nan-ignoring ordering of floats |
| |
189 NaNLeast(x1).cmp(&NaNLeast(x2)) |
| |
190 }); |
| |
191 |
| |
192 // Initialise result |
| |
193 let mut count = 0; |
| |
194 |
| |
195 // Scan consecutive pairs and merge if close enough and accepted by `accept`. |
| |
196 if indices.len() == 0 { |
| |
197 return count; |
| |
198 } |
| |
199 for k in 0..(indices.len() - 1) { |
| |
200 let i = indices[k]; |
| |
201 let j = indices[k + 1]; |
| |
202 let &DeltaMeasure { x: Loc([xi]), α: αi } = &self.spikes[i]; |
| |
203 let &DeltaMeasure { x: Loc([xj]), α: αj } = &self.spikes[j]; |
| |
204 debug_assert!(xi <= xj); |
| |
205 // If close enough, attempt merging |
| |
206 if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { |
| |
207 if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) { |
| |
208 // For this to work (the debug_assert! to not trigger above), the new |
| |
209 // coordinate produced by attempt_merge has to be at most xj. |
| |
210 indices[k + 1] = l; |
| |
211 count += 1 |
| |
212 } |
| |
213 } |
| |
214 } |
| |
215 |
| |
216 count |
| |
217 } |
| |
218 } |
| |
219 |
| |
220 /// Orders `δ1` and `δ1` according to the first coordinate. |
| |
221 fn compare_first_coordinate<F: Float>( |
| |
222 δ1: &DeltaMeasure<Loc<2, F>, F>, |
| |
223 δ2: &DeltaMeasure<Loc<2, F>, F>, |
| |
224 ) -> Ordering { |
| |
225 let &Loc([x11, ..]) = &δ1.x; |
| |
226 let &Loc([x21, ..]) = &δ2.x; |
| |
227 // nan-ignoring ordering of floats |
| |
228 NaNLeast(x11).cmp(&NaNLeast(x21)) |
| |
229 } |
| |
230 |
| |
231 #[replace_float_literals(F::cast_from(literal))] |
| |
232 impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<2, F>, F> { |
| |
233 fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize |
| |
234 where |
| |
235 G: FnMut(&'_ Self) -> bool, |
| |
236 { |
| |
237 // Sort by first coordinate into an indexing array. |
| |
238 let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); |
| |
239 |
| |
240 // Initialise result |
| |
241 let mut count = 0; |
| |
242 let mut start_scan_2nd = 0; |
| |
243 |
| |
244 // Scan in order |
| |
245 if indices.len() == 0 { |
| |
246 return count; |
| |
247 } |
| |
248 for k in 0..indices.len() - 1 { |
| |
249 let i = indices[k]; |
| |
250 let &DeltaMeasure { x: Loc([xi1, xi2]), α: αi } = &self[i]; |
| |
251 |
| |
252 if αi == 0.0 { |
| |
253 // Nothin to be done if the weight is already zero |
| |
254 continue; |
| |
255 } |
| |
256 |
| |
257 let mut closest = None; |
| |
258 |
| |
259 // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd` |
| |
260 // the smallest invalid merging index on the previous loop iteration, because a |
| |
261 // the _closest_ mergeable spike might have index less than `k` in `indices`, and a |
| |
262 // merge with it might have not been attempted with this spike if a different closer |
| |
263 // spike was discovered based on the second coordinate. |
| |
264 'scan_2nd: for l in (start_scan_2nd + 1)..indices.len() { |
| |
265 if l == k { |
| |
266 // Do not attempt to merge a spike with itself |
| |
267 continue; |
| |
268 } |
| |
269 let j = indices[l]; |
| |
270 let &DeltaMeasure { x: Loc([xj1, xj2]), α: αj } = &self[j]; |
| |
271 |
| |
272 if xj1 < xi1 - ρ { |
| |
273 // Spike `j = indices[l]` has too low first coordinate. Update starting index |
| |
274 // for next iteration, and continue scanning. |
| |
275 start_scan_2nd = l; |
| |
276 continue 'scan_2nd; |
| |
277 } else if xj1 > xi1 + ρ { |
| |
278 // Break out: spike `j = indices[l]` has already too high first coordinate, no |
| |
279 // more close enough spikes can be found due to the sorting of `indices`. |
| |
280 break 'scan_2nd; |
| |
281 } |
| |
282 |
| |
283 // If also second coordinate is close enough, attempt merging if closer than |
| |
284 // previously discovered mergeable spikes. |
| |
285 let d2 = (xi2 - xj2).abs(); |
| |
286 if αj != 0.0 && d2 <= ρ { |
| |
287 let r1 = xi1 - xj1; |
| |
288 let d = (d2 * d2 + r1 * r1).sqrt(); |
| |
289 match closest { |
| |
290 None => closest = Some((l, j, d)), |
| |
291 Some((_, _, r)) if r > d => closest = Some((l, j, d)), |
| |
292 _ => {} |
| |
293 } |
| |
294 } |
| |
295 } |
| |
296 |
| |
297 // Attempt merging closest close-enough spike |
| |
298 if let Some((l, j, _)) = closest { |
| |
299 if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) { |
| |
300 // If merge was succesfull, make new spike candidate for merging. |
| |
301 indices[l] = n; |
| |
302 count += 1; |
| |
303 let compare = |i, j| compare_first_coordinate(&self.spikes[i], &self.spikes[j]); |
| |
304 // Re-sort relevant range of indices |
| |
305 if l < k { |
| |
306 indices[l..k].sort_by(|&i, &j| compare(i, j)); |
| |
307 } else { |
| |
308 indices[k + 1..=l].sort_by(|&i, &j| compare(i, j)); |
| |
309 } |
| |
310 } |
| |
311 } |
| |
312 } |
| |
313 |
| |
314 count |
| |
315 } |
| |
316 } |