Sat, 22 Oct 2022 22:28:04 +0300
Convert iteration utilities to GATs
use std::cmp::{max, Ord, Ordering, Ordering::*, PartialOrd};
use std::collections::BinaryHeap;
use std::marker::PhantomData;
use std::rc::Rc;

use crate::nanleast::NaNLeast;
use crate::types::*;

use super::aggregator::*;
use super::bt::*;
use super::support::*;

/// Ordering policy used to prioritise [`Aggregator`]s during [`BT`] refinement.
///
/// Every aggregator is mapped to two totally ordered keys: a “lower” key and an “upper” key.
/// A [`BT`] branch whose upper key is less than the lower key of some other branch is dropped
/// from the refinement process.
pub trait AggregatorSorting {
    /// Aggregator type from which the sorting keys are computed.
    type Agg : Aggregator;

    /// Type of the sorting keys.
    type Sort : Ord + Copy + std::fmt::Debug;

    /// Returns the “bottom” sorting key, from which the refinement search starts its
    /// bookkeeping of lower keys.
    fn bottom() -> Self::Sort;

    /// Computes the lower sorting key of `aggregator`.
    fn sort_lower(aggregator : &Self::Agg) -> Self::Sort;

    /// Computes the upper sorting key of `aggregator`.
    fn sort_upper(aggregator : &Self::Agg) -> Self::Sort;
}

/// [`AggregatorSorting`] for [`Bounds`] that uses the upper bound as the upper key and the
/// lower bound as the lower key.
///
/// [`LowerBoundSorting`] provides the reversed ordering.
pub struct UpperBoundSorting<F>(PhantomData<F>) where F : Float;

/// [`AggregatorSorting`] for [`Bounds`] that uses the (negated) lower bound as the upper key
/// and the (negated) upper bound as the lower key.
///
/// [`UpperBoundSorting`] provides the reversed ordering.
pub struct LowerBoundSorting<F : Float>(PhantomData<F>); impl<F : Float> AggregatorSorting for UpperBoundSorting<F> { type Agg = Bounds<F>; type Sort = NaNLeast<F>; #[inline] fn sort_lower(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(aggregator.lower()) } #[inline] fn sort_upper(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(aggregator.upper()) } #[inline] fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) } } impl<F : Float> AggregatorSorting for LowerBoundSorting<F> { type Agg = Bounds<F>; type Sort = NaNLeast<F>; #[inline] fn sort_upper(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(-aggregator.lower()) } #[inline] fn sort_lower(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(-aggregator.upper()) } #[inline] fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) } } /// Result type of [`Refiner::refine`] for a refiner producing a result of type `R` acting on /// an [`Aggregator`] of type `A`. pub enum RefinerResult<A : Aggregator, R> { /// Indicates in insufficiently refined state: the [`BT`] needs to be further refined. NeedRefinement, /// Indicates a certain result `R`, stop refinement immediately. Certain(R), /// Indicates an uncertain result: continue refinement until candidates have been exhausted /// or a certain result found. Uncertain(A, R) } use RefinerResult::*; /// A `Refiner` is used to determine whether an [`Aggregator`] `A` is sufficiently refined within /// a [`Cube`] of a [`BT`], and in such a case, produce a desired result (e.g. a maximum value of /// a function). pub trait Refiner<F : Float, A, G, const N : usize> where F : Num, A : Aggregator, G : SupportGenerator<F, N> { type Result : std::fmt::Debug; type Sorting : AggregatorSorting<Agg = A>; /// Determines whether `aggregator` is sufficiently refined within `cube`. /// Should return a possibly refined version of the `aggregator` and an arbitrary value of /// the result type of the refiner. 
fn refine( &self, aggregator : &A, domain : &Cube<F, N>, data : &Vec<G::Id>, generator : &G, step : usize, ) -> RefinerResult<A, Self::Result>; } /// Structure for tracking the refinement process in a [`BinaryHeap`]. struct RefinementInfo<'a, F, D, A, S, RResult, const N : usize, const P : usize> where F : Float, D : 'static +, A : Aggregator, S : AggregatorSorting<Agg = A> { cube : Cube<F, N>, node : &'a mut Node<F, D, A, N, P>, refiner_info : Option<(A, RResult)>, sorting : PhantomData<S>, } impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> RefinementInfo<'a, F, D, A, S, RResult, N, P> where F : Float, D : 'static, A : Aggregator, S : AggregatorSorting<Agg = A> { #[inline] fn aggregator(&self) -> &A { match self.refiner_info { Some((ref agg, _)) => agg, None => &self.node.aggregator, } } #[inline] fn sort_lower(&self) -> S::Sort { S::sort_lower(self.aggregator()) } #[inline] fn sort_upper(&self) -> S::Sort { S::sort_upper(self.aggregator()) } } impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialEq for RefinementInfo<'a, F, D, A, S, RResult, N, P> where F : Float, D : 'static, A : Aggregator, S : AggregatorSorting<Agg = A> { #[inline] fn eq(&self, other : &Self) -> bool { self.cmp(other) == Equal } } impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialOrd for RefinementInfo<'a, F, D, A, S, RResult, N, P> where F : Float, D : 'static, A : Aggregator, S : AggregatorSorting<Agg = A> { #[inline] fn partial_cmp(&self, other : &Self) -> Option<Ordering> { Some(self.cmp(other)) } } impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Eq for RefinementInfo<'a, F, D, A, S, RResult, N, P> where F : Float, D : 'static, A : Aggregator, S : AggregatorSorting<Agg = A> { } impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Ord for RefinementInfo<'a, F, D, A, S, RResult, N, P> where F : Float, D : 'static, A : Aggregator, S : AggregatorSorting<Agg = A> { #[inline] fn cmp(&self, other : &Self) -> Ordering 
{ let agg1 = self.aggregator(); let agg2 = other.aggregator(); match S::sort_upper(agg1).cmp(&S::sort_upper(agg2)) { Equal => S::sort_lower(agg1).cmp(&S::sort_lower(agg2)), order => order, } } } pub struct HeapContainer<'a, F, D, A, S, RResult, const N : usize, const P : usize> where F : Float, D : 'static + Copy, Const<P> : BranchCount<N>, A : Aggregator, S : AggregatorSorting<Agg = A> { heap : BinaryHeap<RefinementInfo<'a, F, D, A, S, RResult, N, P>>, glb : S::Sort, glb_stale_counter : usize, stale_insert_counter : usize, } impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> HeapContainer<'a, F, D, A, S, RResult, N, P> where F : Float, D : 'static + Copy, Const<P> : BranchCount<N>, A : Aggregator, S : AggregatorSorting<Agg = A> { fn push(&mut self, ri : RefinementInfo<'a, F, D, A, S, RResult, N, P>) { if ri.sort_upper() >= self.glb { let l = ri.sort_lower(); self.heap.push(ri); self.glb = self.glb.max(l); if self.glb_stale_counter > 0 { self.stale_insert_counter += 1; } } } } impl<F : Float, D : 'static + Copy, A, const N : usize, const P : usize> Branches<F,D,A,N,P> where Const<P> : BranchCount<N>, A : Aggregator { /// Stage all subnodes of `self` into the refinement queue [`container`]. fn stage_refine<'a, S, RResult>( &'a mut self, domain : Cube<F,N>, container : &mut HeapContainer<'a, F, D, A, S, RResult, N, P>, ) where S : AggregatorSorting<Agg = A> { // Insert all subnodes into the refinement heap. for (node, subcube) in self.nodes_and_cubes_mut(&domain) { container.push(RefinementInfo { cube : subcube, node : node, refiner_info : None, sorting : PhantomData, }); } } } impl<F : Float, D : 'static + Copy, A, const N : usize, const P : usize> Node<F,D,A,N,P> where Const<P> : BranchCount<N>, A : Aggregator { /// If `self` is a leaf node, uses the `refiner` to determine whether further subdivision /// is required to get a sufficiently refined solution for the problem the refiner is used /// to solve. 
If the refiner returns [`RefinerResult::Certain`] result, it is returned. /// If [`RefinerResult::Uncertain`] is returned, the leaf is inserted back into the refinement /// queue `container`. If `self` is a branch, its subnodes are staged into `container` using /// [`Branches::stage_refine`]. fn search_and_refine<'a, 'b, R, G>( &'a mut self, domain : Cube<F,N>, refiner : &R, generator : &G, container : &'b mut HeapContainer<'a, F, D, A, R::Sorting, R::Result, N, P>, step : usize ) -> Option<R::Result> where R : Refiner<F, A, G, N>, G : SupportGenerator<F, N, Id=D>, G::SupportType : LocalAnalysis<F, A, N> { // The “complex” repeated pattern matching here is forced by mutability requirements. // Refine a leaf. let res = if let NodeOption::Leaf(ref v) = &mut self.data { let res = refiner.refine(&self.aggregator, &domain, v, generator, step); if let NeedRefinement = res { // The refiner has deemed the leaf unsufficiently refined, so subdivide // it and add the new nodes into the refinement priority heap. // We start iterating from the end to mix support_hint a bit. let mut it = v.iter().rev(); if let Some(&d) = it.next() { // Construct new Branches let support = generator.support_for(d); let b = Rc::new({ let mut b0 = Branches::new_with(&domain, &support); b0.insert(&domain, d, Const::<1>, &support); for &d in it { let support = generator.support_for(d); // TODO: can we be smarter than just refining one level? b0.insert(&domain, d, Const::<1>, &support); } b0 }); // Update current node self.aggregator.summarise(b.aggregators()); self.data = NodeOption::Branches(b); // The branches will be inserted into the refinement priority queue below. } } res } else { NeedRefinement }; if let Uncertain(agg, val) = res { // The refiner gave an undertain result. Push a leaf back into the refinement queue // with the new refined aggregator and custom return value. 
It will be popped and // returned in the loop of [`BT::search_and_refine`] when there are no unrefined // candidates that could potentially be better according to their basic aggregator. container.push(RefinementInfo { cube : domain, node : self, refiner_info : Some((agg, val)), sorting : PhantomData, }); None } else if let Certain(val) = res { // The refiner gave a certain result so return it to allow early termination Some(val) } else if let NodeOption::Branches(ref mut b) = &mut self.data { // Insert branches into refinement priority queue. Rc::make_mut(b).stage_refine(domain, container); None } else { None } } } /// Helper trait for implementing a refining search on a [`BT`]. pub trait BTSearch<F, const N : usize> : BTImpl<F, N> where F : Float { /// Perform a refining search on [`Self`], as determined by `refiner`. Nodes are inserted /// in a priority queue and processed in the order determined by the [`AggregatorSorting`] /// [`Refiner::Sorting`]. Leaf nodes are subdivided until the refiner decides that a /// sufficiently refined leaf node has been found, as determined by either the refiner /// returning a [`RefinerResult::Certain`] result, or a previous [`RefinerResult::Uncertain`] /// result is found again at the top of the priority queue. fn search_and_refine<'b, R, G>( &'b mut self, refiner : &R, generator : &G, ) -> Option<R::Result> where R : Refiner<F, Self::Agg, G, N>, G : SupportGenerator<F, N, Id=Self::Data>, G::SupportType : LocalAnalysis<F, Self::Agg, N>; } // Needed to get access to a Node without a trait interface. macro_rules! 
impl_btsearch { ($($n:literal)*) => { $( impl<'a, M, F, D, A> BTSearch<F, $n> for BT<M,F,D,A,$n> where //Self : BTImpl<F,$n,Data=D,Agg=A, Depth=M>, // <== automatically deduce to be implemented M : Depth, F : Float, A : 'a + Aggregator, D : 'static + Copy + std::fmt::Debug { fn search_and_refine<'b, R, G>( &'b mut self, refiner : &R, generator : &G, ) -> Option<R::Result> where R : Refiner<F, A, G, $n>, G : SupportGenerator<F, $n, Id=D>, G::SupportType : LocalAnalysis<F, A, $n> { let mut container = HeapContainer { heap : BinaryHeap::new(), glb : R::Sorting::bottom(), glb_stale_counter : 0, stale_insert_counter : 0, }; container.push(RefinementInfo { cube : self.domain, node : &mut self.topnode, refiner_info : None, sorting : PhantomData, }); let mut step = 0; while let Some(ri) = container.heap.pop() { if let Some((_, result)) = ri.refiner_info { // Terminate based on a “best possible” result. return Some(result) } if ri.sort_lower() >= container.glb { container.glb_stale_counter += 1; if container.stale_insert_counter + container.glb_stale_counter > container.heap.len()/2 { // GLB propery no longer correct. match container.heap.iter().map(|ri| ri.sort_lower()).reduce(max) { Some(glb) => { container.glb = glb; container.heap.retain(|ri| ri.sort_upper() >= glb); }, None => { container.glb = R::Sorting::bottom() } } container.glb_stale_counter = 0; container.stale_insert_counter = 0; } } let res = ri.node.search_and_refine(ri.cube, refiner, generator, &mut container, step); if let Some(_) = res { // Terminate based on a certain result from the refiner return res } step += 1; } None } } )* } } impl_btsearch!(1 2 3 4);