| |
1 /*! |
| |
2 Spike merging heuristics for [`DiscreteMeasure`]s. |
| |
3 |
| |
4 This module primarily provides the [`SpikeMerging`] trait, and within it, |
| |
5 the [`SpikeMerging::merge_spikes`] method. The trait is implemented on |
| |
6 [`DiscreteMeasure<Loc<F, N>, F>`]s in dimensions `N=1` and `N=2`. |
| |
7 */ |
| |
8 |
| |
9 use numeric_literals::replace_float_literals; |
| |
10 use std::cmp::Ordering; |
| |
11 use serde::{Serialize, Deserialize}; |
| |
12 //use clap::builder::{PossibleValuesParser, PossibleValue}; |
| |
13 use alg_tools::nanleast::NaNLeast; |
| |
14 |
| |
15 use crate::types::*; |
| |
16 use super::delta::*; |
| |
17 use super::discrete::*; |
| |
18 |
| |
19 /// Spike merging heuristic selection |
| |
20 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| |
21 #[allow(dead_code)] |
| |
22 pub enum SpikeMergingMethod<F> { |
| |
23 /// Try to merge spikes within a given radius of eachother |
| |
24 HeuristicRadius(F), |
| |
25 /// No merging |
| |
26 None, |
| |
27 } |
| |
28 |
| |
29 // impl<F : Float> SpikeMergingMethod<F> { |
| |
30 // /// This is for [`clap`] to display command line help. |
| |
31 // pub fn value_parser() -> PossibleValuesParser { |
| |
32 // PossibleValuesParser::new([ |
| |
33 // PossibleValue::new("none").help("No merging"), |
| |
34 // PossibleValue::new("<radius>").help("Heuristic merging within indicated radius") |
| |
35 // ]) |
| |
36 // } |
| |
37 // } |
| |
38 |
| |
39 impl<F : ClapFloat> std::fmt::Display for SpikeMergingMethod<F> { |
| |
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { |
| |
41 match self { |
| |
42 Self::None => write!(f, "none"), |
| |
43 Self::HeuristicRadius(r) => std::fmt::Display::fmt(r, f), |
| |
44 } |
| |
45 } |
| |
46 } |
| |
47 |
| |
48 impl<F : ClapFloat> std::str::FromStr for SpikeMergingMethod<F> { |
| |
49 type Err = F::Err; |
| |
50 |
| |
51 fn from_str(s: &str) -> Result<Self, Self::Err> { |
| |
52 if s == "none" { |
| |
53 Ok(Self::None) |
| |
54 } else { |
| |
55 Ok(Self::HeuristicRadius(F::from_str(s)?)) |
| |
56 } |
| |
57 } |
| |
58 } |
| |
59 |
| |
60 #[replace_float_literals(F::cast_from(literal))] |
| |
61 impl<F : Float> Default for SpikeMergingMethod<F> { |
| |
62 fn default() -> Self { |
| |
63 SpikeMergingMethod::HeuristicRadius(0.02) |
| |
64 } |
| |
65 } |
| |
66 |
| |
67 /// Trait for dimension-dependent implementation of heuristic peak merging strategies. |
| |
68 pub trait SpikeMerging<F> { |
| |
69 /// Attempt spike merging according to [`SpikeMerging`] method. |
| |
70 /// |
| |
71 /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure |
| |
72 /// `accept` if any merging is performed. The closure should accept as its only parameter a |
| |
73 /// new candidate measure (it will generally be internally mutated `self`, although this is |
| |
74 /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of |
| |
75 /// an arbitrary value. This method will return that value for the *last* accepted merge, or |
| |
76 /// [`None`] if no merge was accepted. |
| |
77 /// |
| |
78 /// This method is stable with respect to spike locations: on merge, the weight of existing |
| |
79 /// spikes is set to zero, and a new one inserted at the end of the spike vector. |
| |
80 fn merge_spikes<G, V>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> Option<V> |
| |
81 where G : Fn(&'_ Self) -> Option<V> { |
| |
82 match method { |
| |
83 SpikeMergingMethod::HeuristicRadius(ρ) => self.do_merge_spikes_radius(ρ, accept), |
| |
84 SpikeMergingMethod::None => None, |
| |
85 } |
| |
86 } |
| |
87 |
| |
88 /// Attempt to merge spikes based on a value and a fitness function. |
| |
89 /// |
| |
90 /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of |
| |
91 /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` |
| |
92 // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial |
| |
93 /// `self` is returned. |
| |
94 fn merge_spikes_fitness<G, H, V, O>( |
| |
95 &mut self, |
| |
96 method : SpikeMergingMethod<F>, |
| |
97 value : G, |
| |
98 fitness : H |
| |
99 ) -> V |
| |
100 where G : Fn(&'_ Self) -> V, |
| |
101 H : Fn(&'_ V) -> O, |
| |
102 O : PartialOrd { |
| |
103 let initial_res = value(self); |
| |
104 let initial_fitness = fitness(&initial_res); |
| |
105 self.merge_spikes(method, |μ| { |
| |
106 let res = value(μ); |
| |
107 (fitness(&res) <= initial_fitness).then_some(res) |
| |
108 }).unwrap_or(initial_res) |
| |
109 } |
| |
110 |
| |
111 /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). |
| |
112 /// |
| |
113 /// This method implements [`SpikeMerging::merge_spikes`] for |
| |
114 /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are |
| |
115 /// as for that method. |
| |
116 fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> |
| |
117 where G : Fn(&'_ Self) -> Option<V>; |
| |
118 } |
| |
119 |
| |
120 #[replace_float_literals(F::cast_from(literal))] |
| |
121 impl<F : Float, const N : usize> DiscreteMeasure<Loc<F, N>, F> { |
| |
122 /// Attempts to merge spikes with indices `i` and `j`. |
| |
123 /// |
| |
124 /// This assumes that the weights of the two spikes have already been checked not to be zero. |
| |
125 /// |
| |
126 /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. |
| |
127 /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its |
| |
128 /// return value. |
| |
129 fn attempt_merge<G, V>( |
| |
130 &mut self, |
| |
131 res : &mut Option<V>, |
| |
132 i : usize, |
| |
133 j : usize, |
| |
134 accept : &G |
| |
135 ) -> bool |
| |
136 where G : Fn(&'_ Self) -> Option<V> { |
| |
137 let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; |
| |
138 let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; |
| |
139 |
| |
140 // Merge inplace |
| |
141 self.spikes[i].α = 0.0; |
| |
142 self.spikes[j].α = 0.0; |
| |
143 //self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }); |
| |
144 self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αi + xj * αj) / (αi + αj) }); |
| |
145 match accept(self) { |
| |
146 some@Some(..) => { |
| |
147 // Merge accepted, update our return value |
| |
148 *res = some; |
| |
149 // On next iteration process the newly merged spike. |
| |
150 //indices[k+1] = self.spikes.len() - 1; |
| |
151 true |
| |
152 }, |
| |
153 None => { |
| |
154 // Merge not accepted, restore modification |
| |
155 self.spikes[i].α = αi; |
| |
156 self.spikes[j].α = αj; |
| |
157 self.spikes.pop(); |
| |
158 false |
| |
159 } |
| |
160 } |
| |
161 } |
| |
162 |
| |
163 /* |
| |
164 /// Attempts to merge spikes with indices i and j, acceptance through a delta. |
| |
165 fn attempt_merge_change<G, V>( |
| |
166 &mut self, |
| |
167 res : &mut Option<V>, |
| |
168 i : usize, |
| |
169 j : usize, |
| |
170 accept_change : &G |
| |
171 ) -> bool |
| |
172 where G : Fn(&'_ Self) -> Option<V> { |
| |
173 let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; |
| |
174 let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; |
| |
175 let δ = DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }; |
| |
176 let λ = [-self.spikes[i], -self.spikes[j], δ.clone()].into(); |
| |
177 |
| |
178 match accept_change(&λ) { |
| |
179 some@Some(..) => { |
| |
180 // Merge accepted, update our return value |
| |
181 *res = some; |
| |
182 self.spikes[i].α = 0.0; |
| |
183 self.spikes[j].α = 0.0; |
| |
184 self.spikes.push(δ); |
| |
185 true |
| |
186 }, |
| |
187 None => { |
| |
188 false |
| |
189 } |
| |
190 } |
| |
191 }*/ |
| |
192 |
| |
193 } |
| |
194 |
| |
195 /// Sorts a vector of indices into `slice` by `compare`. |
| |
196 /// |
| |
197 /// The closure `compare` operators on references to elements of `slice`. |
| |
198 /// Returns the sorted vector of indices into `slice`. |
| |
199 pub fn sort_indices_by<V, F>(slice : &[V], mut compare : F) -> Vec<usize> |
| |
200 where F : FnMut(&V, &V) -> Ordering |
| |
201 { |
| |
202 let mut indices = Vec::from_iter(0..slice.len()); |
| |
203 indices.sort_by(|&i, &j| compare(&slice[i], &slice[j])); |
| |
204 indices |
| |
205 } |
| |
206 |
| |
207 #[replace_float_literals(F::cast_from(literal))] |
| |
208 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> { |
| |
209 |
| |
210 fn do_merge_spikes_radius<G, V>( |
| |
211 &mut self, |
| |
212 ρ : F, |
| |
213 accept : G |
| |
214 ) -> Option<V> |
| |
215 where G : Fn(&'_ Self) -> Option<V> { |
| |
216 // Sort by coordinate into an indexing array. |
| |
217 let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { |
| |
218 let &Loc([x1]) = &δ1.x; |
| |
219 let &Loc([x2]) = &δ2.x; |
| |
220 // nan-ignoring ordering of floats |
| |
221 NaNLeast(x1).cmp(&NaNLeast(x2)) |
| |
222 }); |
| |
223 |
| |
224 // Initialise result |
| |
225 let mut res = None; |
| |
226 |
| |
227 // Scan consecutive pairs and merge if close enough and accepted by `accept`. |
| |
228 if indices.len() == 0 { |
| |
229 return res |
| |
230 } |
| |
231 for k in 0..(indices.len()-1) { |
| |
232 let i = indices[k]; |
| |
233 let j = indices[k+1]; |
| |
234 let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i]; |
| |
235 let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j]; |
| |
236 debug_assert!(xi <= xj); |
| |
237 // If close enough, attempt merging |
| |
238 if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { |
| |
239 if self.attempt_merge(&mut res, i, j, &accept) { |
| |
240 indices[k+1] = self.spikes.len() - 1; |
| |
241 } |
| |
242 } |
| |
243 } |
| |
244 |
| |
245 res |
| |
246 } |
| |
247 } |
| |
248 |
| |
249 /// Orders `δ1` and `δ1` according to the first coordinate. |
| |
250 fn compare_first_coordinate<F : Float>( |
| |
251 δ1 : &DeltaMeasure<Loc<F, 2>, F>, |
| |
252 δ2 : &DeltaMeasure<Loc<F, 2>, F> |
| |
253 ) -> Ordering { |
| |
254 let &Loc([x11, ..]) = &δ1.x; |
| |
255 let &Loc([x21, ..]) = &δ2.x; |
| |
256 // nan-ignoring ordering of floats |
| |
257 NaNLeast(x11).cmp(&NaNLeast(x21)) |
| |
258 } |
| |
259 |
| |
260 #[replace_float_literals(F::cast_from(literal))] |
| |
261 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> { |
| |
262 |
| |
263 fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> |
| |
264 where G : Fn(&'_ Self) -> Option<V> { |
| |
265 // Sort by first coordinate into an indexing array. |
| |
266 let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); |
| |
267 |
| |
268 // Initialise result |
| |
269 let mut res = None; |
| |
270 let mut start_scan_2nd = 0; |
| |
271 |
| |
272 // Scan in order |
| |
273 if indices.len() == 0 { |
| |
274 return res |
| |
275 } |
| |
276 for k in 0..indices.len()-1 { |
| |
277 let i = indices[k]; |
| |
278 let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i]; |
| |
279 |
| |
280 if αi == 0.0 { |
| |
281 // Nothin to be done if the weight is already zero |
| |
282 continue |
| |
283 } |
| |
284 |
| |
285 let mut closest = None; |
| |
286 |
| |
287 // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd` |
| |
288 // the smallest invalid merging index on the previous loop iteration, because a |
| |
289 // the _closest_ mergeable spike might have index less than `k` in `indices`, and a |
| |
290 // merge with it might have not been attempted with this spike if a different closer |
| |
291 // spike was discovered based on the second coordinate. |
| |
292 'scan_2nd: for l in (start_scan_2nd+1)..indices.len() { |
| |
293 if l == k { |
| |
294 // Do not attempt to merge a spike with itself |
| |
295 continue |
| |
296 } |
| |
297 let j = indices[l]; |
| |
298 let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j]; |
| |
299 |
| |
300 if xj1 < xi1 - ρ { |
| |
301 // Spike `j = indices[l]` has too low first coordinate. Update starting index |
| |
302 // for next iteration, and continue scanning. |
| |
303 start_scan_2nd = l; |
| |
304 continue 'scan_2nd |
| |
305 } else if xj1 > xi1 + ρ { |
| |
306 // Break out: spike `j = indices[l]` has already too high first coordinate, no |
| |
307 // more close enough spikes can be found due to the sorting of `indices`. |
| |
308 break 'scan_2nd |
| |
309 } |
| |
310 |
| |
311 // If also second coordinate is close enough, attempt merging if closer than |
| |
312 // previously discovered mergeable spikes. |
| |
313 let d2 = (xi2-xj2).abs(); |
| |
314 if αj != 0.0 && d2 <= ρ { |
| |
315 let r1 = xi1-xj1; |
| |
316 let d = (d2*d2 + r1*r1).sqrt(); |
| |
317 match closest { |
| |
318 None => closest = Some((l, j, d)), |
| |
319 Some((_, _, r)) if r > d => closest = Some((l, j, d)), |
| |
320 _ => {}, |
| |
321 } |
| |
322 } |
| |
323 } |
| |
324 |
| |
325 // Attempt merging closest close-enough spike |
| |
326 if let Some((l, j, _)) = closest { |
| |
327 if self.attempt_merge(&mut res, i, j, &accept) { |
| |
328 // If merge was succesfull, make new spike candidate for merging. |
| |
329 indices[l] = self.spikes.len() - 1; |
| |
330 let compare = |i, j| compare_first_coordinate(&self.spikes[i], |
| |
331 &self.spikes[j]); |
| |
332 // Re-sort relevant range of indices |
| |
333 if l < k { |
| |
334 indices[l..k].sort_by(|&i, &j| compare(i, j)); |
| |
335 } else { |
| |
336 indices[k+1..=l].sort_by(|&i, &j| compare(i, j)); |
| |
337 } |
| |
338 } |
| |
339 } |
| |
340 } |
| |
341 |
| |
342 res |
| |
343 } |
| |
344 } |
| |
345 |