src/merging.rs

changeset 0
e8f3b6c55ce7
equal deleted inserted replaced
-1:000000000000 0:e8f3b6c55ce7
1 /*!
2 Spike merging heuristics for [`DiscreteMeasure`]s.
3
4 This module primarily provides the [`SpikeMerging`] trait, and within it,
5 the [`SpikeMerging::merge_spikes`] method. The trait is implemented on
6 [`DiscreteMeasure<Loc<N, F>, F>`]s in dimensions `N=1` and `N=2`.
7 */
8
9 use numeric_literals::replace_float_literals;
10 use serde::{Deserialize, Serialize};
11 use std::cmp::Ordering;
12 //use clap::builder::{PossibleValuesParser, PossibleValue};
13 use alg_tools::nanleast::NaNLeast;
14
15 use super::delta::*;
16 use super::discrete::*;
17 use alg_tools::loc::Loc;
18 use alg_tools::types::*;
19
20 /// Spike merging heuristic selection
21 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
22 #[allow(dead_code)]
23 pub struct SpikeMergingMethod<F> {
24 // Merging radius
25 pub radius: F,
26 // Enabled
27 pub enabled: bool,
28 // Interpolate merged points
29 pub interp: bool,
30 }
31
32 #[replace_float_literals(F::cast_from(literal))]
33 impl<F: Float> Default for SpikeMergingMethod<F> {
34 fn default() -> Self {
35 SpikeMergingMethod { radius: 0.01, enabled: false, interp: true }
36 }
37 }
38
39 /// Trait for dimension-dependent implementation of heuristic peak merging strategies.
40 pub trait SpikeMerging<F> {
41 /// Attempt spike merging according to [`SpikeMerging`] method.
42 ///
43 /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure
44 /// `accept` if any merging is performed. The closure should accept as its only parameter a
45 /// new candidate measure (it will generally be internally mutated `self`, although this is
46 /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of
47 /// an arbitrary value. This method will return that value for the *last* accepted merge, or
48 /// [`None`] if no merge was accepted.
49 ///
50 /// This method is stable with respect to spike locations: on merge, the weights of existing
51 /// removed spikes is set to zero, new ones inserted at the end of the spike vector.
52 /// They merge may also be performed by increasing the weights of the existing spikes,
53 /// without inserting new spikes.
54 fn merge_spikes<G>(&mut self, method: SpikeMergingMethod<F>, accept: G) -> usize
55 where
56 G: FnMut(&'_ Self) -> bool,
57 {
58 if method.enabled {
59 self.do_merge_spikes_radius(method.radius, method.interp, accept)
60 } else {
61 0
62 }
63 }
64
65 /// Attempt to merge spikes based on a value and a fitness function.
66 ///
67 /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of
68 /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value`
69 // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial
70 /// `self` is returned. also the number of merges is returned;
71 fn merge_spikes_fitness<G, H, V, O>(
72 &mut self,
73 method: SpikeMergingMethod<F>,
74 value: G,
75 fitness: H,
76 ) -> (V, usize)
77 where
78 G: Fn(&'_ Self) -> V,
79 H: Fn(&'_ V) -> O,
80 O: PartialOrd,
81 {
82 let mut res = value(self);
83 let initial_fitness = fitness(&res);
84 let count = self.merge_spikes(method, |μ| {
85 res = value(μ);
86 fitness(&res) <= initial_fitness
87 });
88 (res, count)
89 }
90
91 /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm).
92 ///
93 /// This method implements [`SpikeMerging::merge_spikes`].
94 fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, accept: G) -> usize
95 where
96 G: FnMut(&'_ Self) -> bool;
97 }
98
99 #[replace_float_literals(F::cast_from(literal))]
100 impl<F: Float, const N: usize> DiscreteMeasure<Loc<N, F>, F> {
101 /// Attempts to merge spikes with indices `i` and `j`.
102 ///
103 /// This assumes that the weights of the two spikes have already been checked not to be zero.
104 ///
105 /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`].
106 /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its
107 /// return value.
108 ///
109 /// Returns the index of `self.spikes` storing the new spike.
110 fn attempt_merge<G>(
111 &mut self,
112 i: usize,
113 j: usize,
114 interp: bool,
115 accept: &mut G,
116 ) -> Option<usize>
117 where
118 G: FnMut(&'_ Self) -> bool,
119 {
120 let &DeltaMeasure { x: xi, α: αi } = &self.spikes[i];
121 let &DeltaMeasure { x: xj, α: αj } = &self.spikes[j];
122
123 if interp {
124 // Merge inplace
125 self.spikes[i].α = 0.0;
126 self.spikes[j].α = 0.0;
127 let αia = αi.abs();
128 let αja = αj.abs();
129 self.spikes
130 .push(DeltaMeasure { α: αi + αj, x: (xi * αia + xj * αja) / (αia + αja) });
131 if accept(self) {
132 Some(self.spikes.len() - 1)
133 } else {
134 // Merge not accepted, restore modification
135 self.spikes[i].α = αi;
136 self.spikes[j].α = αj;
137 self.spikes.pop();
138 None
139 }
140 } else {
141 // Attempt merge inplace, first combination
142 self.spikes[i].α = αi + αj;
143 self.spikes[j].α = 0.0;
144 if accept(self) {
145 // Merge accepted
146 Some(i)
147 } else {
148 // Attempt merge inplace, second combination
149 self.spikes[i].α = 0.0;
150 self.spikes[j].α = αi + αj;
151 if accept(self) {
152 // Merge accepted
153 Some(j)
154 } else {
155 // Merge not accepted, restore modification
156 self.spikes[i].α = αi;
157 self.spikes[j].α = αj;
158 None
159 }
160 }
161 }
162 }
163 }
164
/// Returns the permutation of indices `0..slice.len()` that orders `slice` according to
/// `compare`.
///
/// The closure `compare` operates on references to elements of `slice`. The sort is stable,
/// so indices of elements that compare equal stay in ascending order.
pub fn sort_indices_by<V, F>(slice: &[V], mut compare: F) -> Vec<usize>
where
    F: FnMut(&V, &V) -> Ordering,
{
    let mut order: Vec<usize> = (0..slice.len()).collect();
    order.sort_by(|&a, &b| compare(&slice[a], &slice[b]));
    order
}
177
178 #[replace_float_literals(F::cast_from(literal))]
179 impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<1, F>, F> {
180 fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize
181 where
182 G: FnMut(&'_ Self) -> bool,
183 {
184 // Sort by coordinate into an indexing array.
185 let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| {
186 let &Loc([x1]) = &δ1.x;
187 let &Loc([x2]) = &δ2.x;
188 // nan-ignoring ordering of floats
189 NaNLeast(x1).cmp(&NaNLeast(x2))
190 });
191
192 // Initialise result
193 let mut count = 0;
194
195 // Scan consecutive pairs and merge if close enough and accepted by `accept`.
196 if indices.len() == 0 {
197 return count;
198 }
199 for k in 0..(indices.len() - 1) {
200 let i = indices[k];
201 let j = indices[k + 1];
202 let &DeltaMeasure { x: Loc([xi]), α: αi } = &self.spikes[i];
203 let &DeltaMeasure { x: Loc([xj]), α: αj } = &self.spikes[j];
204 debug_assert!(xi <= xj);
205 // If close enough, attempt merging
206 if αi != 0.0 && αj != 0.0 && xj <= xi + ρ {
207 if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) {
208 // For this to work (the debug_assert! to not trigger above), the new
209 // coordinate produced by attempt_merge has to be at most xj.
210 indices[k + 1] = l;
211 count += 1
212 }
213 }
214 }
215
216 count
217 }
218 }
219
220 /// Orders `δ1` and `δ1` according to the first coordinate.
221 fn compare_first_coordinate<F: Float>(
222 δ1: &DeltaMeasure<Loc<2, F>, F>,
223 δ2: &DeltaMeasure<Loc<2, F>, F>,
224 ) -> Ordering {
225 let &Loc([x11, ..]) = &δ1.x;
226 let &Loc([x21, ..]) = &δ2.x;
227 // nan-ignoring ordering of floats
228 NaNLeast(x11).cmp(&NaNLeast(x21))
229 }
230
231 #[replace_float_literals(F::cast_from(literal))]
232 impl<F: Float> SpikeMerging<F> for DiscreteMeasure<Loc<2, F>, F> {
233 fn do_merge_spikes_radius<G>(&mut self, ρ: F, interp: bool, mut accept: G) -> usize
234 where
235 G: FnMut(&'_ Self) -> bool,
236 {
237 // Sort by first coordinate into an indexing array.
238 let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate);
239
240 // Initialise result
241 let mut count = 0;
242 let mut start_scan_2nd = 0;
243
244 // Scan in order
245 if indices.len() == 0 {
246 return count;
247 }
248 for k in 0..indices.len() - 1 {
249 let i = indices[k];
250 let &DeltaMeasure { x: Loc([xi1, xi2]), α: αi } = &self[i];
251
252 if αi == 0.0 {
253 // Nothin to be done if the weight is already zero
254 continue;
255 }
256
257 let mut closest = None;
258
259 // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd`
260 // the smallest invalid merging index on the previous loop iteration, because a
261 // the _closest_ mergeable spike might have index less than `k` in `indices`, and a
262 // merge with it might have not been attempted with this spike if a different closer
263 // spike was discovered based on the second coordinate.
264 'scan_2nd: for l in (start_scan_2nd + 1)..indices.len() {
265 if l == k {
266 // Do not attempt to merge a spike with itself
267 continue;
268 }
269 let j = indices[l];
270 let &DeltaMeasure { x: Loc([xj1, xj2]), α: αj } = &self[j];
271
272 if xj1 < xi1 - ρ {
273 // Spike `j = indices[l]` has too low first coordinate. Update starting index
274 // for next iteration, and continue scanning.
275 start_scan_2nd = l;
276 continue 'scan_2nd;
277 } else if xj1 > xi1 + ρ {
278 // Break out: spike `j = indices[l]` has already too high first coordinate, no
279 // more close enough spikes can be found due to the sorting of `indices`.
280 break 'scan_2nd;
281 }
282
283 // If also second coordinate is close enough, attempt merging if closer than
284 // previously discovered mergeable spikes.
285 let d2 = (xi2 - xj2).abs();
286 if αj != 0.0 && d2 <= ρ {
287 let r1 = xi1 - xj1;
288 let d = (d2 * d2 + r1 * r1).sqrt();
289 match closest {
290 None => closest = Some((l, j, d)),
291 Some((_, _, r)) if r > d => closest = Some((l, j, d)),
292 _ => {}
293 }
294 }
295 }
296
297 // Attempt merging closest close-enough spike
298 if let Some((l, j, _)) = closest {
299 if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) {
300 // If merge was succesfull, make new spike candidate for merging.
301 indices[l] = n;
302 count += 1;
303 let compare = |i, j| compare_first_coordinate(&self.spikes[i], &self.spikes[j]);
304 // Re-sort relevant range of indices
305 if l < k {
306 indices[l..k].sort_by(|&i, &j| compare(i, j));
307 } else {
308 indices[k + 1..=l].sort_by(|&i, &j| compare(i, j));
309 }
310 }
311 }
312 }
313
314 count
315 }
316 }

mercurial