src/measures/merging.rs

branch
dev
changeset 34
efa60bc4f743
parent 0
eb3c7813b67a
child 39
6316d68b58af
equal deleted inserted replaced
33:aec67cdd6b14 34:efa60bc4f743
18 18
19 /// Spike merging heuristic selection 19 /// Spike merging heuristic selection
20 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 20 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
21 #[allow(dead_code)] 21 #[allow(dead_code)]
22 pub enum SpikeMergingMethod<F> { 22 pub enum SpikeMergingMethod<F> {
23 /// Try to merge spikes within a given radius of eachother 23 /// Try to merge spikes within a given radius of each other, averaging the location
24 HeuristicRadius(F), 24 HeuristicRadius(F),
25 /// Try to merge spikes within a given radius of each other, attempting original locations
26 HeuristicRadiusNoInterp(F),
25 /// No merging 27 /// No merging
26 None, 28 None,
27 } 29 }
28 30
29 // impl<F : Float> SpikeMergingMethod<F> { 31 // impl<F : Float> SpikeMergingMethod<F> {
38 40
39 impl<F : ClapFloat> std::fmt::Display for SpikeMergingMethod<F> { 41 impl<F : ClapFloat> std::fmt::Display for SpikeMergingMethod<F> {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { 42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
41 match self { 43 match self {
42 Self::None => write!(f, "none"), 44 Self::None => write!(f, "none"),
43 Self::HeuristicRadius(r) => std::fmt::Display::fmt(r, f), 45 Self::HeuristicRadius(r) => write!(f, "i:{}", r),
46 Self::HeuristicRadiusNoInterp(r) => write!(f, "n:{}", r),
44 } 47 }
45 } 48 }
46 } 49 }
47 50
48 impl<F : ClapFloat> std::str::FromStr for SpikeMergingMethod<F> { 51 impl<F : ClapFloat> std::str::FromStr for SpikeMergingMethod<F> {
50 53
51 fn from_str(s: &str) -> Result<Self, Self::Err> { 54 fn from_str(s: &str) -> Result<Self, Self::Err> {
52 if s == "none" { 55 if s == "none" {
53 Ok(Self::None) 56 Ok(Self::None)
54 } else { 57 } else {
55 Ok(Self::HeuristicRadius(F::from_str(s)?)) 58 let mut subs = s.split(':');
59 match subs.next() {
60 None => Ok(Self::HeuristicRadius(F::from_str(s)?)),
61 Some(t) if t == "n" => match subs.next() {
62 None => Err(core::num::dec2flt::pfe_invalid()),
63 Some(v) => Ok(Self::HeuristicRadiusNoInterp(F::from_str(v)?))
64 },
65 Some(t) if t == "i" => match subs.next() {
66 None => Err(core::num::dec2flt::pfe_invalid()),
67 Some(v) => Ok(Self::HeuristicRadius(F::from_str(v)?))
68 },
69 Some(v) => Ok(Self::HeuristicRadius(F::from_str(v)?))
70 }
56 } 71 }
57 } 72 }
58 } 73 }
59 74
60 #[replace_float_literals(F::cast_from(literal))] 75 #[replace_float_literals(F::cast_from(literal))]
75 /// an arbitrary value. This method will return that value for the *last* accepted merge, or 90 /// an arbitrary value. This method will return that value for the *last* accepted merge, or
76 /// [`None`] if no merge was accepted. 91 /// [`None`] if no merge was accepted.
77 /// 92 ///
78 /// This method is stable with respect to spike locations: on merge, the weight of existing 93 /// This method is stable with respect to spike locations: on merge, the weight of existing
79 /// spikes is set to zero, and a new one inserted at the end of the spike vector. 94 /// spikes is set to zero, and a new one inserted at the end of the spike vector.
80 fn merge_spikes<G, V>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> Option<V> 95 fn merge_spikes<G>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> usize
81 where G : Fn(&'_ Self) -> Option<V> { 96 where G : FnMut(&'_ Self) -> bool {
82 match method { 97 match method {
83 SpikeMergingMethod::HeuristicRadius(ρ) => self.do_merge_spikes_radius(ρ, accept), 98 SpikeMergingMethod::HeuristicRadius(ρ) =>
84 SpikeMergingMethod::None => None, 99 self.do_merge_spikes_radius(ρ, true, accept),
100 SpikeMergingMethod::HeuristicRadiusNoInterp(ρ) =>
101 self.do_merge_spikes_radius(ρ, false, accept),
102 SpikeMergingMethod::None => 0,
85 } 103 }
86 } 104 }
87 105
88 /// Attempt to merge spikes based on a value and a fitness function. 106 /// Attempt to merge spikes based on a value and a fitness function.
89 /// 107 ///
90 /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of 108 /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of
91 /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` 109 /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value`
92 /// for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial 110 /// for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial
93 /// `self` is returned. 111 /// `self` is returned. The number of accepted merges is also returned.
94 fn merge_spikes_fitness<G, H, V, O>( 112 fn merge_spikes_fitness<G, H, V, O>(
95 &mut self, 113 &mut self,
96 method : SpikeMergingMethod<F>, 114 method : SpikeMergingMethod<F>,
97 value : G, 115 value : G,
98 fitness : H 116 fitness : H
99 ) -> V 117 ) -> (V, usize)
100 where G : Fn(&'_ Self) -> V, 118 where G : Fn(&'_ Self) -> V,
101 H : Fn(&'_ V) -> O, 119 H : Fn(&'_ V) -> O,
102 O : PartialOrd { 120 O : PartialOrd {
103 let initial_res = value(self); 121 let mut res = value(self);
104 let initial_fitness = fitness(&initial_res); 122 let initial_fitness = fitness(&res);
105 self.merge_spikes(method, |μ| { 123 let count = self.merge_spikes(method, |μ| {
106 let res = value(μ); 124 res = value(μ);
107 (fitness(&res) <= initial_fitness).then_some(res) 125 fitness(&res) <= initial_fitness
108 }).unwrap_or(initial_res) 126 });
127 (res, count)
109 } 128 }
110 129
111 /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). 130 /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm).
112 /// 131 ///
113 /// This method implements [`SpikeMerging::merge_spikes`] for 132 /// This method implements [`SpikeMerging::merge_spikes`] for
114 /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are 133 /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are
115 /// as for that method. 134 /// as for that method.
116 fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> 135 fn do_merge_spikes_radius<G>(&mut self, ρ : F, interp : bool, accept : G) -> usize
117 where G : Fn(&'_ Self) -> Option<V>; 136 where G : FnMut(&'_ Self) -> bool;
118 } 137 }
119 138
120 #[replace_float_literals(F::cast_from(literal))] 139 #[replace_float_literals(F::cast_from(literal))]
121 impl<F : Float, const N : usize> DiscreteMeasure<Loc<F, N>, F> { 140 impl<F : Float, const N : usize> DiscreteMeasure<Loc<F, N>, F> {
122 /// Attempts to merge spikes with indices `i` and `j`. 141 /// Attempts to merge spikes with indices `i` and `j`.
124 /// This assumes that the weights of the two spikes have already been checked not to be zero. 143 /// This assumes that the weights of the two spikes have already been checked not to be zero.
125 /// 144 ///
126 /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. 145 /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`].
127 /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its 146 /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its
128 /// return value. 147 /// return value.
129 fn attempt_merge<G, V>( 148 ///
149 /// Returns the index of `self.spikes` storing the new spike.
150 fn attempt_merge<G>(
130 &mut self, 151 &mut self,
131 res : &mut Option<V>,
132 i : usize, 152 i : usize,
133 j : usize, 153 j : usize,
134 accept : &G 154 interp : bool,
135 ) -> bool 155 accept : &mut G
136 where G : Fn(&'_ Self) -> Option<V> { 156 ) -> Option<usize>
157 where G : FnMut(&'_ Self) -> bool {
137 let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; 158 let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i];
138 let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; 159 let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j];
139 160
140 // Merge inplace 161 if interp {
141 self.spikes[i].α = 0.0; 162 // Merge inplace
142 self.spikes[j].α = 0.0; 163 self.spikes[i].α = 0.0;
143 //self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }); 164 self.spikes[j].α = 0.0;
144 self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αi + xj * αj) / (αi + αj) }); 165 let αia = αi.abs();
145 match accept(self) { 166 let αja = αj.abs();
146 some@Some(..) => { 167 self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αia + xj * αja) / (αia + αja) });
147 // Merge accepted, update our return value 168 if accept(self) {
148 *res = some; 169 Some(self.spikes.len()-1)
149 // On next iteration process the newly merged spike. 170 } else {
150 //indices[k+1] = self.spikes.len() - 1;
151 true
152 },
153 None => {
154 // Merge not accepted, restore modification 171 // Merge not accepted, restore modification
155 self.spikes[i].α = αi; 172 self.spikes[i].α = αi;
156 self.spikes[j].α = αj; 173 self.spikes[j].α = αj;
157 self.spikes.pop(); 174 self.spikes.pop();
158 false 175 None
159 } 176 }
160 } 177 } else {
161 } 178 // Attempt merge inplace, first combination
162 179 self.spikes[i].α = αi + αj;
163 /* 180 self.spikes[j].α = 0.0;
164 /// Attempts to merge spikes with indices i and j, acceptance through a delta. 181 if accept(self) {
165 fn attempt_merge_change<G, V>( 182 // Merge accepted
166 &mut self, 183 Some(i)
167 res : &mut Option<V>, 184 } else {
168 i : usize, 185 // Attempt merge inplace, second combination
169 j : usize,
170 accept_change : &G
171 ) -> bool
172 where G : Fn(&'_ Self) -> Option<V> {
173 let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i];
174 let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j];
175 let δ = DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 };
176 let λ = [-self.spikes[i], -self.spikes[j], δ.clone()].into();
177
178 match accept_change(&λ) {
179 some@Some(..) => {
180 // Merge accepted, update our return value
181 *res = some;
182 self.spikes[i].α = 0.0; 186 self.spikes[i].α = 0.0;
183 self.spikes[j].α = 0.0; 187 self.spikes[j].α = αi + αj;
184 self.spikes.push(δ); 188 if accept(self) {
185 true 189 // Merge accepted
186 }, 190 Some(j)
187 None => { 191 } else {
188 false 192 // Merge not accepted, restore modification
189 } 193 self.spikes[i].α = αi;
190 } 194 self.spikes[j].α = αj;
191 }*/ 195 None
192 196 }
197 }
198 }
199 }
193 } 200 }
194 201
195 /// Sorts a vector of indices into `slice` by `compare`. 202 /// Sorts a vector of indices into `slice` by `compare`.
196 /// 203 ///
197 /// The closure `compare` operates on references to elements of `slice`. 204 /// The closure `compare` operates on references to elements of `slice`.
205 } 212 }
206 213
207 #[replace_float_literals(F::cast_from(literal))] 214 #[replace_float_literals(F::cast_from(literal))]
208 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> { 215 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> {
209 216
210 fn do_merge_spikes_radius<G, V>( 217 fn do_merge_spikes_radius<G>(
211 &mut self, 218 &mut self,
212 ρ : F, 219 ρ : F,
213 accept : G 220 interp : bool,
214 ) -> Option<V> 221 mut accept : G
215 where G : Fn(&'_ Self) -> Option<V> { 222 ) -> usize
223 where G : FnMut(&'_ Self) -> bool {
216 // Sort by coordinate into an indexing array. 224 // Sort by coordinate into an indexing array.
217 let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { 225 let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| {
218 let &Loc([x1]) = &δ1.x; 226 let &Loc([x1]) = &δ1.x;
219 let &Loc([x2]) = &δ2.x; 227 let &Loc([x2]) = &δ2.x;
220 // nan-ignoring ordering of floats 228 // nan-ignoring ordering of floats
221 NaNLeast(x1).cmp(&NaNLeast(x2)) 229 NaNLeast(x1).cmp(&NaNLeast(x2))
222 }); 230 });
223 231
224 // Initialise result 232 // Initialise result
225 let mut res = None; 233 let mut count = 0;
226 234
227 // Scan consecutive pairs and merge if close enough and accepted by `accept`. 235 // Scan consecutive pairs and merge if close enough and accepted by `accept`.
228 if indices.len() == 0 { 236 if indices.len() == 0 {
229 return res 237 return count
230 } 238 }
231 for k in 0..(indices.len()-1) { 239 for k in 0..(indices.len()-1) {
232 let i = indices[k]; 240 let i = indices[k];
233 let j = indices[k+1]; 241 let j = indices[k+1];
234 let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i]; 242 let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i];
235 let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j]; 243 let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j];
236 debug_assert!(xi <= xj); 244 debug_assert!(xi <= xj);
237 // If close enough, attempt merging 245 // If close enough, attempt merging
238 if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { 246 if αi != 0.0 && αj != 0.0 && xj <= xi + ρ {
239 if self.attempt_merge(&mut res, i, j, &accept) { 247 if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) {
240 indices[k+1] = self.spikes.len() - 1; 248 // For this to work (the debug_assert! to not trigger above), the new
241 } 249 // coordinate produced by attempt_merge has to be at most xj.
242 } 250 indices[k+1] = l;
243 } 251 count += 1
244 252 }
245 res 253 }
254 }
255
256 count
246 } 257 }
247 } 258 }
248 259
249 /// Orders `δ1` and `δ2` according to the first coordinate. 260 /// Orders `δ1` and `δ2` according to the first coordinate.
250 fn compare_first_coordinate<F : Float>( 261 fn compare_first_coordinate<F : Float>(
258 } 269 }
259 270
260 #[replace_float_literals(F::cast_from(literal))] 271 #[replace_float_literals(F::cast_from(literal))]
261 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> { 272 impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> {
262 273
263 fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> 274 fn do_merge_spikes_radius<G>(&mut self, ρ : F, interp : bool, mut accept : G) -> usize
264 where G : Fn(&'_ Self) -> Option<V> { 275 where G : FnMut(&'_ Self) -> bool {
265 // Sort by first coordinate into an indexing array. 276 // Sort by first coordinate into an indexing array.
266 let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); 277 let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate);
267 278
268 // Initialise result 279 // Initialise result
269 let mut res = None; 280 let mut count = 0;
270 let mut start_scan_2nd = 0; 281 let mut start_scan_2nd = 0;
271 282
272 // Scan in order 283 // Scan in order
273 if indices.len() == 0 { 284 if indices.len() == 0 {
274 return res 285 return count
275 } 286 }
276 for k in 0..indices.len()-1 { 287 for k in 0..indices.len()-1 {
277 let i = indices[k]; 288 let i = indices[k];
278 let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i]; 289 let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i];
279 290
322 } 333 }
323 } 334 }
324 335
325 // Attempt merging closest close-enough spike 336 // Attempt merging closest close-enough spike
326 if let Some((l, j, _)) = closest { 337 if let Some((l, j, _)) = closest {
327 if self.attempt_merge(&mut res, i, j, &accept) { 338 if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) {
328 // If merge was successful, make new spike candidate for merging. 339 // If merge was successful, make new spike candidate for merging.
329 indices[l] = self.spikes.len() - 1; 340 indices[l] = n;
341 count += 1;
330 let compare = |i, j| compare_first_coordinate(&self.spikes[i], 342 let compare = |i, j| compare_first_coordinate(&self.spikes[i],
331 &self.spikes[j]); 343 &self.spikes[j]);
332 // Re-sort relevant range of indices 344 // Re-sort relevant range of indices
333 if l < k { 345 if l < k {
334 indices[l..k].sort_by(|&i, &j| compare(i, j)); 346 indices[l..k].sort_by(|&i, &j| compare(i, j));
337 } 349 }
338 } 350 }
339 } 351 }
340 } 352 }
341 353
342 res 354 count
343 } 355 }
344 } 356 }
345 357

mercurial