src/bisection_tree/refine.rs

Wed, 26 Oct 2022 22:16:57 +0300

author
Tuomo Valkonen <tuomov@iki.fi>
date
Wed, 26 Oct 2022 22:16:57 +0300
changeset 7
860a54fca7bc
parent 5
59dc4c5883f4
child 8
4e09b7829b51
permissions
-rw-r--r--

Added tag unthreaded for changeset d80b87b8acd0


use std::collections::BinaryHeap;
use std::cmp::{PartialOrd,Ord,Ordering,Ordering::*,max};
use std::rc::Rc;
use std::marker::PhantomData;
use crate::types::*;
use crate::nanleast::NaNLeast;
use crate::sets::Cube;
use super::support::*;
use super::bt::*;
use super::aggregator::*;

/// Trait for sorting [`Aggregator`]s for [`BT`] refinement.
///
/// The sorting involves two sorting keys, the “upper” and the “lower” key. Any [`BT`] nodes
/// whose upper key is less than the lower key of another node are discarded from the
/// refinement process. Nodes with the highest upper sorting key are picked for refinement.
pub trait AggregatorSorting {
    // Priority
    /// The [`Aggregator`] type from which the sorting keys are computed.
    type Agg : Aggregator;
    /// The type of the sorting keys.
    type Sort : Ord + Copy + std::fmt::Debug;

    /// Returns the lower sorting key of `aggregator`.
    fn sort_lower(aggregator : &Self::Agg) -> Self::Sort;

    /// Returns the upper sorting key of `aggregator`.
    fn sort_upper(aggregator : &Self::Agg) -> Self::Sort;

    /// Returns a sorting key that is less than any other sorting key.
    fn bottom() -> Self::Sort;
}

/// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the upper/lower key.
///
/// See [`LowerBoundSorting`] for the opposite ordering.
///
/// The only field is a [`PhantomData`] that ties the type to the float type `F`.
pub struct UpperBoundSorting<F : Float>(PhantomData<F>);

/// An [`AggregatorSorting`] for [`Bounds`], using the negated lower bound as the upper key
/// and the negated upper bound as the lower key.
///
/// See [`UpperBoundSorting`] for the opposite ordering.
///
/// The only field is a [`PhantomData`] that ties the type to the float type `F`.
pub struct LowerBoundSorting<F : Float>(PhantomData<F>);

/// Sorting of [`Bounds`] in which the upper bound is the upper key and the lower bound is
/// the lower key.
impl<F : Float> AggregatorSorting for UpperBoundSorting<F> {
    type Agg = Bounds<F>;
    type Sort = NaNLeast<F>;

    /// The upper key is the upper bound of `aggregator`, wrapped in [`NaNLeast`].
    #[inline]
    fn sort_upper(aggregator : &Self::Agg) -> Self::Sort {
        let upper_bound = aggregator.upper();
        NaNLeast(upper_bound)
    }

    /// The lower key is the lower bound of `aggregator`, wrapped in [`NaNLeast`].
    #[inline]
    fn sort_lower(aggregator : &Self::Agg) -> Self::Sort {
        let lower_bound = aggregator.lower();
        NaNLeast(lower_bound)
    }

    /// Uses negative infinity as the bottom key.
    #[inline]
    fn bottom() -> Self::Sort {
        NaNLeast(F::NEG_INFINITY)
    }
}


/// Sorting of [`Bounds`] by negated bounds: the negated lower bound is the upper key and
/// the negated upper bound is the lower key.
impl<F : Float> AggregatorSorting for LowerBoundSorting<F> {
    type Agg = Bounds<F>;
    type Sort = NaNLeast<F>;

    /// The lower key is the negated upper bound of `aggregator`, wrapped in [`NaNLeast`].
    #[inline]
    fn sort_lower(aggregator : &Self::Agg) -> Self::Sort {
        let negated_upper = -aggregator.upper();
        NaNLeast(negated_upper)
    }

    /// The upper key is the negated lower bound of `aggregator`, wrapped in [`NaNLeast`].
    #[inline]
    fn sort_upper(aggregator : &Self::Agg) -> Self::Sort {
        let negated_lower = -aggregator.lower();
        NaNLeast(negated_lower)
    }

    /// Uses negative infinity as the bottom key.
    #[inline]
    fn bottom() -> Self::Sort {
        NaNLeast(F::NEG_INFINITY)
    }
}

/// Return type of [`Refiner::refine`].
///
/// The parameter `R` is the result type of the refiner acting on an [`Aggregator`] of type `A`.
pub enum RefinerResult<A : Aggregator, R> {
    /// Indicates an insufficiently refined state: the [`BT`] needs to be further refined.
    NeedRefinement,
    /// Indicates a certain result `R`: stop refinement immediately.
    Certain(R),
    /// Indicates an uncertain result: continue refinement until candidates have been exhausted
    /// or a certain result found. Carries an improved aggregator `A` together with the
    /// tentative result `R`.
    Uncertain(A, R)
}

use RefinerResult::*;

/// A `Refiner` is used to search a [`BT`], refining the subdivision when necessary.
///
/// The search is performed by [`BTSearch::search_and_refine`].
/// The `Refiner` is used to determine whether an [`Aggregator`] `A` stored in the [`BT`] is
/// sufficiently refined within a [`Cube`], and in such a case, produce a desired result (e.g.
/// a maximum value of a function).
pub trait Refiner<F : Float, A, G, const N : usize>
where F : Num,
      A : Aggregator,
      G : SupportGenerator<F, N> {

    /// The result type of the refiner.
    type Result : std::fmt::Debug;
    /// The sorting to be employed by [`BTSearch::search_and_refine`] on node aggregators
    /// to determine node priority.
    type Sorting : AggregatorSorting<Agg = A>;

    /// Determines whether `aggregator` is sufficiently refined within `domain`.
    ///
    /// If the aggregator is sufficiently refined that the desired `Self::Result` can be produced,
    /// a [`RefinerResult`]`::Certain` or `Uncertain` should be returned, depending on
    /// the confidence of the solution. In the uncertain case an improved aggregator should also
    /// be included. If the result cannot be produced, `NeedRefinement` should be
    /// returned.
    ///
    /// For example, if the refiner is used to maximise a function presented by the `BT`,
    /// an `Uncertain` result can be used to return a local maximum of the function on `domain`.
    /// The result can be claimed `Certain` if it is a global maximum. In that case the
    /// refinement will stop immediately. A `NeedRefinement` result indicates that the
    /// `aggregator` and/or `domain` are not sufficiently refined to compute a local maximum of
    /// sufficient quality.
    ///
    /// The vector `data` stores all the data of the [`BT`] in the node corresponding to `domain`.
    /// The `generator` can be used to convert `data` into [`Support`]s. The parameter `step`
    /// is the iteration counter of the search loop, and can be used to stop the refinement
    /// when a maximum number of steps is reached.
    fn refine(
        &self,
        aggregator : &A,
        domain : &Cube<F, N>,
        data : &Vec<G::Id>,
        generator : &G,
        step : usize,
    ) -> RefinerResult<A, Self::Result>;
}

/// Structure for tracking the refinement process in a [`BinaryHeap`].
///
/// Each entry pairs a [`Node`] of the tree with the [`Cube`] it covers. The type parameter `S`
/// selects the [`AggregatorSorting`] used to order entries in the heap, and `RResult` is the
/// result type of the refiner.
struct RefinementInfo<'a, F, D, A, S, RResult, const N : usize, const P : usize>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {
    /// The spatial domain corresponding to `node`.
    cube : Cube<F, N>,
    /// The node queued for refinement.
    node : &'a mut Node<F, D, A, N, P>,
    /// If the refiner has returned an uncertain result for `node`, the improved aggregator
    /// and the result; otherwise `None`.
    refiner_info : Option<(A, RResult)>,
    /// Type-level marker for the sorting `S`; carries no data.
    sorting : PhantomData<S>,
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize>
RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {

    /// Returns the aggregator that determines the priority of this entry: the improved one
    /// stored in `refiner_info` if present, otherwise the aggregator of the node itself.
    #[inline]
    fn aggregator(&self) -> &A {
        if let Some((improved, _)) = &self.refiner_info {
            improved
        } else {
            &self.node.aggregator
        }
    }

    /// Lower sorting key of this entry according to the sorting `S`.
    #[inline]
    fn sort_lower(&self) -> S::Sort {
        let agg = self.aggregator();
        S::sort_lower(agg)
    }

    /// Upper sorting key of this entry according to the sorting `S`.
    #[inline]
    fn sort_upper(&self) -> S::Sort {
        let agg = self.aggregator();
        S::sort_upper(agg)
    }
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialEq
for RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {

    /// Two entries are equal exactly when the total order [`Ord::cmp`] says so, i.e., when
    /// their sorting keys agree.
    #[inline]
    fn eq(&self, other : &Self) -> bool {
        matches!(Ord::cmp(self, other), Ordering::Equal)
    }
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialOrd
for RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {

    /// The partial order always agrees with the total order given by [`Ord::cmp`].
    #[inline]
    fn partial_cmp(&self, other : &Self) -> Option<Ordering> {
        let total = Ord::cmp(self, other);
        Some(total)
    }
}

// Marker impl: equality is defined through the total order of `Ord::cmp` (see `PartialEq`).
// `Eq` is a supertrait of `Ord`, which `BinaryHeap` requires of its elements.
impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Eq
for RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Ord
for RefinementInfo<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {

    /// Orders entries primarily by their upper sorting key; ties are broken by the lower
    /// sorting key. The keys are those of [`RefinementInfo::aggregator`] under the sorting `S`.
    #[inline]
    fn cmp(&self, other : &Self) -> Ordering {
        let (mine, theirs) = (self.aggregator(), other.aggregator());
        S::sort_upper(mine)
            .cmp(&S::sort_upper(theirs))
            .then_with(|| S::sort_lower(mine).cmp(&S::sort_lower(theirs)))
    }
}

/// This is a container for a [`BinaryHeap`] of [`RefinementInfo`]s together with tracking of
/// the greatest lower bound of the [`Aggregator`]s of the [`Node`]s therein according to
/// chosen [`AggregatorSorting`].
struct HeapContainer<'a, F, D, A, S, RResult, const N : usize, const P : usize>
where F : Float,
      D : 'static + Copy,
      Const<P> : BranchCount<N>,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {
    /// The priority queue of entries awaiting refinement.
    heap : BinaryHeap<RefinementInfo<'a, F, D, A, S, RResult, N, P>>,
    /// Tracked greatest lower sorting key of the entries pushed into `heap`. It is only raised
    /// on push, so it can become stale once entries are popped; the search loop recomputes it.
    glb : S::Sort,
    /// Number of popped entries whose lower key was at least `glb`, counted since `glb` was
    /// last recomputed by the search loop.
    glb_stale_counter : usize,
    /// Number of entries pushed while `glb_stale_counter` was nonzero, counted since `glb` was
    /// last recomputed by the search loop.
    stale_insert_counter : usize,
}

impl<'a, F, D, A, S, RResult, const N : usize, const P : usize>
HeapContainer<'a, F, D, A, S, RResult, N, P>
where F : Float,
      D : 'static + Copy,
      Const<P> : BranchCount<N>,
      A : Aggregator,
      S : AggregatorSorting<Agg = A> {

    /// Push `ri` into the [`BinaryHeap`], maintaining the greatest lower bound `glb`.
    ///
    /// An entry whose upper sorting key is not at least the current `glb` is dropped without
    /// being queued. Otherwise the entry is queued, `glb` is raised to the lower key of the
    /// entry if that is greater, and the insertion is counted in `stale_insert_counter` if
    /// `glb` is already known to be stale.
    fn push(&mut self, ri : RefinementInfo<'a, F, D, A, S, RResult, N, P>) {
        let is_candidate = ri.sort_upper() >= self.glb;
        if !is_candidate {
            return
        }

        // The lower key has to be read before `ri` is moved into the heap.
        let lower = ri.sort_lower();
        self.heap.push(ri);
        self.glb = max(self.glb, lower);

        let glb_is_stale = self.glb_stale_counter > 0;
        if glb_is_stale {
            self.stale_insert_counter += 1;
        }
    }
}

impl<F : Float, D : 'static + Copy, A, const N : usize, const P : usize>
Branches<F,D,A,N,P>
where Const<P> : BranchCount<N>,
      A : Aggregator {

    /// Stage all subnodes of `self` into the refinement queue `container`.
    ///
    /// `domain` is the cube corresponding to `self`; each subnode is queued together with its
    /// own subcube and without any refiner information.
    fn stage_refine<'a, S, RResult>(
        &'a mut self,
        domain : Cube<F,N>,
        container : &mut HeapContainer<'a, F, D, A, S, RResult, N, P>,
    ) where S : AggregatorSorting<Agg = A> {
        for (subnode, subdomain) in self.nodes_and_cubes_mut(&domain) {
            let staged = RefinementInfo {
                cube : subdomain,
                node : subnode,
                refiner_info : None,
                sorting : PhantomData,
            };
            container.push(staged);
        }
    }
}


impl<F : Float, D : 'static + Copy, A, const N : usize, const P : usize>
Node<F,D,A,N,P>
where Const<P> : BranchCount<N>,
      A : Aggregator {

    /// If `self` is a leaf node, uses the `refiner` to determine whether further subdivision
    /// is required to get a sufficiently refined solution for the problem the refiner is used
    /// to solve. If the refiner returns [`RefinerResult::Certain`] result, it is returned.
    /// If [`RefinerResult::Uncertain`] is returned, the leaf is inserted back into the refinement
    /// queue `container`. If `self` is a branch, its subnodes are staged into `container` using
    /// [`Branches::stage_refine`].
    ///
    /// A leaf for which the refiner returns [`RefinerResult::NeedRefinement`] is converted into
    /// a branch holding the same data, and the new subnodes are staged into `container`. If such
    /// a leaf holds no data, nothing is staged.
    ///
    /// `domain`, as usual, indicates the spatial area corresponding to `self`.
    /// `generator` converts the stored data into supports, and `step` is passed on to
    /// [`Refiner::refine`].
    ///
    /// Returns `Some` only for a certain result; in all other cases `None` is returned and the
    /// search continues from `container`.
    fn search_and_refine<'a, 'b, R, G>(
        &'a mut self,
        domain : Cube<F,N>,
        refiner : &R,
        generator : &G,
        container : &'b mut HeapContainer<'a, F, D, A, R::Sorting, R::Result, N, P>,
        step : usize
    ) -> Option<R::Result>
    where R : Refiner<F, A, G, N>,
          G : SupportGenerator<F, N, Id=D>,
          G::SupportType : LocalAnalysis<F, A, N> {

        // The “complex” repeated pattern matching here is forced by mutability requirements.

        // Refine a leaf.
        let res = if let NodeOption::Leaf(ref v) = &mut self.data {
            let res = refiner.refine(&self.aggregator, &domain, v, generator, step);
            if let NeedRefinement = res {
                // The refiner has deemed the leaf insufficiently refined, so subdivide
                // it and add the new nodes into the refinement priority heap.
                // We start iterating from the end to mix support_hint a bit.
                let mut it = v.iter().rev();
                if let Some(&d) = it.next() {
                    // Construct new Branches
                    let support = generator.support_for(d);
                    let b = Rc::new({
                        let mut b0 = Branches::new_with(&domain, &support);
                        b0.insert(&domain, d, Const::<1>, &support);
                        for &d in it {
                            let support = generator.support_for(d);
                            // TODO: can we be smarter than just refining one level?
                            b0.insert(&domain, d, Const::<1>, &support);
                        }
                        b0
                    });
                    // Update current node
                    self.aggregator.summarise(b.aggregators());
                    self.data = NodeOption::Branches(b);
                    // The branches will be inserted into the refinement priority queue below.
                }
            }
            res
        } else {
            // Not a leaf: treat as needing refinement, so that the branches get staged below.
            NeedRefinement
        };

        if let Uncertain(agg, val) = res {
            // The refiner gave an uncertain result. Push a leaf back into the refinement queue
            // with the new refined aggregator and custom return value. It will be popped and
            // returned in the loop of [`BT::search_and_refine`] when there are no unrefined
            // candidates that could potentially be better according to their basic aggregator.
            container.push(RefinementInfo {
                cube : domain,
                node : self,
                refiner_info : Some((agg, val)),
                sorting : PhantomData,
            });
            None
        } else if let Certain(val) = res {
            // The refiner gave a certain result so return it to allow early termination
            Some(val)
        } else if let NodeOption::Branches(ref mut b) = &mut self.data {
            // Insert branches into refinement priority queue.
            Rc::make_mut(b).stage_refine(domain, container);
            None
        } else {
            // Neither a refinable leaf nor a branch: nothing to stage.
            None
        }
    }
}

/// Interface trait to a refining search on a [`BT`].
///
/// This can be removed and the methods implemented directly on [`BT`] once Rust's const generics
/// are flexible enough to allow fixing `P=pow(2, N)`.
pub trait BTSearch<F, const N : usize> : BTImpl<F, N>
where F : Float {

    /// Perform a search on [`Self`], as determined by `refiner`.
    ///
    /// Nodes are inserted in a priority queue and processed in the order determined by the
    /// [`AggregatorSorting`] [`Refiner::Sorting`]. Leaf nodes are subdivided until the refiner
    /// decides that a sufficiently refined leaf node has been found, as determined by either the
    /// refiner returning a [`RefinerResult::Certain`] result, or a previous
    /// [`RefinerResult::Uncertain`] result is found again at the top of the priority queue.
    ///
    /// The `generator` converts [`BTImpl::Data`] stored in the bisection tree into a [`Support`].
    ///
    /// Returns the result found, or `None` when no result is produced; the implementations in
    /// this file return `None` when the priority queue is exhausted.
    fn search_and_refine<'b, R, G>(
        &'b mut self,
        refiner : &R,
        generator : &G,
    ) -> Option<R::Result>
    where R : Refiner<F, Self::Agg, G, N>,
          G : SupportGenerator<F, N, Id=Self::Data>,
          G::SupportType : LocalAnalysis<F, Self::Agg, N>;
}

// Needed to get access to a Node without a trait interface.
// Implements `BTSearch<F, $n>` for `BT<M,F,D,A,$n>` for each dimension literal `$n` given.
macro_rules! impl_btsearch {
    ($($n:literal)*) => { $(
        impl<'a, M, F, D, A>
        BTSearch<F, $n>
        for BT<M,F,D,A,$n>
        where //Self : BTImpl<F,$n,Data=D,Agg=A, Depth=M>,  //  <== automatically deduce to be implemented
              M : Depth,
              F : Float,
              A : 'a + Aggregator,
              D : 'static + Copy + std::fmt::Debug {
            fn search_and_refine<'b, R, G>(
                &'b mut self,
                refiner : &R,
                generator : &G,
            ) -> Option<R::Result>
            where R : Refiner<F, A, G, $n>,
                  G : SupportGenerator<F, $n, Id=D>,
                  G::SupportType : LocalAnalysis<F, A, $n> {
                // Start with an empty queue whose greatest lower bound is the bottom key.
                let mut container = HeapContainer {
                    heap : BinaryHeap::new(),
                    glb : R::Sorting::bottom(),
                    glb_stale_counter : 0,
                    stale_insert_counter : 0,
                };
                // Seed the queue with the top node covering the whole domain.
                container.push(RefinementInfo {
                    cube : self.domain,
                    node : &mut self.topnode,
                    refiner_info : None,
                    sorting : PhantomData,
                });
                let mut step = 0;
                // Process entries in order of priority: greatest upper sorting key first.
                while let Some(ri) = container.heap.pop() {
                    if let Some((_, result)) = ri.refiner_info {
                        // Terminate based on a “best possible” result.
                        return Some(result)
                    }

                    // The popped entry may have been the one determining `glb`, in which case
                    // `glb` may no longer be correct for the remaining entries.
                    if ri.sort_lower() >= container.glb {
                        container.glb_stale_counter += 1;
                        // Recompute only once enough stale pops and insertions have
                        // accumulated relative to the size of the heap.
                        if container.stale_insert_counter + container.glb_stale_counter
                           > container.heap.len()/2 {
                            // GLB property no longer correct.
                            match container.heap.iter().map(|ri| ri.sort_lower()).reduce(max) {
                                Some(glb) => {
                                    container.glb = glb;
                                    // Discard entries that cannot beat the new bound.
                                    container.heap.retain(|ri| ri.sort_upper() >= glb);
                                },
                                None => {
                                    // The heap is empty.
                                    container.glb = R::Sorting::bottom()
                                }
                            }
                            container.glb_stale_counter = 0;
                            container.stale_insert_counter = 0;
                        }
                    }

                    // Refine the node; this may push new entries into `container`.
                    let res = ri.node.search_and_refine(ri.cube, refiner, generator,
                                                        &mut container, step);
                    if let Some(_) = res {
                        // Terminate based on a certain result from the refiner
                        return res
                    }

                    step += 1;
                }
                // The queue was exhausted without finding a result.
                None
            }
        }
    )* }
}

// Provide the search for bisection trees of dimensions 1 to 4.
impl_btsearch!(1 2 3 4);

mercurial