|
1 /*! |
|
2 Spike merging heuristics for [`DiscreteMeasure`]s. |
|
3 |
|
4 This module primarily provides the [`SpikeMerging`] trait, and within it, |
|
5 the [`SpikeMerging::merge_spikes`] method. The trait is implemented on |
|
6 [`DiscreteMeasure<Loc<F, N>, F>`]s in dimensions `N=1` and `N=2`. |
|
7 */ |
|
8 |
|
9 use numeric_literals::replace_float_literals; |
|
10 use std::cmp::Ordering; |
|
11 use serde::{Serialize, Deserialize}; |
|
12 //use clap::builder::{PossibleValuesParser, PossibleValue}; |
|
13 use alg_tools::nanleast::NaNLeast; |
|
14 |
|
15 use crate::types::*; |
|
16 use super::delta::*; |
|
17 use super::discrete::*; |
|
18 |
|
19 /// Spike merging heuristic selection |
|
20 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
|
21 #[allow(dead_code)] |
|
22 pub enum SpikeMergingMethod<F> { |
|
23 /// Try to merge spikes within a given radius of eachother |
|
24 HeuristicRadius(F), |
|
25 /// No merging |
|
26 None, |
|
27 } |
|
28 |
|
29 // impl<F : Float> SpikeMergingMethod<F> { |
|
30 // /// This is for [`clap`] to display command line help. |
|
31 // pub fn value_parser() -> PossibleValuesParser { |
|
32 // PossibleValuesParser::new([ |
|
33 // PossibleValue::new("none").help("No merging"), |
|
34 // PossibleValue::new("<radius>").help("Heuristic merging within indicated radius") |
|
35 // ]) |
|
36 // } |
|
37 // } |
|
38 |
|
39 impl<F : ClapFloat> std::fmt::Display for SpikeMergingMethod<F> { |
|
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { |
|
41 match self { |
|
42 Self::None => write!(f, "none"), |
|
43 Self::HeuristicRadius(r) => std::fmt::Display::fmt(r, f), |
|
44 } |
|
45 } |
|
46 } |
|
47 |
|
48 impl<F : ClapFloat> std::str::FromStr for SpikeMergingMethod<F> { |
|
49 type Err = F::Err; |
|
50 |
|
51 fn from_str(s: &str) -> Result<Self, Self::Err> { |
|
52 if s == "none" { |
|
53 Ok(Self::None) |
|
54 } else { |
|
55 Ok(Self::HeuristicRadius(F::from_str(s)?)) |
|
56 } |
|
57 } |
|
58 } |
|
59 |
|
60 #[replace_float_literals(F::cast_from(literal))] |
|
61 impl<F : Float> Default for SpikeMergingMethod<F> { |
|
62 fn default() -> Self { |
|
63 SpikeMergingMethod::HeuristicRadius(0.02) |
|
64 } |
|
65 } |
|
66 |
|
67 /// Trait for dimension-dependent implementation of heuristic peak merging strategies. |
|
68 pub trait SpikeMerging<F> { |
|
69 /// Attempt spike merging according to [`SpikeMerging`] method. |
|
70 /// |
|
71 /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure |
|
72 /// `accept` if any merging is performed. The closure should accept as its only parameter a |
|
73 /// new candidate measure (it will generally be internally mutated `self`, although this is |
|
74 /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of |
|
75 /// an arbitrary value. This method will return that value for the *last* accepted merge, or |
|
76 /// [`None`] if no merge was accepted. |
|
77 /// |
|
78 /// This method is stable with respect to spike locations: on merge, the weight of existing |
|
79 /// spikes is set to zero, and a new one inserted at the end of the spike vector. |
|
80 fn merge_spikes<G, V>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> Option<V> |
|
81 where G : Fn(&'_ Self) -> Option<V> { |
|
82 match method { |
|
83 SpikeMergingMethod::HeuristicRadius(ρ) => self.do_merge_spikes_radius(ρ, accept), |
|
84 SpikeMergingMethod::None => None, |
|
85 } |
|
86 } |
|
87 |
|
88 /// Attempt to merge spikes based on a value and a fitness function. |
|
89 /// |
|
90 /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of |
|
91 /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` |
|
92 // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial |
|
93 /// `self` is returned. |
|
94 fn merge_spikes_fitness<G, H, V, O>( |
|
95 &mut self, |
|
96 method : SpikeMergingMethod<F>, |
|
97 value : G, |
|
98 fitness : H |
|
99 ) -> V |
|
100 where G : Fn(&'_ Self) -> V, |
|
101 H : Fn(&'_ V) -> O, |
|
102 O : PartialOrd { |
|
103 let initial_res = value(self); |
|
104 let initial_fitness = fitness(&initial_res); |
|
105 self.merge_spikes(method, |μ| { |
|
106 let res = value(μ); |
|
107 (fitness(&res) <= initial_fitness).then_some(res) |
|
108 }).unwrap_or(initial_res) |
|
109 } |
|
110 |
|
111 /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). |
|
112 /// |
|
113 /// This method implements [`SpikeMerging::merge_spikes`] for |
|
114 /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are |
|
115 /// as for that method. |
|
116 fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> |
|
117 where G : Fn(&'_ Self) -> Option<V>; |
|
118 } |
|
119 |
|
120 #[replace_float_literals(F::cast_from(literal))] |
|
121 impl<F : Float, const N : usize> DiscreteMeasure<Loc<F, N>, F> { |
|
122 /// Attempts to merge spikes with indices `i` and `j`. |
|
123 /// |
|
124 /// This assumes that the weights of the two spikes have already been checked not to be zero. |
|
125 /// |
|
126 /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. |
|
127 /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its |
|
128 /// return value. |
|
129 fn attempt_merge<G, V>( |
|
130 &mut self, |
|
131 res : &mut Option<V>, |
|
132 i : usize, |
|
133 j : usize, |
|
134 accept : &G |
|
135 ) -> bool |
|
136 where G : Fn(&'_ Self) -> Option<V> { |
|
137 let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; |
|
138 let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; |
|
139 |
|
140 // Merge inplace |
|
141 self.spikes[i].α = 0.0; |
|
142 self.spikes[j].α = 0.0; |
|
143 //self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }); |
|
144 self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αi + xj * αj) / (αi + αj) }); |
|
145 match accept(self) { |
|
146 some@Some(..) => { |
|
147 // Merge accepted, update our return value |
|
148 *res = some; |
|
149 // On next iteration process the newly merged spike. |
|
150 //indices[k+1] = self.spikes.len() - 1; |
|
151 true |
|
152 }, |
|
153 None => { |
|
154 // Merge not accepted, restore modification |
|
155 self.spikes[i].α = αi; |
|
156 self.spikes[j].α = αj; |
|
157 self.spikes.pop(); |
|
158 false |
|
159 } |
|
160 } |
|
161 } |
|
162 |
|
163 /* |
|
164 /// Attempts to merge spikes with indices i and j, acceptance through a delta. |
|
165 fn attempt_merge_change<G, V>( |
|
166 &mut self, |
|
167 res : &mut Option<V>, |
|
168 i : usize, |
|
169 j : usize, |
|
170 accept_change : &G |
|
171 ) -> bool |
|
172 where G : Fn(&'_ Self) -> Option<V> { |
|
173 let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; |
|
174 let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; |
|
175 let δ = DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }; |
|
176 let λ = [-self.spikes[i], -self.spikes[j], δ.clone()].into(); |
|
177 |
|
178 match accept_change(&λ) { |
|
179 some@Some(..) => { |
|
180 // Merge accepted, update our return value |
|
181 *res = some; |
|
182 self.spikes[i].α = 0.0; |
|
183 self.spikes[j].α = 0.0; |
|
184 self.spikes.push(δ); |
|
185 true |
|
186 }, |
|
187 None => { |
|
188 false |
|
189 } |
|
190 } |
|
191 }*/ |
|
192 |
|
193 } |
|
194 |
|
/// Returns the permutation of indices into `slice` that orders its elements by `compare`.
///
/// The closure `compare` operates on references to elements of `slice`. The underlying sort is
/// stable, so indices of elements that compare equal keep their original relative order.
pub fn sort_indices_by<V, F>(slice : &[V], mut compare : F) -> Vec<usize>
where F : FnMut(&V, &V) -> Ordering
{
    let mut permutation : Vec<usize> = (0..slice.len()).collect();
    permutation.sort_by(|a, b| compare(&slice[*a], &slice[*b]));
    permutation
}
|
206 |
|
207 #[replace_float_literals(F::cast_from(literal))] |
|
208 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> { |
|
209 |
|
210 fn do_merge_spikes_radius<G, V>( |
|
211 &mut self, |
|
212 ρ : F, |
|
213 accept : G |
|
214 ) -> Option<V> |
|
215 where G : Fn(&'_ Self) -> Option<V> { |
|
216 // Sort by coordinate into an indexing array. |
|
217 let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { |
|
218 let &Loc([x1]) = &δ1.x; |
|
219 let &Loc([x2]) = &δ2.x; |
|
220 // nan-ignoring ordering of floats |
|
221 NaNLeast(x1).cmp(&NaNLeast(x2)) |
|
222 }); |
|
223 |
|
224 // Initialise result |
|
225 let mut res = None; |
|
226 |
|
227 // Scan consecutive pairs and merge if close enough and accepted by `accept`. |
|
228 if indices.len() == 0 { |
|
229 return res |
|
230 } |
|
231 for k in 0..(indices.len()-1) { |
|
232 let i = indices[k]; |
|
233 let j = indices[k+1]; |
|
234 let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i]; |
|
235 let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j]; |
|
236 debug_assert!(xi <= xj); |
|
237 // If close enough, attempt merging |
|
238 if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { |
|
239 if self.attempt_merge(&mut res, i, j, &accept) { |
|
240 indices[k+1] = self.spikes.len() - 1; |
|
241 } |
|
242 } |
|
243 } |
|
244 |
|
245 res |
|
246 } |
|
247 } |
|
248 |
|
249 /// Orders `δ1` and `δ1` according to the first coordinate. |
|
250 fn compare_first_coordinate<F : Float>( |
|
251 δ1 : &DeltaMeasure<Loc<F, 2>, F>, |
|
252 δ2 : &DeltaMeasure<Loc<F, 2>, F> |
|
253 ) -> Ordering { |
|
254 let &Loc([x11, ..]) = &δ1.x; |
|
255 let &Loc([x21, ..]) = &δ2.x; |
|
256 // nan-ignoring ordering of floats |
|
257 NaNLeast(x11).cmp(&NaNLeast(x21)) |
|
258 } |
|
259 |
|
260 #[replace_float_literals(F::cast_from(literal))] |
|
261 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> { |
|
262 |
|
263 fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> |
|
264 where G : Fn(&'_ Self) -> Option<V> { |
|
265 // Sort by first coordinate into an indexing array. |
|
266 let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); |
|
267 |
|
268 // Initialise result |
|
269 let mut res = None; |
|
270 let mut start_scan_2nd = 0; |
|
271 |
|
272 // Scan in order |
|
273 if indices.len() == 0 { |
|
274 return res |
|
275 } |
|
276 for k in 0..indices.len()-1 { |
|
277 let i = indices[k]; |
|
278 let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i]; |
|
279 |
|
280 if αi == 0.0 { |
|
281 // Nothin to be done if the weight is already zero |
|
282 continue |
|
283 } |
|
284 |
|
285 let mut closest = None; |
|
286 |
|
287 // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd` |
|
288 // the smallest invalid merging index on the previous loop iteration, because a |
|
289 // the _closest_ mergeable spike might have index less than `k` in `indices`, and a |
|
290 // merge with it might have not been attempted with this spike if a different closer |
|
291 // spike was discovered based on the second coordinate. |
|
292 'scan_2nd: for l in (start_scan_2nd+1)..indices.len() { |
|
293 if l == k { |
|
294 // Do not attempt to merge a spike with itself |
|
295 continue |
|
296 } |
|
297 let j = indices[l]; |
|
298 let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j]; |
|
299 |
|
300 if xj1 < xi1 - ρ { |
|
301 // Spike `j = indices[l]` has too low first coordinate. Update starting index |
|
302 // for next iteration, and continue scanning. |
|
303 start_scan_2nd = l; |
|
304 continue 'scan_2nd |
|
305 } else if xj1 > xi1 + ρ { |
|
306 // Break out: spike `j = indices[l]` has already too high first coordinate, no |
|
307 // more close enough spikes can be found due to the sorting of `indices`. |
|
308 break 'scan_2nd |
|
309 } |
|
310 |
|
311 // If also second coordinate is close enough, attempt merging if closer than |
|
312 // previously discovered mergeable spikes. |
|
313 let d2 = (xi2-xj2).abs(); |
|
314 if αj != 0.0 && d2 <= ρ { |
|
315 let r1 = xi1-xj1; |
|
316 let d = (d2*d2 + r1*r1).sqrt(); |
|
317 match closest { |
|
318 None => closest = Some((l, j, d)), |
|
319 Some((_, _, r)) if r > d => closest = Some((l, j, d)), |
|
320 _ => {}, |
|
321 } |
|
322 } |
|
323 } |
|
324 |
|
325 // Attempt merging closest close-enough spike |
|
326 if let Some((l, j, _)) = closest { |
|
327 if self.attempt_merge(&mut res, i, j, &accept) { |
|
328 // If merge was succesfull, make new spike candidate for merging. |
|
329 indices[l] = self.spikes.len() - 1; |
|
330 let compare = |i, j| compare_first_coordinate(&self.spikes[i], |
|
331 &self.spikes[j]); |
|
332 // Re-sort relevant range of indices |
|
333 if l < k { |
|
334 indices[l..k].sort_by(|&i, &j| compare(i, j)); |
|
335 } else { |
|
336 indices[k+1..=l].sort_by(|&i, &j| compare(i, j)); |
|
337 } |
|
338 } |
|
339 } |
|
340 } |
|
341 |
|
342 res |
|
343 } |
|
344 } |
|
345 |