--- a/src/bisection_tree/refine.rs Thu May 01 08:40:33 2025 -0500 +++ b/src/bisection_tree/refine.rs Thu May 01 13:06:58 2025 -0500 @@ -1,32 +1,31 @@ - -use std::collections::BinaryHeap; -use std::cmp::{PartialOrd, Ord, Ordering, Ordering::*, max}; -use std::marker::PhantomData; -use std::sync::{Arc, Mutex, MutexGuard, Condvar}; -use crate::types::*; +use super::aggregator::*; +use super::bt::*; +use super::support::*; use crate::nanleast::NaNLeast; +use crate::parallelism::TaskBudget; +use crate::parallelism::{thread_pool, thread_pool_size}; use crate::sets::Cube; -use crate::parallelism::{thread_pool_size, thread_pool}; -use super::support::*; -use super::bt::*; -use super::aggregator::*; -use crate::parallelism::TaskBudget; +use crate::types::*; +use std::cmp::{max, Ord, Ordering, Ordering::*, PartialOrd}; +use std::collections::BinaryHeap; +use std::marker::PhantomData; +use std::sync::{Arc, Condvar, Mutex, MutexGuard}; /// Trait for sorting [`Aggregator`]s for [`BT`] refinement. /// /// The sorting involves two sorting keys, the “upper” and the “lower” key. Any [`BT`] nodes /// with upper key less the lower key of another are discarded from the refinement process. /// Nodes with the highest upper sorting key are picked for refinement. -pub trait AggregatorSorting : Sync + Send + 'static { +pub trait AggregatorSorting: Sync + Send + 'static { // Priority - type Agg : Aggregator; - type Sort : Ord + Copy + std::fmt::Debug + Sync + Send; + type Agg: Aggregator; + type Sort: Ord + Copy + std::fmt::Debug + Sync + Send; /// Returns lower sorting key - fn sort_lower(aggregator : &Self::Agg) -> Self::Sort; + fn sort_lower(aggregator: &Self::Agg) -> Self::Sort; /// Returns upper sorting key - fn sort_upper(aggregator : &Self::Agg) -> Self::Sort; + fn sort_upper(aggregator: &Self::Agg) -> Self::Sort; /// Returns a sorting key that is less than any other sorting key. fn bottom() -> Self::Sort; @@ -35,53 +34,64 @@ /// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the upper/lower key. /// /// See [`LowerBoundSorting`] for the opposite ordering. -pub struct UpperBoundSorting<F : Float>(PhantomData<F>); +pub struct UpperBoundSorting<F: Float>(PhantomData<F>); /// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the lower/upper key. /// /// See [`UpperBoundSorting`] for the opposite ordering. -pub struct LowerBoundSorting<F : Float>(PhantomData<F>); +pub struct LowerBoundSorting<F: Float>(PhantomData<F>); -impl<F : Float> AggregatorSorting for UpperBoundSorting<F> { +impl<F: Float> AggregatorSorting for UpperBoundSorting<F> { type Agg = Bounds<F>; type Sort = NaNLeast<F>; #[inline] - fn sort_lower(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(aggregator.lower()) } - - #[inline] - fn sort_upper(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(aggregator.upper()) } + fn sort_lower(aggregator: &Bounds<F>) -> Self::Sort { + NaNLeast(aggregator.lower()) + } #[inline] - fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) } + fn sort_upper(aggregator: &Bounds<F>) -> Self::Sort { + NaNLeast(aggregator.upper()) + } + + #[inline] + fn bottom() -> Self::Sort { + NaNLeast(F::NEG_INFINITY) + } } - -impl<F : Float> AggregatorSorting for LowerBoundSorting<F> { +impl<F: Float> AggregatorSorting for LowerBoundSorting<F> { type Agg = Bounds<F>; type Sort = NaNLeast<F>; #[inline] - fn sort_upper(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(-aggregator.lower()) } + fn sort_upper(aggregator: &Bounds<F>) -> Self::Sort { + NaNLeast(-aggregator.lower()) + } #[inline] - fn sort_lower(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(-aggregator.upper()) } + fn sort_lower(aggregator: &Bounds<F>) -> Self::Sort { + NaNLeast(-aggregator.upper()) + } #[inline] - fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) } + fn bottom() -> Self::Sort { + NaNLeast(F::NEG_INFINITY) + } } /// Return type of [`Refiner::refine`]. /// /// The parameter `R` is the result type of the refiner acting on an [`Aggregator`] of type `A`. -pub enum RefinerResult<A : Aggregator, R> { +pub enum RefinerResult<A: Aggregator, R> { /// Indicates an insufficiently refined state: the [`BT`] needs to be further refined. NeedRefinement, /// Indicates a certain result `R`, stop refinement immediately. Certain(R), /// Indicates an uncertain result: continue refinement until candidates have been exhausted /// or a certain result found. - Uncertain(A, R) + Uncertain(A, R), } use RefinerResult::*; @@ -92,16 +102,17 @@ /// The `Refiner` is used to determine whether an [`Aggregator`] `A` stored in the [`BT`] is /// sufficiently refined within a [`Cube`], and in such a case, produce a desired result (e.g. /// a maximum value of a function). -pub trait Refiner<F : Float, A, G, const N : usize> : Sync + Send + 'static -where F : Num, - A : Aggregator, - G : SupportGenerator<F, N> { - +pub trait Refiner<F: Float, A, G, const N: usize>: Sync + Send + 'static +where + F: Num, + A: Aggregator, + G: SupportGenerator<N, F>, +{ /// The result type of the refiner - type Result : std::fmt::Debug + Sync + Send + 'static; + type Result: std::fmt::Debug + Sync + Send + 'static; /// The sorting to be employed by [`BTSearch::search_and_refine`] on node aggregators /// to detemrine node priority. - type Sorting : AggregatorSorting<Agg = A>; + type Sorting: AggregatorSorting<Agg = A>; /// Determines whether `aggregator` is sufficiently refined within `domain`. /// @@ -124,42 +135,45 @@ /// number of steps is reached. fn refine( &self, - aggregator : &A, - domain : &Cube<F, N>, - data : &[G::Id], - generator : &G, - step : usize, + aggregator: &A, + domain: &Cube<N, F>, + data: &[G::Id], + generator: &G, + step: usize, ) -> RefinerResult<A, Self::Result>; /// Fuse two [`Self::Result`]s (needed in threaded refinement). - fn fuse_results(r1 : &mut Self::Result, r2 : Self::Result); + fn fuse_results(r1: &mut Self::Result, r2: Self::Result); } /// Structure for tracking the refinement process in a [`BinaryHeap`]. -struct RefinementInfo<'a, F, D, A, S, RResult, const N : usize, const P : usize> -where F : Float, - D : 'static, - A : Aggregator, - S : AggregatorSorting<Agg = A> { +struct RefinementInfo<'a, F, D, A, S, RResult, const N: usize, const P: usize> +where + F: Float, + D: 'static, + A: Aggregator, + S: AggregatorSorting<Agg = A>, +{ /// Domain of `node` - cube : Cube<F, N>, + cube: Cube<N, F>, /// Node to be refined - node : &'a mut Node<F, D, A, N, P>, + node: &'a mut Node<F, D, A, N, P>, /// Result and improve aggregator for the [`Refiner`] - refiner_info : Option<(A, RResult)>, + refiner_info: Option<(A, RResult)>, /// For [`AggregatorSorting`] being used for the type system - sorting : PhantomData<S>, + sorting: PhantomData<S>, } -impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> -RefinementInfo<'a, F, D, A, S, RResult, N, P> -where F : Float, - D : 'static, - A : Aggregator, - S : AggregatorSorting<Agg = A> { - +impl<'a, F, D, A, S, RResult, const N: usize, const P: usize> + RefinementInfo<'a, F, D, A, S, RResult, N, P> +where + F: Float, + D: 'static, + A: Aggregator, + S: AggregatorSorting<Agg = A>, +{ #[inline] - fn with_aggregator<U>(&self, f : impl FnOnce(&A) -> U) -> U { + fn with_aggregator<U>(&self, f: impl FnOnce(&A) -> U) -> U { match self.refiner_info { Some((ref agg, _)) => f(agg), None => f(&self.node.aggregator), @@ -177,93 +191,105 @@ } } -impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialEq -for RefinementInfo<'a, F, D, A, S, RResult, N, P> -where F : Float, - D : 'static, - A : Aggregator, - S : AggregatorSorting<Agg = A> { - +impl<'a, F, D, A, S, RResult, const N: usize, const P: usize> PartialEq + for RefinementInfo<'a, F, D, A, S, RResult, N, P> +where + F: Float, + D: 'static, + A: Aggregator, + S: AggregatorSorting<Agg = A>, +{ #[inline] - fn eq(&self, other : &Self) -> bool { self.cmp(other) == Equal } -} - -impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialOrd -for RefinementInfo<'a, F, D, A, S, RResult, N, P> -where F : Float, - D : 'static, - A : Aggregator, - S : AggregatorSorting<Agg = A> { - - #[inline] - fn partial_cmp(&self, other : &Self) -> Option<Ordering> { Some(self.cmp(other)) } + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Equal + } } -impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Eq -for RefinementInfo<'a, F, D, A, S, RResult, N, P> -where F : Float, - D : 'static, - A : Aggregator, - S : AggregatorSorting<Agg = A> { +impl<'a, F, D, A, S, RResult, const N: usize, const P: usize> PartialOrd + for RefinementInfo<'a, F, D, A, S, RResult, N, P> +where + F: Float, + D: 'static, + A: Aggregator, + S: AggregatorSorting<Agg = A>, +{ + #[inline] + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } } -impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Ord -for RefinementInfo<'a, F, D, A, S, RResult, N, P> -where F : Float, - D : 'static, - A : Aggregator, - S : AggregatorSorting<Agg = A> { +impl<'a, F, D, A, S, RResult, const N: usize, const P: usize> Eq + for RefinementInfo<'a, F, D, A, S, RResult, N, P> +where + F: Float, + D: 'static, + A: Aggregator, + S: AggregatorSorting<Agg = A>, +{ +} +impl<'a, F, D, A, S, RResult, const N: usize, const P: usize> Ord + for RefinementInfo<'a, F, D, A, S, RResult, N, P> +where + F: Float, + D: 'static, + A: Aggregator, + S: AggregatorSorting<Agg = A>, +{ #[inline] - fn cmp(&self, other : &Self) -> Ordering { - self.with_aggregator(|agg1| other.with_aggregator(|agg2| { - match S::sort_upper(agg1).cmp(&S::sort_upper(agg2)) { + fn cmp(&self, other: &Self) -> Ordering { + self.with_aggregator(|agg1| { + other.with_aggregator(|agg2| match S::sort_upper(agg1).cmp(&S::sort_upper(agg2)) { Equal => S::sort_lower(agg1).cmp(&S::sort_lower(agg2)), order => order, - } - })) + }) + }) } } /// This is a container for a [`BinaryHeap`] of [`RefinementInfo`]s together with tracking of /// the greatest lower bound of the [`Aggregator`]s of the [`Node`]s therein accroding to /// chosen [`AggregatorSorting`]. -struct HeapContainer<'a, F, D, A, S, RResult, const N : usize, const P : usize> -where F : Float, - D : 'static + Copy, - Const<P> : BranchCount<N>, - A : Aggregator, - S : AggregatorSorting<Agg = A> { +struct HeapContainer<'a, F, D, A, S, RResult, const N: usize, const P: usize> +where + F: Float, + D: 'static + Copy, + Const<P>: BranchCount<N>, + A: Aggregator, + S: AggregatorSorting<Agg = A>, +{ /// Priority queue of nodes to be refined - heap : BinaryHeap<RefinementInfo<'a, F, D, A, S, RResult, N, P>>, + heap: BinaryHeap<RefinementInfo<'a, F, D, A, S, RResult, N, P>>, /// Maximum of node sorting lower bounds seen in the heap - glb : S::Sort, + glb: S::Sort, /// Number of insertions in the heap since previous prune - insert_counter : usize, + insert_counter: usize, /// If a result has been found by some refinment threat, it is stored here - result : Option<RResult>, + result: Option<RResult>, /// Refinement step counter - step : usize, + step: usize, /// Number of threads currently processing (not sleeping) - n_processing : usize, + n_processing: usize, /// Threshold for heap pruning - heap_prune_threshold : usize, + heap_prune_threshold: usize, } -impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> -HeapContainer<'a, F, D, A, S, RResult, N, P> -where F : Float, - D : 'static + Copy, - Const<P> : BranchCount<N>, - A : Aggregator, - S : AggregatorSorting<Agg = A> { - +impl<'a, F, D, A, S, RResult, const N: usize, const P: usize> + HeapContainer<'a, F, D, A, S, RResult, N, P> +where + F: Float, + D: 'static + Copy, + Const<P>: BranchCount<N>, + A: Aggregator, + S: AggregatorSorting<Agg = A>, +{ /// Push `ri` into the [`BinaryHeap`]. Do greatest lower bound maintenance. /// /// Returns a boolean indicating whether the push was actually performed due to glb /// filtering or not. #[inline] - fn push(&mut self, ri : RefinementInfo<'a, F, D, A, S, RResult, N, P>) -> bool { + fn push(&mut self, ri: RefinementInfo<'a, F, D, A, S, RResult, N, P>) -> bool { if ri.sort_upper() >= self.glb { let l = ri.sort_lower(); self.heap.push(ri); @@ -276,37 +302,38 @@ } } -impl<F : Float, D, A, const N : usize, const P : usize> -Branches<F,D,A,N,P> -where Const<P> : BranchCount<N>, - A : Aggregator, - D : 'static + Copy + Send + Sync { - +impl<F: Float, D, A, const N: usize, const P: usize> Branches<F, D, A, N, P> +where + Const<P>: BranchCount<N>, + A: Aggregator, + D: 'static + Copy + Send + Sync, +{ /// Stage all subnodes of `self` into the refinement queue `container`. fn stage_refine<'a, S, RResult>( &'a mut self, - domain : Cube<F,N>, - container : &mut HeapContainer<'a, F, D, A, S, RResult, N, P>, - ) where S : AggregatorSorting<Agg = A> { + domain: Cube<N, F>, + container: &mut HeapContainer<'a, F, D, A, S, RResult, N, P>, + ) where + S: AggregatorSorting<Agg = A>, + { // Insert all subnodes into the refinement heap. for (node, cube) in self.nodes_and_cubes_mut(&domain) { container.push(RefinementInfo { cube, node, - refiner_info : None, - sorting : PhantomData, + refiner_info: None, + sorting: PhantomData, }); } } } - -impl<F : Float, D, A, const N : usize, const P : usize> -Node<F,D,A,N,P> -where Const<P> : BranchCount<N>, - A : Aggregator, - D : 'static + Copy + Send + Sync { - +impl<F: Float, D, A, const N: usize, const P: usize> Node<F, D, A, N, P> +where + Const<P>: BranchCount<N>, + A: Aggregator, + D: 'static + Copy + Send + Sync, +{ /// If `self` is a leaf node, uses the `refiner` to determine whether further subdivision /// is required to get a sufficiently refined solution for the problem the refiner is used /// to solve. If the refiner returns [`RefinerResult::Certain`] result, it is returned. @@ -316,17 +343,18 @@ /// /// `domain`, as usual, indicates the spatial area corresponding to `self`. fn search_and_refine<'a, 'b, 'c, R, G>( - self : &'a mut Self, - domain : Cube<F,N>, - refiner : &R, - generator : &G, - container_arc : &'c Arc<Mutex<HeapContainer<'a, F, D, A, R::Sorting, R::Result, N, P>>>, - step : usize + self: &'a mut Self, + domain: Cube<N, F>, + refiner: &R, + generator: &G, + container_arc: &'c Arc<Mutex<HeapContainer<'a, F, D, A, R::Sorting, R::Result, N, P>>>, + step: usize, ) -> Result<R::Result, MutexGuard<'c, HeapContainer<'a, F, D, A, R::Sorting, R::Result, N, P>>> - where R : Refiner<F, A, G, N>, - G : SupportGenerator<F, N, Id=D>, - G::SupportType : LocalAnalysis<F, A, N> { - + where + R: Refiner<F, A, G, N>, + G: SupportGenerator<N, F, Id = D>, + G::SupportType: LocalAnalysis<F, A, N>, + { //drop(container); // Refine a leaf. @@ -368,26 +396,27 @@ unsafe { Arc::get_mut_unchecked(arc_b) } .stage_refine(domain, &mut *container); #[cfg(not(nightly))] - Arc::get_mut(arc_b).unwrap() + Arc::get_mut(arc_b) + .unwrap() .stage_refine(domain, &mut *container); - - return Err(container) - }, + + return Err(container); + } _ => unreachable!("This cannot happen"), } } } res - }, + } NodeOption::Branches(ref mut b) => { // Insert branches into refinement priority queue. let mut container = container_arc.lock().unwrap(); Arc::make_mut(b).stage_refine(domain, &mut *container); - return Err(container) - }, + return Err(container); + } NodeOption::Uninitialised => { refiner.refine(&self.aggregator, &domain, &[], generator, step) - }, + } }; match res { @@ -399,17 +428,17 @@ // aggregator. let mut container = container_arc.lock().unwrap(); container.push(RefinementInfo { - cube : domain, - node : self, - refiner_info : Some((agg, val)), - sorting : PhantomData, + cube: domain, + node: self, + refiner_info: Some((agg, val)), + sorting: PhantomData, }); Err(container) - }, + } Certain(val) => { // The refiner gave a certain result so return it to allow early termination Ok(val) - }, + } NeedRefinement => { // This should only happen when we run into NodeOption::Uninitialised above. // There's really nothing to do. @@ -423,9 +452,10 @@ /// /// This can be removed and the methods implemented directly on [`BT`] once Rust's const generics /// are flexible enough to allow fixing `P=pow(2, N)`. -pub trait BTSearch<F, const N : usize> : BTImpl<F, N> -where F : Float { - +pub trait BTSearch<const N: usize, F = f64>: BTImpl<N, F> +where + F: Float, +{ /// Perform a search on on `Self`, as determined by `refiner`. /// /// Nodes are inserted in a priority queue and processed in the order determined by the @@ -437,26 +467,28 @@ /// The `generator` converts [`BTImpl::Data`] stored in the bisection tree into a [`Support`]. fn search_and_refine<'b, R, G>( &'b mut self, - refiner : R, - generator : &Arc<G>, + refiner: R, + generator: &Arc<G>, ) -> Option<R::Result> - where R : Refiner<F, Self::Agg, G, N> + Sync + Send + 'static, - G : SupportGenerator<F, N, Id=Self::Data> + Sync + Send + 'static, - G::SupportType : LocalAnalysis<F, Self::Agg, N>; + where + R: Refiner<F, Self::Agg, G, N> + Sync + Send + 'static, + G: SupportGenerator<N, F, Id = Self::Data> + Sync + Send + 'static, + G::SupportType: LocalAnalysis<F, Self::Agg, N>; } -fn refinement_loop<F : Float, D, A, R, G, const N : usize, const P : usize> ( - wakeup : Option<Arc<Condvar>>, - refiner : &R, - generator_arc : &Arc<G>, - container_arc : &Arc<Mutex<HeapContainer<F, D, A, R::Sorting, R::Result, N, P>>>, -) where A : Aggregator, - R : Refiner<F, A, G, N>, - G : SupportGenerator<F, N, Id=D>, - G::SupportType : LocalAnalysis<F, A, N>, - Const<P> : BranchCount<N>, - D : 'static + Copy + Sync + Send + std::fmt::Debug { - +fn refinement_loop<F: Float, D, A, R, G, const N: usize, const P: usize>( + wakeup: Option<Arc<Condvar>>, + refiner: &R, + generator_arc: &Arc<G>, + container_arc: &Arc<Mutex<HeapContainer<F, D, A, R::Sorting, R::Result, N, P>>>, +) where + A: Aggregator, + R: Refiner<F, A, G, N>, + G: SupportGenerator<N, F, Id = D>, + G::SupportType: LocalAnalysis<F, A, N>, + Const<P>: BranchCount<N>, + D: 'static + Copy + Sync + Send + std::fmt::Debug, +{ let mut did_park = true; let mut container = container_arc.lock().unwrap(); @@ -471,7 +503,7 @@ // Some refinement task/thread has found a result, return if container.result.is_some() { container.n_processing -= 1; - break 'main + break 'main; } match container.heap.pop() { @@ -489,7 +521,7 @@ container = c.wait(container).unwrap(); continue 'get_next; } else { - break 'main + break 'main; } } }; @@ -502,7 +534,7 @@ // Terminate based on a “best possible” result. container.result = Some(result); container.n_processing -= 1; - break 'main + break 'main; } // Do priority queue maintenance @@ -513,10 +545,8 @@ container.glb = glb; // Prune container.heap.retain(|ri| ri.sort_upper() >= glb); - }, - None => { - container.glb = R::Sorting::bottom() } + None => container.glb = R::Sorting::bottom(), } container.insert_counter = 0; } @@ -525,8 +555,14 @@ drop(container); // … and process the node. We may get returned an already unlocked mutex. - match Node::search_and_refine(ri.node, ri.cube, refiner, &**generator_arc, - &container_arc, step) { + match Node::search_and_refine( + ri.node, + ri.cube, + refiner, + &**generator_arc, + &container_arc, + step, + ) { Ok(r) => { let mut container = container_arc.lock().unwrap(); // Terminate based on a certain result from the refiner @@ -534,8 +570,8 @@ Some(ref mut r_prev) => R::fuse_results(r_prev, r), None => container.result = Some(r), } - break 'main - }, + break 'main; + } Err(cnt) => { container = cnt; // Wake up another thread if one is sleeping; there should be now work in the @@ -545,7 +581,6 @@ } } } - } // Make sure no task is sleeping @@ -558,9 +593,9 @@ macro_rules! impl_btsearch { ($($n:literal)*) => { $( impl<'a, M, F, D, A> - BTSearch<F, $n> + BTSearch<$n, F> for BT<M,F,D,A,$n> - where //Self : BTImpl<F,$n,Data=D,Agg=A, Depth=M>, // <== automatically deduced + where //Self : BTImpl<$n, F, Data=D,Agg=A, Depth=M>, // <== automatically deduced M : Depth, F : Float + Send, A : Aggregator, @@ -571,7 +606,7 @@ generator : &Arc<G>, ) -> Option<R::Result> where R : Refiner<F, A, G, $n>, - G : SupportGenerator<F, $n, Id=D>, + G : SupportGenerator< $n, F, Id=D>, G::SupportType : LocalAnalysis<F, A, $n> { let mut init_container = HeapContainer { heap : BinaryHeap::new(), @@ -620,4 +655,3 @@ } impl_btsearch!(1 2 3 4); -