src/measures/merging.rs

changeset 0
eb3c7813b67a
equal deleted inserted replaced
-1:000000000000 0:eb3c7813b67a
1 /*!
2 Spike merging heuristics for [`DiscreteMeasure`]s.
3
4 This module primarily provides the [`SpikeMerging`] trait, and within it,
5 the [`SpikeMerging::merge_spikes`] method. The trait is implemented on
6 [`DiscreteMeasure<Loc<F, N>, F>`]s in dimensions `N=1` and `N=2`.
7 */
8
9 use numeric_literals::replace_float_literals;
10 use std::cmp::Ordering;
11 use serde::{Serialize, Deserialize};
12 //use clap::builder::{PossibleValuesParser, PossibleValue};
13 use alg_tools::nanleast::NaNLeast;
14
15 use crate::types::*;
16 use super::delta::*;
17 use super::discrete::*;
18
19 /// Spike merging heuristic selection
20 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
21 #[allow(dead_code)]
22 pub enum SpikeMergingMethod<F> {
23 /// Try to merge spikes within a given radius of eachother
24 HeuristicRadius(F),
25 /// No merging
26 None,
27 }
28
29 // impl<F : Float> SpikeMergingMethod<F> {
30 // /// This is for [`clap`] to display command line help.
31 // pub fn value_parser() -> PossibleValuesParser {
32 // PossibleValuesParser::new([
33 // PossibleValue::new("none").help("No merging"),
34 // PossibleValue::new("<radius>").help("Heuristic merging within indicated radius")
35 // ])
36 // }
37 // }
38
39 impl<F : ClapFloat> std::fmt::Display for SpikeMergingMethod<F> {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
41 match self {
42 Self::None => write!(f, "none"),
43 Self::HeuristicRadius(r) => std::fmt::Display::fmt(r, f),
44 }
45 }
46 }
47
48 impl<F : ClapFloat> std::str::FromStr for SpikeMergingMethod<F> {
49 type Err = F::Err;
50
51 fn from_str(s: &str) -> Result<Self, Self::Err> {
52 if s == "none" {
53 Ok(Self::None)
54 } else {
55 Ok(Self::HeuristicRadius(F::from_str(s)?))
56 }
57 }
58 }
59
60 #[replace_float_literals(F::cast_from(literal))]
61 impl<F : Float> Default for SpikeMergingMethod<F> {
62 fn default() -> Self {
63 SpikeMergingMethod::HeuristicRadius(0.02)
64 }
65 }
66
67 /// Trait for dimension-dependent implementation of heuristic peak merging strategies.
68 pub trait SpikeMerging<F> {
69 /// Attempt spike merging according to [`SpikeMerging`] method.
70 ///
71 /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure
72 /// `accept` if any merging is performed. The closure should accept as its only parameter a
73 /// new candidate measure (it will generally be internally mutated `self`, although this is
74 /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of
75 /// an arbitrary value. This method will return that value for the *last* accepted merge, or
76 /// [`None`] if no merge was accepted.
77 ///
78 /// This method is stable with respect to spike locations: on merge, the weight of existing
79 /// spikes is set to zero, and a new one inserted at the end of the spike vector.
80 fn merge_spikes<G, V>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> Option<V>
81 where G : Fn(&'_ Self) -> Option<V> {
82 match method {
83 SpikeMergingMethod::HeuristicRadius(ρ) => self.do_merge_spikes_radius(ρ, accept),
84 SpikeMergingMethod::None => None,
85 }
86 }
87
88 /// Attempt to merge spikes based on a value and a fitness function.
89 ///
90 /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of
91 /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value`
92 // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial
93 /// `self` is returned.
94 fn merge_spikes_fitness<G, H, V, O>(
95 &mut self,
96 method : SpikeMergingMethod<F>,
97 value : G,
98 fitness : H
99 ) -> V
100 where G : Fn(&'_ Self) -> V,
101 H : Fn(&'_ V) -> O,
102 O : PartialOrd {
103 let initial_res = value(self);
104 let initial_fitness = fitness(&initial_res);
105 self.merge_spikes(method, |μ| {
106 let res = value(μ);
107 (fitness(&res) <= initial_fitness).then_some(res)
108 }).unwrap_or(initial_res)
109 }
110
111 /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm).
112 ///
113 /// This method implements [`SpikeMerging::merge_spikes`] for
114 /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are
115 /// as for that method.
116 fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V>
117 where G : Fn(&'_ Self) -> Option<V>;
118 }
119
120 #[replace_float_literals(F::cast_from(literal))]
121 impl<F : Float, const N : usize> DiscreteMeasure<Loc<F, N>, F> {
122 /// Attempts to merge spikes with indices `i` and `j`.
123 ///
124 /// This assumes that the weights of the two spikes have already been checked not to be zero.
125 ///
126 /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`].
127 /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its
128 /// return value.
129 fn attempt_merge<G, V>(
130 &mut self,
131 res : &mut Option<V>,
132 i : usize,
133 j : usize,
134 accept : &G
135 ) -> bool
136 where G : Fn(&'_ Self) -> Option<V> {
137 let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i];
138 let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j];
139
140 // Merge inplace
141 self.spikes[i].α = 0.0;
142 self.spikes[j].α = 0.0;
143 //self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 });
144 self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αi + xj * αj) / (αi + αj) });
145 match accept(self) {
146 some@Some(..) => {
147 // Merge accepted, update our return value
148 *res = some;
149 // On next iteration process the newly merged spike.
150 //indices[k+1] = self.spikes.len() - 1;
151 true
152 },
153 None => {
154 // Merge not accepted, restore modification
155 self.spikes[i].α = αi;
156 self.spikes[j].α = αj;
157 self.spikes.pop();
158 false
159 }
160 }
161 }
162
163 /*
164 /// Attempts to merge spikes with indices i and j, acceptance through a delta.
165 fn attempt_merge_change<G, V>(
166 &mut self,
167 res : &mut Option<V>,
168 i : usize,
169 j : usize,
170 accept_change : &G
171 ) -> bool
172 where G : Fn(&'_ Self) -> Option<V> {
173 let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i];
174 let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j];
175 let δ = DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 };
176 let λ = [-self.spikes[i], -self.spikes[j], δ.clone()].into();
177
178 match accept_change(&λ) {
179 some@Some(..) => {
180 // Merge accepted, update our return value
181 *res = some;
182 self.spikes[i].α = 0.0;
183 self.spikes[j].α = 0.0;
184 self.spikes.push(δ);
185 true
186 },
187 None => {
188 false
189 }
190 }
191 }*/
192
193 }
194
/// Returns the indices `0..slice.len()`, ordered so that the elements of `slice` they refer to
/// are sorted according to `compare`.
///
/// The closure `compare` operates on references to elements of `slice`. The sort is stable:
/// indices of elements that compare equal keep their original relative order.
pub fn sort_indices_by<V, F>(slice : &[V], mut compare : F) -> Vec<usize>
where F : FnMut(&V, &V) -> Ordering
{
    let mut order : Vec<usize> = (0..slice.len()).collect();
    order.sort_by(|&a, &b| compare(&slice[a], &slice[b]));
    order
}
206
207 #[replace_float_literals(F::cast_from(literal))]
208 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> {
209
210 fn do_merge_spikes_radius<G, V>(
211 &mut self,
212 ρ : F,
213 accept : G
214 ) -> Option<V>
215 where G : Fn(&'_ Self) -> Option<V> {
216 // Sort by coordinate into an indexing array.
217 let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| {
218 let &Loc([x1]) = &δ1.x;
219 let &Loc([x2]) = &δ2.x;
220 // nan-ignoring ordering of floats
221 NaNLeast(x1).cmp(&NaNLeast(x2))
222 });
223
224 // Initialise result
225 let mut res = None;
226
227 // Scan consecutive pairs and merge if close enough and accepted by `accept`.
228 if indices.len() == 0 {
229 return res
230 }
231 for k in 0..(indices.len()-1) {
232 let i = indices[k];
233 let j = indices[k+1];
234 let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i];
235 let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j];
236 debug_assert!(xi <= xj);
237 // If close enough, attempt merging
238 if αi != 0.0 && αj != 0.0 && xj <= xi + ρ {
239 if self.attempt_merge(&mut res, i, j, &accept) {
240 indices[k+1] = self.spikes.len() - 1;
241 }
242 }
243 }
244
245 res
246 }
247 }
248
249 /// Orders `δ1` and `δ1` according to the first coordinate.
250 fn compare_first_coordinate<F : Float>(
251 δ1 : &DeltaMeasure<Loc<F, 2>, F>,
252 δ2 : &DeltaMeasure<Loc<F, 2>, F>
253 ) -> Ordering {
254 let &Loc([x11, ..]) = &δ1.x;
255 let &Loc([x21, ..]) = &δ2.x;
256 // nan-ignoring ordering of floats
257 NaNLeast(x11).cmp(&NaNLeast(x21))
258 }
259
260 #[replace_float_literals(F::cast_from(literal))]
261 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> {
262
263 fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V>
264 where G : Fn(&'_ Self) -> Option<V> {
265 // Sort by first coordinate into an indexing array.
266 let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate);
267
268 // Initialise result
269 let mut res = None;
270 let mut start_scan_2nd = 0;
271
272 // Scan in order
273 if indices.len() == 0 {
274 return res
275 }
276 for k in 0..indices.len()-1 {
277 let i = indices[k];
278 let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i];
279
280 if αi == 0.0 {
281 // Nothin to be done if the weight is already zero
282 continue
283 }
284
285 let mut closest = None;
286
287 // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd`
288 // the smallest invalid merging index on the previous loop iteration, because a
289 // the _closest_ mergeable spike might have index less than `k` in `indices`, and a
290 // merge with it might have not been attempted with this spike if a different closer
291 // spike was discovered based on the second coordinate.
292 'scan_2nd: for l in (start_scan_2nd+1)..indices.len() {
293 if l == k {
294 // Do not attempt to merge a spike with itself
295 continue
296 }
297 let j = indices[l];
298 let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j];
299
300 if xj1 < xi1 - ρ {
301 // Spike `j = indices[l]` has too low first coordinate. Update starting index
302 // for next iteration, and continue scanning.
303 start_scan_2nd = l;
304 continue 'scan_2nd
305 } else if xj1 > xi1 + ρ {
306 // Break out: spike `j = indices[l]` has already too high first coordinate, no
307 // more close enough spikes can be found due to the sorting of `indices`.
308 break 'scan_2nd
309 }
310
311 // If also second coordinate is close enough, attempt merging if closer than
312 // previously discovered mergeable spikes.
313 let d2 = (xi2-xj2).abs();
314 if αj != 0.0 && d2 <= ρ {
315 let r1 = xi1-xj1;
316 let d = (d2*d2 + r1*r1).sqrt();
317 match closest {
318 None => closest = Some((l, j, d)),
319 Some((_, _, r)) if r > d => closest = Some((l, j, d)),
320 _ => {},
321 }
322 }
323 }
324
325 // Attempt merging closest close-enough spike
326 if let Some((l, j, _)) = closest {
327 if self.attempt_merge(&mut res, i, j, &accept) {
328 // If merge was succesfull, make new spike candidate for merging.
329 indices[l] = self.spikes.len() - 1;
330 let compare = |i, j| compare_first_coordinate(&self.spikes[i],
331 &self.spikes[j]);
332 // Re-sort relevant range of indices
333 if l < k {
334 indices[l..k].sort_by(|&i, &j| compare(i, j));
335 } else {
336 indices[k+1..=l].sort_by(|&i, &j| compare(i, j));
337 }
338 }
339 }
340 }
341
342 res
343 }
344 }
345

mercurial