src/measures/merging.rs

branch
dev
changeset 51
0693cc9ba9f0
parent 39
6316d68b58af
equal deleted inserted replaced
50:39c5e6c7759d 51:0693cc9ba9f0
5 the [`SpikeMerging::merge_spikes`] method. The trait is implemented on 5 the [`SpikeMerging::merge_spikes`] method. The trait is implemented on
6 [`DiscreteMeasure<Loc<F, N>, F>`]s in dimensions `N=1` and `N=2`. 6 [`DiscreteMeasure<Loc<F, N>, F>`]s in dimensions `N=1` and `N=2`.
7 */ 7 */
8 8
9 use numeric_literals::replace_float_literals; 9 use numeric_literals::replace_float_literals;
10 use serde::{Deserialize, Serialize};
10 use std::cmp::Ordering; 11 use std::cmp::Ordering;
11 use serde::{Serialize, Deserialize};
12 //use clap::builder::{PossibleValuesParser, PossibleValue}; 12 //use clap::builder::{PossibleValuesParser, PossibleValue};
13 use alg_tools::nanleast::NaNLeast; 13 use alg_tools::nanleast::NaNLeast;
14 14
15 use crate::types::*;
16 use super::delta::*; 15 use super::delta::*;
17 use super::discrete::*; 16 use super::discrete::*;
17 use crate::types::*;
18 18
19 /// Spike merging heuristic selection 19 /// Spike merging heuristic selection
20 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 20 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
21 #[allow(dead_code)] 21 #[allow(dead_code)]
22 pub struct SpikeMergingMethod<F> { 22 pub struct SpikeMergingMethod<F> {
23 // Merging radius 23 // Merging radius
24 pub(crate) radius : F, 24 pub(crate) radius: F,
25 // Enabled 25 // Enabled
26 pub(crate) enabled : bool, 26 pub(crate) enabled: bool,
27 // Interpolate merged points 27 // Interpolate merged points
28 pub(crate) interp : bool, 28 pub(crate) interp: bool,
29 } 29 }
30
31 30
32 #[replace_float_literals(F::cast_from(literal))] 31 #[replace_float_literals(F::cast_from(literal))]
33 impl<F : Float> Default for SpikeMergingMethod<F> { 32 impl<F: Float> Default for SpikeMergingMethod<F> {
34 fn default() -> Self { 33 fn default() -> Self {
35 SpikeMergingMethod{ 34 SpikeMergingMethod {
36 radius : 0.01, 35 radius: 0.01,
37 enabled : false, 36 enabled: false,
38 interp : true, 37 interp: true,
39 } 38 }
40 } 39 }
41 } 40 }
42 41
43 /// Trait for dimension-dependent implementation of heuristic peak merging strategies. 42 /// Trait for dimension-dependent implementation of heuristic peak merging strategies.
53 /// 52 ///
54 /// This method is stable with respect to spike locations: on merge, the weights of existing 53 /// This method is stable with respect to spike locations: on merge, the weights of existing
55 /// removed spikes are set to zero, and new ones are inserted at the end of the spike vector. 54 /// removed spikes are set to zero, and new ones are inserted at the end of the spike vector.
56 /// The merge may also be performed by increasing the weights of the existing spikes, 55 /// The merge may also be performed by increasing the weights of the existing spikes,
57 /// without inserting new spikes. 56 /// without inserting new spikes.
58 fn merge_spikes<G>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> usize 57 fn merge_spikes<G>(&mut self, method: SpikeMergingMethod<F>, accept: G) -> usize
59 where G : FnMut(&'_ Self) -> bool { 58 where
59 G: FnMut(&'_ Self) -> bool,
60 {
60 if method.enabled { 61 if method.enabled {
61 self.do_merge_spikes_radius(method.radius, method.interp, accept) 62 self.do_merge_spikes_radius(method.radius, method.interp, accept)
62 } else { 63 } else {
63 0 64 0
64 } 65 }
70 /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` 71 /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value`
71 /// for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial 72 /// for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial
72 /// `self` is returned. The number of merges is also returned. 73 /// `self` is returned. The number of merges is also returned.
73 fn merge_spikes_fitness<G, H, V, O>( 74 fn merge_spikes_fitness<G, H, V, O>(
74 &mut self, 75 &mut self,
75 method : SpikeMergingMethod<F>, 76 method: SpikeMergingMethod<F>,
76 value : G, 77 value: G,
77 fitness : H 78 fitness: H,
78 ) -> (V, usize) 79 ) -> (V, usize)
79 where G : Fn(&'_ Self) -> V, 80 where
80 H : Fn(&'_ V) -> O, 81 G: Fn(&'_ Self) -> V,
81 O : PartialOrd { 82 H: Fn(&'_ V) -> O,
83 O: PartialOrd,
84 {
82 let mut res = value(self); 85 let mut res = value(self);
83 let initial_fitness = fitness(&res); 86 let initial_fitness = fitness(&res);
84 let count = self.merge_spikes(method, |μ| { 87 let count = self.merge_spikes(method, |μ| {
85 res = value(μ); 88 res = value(μ);
86 fitness(&res) <= initial_fitness 89 fitness(&res) <= initial_fitness
88 (res, count) 91 (res, count)
89 } 92 }
90 93
91 /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). 94 /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm).
92 /// 95 ///
93 /// This method implements [`SpikeMerging::merge_spikes`] for 96 /// This method implements [`SpikeMerging::merge_spikes`].
94 /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are 97 fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, accept: G) -> usize
95 /// as for that method. 98 where
96 fn do_merge_spikes_radius<G>(&mut self, ρ : F, interp : bool, accept : G) -> usize 99 G: FnMut(&'_ Self) -> bool;
97 where G : FnMut(&'_ Self) -> bool;
98 } 100 }
99 101
100 #[replace_float_literals(F::cast_from(literal))] 102 #[replace_float_literals(F::cast_from(literal))]
101 impl<F : Float, const N : usize> DiscreteMeasure<Loc<F, N>, F> { 103 impl<F: Float, const N: usize> DiscreteMeasure<Loc<F, N>, F> {
102 /// Attempts to merge spikes with indices `i` and `j`. 104 /// Attempts to merge spikes with indices `i` and `j`.
103 /// 105 ///
104 /// This assumes that the weights of the two spikes have already been checked not to be zero. 106 /// This assumes that the weights of the two spikes have already been checked not to be zero.
105 /// 107 ///
106 /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. 108 /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`].
108 /// return value. 110 /// return value.
109 /// 111 ///
110 /// Returns the index of `self.spikes` storing the new spike. 112 /// Returns the index of `self.spikes` storing the new spike.
111 fn attempt_merge<G>( 113 fn attempt_merge<G>(
112 &mut self, 114 &mut self,
113 i : usize, 115 i: usize,
114 j : usize, 116 j: usize,
115 interp : bool, 117 interp: bool,
116 accept : &mut G 118 accept: &mut G,
117 ) -> Option<usize> 119 ) -> Option<usize>
118 where G : FnMut(&'_ Self) -> bool { 120 where
119 let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; 121 G: FnMut(&'_ Self) -> bool,
120 let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; 122 {
123 let &DeltaMeasure { x: xi, α: αi } = &self.spikes[i];
124 let &DeltaMeasure { x: xj, α: αj } = &self.spikes[j];
121 125
122 if interp { 126 if interp {
123 // Merge inplace 127 // Merge inplace
124 self.spikes[i].α = 0.0; 128 self.spikes[i].α = 0.0;
125 self.spikes[j].α = 0.0; 129 self.spikes[j].α = 0.0;
126 let αia = αi.abs(); 130 let αia = αi.abs();
127 let αja = αj.abs(); 131 let αja = αj.abs();
128 self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αia + xj * αja) / (αia + αja) }); 132 self.spikes.push(DeltaMeasure {
133 α: αi + αj,
134 x: (xi * αia + xj * αja) / (αia + αja),
135 });
129 if accept(self) { 136 if accept(self) {
130 Some(self.spikes.len()-1) 137 Some(self.spikes.len() - 1)
131 } else { 138 } else {
132 // Merge not accepted, restore modification 139 // Merge not accepted, restore modification
133 self.spikes[i].α = αi; 140 self.spikes[i].α = αi;
134 self.spikes[j].α = αj; 141 self.spikes[j].α = αj;
135 self.spikes.pop(); 142 self.spikes.pop();
162 169
163 /// Sorts a vector of indices into `slice` by `compare`. 170 /// Sorts a vector of indices into `slice` by `compare`.
164 /// 171 ///
165 /// The closure `compare` operates on references to elements of `slice`. 172 /// The closure `compare` operates on references to elements of `slice`.
166 /// Returns the sorted vector of indices into `slice`. 173 /// Returns the sorted vector of indices into `slice`.
167 pub fn sort_indices_by<V, F>(slice : &[V], mut compare : F) -> Vec<usize> 174 pub fn sort_indices_by<V, F>(slice: &[V], mut compare: F) -> Vec<usize>
168 where F : FnMut(&V, &V) -> Ordering 175 where
176 F: FnMut(&V, &V) -> Ordering,
169 { 177 {
170 let mut indices = Vec::from_iter(0..slice.len()); 178 let mut indices = Vec::from_iter(0..slice.len());
171 indices.sort_by(|&i, &j| compare(&slice[i], &slice[j])); 179 indices.sort_by(|&i, &j| compare(&slice[i], &slice[j]));
172 indices 180 indices
173 } 181 }
174 182
175 #[replace_float_literals(F::cast_from(literal))] 183 #[replace_float_literals(F::cast_from(literal))]
176 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> { 184 impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> {
177 185 fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize
178 fn do_merge_spikes_radius<G>( 186 where
179 &mut self, 187 G: FnMut(&'_ Self) -> bool,
180 ρ : F, 188 {
181 interp : bool,
182 mut accept : G
183 ) -> usize
184 where G : FnMut(&'_ Self) -> bool {
185 // Sort by coordinate into an indexing array. 189 // Sort by coordinate into an indexing array.
186 let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { 190 let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| {
187 let &Loc([x1]) = &δ1.x; 191 let &Loc([x1]) = &δ1.x;
188 let &Loc([x2]) = &δ2.x; 192 let &Loc([x2]) = &δ2.x;
189 // nan-ignoring ordering of floats 193 // nan-ignoring ordering of floats
193 // Initialise result 197 // Initialise result
194 let mut count = 0; 198 let mut count = 0;
195 199
196 // Scan consecutive pairs and merge if close enough and accepted by `accept`. 200 // Scan consecutive pairs and merge if close enough and accepted by `accept`.
197 if indices.len() == 0 { 201 if indices.len() == 0 {
198 return count 202 return count;
199 } 203 }
200 for k in 0..(indices.len()-1) { 204 for k in 0..(indices.len() - 1) {
201 let i = indices[k]; 205 let i = indices[k];
202 let j = indices[k+1]; 206 let j = indices[k + 1];
203 let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i]; 207 let &DeltaMeasure {
204 let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j]; 208 x: Loc([xi]),
209 α: αi,
210 } = &self.spikes[i];
211 let &DeltaMeasure {
212 x: Loc([xj]),
213 α: αj,
214 } = &self.spikes[j];
205 debug_assert!(xi <= xj); 215 debug_assert!(xi <= xj);
206 // If close enough, attempt merging 216 // If close enough, attempt merging
207 if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { 217 if αi != 0.0 && αj != 0.0 && xj <= xi + ρ {
208 if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) { 218 if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) {
209 // For this to work (the debug_assert! to not trigger above), the new 219 // For this to work (the debug_assert! to not trigger above), the new
210 // coordinate produced by attempt_merge has to be at most xj. 220 // coordinate produced by attempt_merge has to be at most xj.
211 indices[k+1] = l; 221 indices[k + 1] = l;
212 count += 1 222 count += 1
213 } 223 }
214 } 224 }
215 } 225 }
216 226
217 count 227 count
218 } 228 }
219 } 229 }
220 230
221 /// Orders `δ1` and `δ2` according to the first coordinate. 231 /// Orders `δ1` and `δ2` according to the first coordinate.
222 fn compare_first_coordinate<F : Float>( 232 fn compare_first_coordinate<F: Float>(
223 δ1 : &DeltaMeasure<Loc<F, 2>, F>, 233 δ1: &DeltaMeasure<Loc<F, 2>, F>,
224 δ2 : &DeltaMeasure<Loc<F, 2>, F> 234 δ2: &DeltaMeasure<Loc<F, 2>, F>,
225 ) -> Ordering { 235 ) -> Ordering {
226 let &Loc([x11, ..]) = &δ1.x; 236 let &Loc([x11, ..]) = &δ1.x;
227 let &Loc([x21, ..]) = &δ2.x; 237 let &Loc([x21, ..]) = &δ2.x;
228 // nan-ignoring ordering of floats 238 // nan-ignoring ordering of floats
229 NaNLeast(x11).cmp(&NaNLeast(x21)) 239 NaNLeast(x11).cmp(&NaNLeast(x21))
230 } 240 }
231 241
232 #[replace_float_literals(F::cast_from(literal))] 242 #[replace_float_literals(F::cast_from(literal))]
233 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> { 243 impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> {
234 244 fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize
235 fn do_merge_spikes_radius<G>(&mut self, ρ : F, interp : bool, mut accept : G) -> usize 245 where
236 where G : FnMut(&'_ Self) -> bool { 246 G: FnMut(&'_ Self) -> bool,
247 {
237 // Sort by first coordinate into an indexing array. 248 // Sort by first coordinate into an indexing array.
238 let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); 249 let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate);
239 250
240 // Initialise result 251 // Initialise result
241 let mut count = 0; 252 let mut count = 0;
242 let mut start_scan_2nd = 0; 253 let mut start_scan_2nd = 0;
243 254
244 // Scan in order 255 // Scan in order
245 if indices.len() == 0 { 256 if indices.len() == 0 {
246 return count 257 return count;
247 } 258 }
248 for k in 0..indices.len()-1 { 259 for k in 0..indices.len() - 1 {
249 let i = indices[k]; 260 let i = indices[k];
250 let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i]; 261 let &DeltaMeasure {
262 x: Loc([xi1, xi2]),
263 α: αi,
264 } = &self[i];
251 265
252 if αi == 0.0 { 266 if αi == 0.0 {
253 // Nothing to be done if the weight is already zero 267 // Nothing to be done if the weight is already zero
254 continue 268 continue;
255 } 269 }
256 270
257 let mut closest = None; 271 let mut closest = None;
258 272
259 // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd` 273 // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd`
260 // the smallest invalid merging index on the previous loop iteration, because 274 // the smallest invalid merging index on the previous loop iteration, because
261 // the _closest_ mergeable spike might have index less than `k` in `indices`, and a 275 // the _closest_ mergeable spike might have index less than `k` in `indices`, and a
262 // merge with it might have not been attempted with this spike if a different closer 276 // merge with it might have not been attempted with this spike if a different closer
263 // spike was discovered based on the second coordinate. 277 // spike was discovered based on the second coordinate.
264 'scan_2nd: for l in (start_scan_2nd+1)..indices.len() { 278 'scan_2nd: for l in (start_scan_2nd + 1)..indices.len() {
265 if l == k { 279 if l == k {
266 // Do not attempt to merge a spike with itself 280 // Do not attempt to merge a spike with itself
267 continue 281 continue;
268 } 282 }
269 let j = indices[l]; 283 let j = indices[l];
270 let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j]; 284 let &DeltaMeasure {
285 x: Loc([xj1, xj2]),
286 α: αj,
287 } = &self[j];
271 288
272 if xj1 < xi1 - ρ { 289 if xj1 < xi1 - ρ {
273 // Spike `j = indices[l]` has too low first coordinate. Update starting index 290 // Spike `j = indices[l]` has too low first coordinate. Update starting index
274 // for next iteration, and continue scanning. 291 // for next iteration, and continue scanning.
275 start_scan_2nd = l; 292 start_scan_2nd = l;
276 continue 'scan_2nd 293 continue 'scan_2nd;
277 } else if xj1 > xi1 + ρ { 294 } else if xj1 > xi1 + ρ {
278 // Break out: spike `j = indices[l]` has already too high first coordinate, no 295 // Break out: spike `j = indices[l]` has already too high first coordinate, no
279 // more close enough spikes can be found due to the sorting of `indices`. 296 // more close enough spikes can be found due to the sorting of `indices`.
280 break 'scan_2nd 297 break 'scan_2nd;
281 } 298 }
282 299
283 // If also second coordinate is close enough, attempt merging if closer than 300 // If also second coordinate is close enough, attempt merging if closer than
284 // previously discovered mergeable spikes. 301 // previously discovered mergeable spikes.
285 let d2 = (xi2-xj2).abs(); 302 let d2 = (xi2 - xj2).abs();
286 if αj != 0.0 && d2 <= ρ { 303 if αj != 0.0 && d2 <= ρ {
287 let r1 = xi1-xj1; 304 let r1 = xi1 - xj1;
288 let d = (d2*d2 + r1*r1).sqrt(); 305 let d = (d2 * d2 + r1 * r1).sqrt();
289 match closest { 306 match closest {
290 None => closest = Some((l, j, d)), 307 None => closest = Some((l, j, d)),
291 Some((_, _, r)) if r > d => closest = Some((l, j, d)), 308 Some((_, _, r)) if r > d => closest = Some((l, j, d)),
292 _ => {}, 309 _ => {}
293 } 310 }
294 } 311 }
295 } 312 }
296 313
297 // Attempt merging closest close-enough spike 314 // Attempt merging closest close-enough spike
298 if let Some((l, j, _)) = closest { 315 if let Some((l, j, _)) = closest {
299 if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) { 316 if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) {
300 // If the merge was successful, make the new spike a candidate for merging. 317 // If the merge was successful, make the new spike a candidate for merging.
301 indices[l] = n; 318 indices[l] = n;
302 count += 1; 319 count += 1;
303 let compare = |i, j| compare_first_coordinate(&self.spikes[i], 320 let compare = |i, j| compare_first_coordinate(&self.spikes[i], &self.spikes[j]);
304 &self.spikes[j]);
305 // Re-sort relevant range of indices 321 // Re-sort relevant range of indices
306 if l < k { 322 if l < k {
307 indices[l..k].sort_by(|&i, &j| compare(i, j)); 323 indices[l..k].sort_by(|&i, &j| compare(i, j));
308 } else { 324 } else {
309 indices[k+1..=l].sort_by(|&i, &j| compare(i, j)); 325 indices[k + 1..=l].sort_by(|&i, &j| compare(i, j));
310 } 326 }
311 } 327 }
312 } 328 }
313 } 329 }
314 330
315 count 331 count
316 } 332 }
317 } 333 }
318

mercurial