use std::collections::BinaryHeap;
use std::cmp::{PartialOrd, Ord, Ordering, Ordering::*, max};
use std::marker::PhantomData;
use std::sync::{Arc, Mutex, MutexGuard, Condvar};
use crate::types::*;
use crate::nanleast::NaNLeast;
use crate::sets::Cube;
use crate::parallelism::{thread_pool_size, thread_pool};
use super::support::*;
use super::bt::*;
use super::aggregator::*;
use crate::parallelism::TaskBudget;

/// Trait for sorting [`Aggregator`]s for [`BT`] refinement.
///
/// The sorting involves two sorting keys, the “upper” and the “lower” key. Any [`BT`] nodes
/// with upper key less than the lower key of another are discarded from the refinement process.
/// Nodes with the highest upper sorting key are picked for refinement.
pub trait AggregatorSorting : Sync + Send + 'static {
    // Priority
    type Agg : Aggregator;

    type Sort : Ord + Copy + std::fmt::Debug + Sync + Send;

    /// Returns the lower sorting key
    fn sort_lower(aggregator : &Self::Agg) -> Self::Sort;

    /// Returns the upper sorting key
    fn sort_upper(aggregator : &Self::Agg) -> Self::Sort;

    /// Returns a sorting key that is less than any other sorting key.
    fn bottom() -> Self::Sort;
}

/// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the upper/lower key.
///
/// See [`LowerBoundSorting`] for the opposite ordering.
pub struct UpperBoundSorting<F : Float>(PhantomData<F>);

/// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the lower/upper key.
///
/// See [`UpperBoundSorting`] for the opposite ordering.
pub struct LowerBoundSorting<F : Float>(PhantomData<F>);

impl<F : Float> AggregatorSorting for UpperBoundSorting<F> {
    type Agg = Bounds<F>;
    type Sort = NaNLeast<F>;

    #[inline]
    fn sort_lower(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(aggregator.lower()) }

    #[inline]
    fn sort_upper(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(aggregator.upper()) }

    #[inline]
    fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) }
}

impl<F : Float> AggregatorSorting for LowerBoundSorting<F> {
    type Agg = Bounds<F>;
    type Sort = NaNLeast<F>;

    #[inline]
    fn sort_upper(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(-aggregator.lower()) }

    #[inline]
    fn sort_lower(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(-aggregator.upper()) }

    #[inline]
    fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) }
}
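// A minimal sketch (as a unit test) of how the two sortings above rank the same
// aggregators. It assumes that `Bounds` can be constructed as `Bounds(lower, upper)`,
// as its `lower()`/`upper()` accessors used above suggest.
#[cfg(test)]
mod sorting_sketch {
    use super::*;

    #[test]
    fn bound_sorting_keys() {
        let a = Bounds(0.0_f64, 1.0);
        let b = Bounds(0.5_f64, 2.0);
        // UpperBoundSorting prioritises the node with the greater upper bound…
        assert!(UpperBoundSorting::<f64>::sort_upper(&b)
                > UpperBoundSorting::<f64>::sort_upper(&a));
        // …whereas LowerBoundSorting, through negation, prioritises the lesser lower bound.
        assert!(LowerBoundSorting::<f64>::sort_upper(&a)
                > LowerBoundSorting::<f64>::sort_upper(&b));
    }
}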
/// Return type of [`Refiner::refine`].
///
/// The parameter `R` is the result type of the refiner acting on an [`Aggregator`] of type `A`.
pub enum RefinerResult<A : Aggregator, R> {
    /// Indicates an insufficiently refined state: the [`BT`] needs to be further refined.
    NeedRefinement,
    /// Indicates a certain result `R`; refinement stops immediately.
    Certain(R),
    /// Indicates an uncertain result: continue refinement until candidates have been exhausted
    /// or a certain result is found.
    Uncertain(A, R)
}

use RefinerResult::*;

/// A `Refiner` is used to search a [`BT`], refining the subdivision when necessary.
///
/// The search is performed by [`BTSearch::search_and_refine`].
/// The `Refiner` is used to determine whether an [`Aggregator`] `A` stored in the [`BT`] is
/// sufficiently refined within a [`Cube`], and in such a case, to produce a desired result
/// (e.g. a maximum value of a function).
pub trait Refiner<F : Float, A, G, const N : usize> : Sync + Send + 'static
where F : Num,
      A : Aggregator,
      G : SupportGenerator<F, N> {

    /// The result type of the refiner
    type Result : std::fmt::Debug + Sync + Send + 'static;

    /// The sorting to be employed by [`BTSearch::search_and_refine`] on node aggregators
    /// to determine node priority.
    type Sorting : AggregatorSorting<Agg = A>;

    /// Determines whether `aggregator` is sufficiently refined within `domain`.
    ///
    /// If the aggregator is sufficiently refined for the desired `Self::Result` to be produced,
    /// a [`RefinerResult`]`::Certain` or `Uncertain` should be returned, depending on
    /// the confidence of the solution. In the uncertain case an improved aggregator should also
    /// be included. If the result cannot be produced, `NeedRefinement` should be
    /// returned.
    ///
    /// For example, if the refiner is used to maximise a function presented by the `BT`,
    /// an `Uncertain` result can be used to return a local maximum of the function on `domain`.
    /// The result can be claimed `Certain` if it is a global maximum. In that case the
    /// refinement will stop immediately. A `NeedRefinement` result indicates that the
    /// `aggregator` and/or `domain` are not sufficiently refined to compute a local maximum
    /// of sufficient quality.
    ///
    /// The vector `data` stores all the data of the [`BT`] in the node corresponding to `domain`.
    /// The `generator` can be used to convert `data` into [`Support`]s. The parameter `step`
    /// counts the calls to `refine`, and can be used to stop the refinement when a maximum
    /// number of steps is reached.
    fn refine(
        &self,
        aggregator : &A,
        domain : &Cube<F, N>,
        data : &[G::Id],
        generator : &G,
        step : usize,
    ) -> RefinerResult<A, Self::Result>;

    /// Fuse two [`Self::Result`]s (needed in threaded refinement).
    fn fuse_results(r1 : &mut Self::Result, r2 : Self::Result);
}
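// A minimal sketch of a possible `Refiner` for maximisation over `Bounds` aggregators:
// it accepts a node once the gap between the bounds falls below a tolerance.
// `BoundsGapRefiner` and its `tolerance` field are hypothetical illustrations, not part
// of this module; real refiners will typically also inspect `domain` and `data`. The
// extra `Send + Sync + 'static` bounds on `F` are assumed to hold for the crate's floats.
#[allow(dead_code)]
struct BoundsGapRefiner<F : Float> {
    tolerance : F,
}

impl<F : Float + Send + Sync + 'static, G, const N : usize>
Refiner<F, Bounds<F>, G, N> for BoundsGapRefiner<F>
where G : SupportGenerator<F, N> {
    type Result = F;
    type Sorting = UpperBoundSorting<F>;

    fn refine(
        &self,
        aggregator : &Bounds<F>,
        _domain : &Cube<F, N>,
        _data : &[G::Id],
        _generator : &G,
        _step : usize,
    ) -> RefinerResult<Bounds<F>, F> {
        if aggregator.upper() - aggregator.lower() < self.tolerance {
            // The bounds pin the maximum on `domain` down to within `tolerance`, but
            // only exhausting the queue of potentially better nodes makes it certain.
            Uncertain(aggregator.clone(), aggregator.upper())
        } else {
            NeedRefinement
        }
    }

    fn fuse_results(r1 : &mut F, r2 : F) {
        // Of two candidate maxima, keep the larger.
        if r2 > *r1 { *r1 = r2 }
    }
}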
/// Structure for tracking the refinement process in a [`BinaryHeap`].
struct RefinementInfo<'a, F, D, A, S, RResult, const N : usize, const P : usize>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {
    /// Domain of `node`
    cube : Cube<F, N>,
    /// Node to be refined
    node : &'a mut Node<F, D, A, N, P>,
    /// Result and improved aggregator for the [`Refiner`]
    refiner_info : Option<(A, RResult)>,
    /// The [`AggregatorSorting`] in use, tracked for the type system
    sorting : PhantomData<S>,
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize>
RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {

    #[inline]
    fn with_aggregator<U>(&self, f : impl FnOnce(&A) -> U) -> U {
        match self.refiner_info {
            Some((ref agg, _)) => f(agg),
            None => f(&self.node.aggregator),
        }
    }

    #[inline]
    fn sort_lower(&self) -> S::Sort {
        self.with_aggregator(S::sort_lower)
    }

    #[inline]
    fn sort_upper(&self) -> S::Sort {
        self.with_aggregator(S::sort_upper)
    }
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialEq
for RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {
    #[inline]
    fn eq(&self, other : &Self) -> bool { self.cmp(other) == Equal }
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialOrd
for RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {
    #[inline]
    fn partial_cmp(&self, other : &Self) -> Option<Ordering> { Some(self.cmp(other)) }
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Eq
for RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> { }

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Ord
for RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {
    #[inline]
    fn cmp(&self, other : &Self) -> Ordering {
        self.with_aggregator(|agg1| other.with_aggregator(|agg2| {
            match S::sort_upper(agg1).cmp(&S::sort_upper(agg2)) {
                Equal => S::sort_lower(agg1).cmp(&S::sort_lower(agg2)),
                order => order,
            }
        }))
    }
}

/// This is a container for a [`BinaryHeap`] of [`RefinementInfo`]s together with tracking of
/// the greatest lower bound of the [`Aggregator`]s of the [`Node`]s therein, according to
/// the chosen [`AggregatorSorting`].
struct HeapContainer<'a, F, D, A, S, RResult, const N : usize, const P : usize>
where F : Float,
      D : 'static + Copy,
      Const<P> : BranchCount<N>,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {
    /// Priority queue of nodes to be refined
    heap : BinaryHeap<RefinementInfo<'a, F, D, A, S, RResult, N, P>>,
    /// Maximum of the node sorting lower bounds seen in the heap
    glb : S::Sort,
    /// Number of insertions into the heap since the previous prune
    insert_counter : usize,
    /// If a result has been found by some refinement thread, it is stored here
    result : Option<RResult>,
    /// Refinement step counter
    step : usize,
    /// Number of threads currently processing (not sleeping)
    n_processing : usize,
    /// Threshold for heap pruning
    heap_prune_threshold : usize,
}
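// A standalone sketch (as a unit test, with plain `f64` keys) of the greatest-lower-bound
// pruning rule maintained through `HeapContainer::push` below and the periodic heap
// maintenance in `refinement_loop`: a node is kept only if its upper key can still reach
// the best lower key seen so far.
#[cfg(test)]
mod glb_pruning_sketch {
    #[test]
    fn prune_rule() {
        // (lower, upper) sorting keys of hypothetical nodes
        let candidates = [(0.0, 1.0), (0.8, 2.0), (-1.0, 0.5)];
        let glb = candidates.iter()
                            .map(|&(l, _)| l)
                            .fold(f64::NEG_INFINITY, f64::max);
        // The node with upper key 0.5 < glb = 0.8 can never be selected as optimal,
        // so it may be discarded.
        let retained : Vec<_> = candidates.iter().filter(|&&(_, u)| u >= glb).collect();
        assert_eq!(retained, [&(0.0, 1.0), &(0.8, 2.0)]);
    }
}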
impl<'a, F, D, A, S, RResult, const N : usize, const P : usize>
HeapContainer<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static + Copy,
      Const<P> : BranchCount<N>,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {

    /// Push `ri` into the [`BinaryHeap`]. Do greatest lower bound maintenance.
    ///
    /// Returns a boolean indicating whether the push was actually performed, or whether
    /// `ri` was discarded by glb filtering.
    #[inline]
    fn push(&mut self, ri : RefinementInfo<'a, F, D, A, S, RResult, N, P>) -> bool {
        if ri.sort_upper() >= self.glb {
            let l = ri.sort_lower();
            self.heap.push(ri);
            self.glb = self.glb.max(l);
            self.insert_counter += 1;
            true
        } else {
            false
        }
    }
}

impl<F : Float, D, A, const N : usize, const P : usize> Branches<F,D,A,N,P>
where Const<P> : BranchCount<N>,
      A : Aggregator,
      D : 'static + Copy + Send + Sync {

    /// Stage all subnodes of `self` into the refinement queue `container`.
    fn stage_refine<'a, S, RResult>(
        &'a mut self,
        domain : Cube<F,N>,
        container : &mut HeapContainer<'a, F, D, A, S, RResult, N, P>,
    ) where S : AggregatorSorting<Agg = A> {
        // Insert all subnodes into the refinement heap.
        for (node, cube) in self.nodes_and_cubes_mut(&domain) {
            container.push(RefinementInfo {
                cube,
                node,
                refiner_info : None,
                sorting : PhantomData,
            });
        }
    }
}

impl<F : Float, D, A, const N : usize, const P : usize> Node<F,D,A,N,P>
where Const<P> : BranchCount<N>,
      A : Aggregator,
      D : 'static + Copy + Send + Sync {

    /// If `self` is a leaf node, uses the `refiner` to determine whether further subdivision
    /// is required to get a sufficiently refined solution for the problem the refiner is used
    /// to solve. If the refiner returns a [`RefinerResult::Certain`] result, it is returned.
    /// If [`RefinerResult::Uncertain`] is returned, the leaf is inserted back into the
    /// refinement queue `container`. If `self` is a branch, its subnodes are staged into
    /// `container` using [`Branches::stage_refine`].
    ///
    /// `domain`, as usual, indicates the spatial area corresponding to `self`.
    fn search_and_refine<'a, 'b, 'c, R, G>(
        self : &'a mut Self,
        domain : Cube<F,N>,
        refiner : &R,
        generator : &G,
        container_arc : &'c Arc<Mutex<HeapContainer<'a, F, D, A, R::Sorting, R::Result, N, P>>>,
        step : usize
    ) -> Result<R::Result,
                MutexGuard<'c, HeapContainer<'a, F, D, A, R::Sorting, R::Result, N, P>>>
    where R : Refiner<F, A, G, N>,
          G : SupportGenerator<F, N, Id=D>,
          G::SupportType : LocalAnalysis<F, A, N> {
        // Refine a leaf.
        let res = match self.data {
            NodeOption::Leaf(ref mut v) => {
                let res = refiner.refine(&self.aggregator, &domain, v, generator, step);
                if let NeedRefinement = res {
                    // The refiner has deemed the leaf insufficiently refined, so subdivide
                    // it and add the new nodes into the refinement priority heap.
                    let mut it = v.iter();
                    // Only create new branches if there's anything to add.
                    // We insert the last item first to mix the support_hint a bit.
                    if let Some(&d0) = it.next_back() {
                        // Construct new Branches
                        let support = generator.support_for(d0);
                        let mut b = Branches::new_with(&domain, &support);
                        b.insert(&domain, d0, Const::<1>, &support, TaskBudget::none());
                        for &d in it {
                            let support = generator.support_for(d);
                            // TODO: can we be smarter than just refining one level?
                            b.insert(&domain, d, Const::<1>, &support, TaskBudget::none());
                        }
                        // Update the current node and stage refinement of the new branches.
                        b.summarise_into(&mut self.aggregator);
                        // FIXME: parent aggregators are not updated and will be out of date,
                        // but right now pointsource_algs is not really taking advantage of the
                        // aggregators being updated. Moreover, insertion and aggregator
                        // refinement code in `bt.rs` will overflow the stack in deeply nested
                        // trees. We nevertheless need to store `b` into `self` to be able to
                        // queue the branches.
                        self.data = NodeOption::Branches(Arc::new(b));
                        // This ugly match is needed to keep the compiler happy about lifetimes.
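                        // (Returning the still-locked guard through the `Err` variant lets
                        // `refinement_loop` below continue with the mutex already held,
                        // avoiding an unlock/relock cycle after staging the new branches.)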
                        match self.data {
                            NodeOption::Branches(ref mut arc_b) => {
                                let mut container = container_arc.lock().unwrap();
                                // Safe: we just created arc_b and have an exclusive mutable
                                // reference to the self containing it.
                                unsafe { Arc::get_mut_unchecked(arc_b) }
                                    .stage_refine(domain, &mut *container);
                                return Err(container)
                            },
                            _ => unreachable!("This cannot happen"),
                        }
                    }
                }
                res
            },
            NodeOption::Branches(ref mut b) => {
                // Insert branches into the refinement priority queue.
                let mut container = container_arc.lock().unwrap();
                Arc::make_mut(b).stage_refine(domain, &mut *container);
                return Err(container)
            },
            NodeOption::Uninitialised => {
                refiner.refine(&self.aggregator, &domain, &[], generator, step)
            },
        };

        match res {
            Uncertain(agg, val) => {
                // The refiner gave an uncertain result. Push a leaf back into the refinement
                // queue with the new refined aggregator and custom return value. It will be
                // popped and returned in the loop of [`BTSearch::search_and_refine`] when
                // there are no unrefined candidates that could potentially be better
                // according to their basic aggregator.
                let mut container = container_arc.lock().unwrap();
                container.push(RefinementInfo {
                    cube : domain,
                    node : self,
                    refiner_info : Some((agg, val)),
                    sorting : PhantomData,
                });
                Err(container)
            },
            Certain(val) => {
                // The refiner gave a certain result, so return it to allow early termination.
                Ok(val)
            },
            NeedRefinement => {
                // This should only happen when we run into NodeOption::Uninitialised above.
                // There's really nothing to do.
                panic!("Do not know how to refine uninitialised nodes");
            }
        }
    }
}

/// Interface trait to a refining search on a [`BT`].
///
/// This can be removed and the methods implemented directly on [`BT`] once Rust's const
/// generics are flexible enough to allow fixing `P=pow(2, N)`.
pub trait BTSearch<F, const N : usize> : BTImpl<F, N>
where F : Float {

    /// Perform a search on `Self`, as determined by `refiner`.
    ///
    /// Nodes are inserted in a priority queue and processed in the order determined by the
    /// [`AggregatorSorting`] [`Refiner::Sorting`]. Leaf nodes are subdivided until the refiner
    /// decides that a sufficiently refined leaf node has been found, as determined by either
    /// the refiner returning a [`RefinerResult::Certain`] result, or a previous
    /// [`RefinerResult::Uncertain`] result being found again at the top of the priority queue.
    ///
    /// The `generator` converts [`BTImpl::Data`] stored in the bisection tree into a
    /// [`Support`].
    fn search_and_refine<'b, R, G>(
        &'b mut self,
        refiner : R,
        generator : &Arc<G>,
    ) -> Option<R::Result>
    where R : Refiner<F, Self::Agg, G, N> + Sync + Send + 'static,
          G : SupportGenerator<F, N, Id=Self::Data> + Sync + Send + 'static,
          G::SupportType : LocalAnalysis<F, Self::Agg, N>;
}
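// A hedged convenience sketch, not part of the original interface: it simply forwards to
// `BTSearch::search_and_refine`, substituting a caller-provided `fallback` result when the
// search produces none. The bounds mirror those of the trait method above.
#[allow(dead_code)]
fn search_and_refine_or<F, B, R, G, const N : usize>(
    bt : &mut B,
    refiner : R,
    generator : &Arc<G>,
    fallback : R::Result,
) -> R::Result
where F : Float,
      B : BTSearch<F, N>,
      R : Refiner<F, B::Agg, G, N> + Sync + Send + 'static,
      G : SupportGenerator<F, N, Id=B::Data> + Sync + Send + 'static,
      G::SupportType : LocalAnalysis<F, B::Agg, N> {
    bt.search_and_refine(refiner, generator)
      .unwrap_or(fallback)
}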
fn refinement_loop<F : Float, D, A, R, G, const N : usize, const P : usize> (
    wakeup : Option<Arc<Condvar>>,
    refiner : &R,
    generator_arc : &Arc<G>,
    container_arc : &Arc<Mutex<HeapContainer<F, D, A, R::Sorting, R::Result, N, P>>>,
) where A : Aggregator,
        R : Refiner<F, A, G, N>,
        G : SupportGenerator<F, N, Id=D>,
        G::SupportType : LocalAnalysis<F, A, N>,
        Const<P> : BranchCount<N>,
        D : 'static + Copy + Sync + Send + std::fmt::Debug {

    let mut did_park = true;
    let mut container = container_arc.lock().unwrap();

    'main: loop {
        // Find a node to process
        let ri = 'get_next: loop {
            if did_park {
                container.n_processing += 1;
                did_park = false;
            }

            // Some refinement task/thread has found a result, return
            if container.result.is_some() {
                container.n_processing -= 1;
                break 'main
            }

            match container.heap.pop() {
                // There's work to be done.
                Some(ri) => break 'get_next ri,
                // No work to be done; park if some task/thread is still processing nodes,
                // fail if not.
                None => {
                    debug_assert!(container.n_processing > 0);
                    container.n_processing -= 1;
                    if container.n_processing == 0 {
                        break 'main;
                    } else if let Some(ref c) = wakeup {
                        did_park = true;
                        container = c.wait(container).unwrap();
                        continue 'get_next;
                    } else {
                        break 'main
                    }
                }
            };
        };

        let step = container.step;
        container.step += 1;

        if let Some((_, result)) = ri.refiner_info {
            // Terminate based on a “best possible” result.
            container.result = Some(result);
            container.n_processing -= 1;
            break 'main
        }

        // Do priority queue maintenance
        if container.insert_counter > container.heap_prune_threshold {
            // Make sure the glb is good.
            match container.heap.iter().map(|ri| ri.sort_lower()).reduce(max) {
                Some(glb) => {
                    container.glb = glb;
                    // Prune
                    container.heap.retain(|ri| ri.sort_upper() >= glb);
                },
                None => {
                    container.glb = R::Sorting::bottom()
                }
            }
            container.insert_counter = 0;
        }

        // Unlock the mutex…
        drop(container);

        // … and process the node. We may get returned an already unlocked mutex.
        match Node::search_and_refine(ri.node, ri.cube, refiner, &**generator_arc,
                                      &container_arc, step) {
            Ok(r) => {
                let mut container = container_arc.lock().unwrap();
                // Terminate based on a certain result from the refiner
                match container.result {
                    Some(ref mut r_prev) => R::fuse_results(r_prev, r),
                    None => container.result = Some(r),
                }
                break 'main
            },
            Err(cnt) => {
                container = cnt;
                // Wake up another thread if one is sleeping; there should now be work in
                // the queue.
                if let Some(ref c) = wakeup {
                    c.notify_one();
                }
            }
        }
    }

    // Make sure no task is left sleeping
    if let Some(ref c) = wakeup {
        c.notify_all();
    }
}

// Needed to get access to a Node without a trait interface.
macro_rules! impl_btsearch {
    ($($n:literal)*) => { $(
        impl<'a, M, F, D, A> BTSearch<F, $n>
        for BT<M,F,D,A,$n>
        where //Self : BTImpl<F,$n,Data=D,Agg=A, Depth=M>, // <== automatically deduced
              M : Depth,
              F : Float + Send,
              A : Aggregator,
              D : 'static + Copy + Sync + Send + std::fmt::Debug {
            fn search_and_refine<'b, R, G>(
                &'b mut self,
                refiner : R,
                generator : &Arc<G>,
            ) -> Option<R::Result>
            where R : Refiner<F, A, G, $n>,
                  G : SupportGenerator<F, $n, Id=D>,
                  G::SupportType : LocalAnalysis<F, A, $n> {
                let mut init_container = HeapContainer {
                    heap : BinaryHeap::new(),
                    glb : R::Sorting::bottom(),
                    insert_counter : 0,
                    result : None,
                    step : 0,
                    n_processing : 0,
                    // An arbitrary threshold for starting pruning of the heap
                    heap_prune_threshold : 2u32.pow(16.max($n * self.depth.value())) as usize
                };
                init_container.push(RefinementInfo {
                    cube : self.domain,
                    node : &mut self.topnode,
                    refiner_info : None,
                    sorting : PhantomData,
                });

                let container_arc = Arc::new(Mutex::new(init_container));
                if let Some(pool) = thread_pool() {
                    let n = thread_pool_size();
                    pool.scope(|s| {
                        let wakeup = Arc::new(Condvar::new());
                        for _ in 0..n {
                            let refiner_ref = &refiner;
                            let container_t = Arc::clone(&container_arc);
                            let wakeup_t = Arc::clone(&wakeup);
                            s.spawn(move |_| {
                                refinement_loop(Some(wakeup_t), refiner_ref, generator,
                                                &container_t);
                            });
                        }
                        refinement_loop(Some(wakeup), &refiner, generator, &container_arc);
                    });
                } else {
                    refinement_loop(None, &refiner, generator, &container_arc);
                }

                match Arc::try_unwrap(container_arc) {
                    Ok(mtx) => mtx.into_inner().unwrap().result,
                    Err(_) => panic!("Refinement threads not finished properly."),
                }
            }
        }
    )* }
}

impl_btsearch!(1 2 3 4);