Sun, 11 Dec 2022 23:19:17 +0200
Print out experiment information when running it
0 | 1 | /*! |
2 | Spike merging heuristics for [`DiscreteMeasure`]s. | |
3 | ||
4 | This module primarily provides the [`SpikeMerging`] trait, and within it, | |
5 | the [`SpikeMerging::merge_spikes`] method. The trait is implemented on | |
6 | [`DiscreteMeasure<Loc<F, N>, F>`]s in dimensions `N=1` and `N=2`. | |
7 | */ | |
8 | ||
9 | use numeric_literals::replace_float_literals; | |
10 | use std::cmp::Ordering; | |
11 | use serde::{Serialize, Deserialize}; | |
12 | //use clap::builder::{PossibleValuesParser, PossibleValue}; | |
13 | use alg_tools::nanleast::NaNLeast; | |
14 | ||
15 | use crate::types::*; | |
16 | use super::delta::*; | |
17 | use super::discrete::*; | |
18 | ||
19 | /// Spike merging heuristic selection | |
20 | #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] | |
21 | #[allow(dead_code)] | |
22 | pub enum SpikeMergingMethod<F> { | |
23 | /// Try to merge spikes within a given radius of eachother | |
24 | HeuristicRadius(F), | |
25 | /// No merging | |
26 | None, | |
27 | } | |
28 | ||
29 | // impl<F : Float> SpikeMergingMethod<F> { | |
30 | // /// This is for [`clap`] to display command line help. | |
31 | // pub fn value_parser() -> PossibleValuesParser { | |
32 | // PossibleValuesParser::new([ | |
33 | // PossibleValue::new("none").help("No merging"), | |
34 | // PossibleValue::new("<radius>").help("Heuristic merging within indicated radius") | |
35 | // ]) | |
36 | // } | |
37 | // } | |
38 | ||
39 | impl<F : ClapFloat> std::fmt::Display for SpikeMergingMethod<F> { | |
40 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { | |
41 | match self { | |
42 | Self::None => write!(f, "none"), | |
43 | Self::HeuristicRadius(r) => std::fmt::Display::fmt(r, f), | |
44 | } | |
45 | } | |
46 | } | |
47 | ||
48 | impl<F : ClapFloat> std::str::FromStr for SpikeMergingMethod<F> { | |
49 | type Err = F::Err; | |
50 | ||
51 | fn from_str(s: &str) -> Result<Self, Self::Err> { | |
52 | if s == "none" { | |
53 | Ok(Self::None) | |
54 | } else { | |
55 | Ok(Self::HeuristicRadius(F::from_str(s)?)) | |
56 | } | |
57 | } | |
58 | } | |
59 | ||
60 | #[replace_float_literals(F::cast_from(literal))] | |
61 | impl<F : Float> Default for SpikeMergingMethod<F> { | |
62 | fn default() -> Self { | |
63 | SpikeMergingMethod::HeuristicRadius(0.02) | |
64 | } | |
65 | } | |
66 | ||
67 | /// Trait for dimension-dependent implementation of heuristic peak merging strategies. | |
68 | pub trait SpikeMerging<F> { | |
69 | /// Attempt spike merging according to [`SpikeMerging`] method. | |
70 | /// | |
71 | /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure | |
72 | /// `accept` if any merging is performed. The closure should accept as its only parameter a | |
73 | /// new candidate measure (it will generally be internally mutated `self`, although this is | |
74 | /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of | |
75 | /// an arbitrary value. This method will return that value for the *last* accepted merge, or | |
76 | /// [`None`] if no merge was accepted. | |
77 | /// | |
78 | /// This method is stable with respect to spike locations: on merge, the weight of existing | |
79 | /// spikes is set to zero, and a new one inserted at the end of the spike vector. | |
80 | fn merge_spikes<G, V>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> Option<V> | |
81 | where G : Fn(&'_ Self) -> Option<V> { | |
82 | match method { | |
83 | SpikeMergingMethod::HeuristicRadius(ρ) => self.do_merge_spikes_radius(ρ, accept), | |
84 | SpikeMergingMethod::None => None, | |
85 | } | |
86 | } | |
87 | ||
88 | /// Attempt to merge spikes based on a value and a fitness function. | |
89 | /// | |
90 | /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of | |
91 | /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` | |
92 | // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial | |
93 | /// `self` is returned. | |
94 | fn merge_spikes_fitness<G, H, V, O>( | |
95 | &mut self, | |
96 | method : SpikeMergingMethod<F>, | |
97 | value : G, | |
98 | fitness : H | |
99 | ) -> V | |
100 | where G : Fn(&'_ Self) -> V, | |
101 | H : Fn(&'_ V) -> O, | |
102 | O : PartialOrd { | |
103 | let initial_res = value(self); | |
104 | let initial_fitness = fitness(&initial_res); | |
105 | self.merge_spikes(method, |μ| { | |
106 | let res = value(μ); | |
107 | (fitness(&res) <= initial_fitness).then_some(res) | |
108 | }).unwrap_or(initial_res) | |
109 | } | |
110 | ||
111 | /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). | |
112 | /// | |
113 | /// This method implements [`SpikeMerging::merge_spikes`] for | |
114 | /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are | |
115 | /// as for that method. | |
116 | fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> | |
117 | where G : Fn(&'_ Self) -> Option<V>; | |
118 | } | |
119 | ||
120 | #[replace_float_literals(F::cast_from(literal))] | |
121 | impl<F : Float, const N : usize> DiscreteMeasure<Loc<F, N>, F> { | |
122 | /// Attempts to merge spikes with indices `i` and `j`. | |
123 | /// | |
124 | /// This assumes that the weights of the two spikes have already been checked not to be zero. | |
125 | /// | |
126 | /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. | |
127 | /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its | |
128 | /// return value. | |
129 | fn attempt_merge<G, V>( | |
130 | &mut self, | |
131 | res : &mut Option<V>, | |
132 | i : usize, | |
133 | j : usize, | |
134 | accept : &G | |
135 | ) -> bool | |
136 | where G : Fn(&'_ Self) -> Option<V> { | |
137 | let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; | |
138 | let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; | |
139 | ||
140 | // Merge inplace | |
141 | self.spikes[i].α = 0.0; | |
142 | self.spikes[j].α = 0.0; | |
143 | //self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }); | |
144 | self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αi + xj * αj) / (αi + αj) }); | |
145 | match accept(self) { | |
146 | some@Some(..) => { | |
147 | // Merge accepted, update our return value | |
148 | *res = some; | |
149 | // On next iteration process the newly merged spike. | |
150 | //indices[k+1] = self.spikes.len() - 1; | |
151 | true | |
152 | }, | |
153 | None => { | |
154 | // Merge not accepted, restore modification | |
155 | self.spikes[i].α = αi; | |
156 | self.spikes[j].α = αj; | |
157 | self.spikes.pop(); | |
158 | false | |
159 | } | |
160 | } | |
161 | } | |
162 | ||
163 | /* | |
164 | /// Attempts to merge spikes with indices i and j, acceptance through a delta. | |
165 | fn attempt_merge_change<G, V>( | |
166 | &mut self, | |
167 | res : &mut Option<V>, | |
168 | i : usize, | |
169 | j : usize, | |
170 | accept_change : &G | |
171 | ) -> bool | |
172 | where G : Fn(&'_ Self) -> Option<V> { | |
173 | let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; | |
174 | let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; | |
175 | let δ = DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }; | |
176 | let λ = [-self.spikes[i], -self.spikes[j], δ.clone()].into(); | |
177 | ||
178 | match accept_change(&λ) { | |
179 | some@Some(..) => { | |
180 | // Merge accepted, update our return value | |
181 | *res = some; | |
182 | self.spikes[i].α = 0.0; | |
183 | self.spikes[j].α = 0.0; | |
184 | self.spikes.push(δ); | |
185 | true | |
186 | }, | |
187 | None => { | |
188 | false | |
189 | } | |
190 | } | |
191 | }*/ | |
192 | ||
193 | } | |
194 | ||
/// Returns the permutation of indices `0..slice.len()` that orders `slice` by `compare`.
///
/// The closure `compare` operates on references to elements of `slice`, not on the indices
/// themselves. The underlying sort is stable, so indices of elements comparing equal keep
/// their original relative order.
pub fn sort_indices_by<V, F>(slice : &[V], mut compare : F) -> Vec<usize>
where F : FnMut(&V, &V) -> Ordering
{
    let mut order : Vec<usize> = (0..slice.len()).collect();
    order.sort_by(|&a, &b| compare(&slice[a], &slice[b]));
    order
}
206 | ||
207 | #[replace_float_literals(F::cast_from(literal))] | |
208 | impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> { | |
209 | ||
210 | fn do_merge_spikes_radius<G, V>( | |
211 | &mut self, | |
212 | ρ : F, | |
213 | accept : G | |
214 | ) -> Option<V> | |
215 | where G : Fn(&'_ Self) -> Option<V> { | |
216 | // Sort by coordinate into an indexing array. | |
217 | let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { | |
218 | let &Loc([x1]) = &δ1.x; | |
219 | let &Loc([x2]) = &δ2.x; | |
220 | // nan-ignoring ordering of floats | |
221 | NaNLeast(x1).cmp(&NaNLeast(x2)) | |
222 | }); | |
223 | ||
224 | // Initialise result | |
225 | let mut res = None; | |
226 | ||
227 | // Scan consecutive pairs and merge if close enough and accepted by `accept`. | |
228 | if indices.len() == 0 { | |
229 | return res | |
230 | } | |
231 | for k in 0..(indices.len()-1) { | |
232 | let i = indices[k]; | |
233 | let j = indices[k+1]; | |
234 | let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i]; | |
235 | let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j]; | |
236 | debug_assert!(xi <= xj); | |
237 | // If close enough, attempt merging | |
238 | if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { | |
239 | if self.attempt_merge(&mut res, i, j, &accept) { | |
240 | indices[k+1] = self.spikes.len() - 1; | |
241 | } | |
242 | } | |
243 | } | |
244 | ||
245 | res | |
246 | } | |
247 | } | |
248 | ||
249 | /// Orders `δ1` and `δ1` according to the first coordinate. | |
250 | fn compare_first_coordinate<F : Float>( | |
251 | δ1 : &DeltaMeasure<Loc<F, 2>, F>, | |
252 | δ2 : &DeltaMeasure<Loc<F, 2>, F> | |
253 | ) -> Ordering { | |
254 | let &Loc([x11, ..]) = &δ1.x; | |
255 | let &Loc([x21, ..]) = &δ2.x; | |
256 | // nan-ignoring ordering of floats | |
257 | NaNLeast(x11).cmp(&NaNLeast(x21)) | |
258 | } | |
259 | ||
260 | #[replace_float_literals(F::cast_from(literal))] | |
261 | impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> { | |
262 | ||
263 | fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> | |
264 | where G : Fn(&'_ Self) -> Option<V> { | |
265 | // Sort by first coordinate into an indexing array. | |
266 | let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); | |
267 | ||
268 | // Initialise result | |
269 | let mut res = None; | |
270 | let mut start_scan_2nd = 0; | |
271 | ||
272 | // Scan in order | |
273 | if indices.len() == 0 { | |
274 | return res | |
275 | } | |
276 | for k in 0..indices.len()-1 { | |
277 | let i = indices[k]; | |
278 | let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i]; | |
279 | ||
280 | if αi == 0.0 { | |
281 | // Nothin to be done if the weight is already zero | |
282 | continue | |
283 | } | |
284 | ||
285 | let mut closest = None; | |
286 | ||
287 | // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd` | |
288 | // the smallest invalid merging index on the previous loop iteration, because a | |
289 | // the _closest_ mergeable spike might have index less than `k` in `indices`, and a | |
290 | // merge with it might have not been attempted with this spike if a different closer | |
291 | // spike was discovered based on the second coordinate. | |
292 | 'scan_2nd: for l in (start_scan_2nd+1)..indices.len() { | |
293 | if l == k { | |
294 | // Do not attempt to merge a spike with itself | |
295 | continue | |
296 | } | |
297 | let j = indices[l]; | |
298 | let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j]; | |
299 | ||
300 | if xj1 < xi1 - ρ { | |
301 | // Spike `j = indices[l]` has too low first coordinate. Update starting index | |
302 | // for next iteration, and continue scanning. | |
303 | start_scan_2nd = l; | |
304 | continue 'scan_2nd | |
305 | } else if xj1 > xi1 + ρ { | |
306 | // Break out: spike `j = indices[l]` has already too high first coordinate, no | |
307 | // more close enough spikes can be found due to the sorting of `indices`. | |
308 | break 'scan_2nd | |
309 | } | |
310 | ||
311 | // If also second coordinate is close enough, attempt merging if closer than | |
312 | // previously discovered mergeable spikes. | |
313 | let d2 = (xi2-xj2).abs(); | |
314 | if αj != 0.0 && d2 <= ρ { | |
315 | let r1 = xi1-xj1; | |
316 | let d = (d2*d2 + r1*r1).sqrt(); | |
317 | match closest { | |
318 | None => closest = Some((l, j, d)), | |
319 | Some((_, _, r)) if r > d => closest = Some((l, j, d)), | |
320 | _ => {}, | |
321 | } | |
322 | } | |
323 | } | |
324 | ||
325 | // Attempt merging closest close-enough spike | |
326 | if let Some((l, j, _)) = closest { | |
327 | if self.attempt_merge(&mut res, i, j, &accept) { | |
328 | // If merge was succesfull, make new spike candidate for merging. | |
329 | indices[l] = self.spikes.len() - 1; | |
330 | let compare = |i, j| compare_first_coordinate(&self.spikes[i], | |
331 | &self.spikes[j]); | |
332 | // Re-sort relevant range of indices | |
333 | if l < k { | |
334 | indices[l..k].sort_by(|&i, &j| compare(i, j)); | |
335 | } else { | |
336 | indices[k+1..=l].sort_by(|&i, &j| compare(i, j)); | |
337 | } | |
338 | } | |
339 | } | |
340 | } | |
341 | ||
342 | res | |
343 | } | |
344 | } | |
345 |