Tue, 25 Oct 2022 23:05:40 +0300
Added NormExponent trait for exponents of norms
| 0 | 1 | |
| 2 | use std::collections::BinaryHeap; | |
| 3 | use std::cmp::{PartialOrd,Ord,Ordering,Ordering::*,max}; | |
| 4 | use std::rc::Rc; | |
| 5 | use std::marker::PhantomData; | |
| 6 | use crate::types::*; | |
| 5 | 7 | use crate::nanleast::NaNLeast; |
| 8 | use crate::sets::Cube; | |
| 0 | 9 | use super::support::*; |
| 10 | use super::bt::*; | |
| 11 | use super::aggregator::*; | |
| 12 | ||
| 13 | /// Trait for sorting [`Aggregator`]s for [`BT`] refinement. | |
| 14 | /// | |
| 5 | 15 | /// The sorting involves two sorting keys, the “upper” and the “lower” key. Any [`BT`] nodes |
| 0 | 16 | /// with upper key less the lower key of another are discarded from the refinement process. |
| 5 | 17 | /// Nodes with the highest upper sorting key are picked for refinement. |
| 0 | 18 | pub trait AggregatorSorting { |
| 19 | // Priority | |
| 20 | type Agg : Aggregator; | |
| 21 | type Sort : Ord + Copy + std::fmt::Debug; | |
| 22 | ||
| 23 | /// Returns lower sorting key | |
| 24 | fn sort_lower(aggregator : &Self::Agg) -> Self::Sort; | |
| 25 | ||
| 26 | /// Returns upper sorting key | |
| 27 | fn sort_upper(aggregator : &Self::Agg) -> Self::Sort; | |
| 28 | ||
| 5 | 29 | /// Returns a sorting key that is less than any other sorting key. |
| 0 | 30 | fn bottom() -> Self::Sort; |
| 31 | } | |
| 32 | ||
| 33 | /// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the upper/lower key. | |
| 34 | /// | |
| 35 | /// See [`LowerBoundSorting`] for the opposite ordering. | |
| 36 | pub struct UpperBoundSorting<F : Float>(PhantomData<F>); | |
| 37 | ||
| 38 | /// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the lower/upper key. | |
| 39 | /// | |
| 40 | /// See [`UpperBoundSorting`] for the opposite ordering. | |
| 41 | pub struct LowerBoundSorting<F : Float>(PhantomData<F>); | |
| 42 | ||
| 43 | impl<F : Float> AggregatorSorting for UpperBoundSorting<F> { | |
| 44 | type Agg = Bounds<F>; | |
| 45 | type Sort = NaNLeast<F>; | |
| 46 | ||
| 47 | #[inline] | |
| 48 | fn sort_lower(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(aggregator.lower()) } | |
| 49 | ||
| 50 | #[inline] | |
| 51 | fn sort_upper(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(aggregator.upper()) } | |
| 52 | ||
| 53 | #[inline] | |
| 54 | fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) } | |
| 55 | } | |
| 56 | ||
| 57 | ||
| 58 | impl<F : Float> AggregatorSorting for LowerBoundSorting<F> { | |
| 59 | type Agg = Bounds<F>; | |
| 60 | type Sort = NaNLeast<F>; | |
| 61 | ||
| 62 | #[inline] | |
| 63 | fn sort_upper(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(-aggregator.lower()) } | |
| 64 | ||
| 65 | #[inline] | |
| 66 | fn sort_lower(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(-aggregator.upper()) } | |
| 67 | ||
| 68 | #[inline] | |
| 69 | fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) } | |
| 70 | } | |
| 71 | ||
| 5 | 72 | /// Return type of [`Refiner::refine`]. |
| 73 | /// | |
| 74 | /// The parameter `R` is the result type of the refiner acting on an [`Aggregator`] of type `A`. | |
| 0 | 75 | pub enum RefinerResult<A : Aggregator, R> { |
| 5 | 76 | /// Indicates an insufficiently refined state: the [`BT`] needs to be further refined. |
| 0 | 77 | NeedRefinement, |
| 78 | /// Indicates a certain result `R`, stop refinement immediately. | |
| 79 | Certain(R), | |
| 80 | /// Indicates an uncertain result: continue refinement until candidates have been exhausted | |
| 81 | /// or a certain result found. | |
| 82 | Uncertain(A, R) | |
| 83 | } | |
| 84 | ||
| 85 | use RefinerResult::*; | |
| 86 | ||
| 5 | 87 | /// A `Refiner` is used to search a [`BT`], refining the subdivision when necessary. |
| 88 | /// | |
| 89 | /// The search is performed by [`BTSearch::search_and_refine`]. | |
| 90 | /// The `Refiner` is used to determine whether an [`Aggregator`] `A` stored in the [`BT`] is | |
| 91 | /// sufficiently refined within a [`Cube`], and in such a case, produce a desired result (e.g. | |
| 92 | /// a maximum value of a function). | |
| 0 | 93 | pub trait Refiner<F : Float, A, G, const N : usize> |
| 94 | where F : Num, | |
| 95 | A : Aggregator, | |
| 96 | G : SupportGenerator<F, N> { | |
| 97 | ||
| 5 | 98 | /// The result type of the refiner |
| 0 | 99 | type Result : std::fmt::Debug; |
| 5 | 100 | /// The sorting to be employed by [`BTSearch::search_and_refine`] on node aggregators |
| 101 | /// to detemrine node priority. | |
| 0 | 102 | type Sorting : AggregatorSorting<Agg = A>; |
| 103 | ||
| 5 | 104 | /// Determines whether `aggregator` is sufficiently refined within `domain`. |
| 105 | /// | |
| 106 | /// If the aggregator is sufficiently refined that the desired `Self::Result` can be produced, | |
| 107 | /// a [`RefinerResult`]`::Certain` or `Uncertain` should be returned, depending on | |
| 108 | /// the confidence of the solution. In the uncertain case an improved aggregator should also | |
| 109 | /// be included. If the result cannot be produced, `NeedRefinement` should be | |
| 110 | /// returned. | |
| 111 | /// | |
| 112 | /// For example, if the refiner is used to minimise a function presented by the `BT`, | |
| 113 | /// an `Uncertain` result can be used to return a local maximum of the function on `domain` | |
| 114 | /// The result can be claimed `Certain` if it is a global maximum. In that case the | |
| 115 | /// refinment will stop immediately. A `NeedRefinement` result indicates that the `aggregator` | |
| 116 | /// and/or `domain` are not sufficiently refined to compute a lcoal maximum of sufficient | |
| 117 | /// quality. | |
| 118 | /// | |
| 119 | /// The vector `data` stored all the data of the [`BT`] in the node corresponding to `domain`. | |
| 120 | /// The `generator` can be used to convert `data` into [`Support`]s. The parameter `step` | |
| 121 | /// counts the calls to `refine`, and can be used to stop the refinement when a maximum | |
| 122 | /// number of steps is reached. | |
| 0 | 123 | fn refine( |
| 124 | &self, | |
| 125 | aggregator : &A, | |
| 126 | domain : &Cube<F, N>, | |
| 127 | data : &Vec<G::Id>, | |
| 128 | generator : &G, | |
| 129 | step : usize, | |
| 130 | ) -> RefinerResult<A, Self::Result>; | |
| 131 | } | |
| 132 | ||
| 133 | /// Structure for tracking the refinement process in a [`BinaryHeap`]. | |
| 134 | struct RefinementInfo<'a, F, D, A, S, RResult, const N : usize, const P : usize> | |
| 135 | where F : Float, | |
| 136 | D : 'static +, | |
| 137 | A : Aggregator, | |
| 138 | S : AggregatorSorting<Agg = A> { | |
| 139 | cube : Cube<F, N>, | |
| 140 | node : &'a mut Node<F, D, A, N, P>, | |
| 141 | refiner_info : Option<(A, RResult)>, | |
| 142 | sorting : PhantomData<S>, | |
| 143 | } | |
| 144 | ||
| 145 | impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> | |
| 146 | RefinementInfo<'a, F, D, A, S, RResult, N, P> | |
| 147 | where F : Float, | |
| 148 | D : 'static, | |
| 149 | A : Aggregator, | |
| 150 | S : AggregatorSorting<Agg = A> { | |
| 151 | ||
| 152 | #[inline] | |
| 153 | fn aggregator(&self) -> &A { | |
| 154 | match self.refiner_info { | |
| 155 | Some((ref agg, _)) => agg, | |
| 156 | None => &self.node.aggregator, | |
| 157 | } | |
| 158 | } | |
| 159 | ||
| 160 | #[inline] | |
| 161 | fn sort_lower(&self) -> S::Sort { | |
| 162 | S::sort_lower(self.aggregator()) | |
| 163 | } | |
| 164 | ||
| 165 | #[inline] | |
| 166 | fn sort_upper(&self) -> S::Sort { | |
| 167 | S::sort_upper(self.aggregator()) | |
| 168 | } | |
| 169 | } | |
| 170 | ||
| 171 | impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialEq | |
| 172 | for RefinementInfo<'a, F, D, A, S, RResult, N, P> | |
| 173 | where F : Float, | |
| 174 | D : 'static, | |
| 175 | A : Aggregator, | |
| 176 | S : AggregatorSorting<Agg = A> { | |
| 177 | ||
| 178 | #[inline] | |
| 179 | fn eq(&self, other : &Self) -> bool { self.cmp(other) == Equal } | |
| 180 | } | |
| 181 | ||
| 182 | impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialOrd | |
| 183 | for RefinementInfo<'a, F, D, A, S, RResult, N, P> | |
| 184 | where F : Float, | |
| 185 | D : 'static, | |
| 186 | A : Aggregator, | |
| 187 | S : AggregatorSorting<Agg = A> { | |
| 188 | ||
| 189 | #[inline] | |
| 190 | fn partial_cmp(&self, other : &Self) -> Option<Ordering> { Some(self.cmp(other)) } | |
| 191 | } | |
| 192 | ||
| 193 | impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Eq | |
| 194 | for RefinementInfo<'a, F, D, A, S, RResult, N, P> | |
| 195 | where F : Float, | |
| 196 | D : 'static, | |
| 197 | A : Aggregator, | |
| 198 | S : AggregatorSorting<Agg = A> { | |
| 199 | } | |
| 200 | ||
| 201 | impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Ord | |
| 202 | for RefinementInfo<'a, F, D, A, S, RResult, N, P> | |
| 203 | where F : Float, | |
| 204 | D : 'static, | |
| 205 | A : Aggregator, | |
| 206 | S : AggregatorSorting<Agg = A> { | |
| 207 | ||
| 208 | #[inline] | |
| 209 | fn cmp(&self, other : &Self) -> Ordering { | |
| 210 | let agg1 = self.aggregator(); | |
| 211 | let agg2 = other.aggregator(); | |
| 212 | match S::sort_upper(agg1).cmp(&S::sort_upper(agg2)) { | |
| 213 | Equal => S::sort_lower(agg1).cmp(&S::sort_lower(agg2)), | |
| 214 | order => order, | |
| 215 | } | |
| 216 | } | |
| 217 | } | |
| 218 | ||
| 5 | 219 | /// This is a container for a [`BinaryHeap`] of [`RefinementInfo`]s together with tracking of |
| 220 | /// the greatest lower bound of the [`Aggregator`]s of the [`Node`]s therein accroding to | |
| 221 | /// chosen [`AggregatorSorting`]. | |
| 222 | struct HeapContainer<'a, F, D, A, S, RResult, const N : usize, const P : usize> | |
| 0 | 223 | where F : Float, |
| 224 | D : 'static + Copy, | |
| 225 | Const<P> : BranchCount<N>, | |
| 226 | A : Aggregator, | |
| 227 | S : AggregatorSorting<Agg = A> { | |
| 228 | heap : BinaryHeap<RefinementInfo<'a, F, D, A, S, RResult, N, P>>, | |
| 229 | glb : S::Sort, | |
| 230 | glb_stale_counter : usize, | |
| 231 | stale_insert_counter : usize, | |
| 232 | } | |
| 233 | ||
| 234 | impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> | |
| 235 | HeapContainer<'a, F, D, A, S, RResult, N, P> | |
| 236 | where F : Float, | |
| 237 | D : 'static + Copy, | |
| 238 | Const<P> : BranchCount<N>, | |
| 239 | A : Aggregator, | |
| 240 | S : AggregatorSorting<Agg = A> { | |
| 241 | ||
| 5 | 242 | /// Push `ri` into the [`BinaryHeap`]. Do greatest lower bound maintenance. |
| 0 | 243 | fn push(&mut self, ri : RefinementInfo<'a, F, D, A, S, RResult, N, P>) { |
| 244 | if ri.sort_upper() >= self.glb { | |
| 245 | let l = ri.sort_lower(); | |
| 246 | self.heap.push(ri); | |
| 247 | self.glb = self.glb.max(l); | |
| 248 | if self.glb_stale_counter > 0 { | |
| 249 | self.stale_insert_counter += 1; | |
| 250 | } | |
| 251 | } | |
| 252 | } | |
| 253 | } | |
| 254 | ||
| 255 | impl<F : Float, D : 'static + Copy, A, const N : usize, const P : usize> | |
| 256 | Branches<F,D,A,N,P> | |
| 257 | where Const<P> : BranchCount<N>, | |
| 258 | A : Aggregator { | |
| 259 | ||
| 260 | /// Stage all subnodes of `self` into the refinement queue [`container`]. | |
| 261 | fn stage_refine<'a, S, RResult>( | |
| 262 | &'a mut self, | |
| 263 | domain : Cube<F,N>, | |
| 264 | container : &mut HeapContainer<'a, F, D, A, S, RResult, N, P>, | |
| 265 | ) where S : AggregatorSorting<Agg = A> { | |
| 266 | // Insert all subnodes into the refinement heap. | |
| 267 | for (node, subcube) in self.nodes_and_cubes_mut(&domain) { | |
| 268 | container.push(RefinementInfo { | |
| 269 | cube : subcube, | |
| 270 | node : node, | |
| 271 | refiner_info : None, | |
| 272 | sorting : PhantomData, | |
| 273 | }); | |
| 274 | } | |
| 275 | } | |
| 276 | } | |
| 277 | ||
| 278 | ||
impl<F : Float, D : 'static + Copy, A, const N : usize, const P : usize>
Node<F,D,A,N,P>
where Const<P> : BranchCount<N>,
      A : Aggregator {

    /// If `self` is a leaf node, uses the `refiner` to determine whether further subdivision
    /// is required to get a sufficiently refined solution for the problem the refiner is used
    /// to solve. If the refiner returns a [`RefinerResult::Certain`] result, it is returned.
    /// If [`RefinerResult::Uncertain`] is returned, the leaf is inserted back into the refinement
    /// queue `container` together with the refined aggregator and result. If the refiner returns
    /// [`RefinerResult::NeedRefinement`] for a leaf holding data, the leaf is subdivided into
    /// branches, which are then staged into `container`. If `self` is a branch, its subnodes are
    /// staged into `container` using [`Branches::stage_refine`].
    ///
    /// `domain`, as usual, indicates the spatial area corresponding to `self`. `step` is passed
    /// through to [`Refiner::refine`].
    ///
    /// Returns `Some` only for a certain result from the refiner, and `None` otherwise.
    fn search_and_refine<'a, 'b, R, G>(
        &'a mut self,
        domain : Cube<F,N>,
        refiner : &R,
        generator : &G,
        container : &'b mut HeapContainer<'a, F, D, A, R::Sorting, R::Result, N, P>,
        step : usize
    ) -> Option<R::Result>
    where R : Refiner<F, A, G, N>,
          G : SupportGenerator<F, N, Id=D>,
          G::SupportType : LocalAnalysis<F, A, N> {

        // The “complex” repeated pattern matching here is forced by mutability requirements.

        // Refine a leaf. Non-leaf nodes are treated as needing refinement, so that a branch
        // falls through to the staging of its subnodes below.
        let res = if let NodeOption::Leaf(ref v) = &mut self.data {
            let res = refiner.refine(&self.aggregator, &domain, v, generator, step);
            if let NeedRefinement = res {
                // The refiner has deemed the leaf insufficiently refined, so subdivide
                // it and add the new nodes into the refinement priority heap.
                // We start iterating from the end to mix support_hint a bit.
                let mut it = v.iter().rev();
                if let Some(&d) = it.next() {
                    // Construct new Branches
                    let support = generator.support_for(d);
                    let b = Rc::new({
                        let mut b0 = Branches::new_with(&domain, &support);
                        b0.insert(&domain, d, Const::<1>, &support);
                        for &d in it {
                            let support = generator.support_for(d);
                            // TODO: can we be smarter than just refining one level?
                            b0.insert(&domain, d, Const::<1>, &support);
                        }
                        b0
                    });
                    // Update current node
                    self.aggregator.summarise(b.aggregators());
                    self.data = NodeOption::Branches(b);
                    // The branches will be inserted into the refinement priority queue below.
                }
            }
            res
        } else {
            NeedRefinement
        };

        if let Uncertain(agg, val) = res {
            // The refiner gave an uncertain result. Push a leaf back into the refinement queue
            // with the new refined aggregator and custom return value. It will be popped and
            // returned in the loop of [`BT::search_and_refine`] when there are no unrefined
            // candidates that could potentially be better according to their basic aggregator.
            container.push(RefinementInfo {
                cube : domain,
                node : self,
                refiner_info : Some((agg, val)),
                sorting : PhantomData,
            });
            None
        } else if let Certain(val) = res {
            // The refiner gave a certain result so return it to allow early termination
            Some(val)
        } else if let NodeOption::Branches(ref mut b) = &mut self.data {
            // Insert branches into refinement priority queue.
            Rc::make_mut(b).stage_refine(domain, container);
            None
        } else {
            // Reached e.g. for a leaf without data that the refiner deemed insufficiently
            // refined: it could not be subdivided above, so it is not re-queued.
            None
        }
    }
}
| 362 | ||
| 5 | 363 | /// Interface trait to a refining search on a [`BT`]. |
| 364 | /// | |
| 365 | /// This can be removed and the methods implemented directly on [`BT`] once Rust's const generics | |
| 366 | /// are flexible enough to allow fixing `P=pow(2, N)`. | |
| 0 | 367 | pub trait BTSearch<F, const N : usize> : BTImpl<F, N> |
| 368 | where F : Float { | |
| 369 | ||
| 5 | 370 | /// Perform a search on [`Self`], as determined by `refiner`. |
| 371 | /// | |
| 372 | /// Nodes are inserted in a priority queue and processed in the order determined by the | |
| 373 | /// [`AggregatorSorting`] [`Refiner::Sorting`]. Leaf nodes are subdivided until the refiner | |
| 374 | /// decides that a sufficiently refined leaf node has been found, as determined by either the | |
| 375 | /// refiner returning a [`RefinerResult::Certain`] result, or a previous | |
| 376 | /// [`RefinerResult::Uncertain`] result is found again at the top of the priority queue. | |
| 377 | /// | |
| 378 | /// The `generator` converts [`BTImpl::Data`] stored in the bisection tree into a [`Support`]. | |
| 0 | 379 | fn search_and_refine<'b, R, G>( |
| 380 | &'b mut self, | |
| 381 | refiner : &R, | |
| 382 | generator : &G, | |
| 383 | ) -> Option<R::Result> | |
| 384 | where R : Refiner<F, Self::Agg, G, N>, | |
| 385 | G : SupportGenerator<F, N, Id=Self::Data>, | |
| 386 | G::SupportType : LocalAnalysis<F, Self::Agg, N>; | |
| 387 | } | |
| 388 | ||
// Needed to get access to a Node without a trait interface.
//
// Implements `BTSearch` for `BT<M,F,D,A,$n>` for each dimension literal `$n` given.
// The search keeps a `HeapContainer` priority queue of tree nodes, starting from the top node,
// and processes them in priority order until a result is found or the queue is exhausted.
macro_rules! impl_btsearch {
    ($($n:literal)*) => { $(
        impl<'a, M, F, D, A>
        BTSearch<F, $n>
        for BT<M,F,D,A,$n>
        where //Self : BTImpl<F,$n,Data=D,Agg=A, Depth=M>, // <== automatically deduce to be implemented
              M : Depth,
              F : Float,
              A : 'a + Aggregator,
              D : 'static + Copy + std::fmt::Debug {
            fn search_and_refine<'b, R, G>(
                &'b mut self,
                refiner : &R,
                generator : &G,
            ) -> Option<R::Result>
            where R : Refiner<F, A, G, $n>,
                  G : SupportGenerator<F, $n, Id=D>,
                  G::SupportType : LocalAnalysis<F, A, $n> {
                let mut container = HeapContainer {
                    heap : BinaryHeap::new(),
                    glb : R::Sorting::bottom(),
                    glb_stale_counter : 0,
                    stale_insert_counter : 0,
                };
                container.push(RefinementInfo {
                    cube : self.domain,
                    node : &mut self.topnode,
                    refiner_info : None,
                    sorting : PhantomData,
                });
                // Number of processed queue entries; passed on to the refiner.
                let mut step = 0;
                while let Some(ri) = container.heap.pop() {
                    // An entry carrying a stored refiner result reached the top of the queue:
                    // no remaining entry has greater sorting keys.
                    if let Some((_, result)) = ri.refiner_info {
                        // Terminate based on a “best possible” result.
                        return Some(result)
                    }

                    // The popped entry may have been the one determining `glb`, in which case
                    // `glb` can now overestimate the lower keys remaining in the heap. Once
                    // such stale pops plus the insertions since exceed half the heap size,
                    // recompute `glb` from scratch and prune entries that can no longer win.
                    if ri.sort_lower() >= container.glb {
                        container.glb_stale_counter += 1;
                        if container.stale_insert_counter + container.glb_stale_counter
                           > container.heap.len()/2 {
                            // GLB property no longer correct.
                            match container.heap.iter().map(|ri| ri.sort_lower()).reduce(max) {
                                Some(glb) => {
                                    container.glb = glb;
                                    container.heap.retain(|ri| ri.sort_upper() >= glb);
                                },
                                None => {
                                    container.glb = R::Sorting::bottom()
                                }
                            }
                            container.glb_stale_counter = 0;
                            container.stale_insert_counter = 0;
                        }
                    }

                    let res = ri.node.search_and_refine(ri.cube, refiner, generator,
                                                        &mut container, step);
                    if let Some(_) = res {
                        // Terminate based on a certain result from the refiner
                        return res
                    }

                    step += 1;
                }
                None
            }
        }
    )* }
}

impl_btsearch!(1 2 3 4);
| 462 |