Sat, 22 Oct 2022 22:28:04 +0300
Convert iteration utilities to GATs
0 | 1 | |
2 | use std::collections::BinaryHeap; | |
3 | use std::cmp::{PartialOrd,Ord,Ordering,Ordering::*,max}; | |
4 | use std::rc::Rc; | |
5 | use std::marker::PhantomData; | |
6 | use crate::types::*; | |
7 | use super::support::*; | |
8 | use super::bt::*; | |
9 | use super::aggregator::*; | |
10 | use crate::nanleast::NaNLeast; | |
11 | ||
12 | /// Trait for sorting [`Aggregator`]s for [`BT`] refinement. | |
13 | /// | |
14 | /// The sorting involves two sorting keys, the “upper” and the “lower” key. Any [`BT`] branches | |
15 | /// with upper key less the lower key of another are discarded from the refinement process. | |
16 | pub trait AggregatorSorting { | |
17 | // Priority | |
18 | type Agg : Aggregator; | |
19 | type Sort : Ord + Copy + std::fmt::Debug; | |
20 | ||
21 | /// Returns lower sorting key | |
22 | fn sort_lower(aggregator : &Self::Agg) -> Self::Sort; | |
23 | ||
24 | /// Returns upper sorting key | |
25 | fn sort_upper(aggregator : &Self::Agg) -> Self::Sort; | |
26 | ||
27 | /// Returns bottom sorting key. | |
28 | fn bottom() -> Self::Sort; | |
29 | } | |
30 | ||
31 | /// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the upper/lower key. | |
32 | /// | |
33 | /// See [`LowerBoundSorting`] for the opposite ordering. | |
34 | pub struct UpperBoundSorting<F : Float>(PhantomData<F>); | |
35 | ||
36 | /// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the lower/upper key. | |
37 | /// | |
38 | /// See [`UpperBoundSorting`] for the opposite ordering. | |
39 | pub struct LowerBoundSorting<F : Float>(PhantomData<F>); | |
40 | ||
41 | impl<F : Float> AggregatorSorting for UpperBoundSorting<F> { | |
42 | type Agg = Bounds<F>; | |
43 | type Sort = NaNLeast<F>; | |
44 | ||
45 | #[inline] | |
46 | fn sort_lower(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(aggregator.lower()) } | |
47 | ||
48 | #[inline] | |
49 | fn sort_upper(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(aggregator.upper()) } | |
50 | ||
51 | #[inline] | |
52 | fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) } | |
53 | } | |
54 | ||
55 | ||
56 | impl<F : Float> AggregatorSorting for LowerBoundSorting<F> { | |
57 | type Agg = Bounds<F>; | |
58 | type Sort = NaNLeast<F>; | |
59 | ||
60 | #[inline] | |
61 | fn sort_upper(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(-aggregator.lower()) } | |
62 | ||
63 | #[inline] | |
64 | fn sort_lower(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(-aggregator.upper()) } | |
65 | ||
66 | #[inline] | |
67 | fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) } | |
68 | } | |
69 | ||
70 | /// Result type of [`Refiner::refine`] for a refiner producing a result of type `R` acting on | |
71 | /// an [`Aggregator`] of type `A`. | |
72 | pub enum RefinerResult<A : Aggregator, R> { | |
73 | /// Indicates in insufficiently refined state: the [`BT`] needs to be further refined. | |
74 | NeedRefinement, | |
75 | /// Indicates a certain result `R`, stop refinement immediately. | |
76 | Certain(R), | |
77 | /// Indicates an uncertain result: continue refinement until candidates have been exhausted | |
78 | /// or a certain result found. | |
79 | Uncertain(A, R) | |
80 | } | |
81 | ||
82 | use RefinerResult::*; | |
83 | ||
84 | /// A `Refiner` is used to determine whether an [`Aggregator`] `A` is sufficiently refined within | |
85 | /// a [`Cube`] of a [`BT`], and in such a case, produce a desired result (e.g. a maximum value of | |
86 | /// a function). | |
87 | pub trait Refiner<F : Float, A, G, const N : usize> | |
88 | where F : Num, | |
89 | A : Aggregator, | |
90 | G : SupportGenerator<F, N> { | |
91 | ||
92 | type Result : std::fmt::Debug; | |
93 | type Sorting : AggregatorSorting<Agg = A>; | |
94 | ||
95 | /// Determines whether `aggregator` is sufficiently refined within `cube`. | |
96 | /// Should return a possibly refined version of the `aggregator` and an arbitrary value of | |
97 | /// the result type of the refiner. | |
98 | fn refine( | |
99 | &self, | |
100 | aggregator : &A, | |
101 | domain : &Cube<F, N>, | |
102 | data : &Vec<G::Id>, | |
103 | generator : &G, | |
104 | step : usize, | |
105 | ) -> RefinerResult<A, Self::Result>; | |
106 | } | |
107 | ||
108 | /// Structure for tracking the refinement process in a [`BinaryHeap`]. | |
109 | struct RefinementInfo<'a, F, D, A, S, RResult, const N : usize, const P : usize> | |
110 | where F : Float, | |
111 | D : 'static +, | |
112 | A : Aggregator, | |
113 | S : AggregatorSorting<Agg = A> { | |
114 | cube : Cube<F, N>, | |
115 | node : &'a mut Node<F, D, A, N, P>, | |
116 | refiner_info : Option<(A, RResult)>, | |
117 | sorting : PhantomData<S>, | |
118 | } | |
119 | ||
120 | impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> | |
121 | RefinementInfo<'a, F, D, A, S, RResult, N, P> | |
122 | where F : Float, | |
123 | D : 'static, | |
124 | A : Aggregator, | |
125 | S : AggregatorSorting<Agg = A> { | |
126 | ||
127 | #[inline] | |
128 | fn aggregator(&self) -> &A { | |
129 | match self.refiner_info { | |
130 | Some((ref agg, _)) => agg, | |
131 | None => &self.node.aggregator, | |
132 | } | |
133 | } | |
134 | ||
135 | #[inline] | |
136 | fn sort_lower(&self) -> S::Sort { | |
137 | S::sort_lower(self.aggregator()) | |
138 | } | |
139 | ||
140 | #[inline] | |
141 | fn sort_upper(&self) -> S::Sort { | |
142 | S::sort_upper(self.aggregator()) | |
143 | } | |
144 | } | |
145 | ||
146 | impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialEq | |
147 | for RefinementInfo<'a, F, D, A, S, RResult, N, P> | |
148 | where F : Float, | |
149 | D : 'static, | |
150 | A : Aggregator, | |
151 | S : AggregatorSorting<Agg = A> { | |
152 | ||
153 | #[inline] | |
154 | fn eq(&self, other : &Self) -> bool { self.cmp(other) == Equal } | |
155 | } | |
156 | ||
157 | impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialOrd | |
158 | for RefinementInfo<'a, F, D, A, S, RResult, N, P> | |
159 | where F : Float, | |
160 | D : 'static, | |
161 | A : Aggregator, | |
162 | S : AggregatorSorting<Agg = A> { | |
163 | ||
164 | #[inline] | |
165 | fn partial_cmp(&self, other : &Self) -> Option<Ordering> { Some(self.cmp(other)) } | |
166 | } | |
167 | ||
168 | impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Eq | |
169 | for RefinementInfo<'a, F, D, A, S, RResult, N, P> | |
170 | where F : Float, | |
171 | D : 'static, | |
172 | A : Aggregator, | |
173 | S : AggregatorSorting<Agg = A> { | |
174 | } | |
175 | ||
176 | impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Ord | |
177 | for RefinementInfo<'a, F, D, A, S, RResult, N, P> | |
178 | where F : Float, | |
179 | D : 'static, | |
180 | A : Aggregator, | |
181 | S : AggregatorSorting<Agg = A> { | |
182 | ||
183 | #[inline] | |
184 | fn cmp(&self, other : &Self) -> Ordering { | |
185 | let agg1 = self.aggregator(); | |
186 | let agg2 = other.aggregator(); | |
187 | match S::sort_upper(agg1).cmp(&S::sort_upper(agg2)) { | |
188 | Equal => S::sort_lower(agg1).cmp(&S::sort_lower(agg2)), | |
189 | order => order, | |
190 | } | |
191 | } | |
192 | } | |
193 | ||
194 | pub struct HeapContainer<'a, F, D, A, S, RResult, const N : usize, const P : usize> | |
195 | where F : Float, | |
196 | D : 'static + Copy, | |
197 | Const<P> : BranchCount<N>, | |
198 | A : Aggregator, | |
199 | S : AggregatorSorting<Agg = A> { | |
200 | heap : BinaryHeap<RefinementInfo<'a, F, D, A, S, RResult, N, P>>, | |
201 | glb : S::Sort, | |
202 | glb_stale_counter : usize, | |
203 | stale_insert_counter : usize, | |
204 | } | |
205 | ||
206 | impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> | |
207 | HeapContainer<'a, F, D, A, S, RResult, N, P> | |
208 | where F : Float, | |
209 | D : 'static + Copy, | |
210 | Const<P> : BranchCount<N>, | |
211 | A : Aggregator, | |
212 | S : AggregatorSorting<Agg = A> { | |
213 | ||
214 | fn push(&mut self, ri : RefinementInfo<'a, F, D, A, S, RResult, N, P>) { | |
215 | if ri.sort_upper() >= self.glb { | |
216 | let l = ri.sort_lower(); | |
217 | self.heap.push(ri); | |
218 | self.glb = self.glb.max(l); | |
219 | if self.glb_stale_counter > 0 { | |
220 | self.stale_insert_counter += 1; | |
221 | } | |
222 | } | |
223 | } | |
224 | } | |
225 | ||
226 | impl<F : Float, D : 'static + Copy, A, const N : usize, const P : usize> | |
227 | Branches<F,D,A,N,P> | |
228 | where Const<P> : BranchCount<N>, | |
229 | A : Aggregator { | |
230 | ||
231 | /// Stage all subnodes of `self` into the refinement queue [`container`]. | |
232 | fn stage_refine<'a, S, RResult>( | |
233 | &'a mut self, | |
234 | domain : Cube<F,N>, | |
235 | container : &mut HeapContainer<'a, F, D, A, S, RResult, N, P>, | |
236 | ) where S : AggregatorSorting<Agg = A> { | |
237 | // Insert all subnodes into the refinement heap. | |
238 | for (node, subcube) in self.nodes_and_cubes_mut(&domain) { | |
239 | container.push(RefinementInfo { | |
240 | cube : subcube, | |
241 | node : node, | |
242 | refiner_info : None, | |
243 | sorting : PhantomData, | |
244 | }); | |
245 | } | |
246 | } | |
247 | } | |
248 | ||
249 | ||
impl<F : Float, D : 'static + Copy, A, const N : usize, const P : usize>
Node<F,D,A,N,P>
where Const<P> : BranchCount<N>,
      A : Aggregator {

    /// If `self` is a leaf node, uses the `refiner` to determine whether further subdivision
    /// is required to get a sufficiently refined solution for the problem the refiner is used
    /// to solve. If the refiner returns a [`RefinerResult::Certain`] result, it is returned.
    /// If [`RefinerResult::Uncertain`] is returned, the leaf is inserted back into the refinement
    /// queue `container` together with the refined aggregator and candidate result.
    /// If the refiner returns [`RefinerResult::NeedRefinement`], the leaf is subdivided into
    /// new [`Branches`]. If `self` is (or has just become) a branch, its subnodes are staged
    /// into `container` using [`Branches::stage_refine`].
    ///
    /// Returns `Some` only for a certain result; otherwise `None`.
    ///
    /// NOTE(review): a leaf with no data for which the refiner returns `NeedRefinement` is
    /// neither subdivided nor re-queued, so it silently drops out of the search.
    fn search_and_refine<'a, 'b, R, G>(
        &'a mut self,
        domain : Cube<F,N>,
        refiner : &R,
        generator : &G,
        container : &'b mut HeapContainer<'a, F, D, A, R::Sorting, R::Result, N, P>,
        step : usize
    ) -> Option<R::Result>
    where R : Refiner<F, A, G, N>,
          G : SupportGenerator<F, N, Id=D>,
          G::SupportType : LocalAnalysis<F, A, N> {

        // The “complex” repeated pattern matching here is forced by mutability requirements.

        // Refine a leaf. Branches skip the refiner and are treated as needing refinement.
        let res = if let NodeOption::Leaf(ref v) = &mut self.data {
            let res = refiner.refine(&self.aggregator, &domain, v, generator, step);
            if let NeedRefinement = res {
                // The refiner has deemed the leaf insufficiently refined, so subdivide
                // it and add the new nodes into the refinement priority heap.
                // We start iterating from the end to mix support_hint a bit.
                let mut it = v.iter().rev();
                if let Some(&d) = it.next() {
                    // Construct new Branches
                    let support = generator.support_for(d);
                    let b = Rc::new({
                        let mut b0 = Branches::new_with(&domain, &support);
                        b0.insert(&domain, d, Const::<1>, &support);
                        for &d in it {
                            let support = generator.support_for(d);
                            // TODO: can we be smarter than just refining one level?
                            b0.insert(&domain, d, Const::<1>, &support);
                        }
                        b0
                    });
                    // Update current node: recompute its aggregator from the new branches
                    // and replace the leaf data with them.
                    self.aggregator.summarise(b.aggregators());
                    self.data = NodeOption::Branches(b);
                    // The branches will be inserted into the refinement priority queue below.
                }
            }
            res
        } else {
            NeedRefinement
        };

        if let Uncertain(agg, val) = res {
            // The refiner gave an uncertain result. Push a leaf back into the refinement queue
            // with the new refined aggregator and custom return value. It will be popped and
            // returned in the loop of [`BT::search_and_refine`] when there are no unrefined
            // candidates that could potentially be better according to their basic aggregator.
            container.push(RefinementInfo {
                cube : domain,
                node : self,
                refiner_info : Some((agg, val)),
                sorting : PhantomData,
            });
            None
        } else if let Certain(val) = res {
            // The refiner gave a certain result so return it to allow early termination
            Some(val)
        } else if let NodeOption::Branches(ref mut b) = &mut self.data {
            // Insert branches into refinement priority queue. `Rc::make_mut` clones the
            // branches first if they are shared with another owner.
            Rc::make_mut(b).stage_refine(domain, container);
            None
        } else {
            None
        }
    }
}
331 | ||
332 | /// Helper trait for implementing a refining search on a [`BT`]. | |
333 | pub trait BTSearch<F, const N : usize> : BTImpl<F, N> | |
334 | where F : Float { | |
335 | ||
336 | /// Perform a refining search on [`Self`], as determined by `refiner`. Nodes are inserted | |
337 | /// in a priority queue and processed in the order determined by the [`AggregatorSorting`] | |
338 | /// [`Refiner::Sorting`]. Leaf nodes are subdivided until the refiner decides that a | |
339 | /// sufficiently refined leaf node has been found, as determined by either the refiner | |
340 | /// returning a [`RefinerResult::Certain`] result, or a previous [`RefinerResult::Uncertain`] | |
341 | /// result is found again at the top of the priority queue. | |
342 | fn search_and_refine<'b, R, G>( | |
343 | &'b mut self, | |
344 | refiner : &R, | |
345 | generator : &G, | |
346 | ) -> Option<R::Result> | |
347 | where R : Refiner<F, Self::Agg, G, N>, | |
348 | G : SupportGenerator<F, N, Id=Self::Data>, | |
349 | G::SupportType : LocalAnalysis<F, Self::Agg, N>; | |
350 | } | |
351 | ||
352 | // Needed to get access to a Node without a trait interface. | |
353 | macro_rules! impl_btsearch { | |
354 | ($($n:literal)*) => { $( | |
355 | impl<'a, M, F, D, A> | |
356 | BTSearch<F, $n> | |
357 | for BT<M,F,D,A,$n> | |
358 | where //Self : BTImpl<F,$n,Data=D,Agg=A, Depth=M>, // <== automatically deduce to be implemented | |
359 | M : Depth, | |
360 | F : Float, | |
361 | A : 'a + Aggregator, | |
362 | D : 'static + Copy + std::fmt::Debug { | |
363 | fn search_and_refine<'b, R, G>( | |
364 | &'b mut self, | |
365 | refiner : &R, | |
366 | generator : &G, | |
367 | ) -> Option<R::Result> | |
368 | where R : Refiner<F, A, G, $n>, | |
369 | G : SupportGenerator<F, $n, Id=D>, | |
370 | G::SupportType : LocalAnalysis<F, A, $n> { | |
371 | let mut container = HeapContainer { | |
372 | heap : BinaryHeap::new(), | |
373 | glb : R::Sorting::bottom(), | |
374 | glb_stale_counter : 0, | |
375 | stale_insert_counter : 0, | |
376 | }; | |
377 | container.push(RefinementInfo { | |
378 | cube : self.domain, | |
379 | node : &mut self.topnode, | |
380 | refiner_info : None, | |
381 | sorting : PhantomData, | |
382 | }); | |
383 | let mut step = 0; | |
384 | while let Some(ri) = container.heap.pop() { | |
385 | if let Some((_, result)) = ri.refiner_info { | |
386 | // Terminate based on a “best possible” result. | |
387 | return Some(result) | |
388 | } | |
389 | ||
390 | if ri.sort_lower() >= container.glb { | |
391 | container.glb_stale_counter += 1; | |
392 | if container.stale_insert_counter + container.glb_stale_counter | |
393 | > container.heap.len()/2 { | |
394 | // GLB propery no longer correct. | |
395 | match container.heap.iter().map(|ri| ri.sort_lower()).reduce(max) { | |
396 | Some(glb) => { | |
397 | container.glb = glb; | |
398 | container.heap.retain(|ri| ri.sort_upper() >= glb); | |
399 | }, | |
400 | None => { | |
401 | container.glb = R::Sorting::bottom() | |
402 | } | |
403 | } | |
404 | container.glb_stale_counter = 0; | |
405 | container.stale_insert_counter = 0; | |
406 | } | |
407 | } | |
408 | ||
409 | let res = ri.node.search_and_refine(ri.cube, refiner, generator, | |
410 | &mut container, step); | |
411 | if let Some(_) = res { | |
412 | // Terminate based on a certain result from the refiner | |
413 | return res | |
414 | } | |
415 | ||
416 | step += 1; | |
417 | } | |
418 | None | |
419 | } | |
420 | } | |
421 | )* } | |
422 | } | |
423 | ||
424 | impl_btsearch!(1 2 3 4); | |
425 |