Tue, 21 Mar 2023 20:31:01 +0200
Implement non-negativity constraints for the conditional gradient methods
| 0 | 1 | /*! |
| 2 | Spike merging heuristics for [`DiscreteMeasure`]s. | |
| 3 | ||
| 4 | This module primarily provides the [`SpikeMerging`] trait, and within it, | |
| 5 | the [`SpikeMerging::merge_spikes`] method. The trait is implemented on | |
| 6 | [`DiscreteMeasure<Loc<F, N>, F>`]s in dimensions `N=1` and `N=2`. | |
| 7 | */ | |
| 8 | ||
| 9 | use numeric_literals::replace_float_literals; | |
| 10 | use std::cmp::Ordering; | |
| 11 | use serde::{Serialize, Deserialize}; | |
| 12 | //use clap::builder::{PossibleValuesParser, PossibleValue}; | |
| 13 | use alg_tools::nanleast::NaNLeast; | |
| 14 | ||
| 15 | use crate::types::*; | |
| 16 | use super::delta::*; | |
| 17 | use super::discrete::*; | |
| 18 | ||
| 19 | /// Spike merging heuristic selection | |
| 20 | #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] | |
| 21 | #[allow(dead_code)] | |
| 22 | pub enum SpikeMergingMethod<F> { | |
| 23 | /// Try to merge spikes within a given radius of eachother | |
| 24 | HeuristicRadius(F), | |
| 25 | /// No merging | |
| 26 | None, | |
| 27 | } | |
| 28 | ||
| 29 | // impl<F : Float> SpikeMergingMethod<F> { | |
| 30 | // /// This is for [`clap`] to display command line help. | |
| 31 | // pub fn value_parser() -> PossibleValuesParser { | |
| 32 | // PossibleValuesParser::new([ | |
| 33 | // PossibleValue::new("none").help("No merging"), | |
| 34 | // PossibleValue::new("<radius>").help("Heuristic merging within indicated radius") | |
| 35 | // ]) | |
| 36 | // } | |
| 37 | // } | |
| 38 | ||
| 39 | impl<F : ClapFloat> std::fmt::Display for SpikeMergingMethod<F> { | |
| 40 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { | |
| 41 | match self { | |
| 42 | Self::None => write!(f, "none"), | |
| 43 | Self::HeuristicRadius(r) => std::fmt::Display::fmt(r, f), | |
| 44 | } | |
| 45 | } | |
| 46 | } | |
| 47 | ||
| 48 | impl<F : ClapFloat> std::str::FromStr for SpikeMergingMethod<F> { | |
| 49 | type Err = F::Err; | |
| 50 | ||
| 51 | fn from_str(s: &str) -> Result<Self, Self::Err> { | |
| 52 | if s == "none" { | |
| 53 | Ok(Self::None) | |
| 54 | } else { | |
| 55 | Ok(Self::HeuristicRadius(F::from_str(s)?)) | |
| 56 | } | |
| 57 | } | |
| 58 | } | |
| 59 | ||
| 60 | #[replace_float_literals(F::cast_from(literal))] | |
| 61 | impl<F : Float> Default for SpikeMergingMethod<F> { | |
| 62 | fn default() -> Self { | |
| 63 | SpikeMergingMethod::HeuristicRadius(0.02) | |
| 64 | } | |
| 65 | } | |
| 66 | ||
| 67 | /// Trait for dimension-dependent implementation of heuristic peak merging strategies. | |
| 68 | pub trait SpikeMerging<F> { | |
| 69 | /// Attempt spike merging according to [`SpikeMerging`] method. | |
| 70 | /// | |
| 71 | /// Returns the last [`Some`] returned by the merging candidate acceptance decision closure | |
| 72 | /// `accept` if any merging is performed. The closure should accept as its only parameter a | |
| 73 | /// new candidate measure (it will generally be internally mutated `self`, although this is | |
| 74 | /// not guaranteed), and return [`None`] if the merge is accepted, and otherwise a [`Some`] of | |
| 75 | /// an arbitrary value. This method will return that value for the *last* accepted merge, or | |
| 76 | /// [`None`] if no merge was accepted. | |
| 77 | /// | |
| 78 | /// This method is stable with respect to spike locations: on merge, the weight of existing | |
| 79 | /// spikes is set to zero, and a new one inserted at the end of the spike vector. | |
| 80 | fn merge_spikes<G, V>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> Option<V> | |
| 81 | where G : Fn(&'_ Self) -> Option<V> { | |
| 82 | match method { | |
| 83 | SpikeMergingMethod::HeuristicRadius(ρ) => self.do_merge_spikes_radius(ρ, accept), | |
| 84 | SpikeMergingMethod::None => None, | |
| 85 | } | |
| 86 | } | |
| 87 | ||
| 88 | /// Attempt to merge spikes based on a value and a fitness function. | |
| 89 | /// | |
| 90 | /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of | |
| 91 | /// `value` and `fitness`, compared to initial fitness. Returns the last return value of `value` | |
| 92 | // for a merge accepted by `fitness`. If no merge was accepted, `value` applied to the initial | |
| 93 | /// `self` is returned. | |
| 94 | fn merge_spikes_fitness<G, H, V, O>( | |
| 95 | &mut self, | |
| 96 | method : SpikeMergingMethod<F>, | |
| 97 | value : G, | |
| 98 | fitness : H | |
| 99 | ) -> V | |
| 100 | where G : Fn(&'_ Self) -> V, | |
| 101 | H : Fn(&'_ V) -> O, | |
| 102 | O : PartialOrd { | |
| 103 | let initial_res = value(self); | |
| 104 | let initial_fitness = fitness(&initial_res); | |
| 105 | self.merge_spikes(method, |μ| { | |
| 106 | let res = value(μ); | |
| 107 | (fitness(&res) <= initial_fitness).then_some(res) | |
| 108 | }).unwrap_or(initial_res) | |
| 109 | } | |
| 110 | ||
| 111 | /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm). | |
| 112 | /// | |
| 113 | /// This method implements [`SpikeMerging::merge_spikes`] for | |
| 114 | /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are | |
| 115 | /// as for that method. | |
| 116 | fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> | |
| 117 | where G : Fn(&'_ Self) -> Option<V>; | |
| 118 | } | |
| 119 | ||
| 120 | #[replace_float_literals(F::cast_from(literal))] | |
| 121 | impl<F : Float, const N : usize> DiscreteMeasure<Loc<F, N>, F> { | |
| 122 | /// Attempts to merge spikes with indices `i` and `j`. | |
| 123 | /// | |
| 124 | /// This assumes that the weights of the two spikes have already been checked not to be zero. | |
| 125 | /// | |
| 126 | /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`]. | |
| 127 | /// If the merge is accepted by `accept` returning a [`Some`], `res` will be replaced by its | |
| 128 | /// return value. | |
| 129 | fn attempt_merge<G, V>( | |
| 130 | &mut self, | |
| 131 | res : &mut Option<V>, | |
| 132 | i : usize, | |
| 133 | j : usize, | |
| 134 | accept : &G | |
| 135 | ) -> bool | |
| 136 | where G : Fn(&'_ Self) -> Option<V> { | |
| 137 | let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; | |
| 138 | let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; | |
| 139 | ||
| 140 | // Merge inplace | |
| 141 | self.spikes[i].α = 0.0; | |
| 142 | self.spikes[j].α = 0.0; | |
| 143 | //self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }); | |
| 144 | self.spikes.push(DeltaMeasure{ α : αi + αj, x : (xi * αi + xj * αj) / (αi + αj) }); | |
| 145 | match accept(self) { | |
| 146 | some@Some(..) => { | |
| 147 | // Merge accepted, update our return value | |
| 148 | *res = some; | |
| 149 | // On next iteration process the newly merged spike. | |
| 150 | //indices[k+1] = self.spikes.len() - 1; | |
| 151 | true | |
| 152 | }, | |
| 153 | None => { | |
| 154 | // Merge not accepted, restore modification | |
| 155 | self.spikes[i].α = αi; | |
| 156 | self.spikes[j].α = αj; | |
| 157 | self.spikes.pop(); | |
| 158 | false | |
| 159 | } | |
| 160 | } | |
| 161 | } | |
| 162 | ||
| 163 | /* | |
| 164 | /// Attempts to merge spikes with indices i and j, acceptance through a delta. | |
| 165 | fn attempt_merge_change<G, V>( | |
| 166 | &mut self, | |
| 167 | res : &mut Option<V>, | |
| 168 | i : usize, | |
| 169 | j : usize, | |
| 170 | accept_change : &G | |
| 171 | ) -> bool | |
| 172 | where G : Fn(&'_ Self) -> Option<V> { | |
| 173 | let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i]; | |
| 174 | let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j]; | |
| 175 | let δ = DeltaMeasure{ α : αi + αj, x : (xi + xj)/2.0 }; | |
| 176 | let λ = [-self.spikes[i], -self.spikes[j], δ.clone()].into(); | |
| 177 | ||
| 178 | match accept_change(&λ) { | |
| 179 | some@Some(..) => { | |
| 180 | // Merge accepted, update our return value | |
| 181 | *res = some; | |
| 182 | self.spikes[i].α = 0.0; | |
| 183 | self.spikes[j].α = 0.0; | |
| 184 | self.spikes.push(δ); | |
| 185 | true | |
| 186 | }, | |
| 187 | None => { | |
| 188 | false | |
| 189 | } | |
| 190 | } | |
| 191 | }*/ | |
| 192 | ||
| 193 | } | |
| 194 | ||
/// Returns the permutation of indices into `slice` that orders its elements by `compare`.
///
/// The closure `compare` operates on references to elements of `slice`. The underlying sort is
/// stable: indices of elements comparing equal keep their original relative order.
pub fn sort_indices_by<V, F>(slice : &[V], mut compare : F) -> Vec<usize>
where F : FnMut(&V, &V) -> Ordering
{
    let mut permutation : Vec<usize> = (0..slice.len()).collect();
    permutation.sort_by(|&a, &b| compare(&slice[a], &slice[b]));
    permutation
}
| 206 | ||
| 207 | #[replace_float_literals(F::cast_from(literal))] | |
| 208 | impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> { | |
| 209 | ||
| 210 | fn do_merge_spikes_radius<G, V>( | |
| 211 | &mut self, | |
| 212 | ρ : F, | |
| 213 | accept : G | |
| 214 | ) -> Option<V> | |
| 215 | where G : Fn(&'_ Self) -> Option<V> { | |
| 216 | // Sort by coordinate into an indexing array. | |
| 217 | let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| { | |
| 218 | let &Loc([x1]) = &δ1.x; | |
| 219 | let &Loc([x2]) = &δ2.x; | |
| 220 | // nan-ignoring ordering of floats | |
| 221 | NaNLeast(x1).cmp(&NaNLeast(x2)) | |
| 222 | }); | |
| 223 | ||
| 224 | // Initialise result | |
| 225 | let mut res = None; | |
| 226 | ||
| 227 | // Scan consecutive pairs and merge if close enough and accepted by `accept`. | |
| 228 | if indices.len() == 0 { | |
| 229 | return res | |
| 230 | } | |
| 231 | for k in 0..(indices.len()-1) { | |
| 232 | let i = indices[k]; | |
| 233 | let j = indices[k+1]; | |
| 234 | let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i]; | |
| 235 | let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j]; | |
| 236 | debug_assert!(xi <= xj); | |
| 237 | // If close enough, attempt merging | |
| 238 | if αi != 0.0 && αj != 0.0 && xj <= xi + ρ { | |
| 239 | if self.attempt_merge(&mut res, i, j, &accept) { | |
| 240 | indices[k+1] = self.spikes.len() - 1; | |
| 241 | } | |
| 242 | } | |
| 243 | } | |
| 244 | ||
| 245 | res | |
| 246 | } | |
| 247 | } | |
| 248 | ||
| 249 | /// Orders `δ1` and `δ1` according to the first coordinate. | |
| 250 | fn compare_first_coordinate<F : Float>( | |
| 251 | δ1 : &DeltaMeasure<Loc<F, 2>, F>, | |
| 252 | δ2 : &DeltaMeasure<Loc<F, 2>, F> | |
| 253 | ) -> Ordering { | |
| 254 | let &Loc([x11, ..]) = &δ1.x; | |
| 255 | let &Loc([x21, ..]) = &δ2.x; | |
| 256 | // nan-ignoring ordering of floats | |
| 257 | NaNLeast(x11).cmp(&NaNLeast(x21)) | |
| 258 | } | |
| 259 | ||
| 260 | #[replace_float_literals(F::cast_from(literal))] | |
| 261 | impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> { | |
| 262 | ||
| 263 | fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V> | |
| 264 | where G : Fn(&'_ Self) -> Option<V> { | |
| 265 | // Sort by first coordinate into an indexing array. | |
| 266 | let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate); | |
| 267 | ||
| 268 | // Initialise result | |
| 269 | let mut res = None; | |
| 270 | let mut start_scan_2nd = 0; | |
| 271 | ||
| 272 | // Scan in order | |
| 273 | if indices.len() == 0 { | |
| 274 | return res | |
| 275 | } | |
| 276 | for k in 0..indices.len()-1 { | |
| 277 | let i = indices[k]; | |
| 278 | let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i]; | |
| 279 | ||
| 280 | if αi == 0.0 { | |
| 281 | // Nothin to be done if the weight is already zero | |
| 282 | continue | |
| 283 | } | |
| 284 | ||
| 285 | let mut closest = None; | |
| 286 | ||
| 287 | // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd` | |
| 288 | // the smallest invalid merging index on the previous loop iteration, because a | |
| 289 | // the _closest_ mergeable spike might have index less than `k` in `indices`, and a | |
| 290 | // merge with it might have not been attempted with this spike if a different closer | |
| 291 | // spike was discovered based on the second coordinate. | |
| 292 | 'scan_2nd: for l in (start_scan_2nd+1)..indices.len() { | |
| 293 | if l == k { | |
| 294 | // Do not attempt to merge a spike with itself | |
| 295 | continue | |
| 296 | } | |
| 297 | let j = indices[l]; | |
| 298 | let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j]; | |
| 299 | ||
| 300 | if xj1 < xi1 - ρ { | |
| 301 | // Spike `j = indices[l]` has too low first coordinate. Update starting index | |
| 302 | // for next iteration, and continue scanning. | |
| 303 | start_scan_2nd = l; | |
| 304 | continue 'scan_2nd | |
| 305 | } else if xj1 > xi1 + ρ { | |
| 306 | // Break out: spike `j = indices[l]` has already too high first coordinate, no | |
| 307 | // more close enough spikes can be found due to the sorting of `indices`. | |
| 308 | break 'scan_2nd | |
| 309 | } | |
| 310 | ||
| 311 | // If also second coordinate is close enough, attempt merging if closer than | |
| 312 | // previously discovered mergeable spikes. | |
| 313 | let d2 = (xi2-xj2).abs(); | |
| 314 | if αj != 0.0 && d2 <= ρ { | |
| 315 | let r1 = xi1-xj1; | |
| 316 | let d = (d2*d2 + r1*r1).sqrt(); | |
| 317 | match closest { | |
| 318 | None => closest = Some((l, j, d)), | |
| 319 | Some((_, _, r)) if r > d => closest = Some((l, j, d)), | |
| 320 | _ => {}, | |
| 321 | } | |
| 322 | } | |
| 323 | } | |
| 324 | ||
| 325 | // Attempt merging closest close-enough spike | |
| 326 | if let Some((l, j, _)) = closest { | |
| 327 | if self.attempt_merge(&mut res, i, j, &accept) { | |
| 328 | // If merge was succesfull, make new spike candidate for merging. | |
| 329 | indices[l] = self.spikes.len() - 1; | |
| 330 | let compare = |i, j| compare_first_coordinate(&self.spikes[i], | |
| 331 | &self.spikes[j]); | |
| 332 | // Re-sort relevant range of indices | |
| 333 | if l < k { | |
| 334 | indices[l..k].sort_by(|&i, &j| compare(i, j)); | |
| 335 | } else { | |
| 336 | indices[k+1..=l].sort_by(|&i, &j| compare(i, j)); | |
| 337 | } | |
| 338 | } | |
| 339 | } | |
| 340 | } | |
| 341 | ||
| 342 | res | |
| 343 | } | |
| 344 | } | |
| 345 |