// src/bisection_tree/refine.rs

use std::collections::BinaryHeap;
use std::cmp::{PartialOrd, Ord, Ordering, Ordering::*, max};
use std::marker::PhantomData;
use std::sync::{Arc, Mutex, MutexGuard, Condvar};
use crate::types::*;
use crate::nanleast::NaNLeast;
use crate::sets::Cube;
use crate::parallelism::{thread_pool_size, thread_pool};
use super::support::*;
use super::bt::*;
use super::aggregator::*;
use crate::parallelism::TaskBudget;

/// Trait for sorting [`Aggregator`]s for [`BT`] refinement.
///
/// The sorting involves two sorting keys, the “upper” and the “lower” key. Any [`BT`] nodes
/// whose upper key is less than the lower key of another node are discarded from the refinement
/// process. Nodes with the highest upper sorting key are picked for refinement.
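///
/// For example, with [`UpperBoundSorting`] the keys are the lower and upper bounds of a
/// [`Bounds`] aggregator: a node with bounds 0…1 is discarded once a node with bounds 2…3 is
/// present, since its upper key (1) is below the other node's lower key (2); among the remaining
/// nodes, the one with the highest upper bound is refined first.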
pub trait AggregatorSorting : Sync + Send + 'static {
    // Priority
    type Agg : Aggregator;
    type Sort : Ord + Copy + std::fmt::Debug + Sync + Send;

    /// Returns the lower sorting key.
    fn sort_lower(aggregator : &Self::Agg) -> Self::Sort;

    /// Returns the upper sorting key.
    fn sort_upper(aggregator : &Self::Agg) -> Self::Sort;

    /// Returns a sorting key that is less than any other sorting key.
    fn bottom() -> Self::Sort;
}

/// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the upper/lower key.
///
/// See [`LowerBoundSorting`] for the opposite ordering.
pub struct UpperBoundSorting<F : Float>(PhantomData<F>);

/// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the lower/upper key.
///
/// See [`UpperBoundSorting`] for the opposite ordering.
pub struct LowerBoundSorting<F : Float>(PhantomData<F>);

impl<F : Float> AggregatorSorting for UpperBoundSorting<F> {
    type Agg = Bounds<F>;
    type Sort = NaNLeast<F>;

    #[inline]
    fn sort_lower(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(aggregator.lower()) }

    #[inline]
    fn sort_upper(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(aggregator.upper()) }

    #[inline]
    fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) }
}


impl<F : Float> AggregatorSorting for LowerBoundSorting<F> {
    type Agg = Bounds<F>;
    type Sort = NaNLeast<F>;

    #[inline]
    fn sort_upper(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(-aggregator.lower()) }

    #[inline]
    fn sort_lower(aggregator : &Bounds<F>) -> Self::Sort { NaNLeast(-aggregator.upper()) }

    #[inline]
    fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) }
}

/// Return type of [`Refiner::refine`].
///
/// The parameter `R` is the result type of the refiner acting on an [`Aggregator`] of type `A`.
pub enum RefinerResult<A : Aggregator, R> {
    /// Indicates an insufficiently refined state: the [`BT`] needs to be further refined.
    NeedRefinement,
    /// Indicates a certain result `R`; refinement should stop immediately.
    Certain(R),
    /// Indicates an uncertain result: continue refinement until candidates have been exhausted
    /// or a certain result found.
    Uncertain(A, R)
}

use RefinerResult::*;

/// A `Refiner` is used to search a [`BT`], refining the subdivision when necessary.
///
/// The search is performed by [`BTSearch::search_and_refine`].
/// The `Refiner` is used to determine whether an [`Aggregator`] `A` stored in the [`BT`] is
/// sufficiently refined within a [`Cube`], and in such a case, produce a desired result (e.g.
/// a maximum value of a function).
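///
/// The following is a minimal, non-compiled sketch of a maximisation-oriented refiner over
/// [`Bounds`] aggregators. The `MaximiserRefiner` type and its `tolerance` field are purely
/// illustrative and not part of this crate; the exact trait bounds may need adjustment.
///
/// ```ignore
/// // Hypothetical refiner: seek a maximum once the bound gap is below `tolerance`.
/// struct MaximiserRefiner<F : Float> {
///     tolerance : F,
/// }
///
/// impl<F : Float, G, const N : usize> Refiner<F, Bounds<F>, G, N>
/// for MaximiserRefiner<F>
/// where G : SupportGenerator<F, N> {
///     type Result = F;
///     type Sorting = UpperBoundSorting<F>;
///
///     fn refine(
///         &self,
///         aggregator : &Bounds<F>,
///         _domain : &Cube<F, N>,
///         _data : &[G::Id],
///         _generator : &G,
///         _step : usize,
///     ) -> RefinerResult<Bounds<F>, F> {
///         if aggregator.upper() - aggregator.lower() < self.tolerance {
///             // The bounds are tight enough: report the upper bound as a candidate
///             // maximum, but let the priority queue check for potentially better nodes.
///             RefinerResult::Uncertain(aggregator.clone(), aggregator.upper())
///         } else {
///             // Ask for further subdivision of this leaf.
///             RefinerResult::NeedRefinement
///         }
///     }
///
///     fn fuse_results(r1 : &mut F, r2 : F) {
///         // Keep the larger of two candidate maxima (used in threaded refinement).
///         if r2 > *r1 { *r1 = r2 }
///     }
/// }
/// ```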
pub trait Refiner<F : Float, A, G, const N : usize> : Sync + Send + 'static
where F : Num,
      A : Aggregator,
      G : SupportGenerator<F, N> {

    /// The result type of the refiner
    type Result : std::fmt::Debug + Sync + Send + 'static;
    /// The sorting to be employed by [`BTSearch::search_and_refine`] on node aggregators
    /// to determine node priority.
    type Sorting : AggregatorSorting<Agg = A>;

    /// Determines whether `aggregator` is sufficiently refined within `domain`.
    ///
    /// If the aggregator is sufficiently refined that the desired `Self::Result` can be produced,
    /// a [`RefinerResult`]`::Certain` or `Uncertain` should be returned, depending on
    /// the confidence of the solution. In the uncertain case an improved aggregator should also
    /// be included. If the result cannot be produced, `NeedRefinement` should be
    /// returned.
    ///
    /// For example, if the refiner is used to maximise a function presented by the `BT`,
    /// an `Uncertain` result can be used to return a local maximum of the function on `domain`.
    /// The result can be claimed `Certain` if it is a global maximum. In that case the
    /// refinement will stop immediately. A `NeedRefinement` result indicates that the `aggregator`
    /// and/or `domain` are not sufficiently refined to compute a local maximum of sufficient
    /// quality.
    ///
    /// The vector `data` stores all the data of the [`BT`] in the node corresponding to `domain`.
    /// The `generator` can be used to convert `data` into [`Support`]s. The parameter `step`
    /// counts the calls to `refine`, and can be used to stop the refinement when a maximum
    /// number of steps is reached.
    fn refine(
        &self,
        aggregator : &A,
        domain : &Cube<F, N>,
        data : &[G::Id],
        generator : &G,
        step : usize,
    ) -> RefinerResult<A, Self::Result>;

    /// Fuse two [`Self::Result`]s (needed in threaded refinement).
    fn fuse_results(r1 : &mut Self::Result, r2 : Self::Result);
}

/// Structure for tracking the refinement process in a [`BinaryHeap`].
struct RefinementInfo<'a, F, D, A, S, RResult, const N : usize, const P : usize>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {
    /// Domain of `node`
    cube : Cube<F, N>,
    /// Node to be refined
    node : &'a mut Node<F, D, A, N, P>,
    /// Result and improved aggregator for the [`Refiner`]
    refiner_info : Option<(A, RResult)>,
    /// Marker for the [`AggregatorSorting`] in use; only needed for the type system
    sorting : PhantomData<S>,
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize>
RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {

    #[inline]
    fn with_aggregator<U>(&self, f : impl FnOnce(&A) -> U) -> U {
        match self.refiner_info {
            Some((ref agg, _)) => f(agg),
            None => f(&self.node.aggregator),
        }
    }

    #[inline]
    fn sort_lower(&self) -> S::Sort {
        self.with_aggregator(S::sort_lower)
    }

    #[inline]
    fn sort_upper(&self) -> S::Sort {
        self.with_aggregator(S::sort_upper)
    }
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialEq
for RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {

    #[inline]
    fn eq(&self, other : &Self) -> bool { self.cmp(other) == Equal }
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialOrd
for RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {

    #[inline]
    fn partial_cmp(&self, other : &Self) -> Option<Ordering> { Some(self.cmp(other)) }
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Eq
for RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Ord
for RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {

    #[inline]
    fn cmp(&self, other : &Self) -> Ordering {
        self.with_aggregator(|agg1| other.with_aggregator(|agg2| {
            match S::sort_upper(agg1).cmp(&S::sort_upper(agg2)) {
                Equal => S::sort_lower(agg1).cmp(&S::sort_lower(agg2)),
                order => order,
            }
        }))
    }
}

/// This is a container for a [`BinaryHeap`] of [`RefinementInfo`]s together with tracking of
/// the greatest lower bound of the [`Aggregator`]s of the [`Node`]s therein, according to the
/// chosen [`AggregatorSorting`].
struct HeapContainer<'a, F, D, A, S, RResult, const N : usize, const P : usize>
where F : Float,
      D : 'static + Copy,
      Const<P> : BranchCount<N>,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {
    /// Priority queue of nodes to be refined
    heap : BinaryHeap<RefinementInfo<'a, F, D, A, S, RResult, N, P>>,
    /// Maximum of node sorting lower bounds seen in the heap
    glb : S::Sort,
    /// Number of insertions in the heap since previous prune
    insert_counter : usize,
    /// If a result has been found by some refinement thread, it is stored here
    result : Option<RResult>,
    /// Refinement step counter
    step : usize,
    /// Number of threads currently processing (not sleeping)
    n_processing : usize,
    /// Threshold for heap pruning
    heap_prune_threshold : usize,
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize>
HeapContainer<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static + Copy,
      Const<P> : BranchCount<N>,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {

    /// Push `ri` into the [`BinaryHeap`]. Do greatest lower bound maintenance.
    ///
    /// Returns a boolean indicating whether the push was actually performed or whether it was
    /// skipped due to glb filtering.
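    ///
    /// For example, with [`UpperBoundSorting`], a node whose upper bound is below the greatest
    /// lower bound seen so far cannot contain a better maximiser than some already-staged node,
    /// so it is not inserted.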
    #[inline]
    fn push(&mut self, ri : RefinementInfo<'a, F, D, A, S, RResult, N, P>) -> bool {
        if ri.sort_upper() >= self.glb {
            let l = ri.sort_lower();
            self.heap.push(ri);
            self.glb = self.glb.max(l);
            self.insert_counter += 1;
            true
        } else {
            false
        }
    }
}

impl<F : Float, D, A, const N : usize, const P : usize>
Branches<F,D,A,N,P>
where Const<P> : BranchCount<N>,
      A : Aggregator,
      D : 'static + Copy + Send + Sync {

    /// Stage all subnodes of `self` into the refinement queue `container`.
    fn stage_refine<'a, S, RResult>(
        &'a mut self,
        domain : Cube<F,N>,
        container : &mut HeapContainer<'a, F, D, A, S, RResult, N, P>,
    ) where S : AggregatorSorting<Agg = A> {
        // Insert all subnodes into the refinement heap.
        for (node, cube) in self.nodes_and_cubes_mut(&domain) {
            container.push(RefinementInfo {
                cube,
                node,
                refiner_info : None,
                sorting : PhantomData,
            });
        }
    }
}


impl<F : Float, D, A, const N : usize, const P : usize>
Node<F,D,A,N,P>
where Const<P> : BranchCount<N>,
      A : Aggregator,
      D : 'static + Copy + Send + Sync {

    /// If `self` is a leaf node, uses the `refiner` to determine whether further subdivision
    /// is required to get a sufficiently refined solution for the problem the refiner is used
    /// to solve. If the refiner returns a [`RefinerResult::Certain`] result, it is returned.
    /// If [`RefinerResult::Uncertain`] is returned, the leaf is inserted back into the refinement
    /// queue `container`. If `self` is a branch, its subnodes are staged into `container` using
    /// [`Branches::stage_refine`].
    ///
    /// `domain`, as usual, indicates the spatial area corresponding to `self`.
    fn search_and_refine<'a, 'b, 'c, R, G>(
        self : &'a mut Self,
        domain : Cube<F,N>,
        refiner : &R,
        generator : &G,
        container_arc : &'c Arc<Mutex<HeapContainer<'a, F, D, A, R::Sorting, R::Result, N, P>>>,
        step : usize
    ) -> Result<R::Result, MutexGuard<'c, HeapContainer<'a, F, D, A, R::Sorting, R::Result, N, P>>>
    where R : Refiner<F, A, G, N>,
          G : SupportGenerator<F, N, Id=D>,
          G::SupportType : LocalAnalysis<F, A, N> {

        //drop(container);

        // Refine a leaf.
        let res = match self.data {
            NodeOption::Leaf(ref mut v) => {
                let res = refiner.refine(&self.aggregator, &domain, v, generator, step);
                if let NeedRefinement = res {
                    // The refiner has deemed the leaf insufficiently refined, so subdivide
                    // it and add the new nodes into the refinement priority heap.
                    let mut it = v.iter();
                    // Only create new branches if there's anything to add.
                    // We insert the last item first to mix the support_hint a bit.
                    if let Some(&d0) = it.next_back() {
                        // Construct new Branches
                        let support = generator.support_for(d0);
                        let mut b = Branches::new_with(&domain, &support);
                        b.insert(&domain, d0, Const::<1>, &support, TaskBudget::none());
                        for &d in it {
                            let support = generator.support_for(d);
                            // TODO: can we be smarter than just refining one level?
                            b.insert(&domain, d, Const::<1>, &support, TaskBudget::none());
                        }
                        // Update current node and stage refinement of new branches.
                        b.summarise_into(&mut self.aggregator);
                        // FIXME: parent aggregators are not updated and will be out-of-date, but
                        // right now pointsource_algs is not really taking advantage of the
                        // aggregators being updated. Moreover, insertion and aggregator refinement
                        // code in `bt.rs` will overflow the stack in deeply nested trees.
                        // We nevertheless need to store `b` into `self` to be able to queue
                        // the branches.
                        self.data = NodeOption::Branches(Arc::new(b));
                        // This ugly match is needed to keep the compiler happy about lifetimes.
                        match self.data {
                            NodeOption::Branches(ref mut arc_b) => {
                                let mut container = container_arc.lock().unwrap();
                                // Safe: we just created `arc_b` and have an exclusive mutable
                                // reference to `self` containing it.
                                unsafe { Arc::get_mut_unchecked(arc_b) }
                                    .stage_refine(domain, &mut *container);
                                return Err(container)
                            },
                            _ => unreachable!("This cannot happen"),
                        }
                    }
                }
                res
            },
            NodeOption::Branches(ref mut b) => {
                // Insert branches into refinement priority queue.
                let mut container = container_arc.lock().unwrap();
                Arc::make_mut(b).stage_refine(domain, &mut *container);
                return Err(container)
            },
            NodeOption::Uninitialised => {
                refiner.refine(&self.aggregator, &domain, &[], generator, step)
            },
        };

        match res {
            Uncertain(agg, val) => {
                // The refiner gave an uncertain result. Push the leaf back into the refinement queue
                // with the new refined aggregator and custom return value. It will be popped and
                // returned in the loop of [`BTSearch::search_and_refine`] when there are no
                // unrefined candidates that could potentially be better according to their basic
                // aggregator.
                let mut container = container_arc.lock().unwrap();
                container.push(RefinementInfo {
                    cube : domain,
                    node : self,
                    refiner_info : Some((agg, val)),
                    sorting : PhantomData,
                });
                Err(container)
            },
            Certain(val) => {
                // The refiner gave a certain result so return it to allow early termination
                Ok(val)
            },
            NeedRefinement => {
                // This should only happen when we run into NodeOption::Uninitialised above.
                // There's really nothing to do.
                panic!("Do not know how to refine uninitialised nodes");
            }
        }
    }
}

/// Interface trait to a refining search on a [`BT`].
///
/// This can be removed and the methods implemented directly on [`BT`] once Rust's const generics
/// are flexible enough to allow fixing `P=pow(2, N)`.
pub trait BTSearch<F, const N : usize> : BTImpl<F, N>
where F : Float {

    /// Perform a search on `Self`, as determined by `refiner`.
    ///
    /// Nodes are inserted in a priority queue and processed in the order determined by the
    /// [`AggregatorSorting`] [`Refiner::Sorting`]. Leaf nodes are subdivided until the refiner
    /// decides that a sufficiently refined leaf node has been found, as determined by either the
    /// refiner returning a [`RefinerResult::Certain`] result, or a previous
    /// [`RefinerResult::Uncertain`] result being found again at the top of the priority queue.
    ///
    /// The `generator` converts [`BTImpl::Data`] stored in the bisection tree into a [`Support`].
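    ///
    /// A hedged usage sketch (not compiled): `bt` is assumed to be an already populated tree,
    /// `gen` its matching [`SupportGenerator`], and `MaximiserRefiner` the hypothetical refiner
    /// sketched in the [`Refiner`] documentation.
    ///
    /// ```ignore
    /// // Returns `Option<R::Result>`; `None` if the candidates were exhausted without a result.
    /// let result = bt.search_and_refine(
    ///     MaximiserRefiner { tolerance : 1e-6 },
    ///     &Arc::new(gen),
    /// );
    /// ```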
    fn search_and_refine<'b, R, G>(
        &'b mut self,
        refiner : R,
        generator : &Arc<G>,
    ) -> Option<R::Result>
    where R : Refiner<F, Self::Agg, G, N> + Sync + Send + 'static,
          G : SupportGenerator<F, N, Id=Self::Data> + Sync + Send + 'static,
          G::SupportType : LocalAnalysis<F, Self::Agg, N>;
}

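/// Worker loop for [`BTSearch::search_and_refine`].
///
/// Repeatedly pops the highest-priority node from the shared heap in `container_arc` and
/// processes it with [`Node::search_and_refine`], either recording a result or re-staging
/// subdivided nodes. When `wakeup` is provided (threaded refinement), an idle worker parks on
/// the condition variable while other workers still have nodes in flight.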
fn refinement_loop<F : Float, D, A, R, G, const N : usize, const P : usize> (
    wakeup : Option<Arc<Condvar>>,
    refiner : &R,
    generator_arc : &Arc<G>,
    container_arc : &Arc<Mutex<HeapContainer<F, D, A, R::Sorting, R::Result, N, P>>>,
) where A : Aggregator,
        R : Refiner<F, A, G, N>,
        G : SupportGenerator<F, N, Id=D>,
        G::SupportType : LocalAnalysis<F, A, N>,
        Const<P> : BranchCount<N>,
        D : 'static + Copy + Sync + Send + std::fmt::Debug {

    let mut did_park = true;
    let mut container = container_arc.lock().unwrap();

    'main: loop {
        // Find a node to process
        let ri = 'get_next: loop {
            if did_park {
                container.n_processing += 1;
                did_park = false;
            }

            // Some refinement task/thread has found a result, return
            if container.result.is_some() {
                container.n_processing -= 1;
                break 'main
            }

            match container.heap.pop() {
                // There's work to be done.
                Some(ri) => break 'get_next ri,
                // No work to be done; park if some task/thread is still processing nodes,
                // fail if not.
                None => {
                    debug_assert!(container.n_processing > 0);
                    container.n_processing -= 1;
                    if container.n_processing == 0 {
                        break 'main;
                    } else if let Some(ref c) = wakeup {
                        did_park = true;
                        container = c.wait(container).unwrap();
                        continue 'get_next;
                    } else {
                        break 'main
                    }
                }
            };
        };

        let step = container.step;
        container.step += 1;

        if let Some((_, result)) = ri.refiner_info {
            // Terminate based on a “best possible” result.
            container.result = Some(result);
            container.n_processing -= 1;
            break 'main
        }

        // Do priority queue maintenance
        if container.insert_counter > container.heap_prune_threshold {
            // Make sure glb is good.
            match container.heap.iter().map(|ri| ri.sort_lower()).reduce(max) {
                Some(glb) => {
                    container.glb = glb;
                    // Prune
                    container.heap.retain(|ri| ri.sort_upper() >= glb);
                },
                None => {
                    container.glb = R::Sorting::bottom()
                }
            }
            container.insert_counter = 0;
        }

        // Unlock the mutex…
        drop(container);

        // … and process the node. On `Err` we get the mutex back already locked, as a guard.
        match Node::search_and_refine(ri.node, ri.cube, refiner, &**generator_arc,
                                      &container_arc, step) {
            Ok(r) => {
                let mut container = container_arc.lock().unwrap();
                // Terminate based on a certain result from the refiner
                match container.result {
                    Some(ref mut r_prev) => R::fuse_results(r_prev, r),
                    None => container.result = Some(r),
                }
                break 'main
            },
            Err(cnt) => {
                container = cnt;
                // Wake up another thread if one is sleeping; there should now be work in the
                // queue.
                if let Some(ref c) = wakeup {
                    c.notify_one();
                }
            }
        }

    }

    // Make sure no task is sleeping
    if let Some(ref c) = wakeup {
        c.notify_all();
    }
}

// Needed to get access to a Node without a trait interface.
macro_rules! impl_btsearch {
    ($($n:literal)*) => { $(
        impl<'a, M, F, D, A>
        BTSearch<F, $n>
        for BT<M,F,D,A,$n>
        where //Self : BTImpl<F,$n,Data=D,Agg=A, Depth=M>, //  <== automatically deduced
              M : Depth,
              F : Float + Send,
              A : Aggregator,
              D : 'static + Copy + Sync + Send + std::fmt::Debug {
            fn search_and_refine<'b, R, G>(
                &'b mut self,
                refiner : R,
                generator : &Arc<G>,
            ) -> Option<R::Result>
            where R : Refiner<F, A, G, $n>,
                  G : SupportGenerator<F, $n, Id=D>,
                  G::SupportType : LocalAnalysis<F, A, $n> {
                let mut init_container = HeapContainer {
                    heap : BinaryHeap::new(),
                    glb : R::Sorting::bottom(),
                    insert_counter : 0,
                    result : None,
                    step : 0,
                    n_processing : 0,
                    // An arbitrary threshold for starting pruning of the heap
                    heap_prune_threshold : 2u32.pow(16.max($n * self.depth.value())) as usize
                };
                init_container.push(RefinementInfo {
                    cube : self.domain,
                    node : &mut self.topnode,
                    refiner_info : None,
                    sorting : PhantomData,
                });

                let container_arc = Arc::new(Mutex::new(init_container));
                if let Some(pool) = thread_pool() {
                    let n = thread_pool_size();
                    pool.scope(|s| {
                        let wakeup = Arc::new(Condvar::new());
                        for _ in 0..n {
                            let refiner_ref = &refiner;
                            let container_t = Arc::clone(&container_arc);
                            let wakeup_t = Arc::clone(&wakeup);
                            s.spawn(move |_| {
                                refinement_loop(Some(wakeup_t), refiner_ref, generator,
                                                &container_t);
                            });
                        }
                        refinement_loop(Some(wakeup), &refiner, generator, &container_arc);
                    });
                } else {
                    refinement_loop(None, &refiner, generator, &container_arc);
                }

                match Arc::try_unwrap(container_arc) {
                    Ok(mtx) => mtx.into_inner().unwrap().result,
                    Err(_) => panic!("Refinement threads not finished properly."),
                }
            }
        }
    )* }
}

impl_btsearch!(1 2 3 4);
