diff -r 4e09b7829b51 -r f40dfaf2166d src/bisection_tree/refine.rs --- a/src/bisection_tree/refine.rs Tue Nov 01 09:24:45 2022 +0200 +++ b/src/bisection_tree/refine.rs Fri Nov 18 10:29:50 2022 +0200 @@ -2,7 +2,7 @@ use std::collections::BinaryHeap; use std::cmp::{PartialOrd, Ord, Ordering, Ordering::*, max}; use std::marker::PhantomData; -use std::sync::{Arc, Mutex, Condvar}; +use std::sync::{Arc, Mutex, MutexGuard, Condvar}; use crate::types::*; use crate::nanleast::NaNLeast; use crate::sets::Cube; @@ -126,7 +126,7 @@ &self, aggregator : &A, domain : &Cube, - data : &Vec, + data : &[G::Id], generator : &G, step : usize, ) -> RefinerResult; @@ -141,9 +141,13 @@ D : 'static, A : Aggregator, S : AggregatorSorting { + /// Domain of `node` cube : Cube, + /// Node to be refined node : &'a mut Node, + /// Result and improve aggregator for the [`Refiner`] refiner_info : Option<(A, RResult)>, + /// For [`AggregatorSorting`] being used for the type system sorting : PhantomData, } @@ -214,8 +218,8 @@ fn cmp(&self, other : &Self) -> Ordering { self.with_aggregator(|agg1| other.with_aggregator(|agg2| { match S::sort_upper(agg1).cmp(&S::sort_upper(agg2)) { - Equal => S::sort_lower(agg1).cmp(&S::sort_lower(agg2)), - order => order, + Equal => S::sort_lower(agg1).cmp(&S::sort_lower(agg2)), + order => order, } })) } @@ -230,13 +234,20 @@ Const

: BranchCount, A : Aggregator, S : AggregatorSorting { + /// Priority queue of nodes to be refined heap : BinaryHeap>, + /// Maximum of node sorting lower bounds seen in the heap glb : S::Sort, - glb_stale_counter : usize, - stale_insert_counter : usize, + /// Number of insertions in the heap since previous prune + insert_counter : usize, + /// If a result has been found by some refinment threat, it is stored here result : Option, + /// Refinement step counter step : usize, - n_processing : usize + /// Number of threads currently processing (not sleeping) + n_processing : usize, + /// Threshold for heap pruning + heap_prune_threshold : usize, } impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> @@ -248,14 +259,19 @@ S : AggregatorSorting { /// Push `ri` into the [`BinaryHeap`]. Do greatest lower bound maintenance. - fn push(&mut self, ri : RefinementInfo<'a, F, D, A, S, RResult, N, P>) { + /// + /// Returns a boolean indicating whether the push was actually performed due to glb + /// filtering or not. + #[inline] + fn push(&mut self, ri : RefinementInfo<'a, F, D, A, S, RResult, N, P>) -> bool { if ri.sort_upper() >= self.glb { let l = ri.sort_lower(); self.heap.push(ri); self.glb = self.glb.max(l); - if self.glb_stale_counter > 0 { - self.stale_insert_counter += 1; - } + self.insert_counter += 1; + true + } else { + false } } } @@ -266,18 +282,17 @@ A : Aggregator, D : 'static + Copy + Send + Sync { - /// Stage all subnodes of `self` into the refinement queue [`container`]. + /// Stage all subnodes of `self` into the refinement queue `container`. fn stage_refine<'a, S, RResult>( &'a mut self, domain : Cube, - container_arc : &Arc>>, + container : &mut HeapContainer<'a, F, D, A, S, RResult, N, P>, ) where S : AggregatorSorting { - let mut container = container_arc.lock().unwrap(); // Insert all subnodes into the refinement heap. - for (node, subcube) in self.nodes_and_cubes_mut(&domain) { + for (node, cube) in self.nodes_and_cubes_mut(&domain) { container.push(RefinementInfo { - cube : subcube, - node : node, + cube, + node, refiner_info : None, sorting : PhantomData, }); @@ -300,74 +315,102 @@ /// [`Branches::stage_refine`]. /// /// `domain`, as usual, indicates the spatial area corresponding to `self`. - fn search_and_refine<'a, 'b, R, G>( + fn search_and_refine<'a, 'b, 'c, R, G>( self : &'a mut Self, domain : Cube, refiner : &R, generator : &G, - container_arc : &Arc>>, + container_arc : &'c Arc>>, step : usize - ) -> Option + ) -> Result>> where R : Refiner, G : SupportGenerator, G::SupportType : LocalAnalysis { + //drop(container); + // Refine a leaf. - let res = if let NodeOption::Leaf(ref v) = &mut self.data { - let res = refiner.refine(&self.aggregator, &domain, v, generator, step); - if let NeedRefinement = res { - // The refiner has deemed the leaf unsufficiently refined, so subdivide - // it and add the new nodes into the refinement priority heap. - // We start iterating from the end to mix support_hint a bit. - let mut it = v.iter().rev(); - if let Some(&d) = it.next() { - // Construct new Branches - let support = generator.support_for(d); - let b = Arc::new({ - let mut b0 = Branches::new_with(&domain, &support); - b0.insert(&domain, d, Const::<1>, &support, TaskBudget::none()); + let res = match self.data { + NodeOption::Leaf(ref mut v) => { + let res = refiner.refine(&self.aggregator, &domain, v, generator, step); + if let NeedRefinement = res { + // The refiner has deemed the leaf unsufficiently refined, so subdivide + // it and add the new nodes into the refinement priority heap. + let mut it = v.iter(); + // Only create new branches if there's anything to add. + // We insert the last item first to mix the support_hint a bit. + if let Some(&d0) = it.next_back() { + // Construct new Branches + let support = generator.support_for(d0); + let mut b = Branches::new_with(&domain, &support); + b.insert(&domain, d0, Const::<1>, &support, TaskBudget::none()); for &d in it { let support = generator.support_for(d); // TODO: can we be smarter than just refining one level? - b0.insert(&domain, d, Const::<1>, &support, TaskBudget::none()); + b.insert(&domain, d, Const::<1>, &support, TaskBudget::none()); } - b0 - }); - // Update current node - b.summarise_into(&mut self.aggregator); - self.data = NodeOption::Branches(b); - // The branches will be inserted into the refinement priority queue below. + // Update current node and stage refinement of new branches. + b.summarise_into(&mut self.aggregator); + // FIXME: parent aggregators are not updated and will be out-of-date, but + // right now pointsource_algs is not really taking advantage of the + // aggregators being updated. Moreover, insertion and aggregator refinement + // code in `bt.rs` will overflow the stack in deeply nested trees. + // We nevertheless need to store `b` into `self` to be able to queue + // the branches. + self.data = NodeOption::Branches(Arc::new(b)); + // This ugly match is needed to keep the compiler happy about lifetimes. + match self.data { + NodeOption::Branches(ref mut arc_b) => { + let mut container = container_arc.lock().unwrap(); + // Safe: we just created arg_b and have a mutable exclusive + // reference to self containing it. + unsafe { Arc::get_mut_unchecked(arc_b) } + .stage_refine(domain, &mut *container); + return Err(container) + }, + _ => unreachable!("This cannot happen"), + } + } } - } - res - } else { - NeedRefinement + res + }, + NodeOption::Branches(ref mut b) => { + // Insert branches into refinement priority queue. + let mut container = container_arc.lock().unwrap(); + Arc::make_mut(b).stage_refine(domain, &mut *container); + return Err(container) + }, + NodeOption::Uninitialised => { + refiner.refine(&self.aggregator, &domain, &[], generator, step) + }, }; - if let Uncertain(agg, val) = res { - // The refiner gave an undertain result. Push a leaf back into the refinement queue - // with the new refined aggregator and custom return value. It will be popped and - // returned in the loop of [`BT::search_and_refine`] when there are no unrefined - // candidates that could potentially be better according to their basic aggregator. - let mut container = container_arc.lock().unwrap(); - container.push(RefinementInfo { - cube : domain, - node : self, - refiner_info : Some((agg, val)), - sorting : PhantomData, - }); - None - } else if let Certain(val) = res { - // The refiner gave a certain result so return it to allow early termination - Some(val) - } else if let NodeOption::Branches(ref mut b) = &mut self.data { - // Insert branches into refinement priority queue. - Arc::make_mut(b).stage_refine(domain, container_arc); - None - } else { - None + match res { + Uncertain(agg, val) => { + // The refiner gave an undertain result. Push a leaf back into the refinement queue + // with the new refined aggregator and custom return value. It will be popped and + // returned in the loop of [`BTSearch::search_and_refine`] when there are no + // unrefined candidates that could potentially be better according to their basic + // aggregator. + let mut container = container_arc.lock().unwrap(); + container.push(RefinementInfo { + cube : domain, + node : self, + refiner_info : Some((agg, val)), + sorting : PhantomData, + }); + Err(container) + }, + Certain(val) => { + // The refiner gave a certain result so return it to allow early termination + Ok(val) + }, + NeedRefinement => { + // This should only happen when we run into NodeOption::Uninitialised above. + // There's really nothing to do. + panic!("Do not know whow to refine uninitialised nodes"); + } } - } } @@ -378,7 +421,7 @@ pub trait BTSearch : BTImpl where F : Float { - /// Perform a search on [`Self`], as determined by `refiner`. + /// Perform a search on on `Self`, as determined by `refiner`. /// /// Nodes are inserted in a priority queue and processed in the order determined by the /// [`AggregatorSorting`] [`Refiner::Sorting`]. Leaf nodes are subdivided until the refiner @@ -398,7 +441,7 @@ } fn refinement_loop ( - condvar : Option>, + wakeup : Option>, refiner : &R, generator_arc : &Arc, container_arc : &Arc>>, @@ -410,93 +453,98 @@ D : 'static + Copy + Sync + Send + std::fmt::Debug { let mut did_park = true; + let mut container = container_arc.lock().unwrap(); 'main: loop { - let (ri, step) = { - let mut container = container_arc.lock().unwrap(); - let ri = 'get_next: loop { - if did_park { - container.n_processing += 1; - did_park = false; - } - - // Some thread has found a result, return - if container.result.is_some() { - container.n_processing -= 1; - break 'main - } + // Find a node to process + let ri = 'get_next: loop { + if did_park { + container.n_processing += 1; + did_park = false; + } - let ri = match container.heap.pop() { - Some(ri) => ri, - None => { - debug_assert!(container.n_processing > 0); - container.n_processing -= 1; - if container.n_processing == 0 { - break 'main; - } else if let Some(ref c) = condvar { - //eprintln!("Sleeping {t:?} {n} {worker_counter}\n", t=thread::current(), n=container.n_processing); - did_park = true; - container = c.wait(container).unwrap(); - continue 'get_next; - } else { - break 'main - } - } - }; - break ri; - }; - - let step = container.step; - container.step += 1; - - if let Some((_, result)) = ri.refiner_info { - // Terminate based on a “best possible” result. - container.result = Some(result); + // Some refinement task/thread has found a result, return + if container.result.is_some() { container.n_processing -= 1; break 'main } - if ri.sort_lower() >= container.glb { - container.glb_stale_counter += 1; - if container.stale_insert_counter + container.glb_stale_counter - > container.heap.len()/2 { - // GLB propery no longer correct. - match container.heap.iter().map(|ri| ri.sort_lower()).reduce(max) { - Some(glb) => { - container.glb = glb; - container.heap.retain(|ri| ri.sort_upper() >= glb); - }, - None => { - container.glb = R::Sorting::bottom() - } + match container.heap.pop() { + // There's work to be done. + Some(ri) => break 'get_next ri, + // No work to be done; park if some task/thread is still processing nodes, + // fail if not. + None => { + debug_assert!(container.n_processing > 0); + container.n_processing -= 1; + if container.n_processing == 0 { + break 'main; + } else if let Some(ref c) = wakeup { + did_park = true; + container = c.wait(container).unwrap(); + continue 'get_next; + } else { + break 'main } - container.glb_stale_counter = 0; - container.stale_insert_counter = 0; } - } - - (ri, step) + }; }; - let res = Node::search_and_refine(ri.node, ri.cube, refiner, &**generator_arc, - &container_arc, step); - if let Some(r) = res { - // Terminate based on a certain result from the refiner - let mut container = container_arc.lock().unwrap(); - if let &mut Some(ref mut r_prev) = &mut container.result { - R::fuse_results(r_prev, r); - } else { - container.result = Some(r); - } + let step = container.step; + container.step += 1; + + if let Some((_, result)) = ri.refiner_info { + // Terminate based on a “best possible” result. + container.result = Some(result); + container.n_processing -= 1; break 'main } - if let Some(ref c) = condvar { - c.notify_one(); + // Do priority queue maintenance + if container.insert_counter > container.heap_prune_threshold { + // Make sure glb is good. + match container.heap.iter().map(|ri| ri.sort_lower()).reduce(max) { + Some(glb) => { + container.glb = glb; + // Prune + container.heap.retain(|ri| ri.sort_upper() >= glb); + }, + None => { + container.glb = R::Sorting::bottom() + } + } + container.insert_counter = 0; } + + // Unlock the mutex… + drop(container); + + // … and process the node. We may get returned an already unlocked mutex. + match Node::search_and_refine(ri.node, ri.cube, refiner, &**generator_arc, + &container_arc, step) { + Ok(r) => { + let mut container = container_arc.lock().unwrap(); + // Terminate based on a certain result from the refiner + match container.result { + Some(ref mut r_prev) => R::fuse_results(r_prev, r), + None => container.result = Some(r), + } + break 'main + }, + Err(cnt) => { + container = cnt; + // Wake up another thread if one is sleeping; there should be now work in the + // queue. + if let Some(ref c) = wakeup { + c.notify_one(); + } + } + } + } - if let Some(ref c) = condvar { + // Make sure no task is sleeping + if let Some(ref c) = wakeup { c.notify_all(); } } @@ -507,7 +555,7 @@ impl<'a, M, F, D, A> BTSearch for BT - where //Self : BTImpl, // <== automatically deduce to be implemented + where //Self : BTImpl, // <== automatically deduced M : Depth, F : Float + Send, A : Aggregator, @@ -520,15 +568,15 @@ where R : Refiner, G : SupportGenerator, G::SupportType : LocalAnalysis { - let mut init_container /*: HeapContainer*/ - = HeapContainer { + let mut init_container = HeapContainer { heap : BinaryHeap::new(), glb : R::Sorting::bottom(), - glb_stale_counter : 0, - stale_insert_counter : 0, + insert_counter : 0, result : None, step : 0, n_processing : 0, + // An arbitrary threshold for starting pruning of the heap + heap_prune_threshold : 2u32.pow(16.max($n * self.depth.value())) as usize }; init_container.push(RefinementInfo { cube : self.domain, @@ -536,33 +584,31 @@ refiner_info : None, sorting : PhantomData, }); - // let n_workers = thread::available_parallelism() - // .map_or(1, core::num::NonZeroUsize::get); - let maybe_pool = thread_pool(); + let container_arc = Arc::new(Mutex::new(init_container)); - if let Some(pool) = maybe_pool { + if let Some(pool) = thread_pool() { let n = thread_pool_size(); pool.scope(|s| { - let condvar = Arc::new(Condvar::new()); - for _ in 0..n{ + let wakeup = Arc::new(Condvar::new()); + for _ in 0..n { let refiner_ref = &refiner; let container_t = Arc::clone(&container_arc); - let condvar_t = Arc::clone(&condvar); + let wakeup_t = Arc::clone(&wakeup); s.spawn(move |_| { - refinement_loop(Some(condvar_t), refiner_ref, generator, + refinement_loop(Some(wakeup_t), refiner_ref, generator, &container_t); }); } - refinement_loop(Some(condvar), &refiner, generator, &container_arc); + refinement_loop(Some(wakeup), &refiner, generator, &container_arc); }); } else { refinement_loop(None, &refiner, generator, &container_arc); } - Arc::try_unwrap(container_arc) - .map(|mtx| mtx.into_inner().unwrap().result) - .ok() - .flatten() + match Arc::try_unwrap(container_arc) { + Ok(mtx) => mtx.into_inner().unwrap().result, + Err(_) => panic!("Refinement threads not finished properly."), + } } } )* }