src/bisection_tree/bt.rs

Sat, 22 Oct 2022 18:12:49 +0300

author
Tuomo Valkonen <tuomov@iki.fi>
date
Sat, 22 Oct 2022 18:12:49 +0300
changeset 1
df3901ec2f5d
parent 0
9f27689eb130
child 5
59dc4c5883f4
permissions
-rw-r--r--

Fix some unit tests after fundamental changes that made them invalid


use std::slice::{Iter,IterMut};
use std::iter::once;
use std::rc::Rc;
use serde::{Serialize, Deserialize};
pub use nalgebra::Const;
use itertools::izip;

use crate::iter::{MapF,Mappable};
use crate::types::{Float, Num};
use crate::coefficients::pow;
use crate::maputil::{
    array_init,
    map2,
    map2_indexed,
    collect_into_array_unchecked
};
pub use crate::sets::Cube;
pub use crate::loc::Loc;
use super::support::*;
use super::aggregator::*;

/// Contents of a [`Node`] in a [`BT`] bisection tree.
///
/// A node starts out [`Uninitialised`](NodeOption::Uninitialised). On its first
/// insertion it is replaced either by a [`Leaf`](NodeOption::Leaf) holding data items or
/// by [`Branches`](NodeOption::Branches) subdividing the node. Both payloads sit
/// behind an [`Rc`], so cloning a node only bumps reference counts.
#[derive(Debug, Clone)]
pub enum NodeOption<F: Num, D, A: Aggregator, const N: usize, const P: usize> {
    /// Nothing has been inserted into this node yet.
    // TODO: Could optimise Uninitialised away by simply treat Leaf with an empty Vec as
    // something that can be still replaced with Branches.
    Uninitialised,
    /// Terminal node storing the data items inserted into it.
    // TODO: replace with QuickVec fast and w/o allocs on single elements.
    Leaf(Rc<Vec<D>>),
    /// Internal node subdivided into `P` subnodes.
    Branches(Rc<Branches<F, D, A, N, P>>),
}

/// Node of a [`BT`] bisection tree.
///
/// The cube covered by a node is not stored in it; the node methods receive it
/// as a `domain` argument from the caller.
#[derive(Debug, Clone)]
pub struct Node<F: Num, D, A: Aggregator, const N: usize, const P: usize> {
    /// Leaf data, subnodes, or nothing yet.
    pub(super) data: NodeOption<F, D, A, N, P>,
    /// Aggregator for `data`.
    pub(super) aggregator: A,
}

/// Branching information of a [`Node`] in a [`BT`] bisection tree.
#[derive(Debug, Clone)]
pub struct Branches<F: Num, D, A: Aggregator, const N: usize, const P: usize> {
    /// Point at which the node's (unstored) [`Cube`] is split along every axis.
    pub(super) branch_at: Loc<F, N>,
    /// Subnodes, one per combination of lower/upper halves. Bit `j` of a
    /// subnode's index selects the upper half along axis `j`.
    pub(super) nodes: [Node<F, D, A, N, P>; P],
}

/// Iterative destructor for [`Node`]; a workaround for recursive drop, see
/// <https://github.com/rust-lang/rust/issues/58068>.
///
/// The compiler-generated drop glue would recurse through nested [`Branches`].
/// Instead, this detaches branch subtrees into an explicit work list and takes them
/// apart one at a time. Every subnode is therefore dropped with `Uninitialised` data,
/// and the recursion depth of dropping does not grow with the depth of the tree.
impl<F : Num, D, A : Aggregator, const N : usize, const P : usize>
Drop for Node<F, D, A, N, P> {
    fn drop(&mut self) {
        use NodeOption as NO;

        // Work list of detached branch subtrees still to be dismantled.
        let mut to_drop : Vec<Rc<Branches<F, D, A, N, P>>> = Vec::new();

        // Mark `self` as uninitialised, extracting its real contents. Leaf data (if any)
        // is released right here; a branch subtree goes onto the work list.
        if let NO::Branches(brc) = std::mem::replace(&mut self.data, NO::Uninitialised) {
            to_drop.push(brc);
        }

        while let Some(brc) = to_drop.pop() {
            // Only dismantle the subtree if we hold the only strong reference;
            // otherwise dropping `brc` merely decrements its reference count.
            if let Ok(branches) = Rc::try_unwrap(brc) {
                for mut node in branches.nodes {
                    if let NO::Branches(child) = std::mem::replace(&mut node.data, NO::Uninitialised) {
                        to_drop.push(child);
                    }
                    // `node` now holds `Uninitialised` data, so its own `drop`
                    // has nothing left to recurse into.
                }
            }
        }
    }
}

/// Depth of a bisection tree. In this module it is implemented both for
/// compile-time depths (`Const<d>`) and for the runtime [`DynamicDepth`].
pub trait Depth: 'static + Copy + std::fmt::Debug {
    /// Type of the depth one level further down.
    type Lower: Depth;

    /// Returns the depth one level further down, or `None` at the bottom level.
    fn lower(&self) -> Option<Self::Lower>;

    /// Like [`Depth::lower`], but stays at the bottom level instead of returning `None`.
    fn lower_or(&self) -> Self::Lower;
}

/// Tree depth chosen at runtime.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct DynamicDepth(pub u8);

impl Depth for DynamicDepth {
    type Lower = Self;

    /// Decrements the depth, or returns `None` if it is already zero.
    #[inline]
    fn lower(&self) -> Option<Self> {
        self.0.checked_sub(1).map(DynamicDepth)
    }

    /// Decrements the depth, saturating at zero.
    #[inline]
    fn lower_or(&self) -> Self {
        DynamicDepth(self.0.saturating_sub(1))
    }
}

/// Compile-time depth zero: the bottom level, which cannot be lowered further.
impl Depth for Const<0> {
    type Lower = Const<0>;

    fn lower(&self) -> Option<Self::Lower> {
        // Nothing lies below depth zero.
        None
    }

    fn lower_or(&self) -> Self::Lower {
        // Saturate: stay at depth zero.
        Const
    }
}

// Implements `Depth` for each listed `Const<d>`, lowering to `Const<{d - 1}>`.
macro_rules! impl_constdepth {
    ($($d:literal)*) => {
        $(
            impl Depth for Const<$d> {
                type Lower = Const<{ $d - 1 }>;

                fn lower(&self) -> Option<Self::Lower> {
                    Some(Const)
                }

                fn lower_or(&self) -> Self::Lower {
                    Const
                }
            }
        )*
    };
}
impl_constdepth!(1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32);


/// Marker trait relating a branch count to a dimension. The macro below
/// implements `Const<P>: BranchCount<N>` with `P = pow(2, N)` for `N` in `1..=8`.
pub trait BranchCount<const N: usize> {}

// For each listed dimension `d`, marks `Const<{pow(2, d)}>` as its branch count.
macro_rules! impl_branchcount {
    ($($d:literal)*) => {
        $(
            impl BranchCount<$d> for Const<{ pow(2, $d) }> {}
        )*
    };
}
impl_branchcount!(1 2 3 4 5 6 7 8);

impl<F : Float, D, A, const N : usize, const P : usize> Branches<F,D,A,N,P>
where Const<P> : BranchCount<N>,
      A : Aggregator
{
    /// Returns the index into `nodes` of the subnode whose subcube contains `x`.
    ///
    /// Bit `j` of the index is set exactly when `x[j] > branch_at[j]`, so points on a
    /// splitting plane go to the lower subnode. This is the same bit convention that
    /// `get_subcube` uses to construct the subcubes.
    fn get_node_index(&self, x : &Loc<F, N>) -> usize {
        // One bit per coordinate axis, so the axis index ranges over `0..N`.
        izip!(0..N, x.iter(), self.branch_at.iter()).map(|(j, x_j, branch_j)|
            if x_j > branch_j { 1 << j } else { 0 }
        ).sum()
    }

    /// Returns the subnode whose subcube contains `x`.
    #[inline]
    pub fn get_node(&self, x : &Loc<F,N>) -> &Node<F,D,A,N,P> {
         &self.nodes[self.get_node_index(x)]
    }
}


/// Iterator over the data items stored in the leaf that contains a given point;
/// returned by `BTImpl::iter_at`.
pub struct BTIter<'a, D> {
    iter: Iter<'a, D>,
}

/// Iterator over the `P` subcubes into which a splitting point divides a domain
/// cube, in the same order as the subnodes of a [`Branches`].
pub struct SubcubeIter<'b, F: Float, const N: usize, const P: usize> {
    /// The cube being subdivided.
    domain: &'b Cube<F, N>,
    /// Point at which `domain` is split.
    branch_at: Loc<F, N>,
    /// Index of the next subcube to produce.
    index: usize,
}

/// Returns the `i`-th subcube of `domain` when it is split at `branch_at`.
///
/// For each axis, a set bit in `i` at that axis' position selects the upper part
/// `[branch_at[axis], end]`. A clear bit selects the lower part `[start, branch_at[axis]]`.
#[inline]
fn get_subcube<F: Float, const N: usize>(
    branch_at: &Loc<F, N>,
    domain: &Cube<F, N>,
    i: usize,
) -> Cube<F, N> {
    map2_indexed(branch_at, domain, move |axis, &split, &[lo, hi]| {
        let take_upper = (i >> axis) & 1 == 1;
        if take_upper { [split, hi] } else { [lo, split] }
    }).into()
}

/// Yields the subcubes for indices `0..P`, in that order (see `get_subcube` for
/// how an index selects a subcube).
impl<'b, F : Float, const N : usize, const P : usize> Iterator
for SubcubeIter<'b, F, N, P> {
    type Item = Cube<F, N>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.index < P {
            let i = self.index;
            self.index += 1;
            Some(get_subcube(&self.branch_at, self.domain, i))
        } else {
            None
        }
    }

    /// The number of remaining subcubes is known exactly.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = P.saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}

impl<F : Float, D : Copy, A, const N : usize, const P : usize>
Branches<F,D,A,N,P>
where Const<P> : BranchCount<N>,
      A : Aggregator {

    /// Creates branches for `domain` with all subnodes uninitialised.
    ///
    /// The splitting point comes from `support.bisection_hint(domain)`. On axes
    /// where the hint is `None`, the midpoint of `domain` is used. Each coordinate is
    /// then clamped into the corresponding range of `domain`.
    pub fn new_with<S : LocalAnalysis <F, A, N>>(
        domain : &Cube<F,N>,
        support : &S
    ) -> Self {
        let hint = support.bisection_hint(domain);
        let branch_at = map2(&hint, domain, |h, r| {
            h.unwrap_or_else(|| (r[0]+r[1])/F::TWO).max(r[0]).min(r[1])
        }).into();
        Branches{
            branch_at,
            nodes : array_init(|| Node::new()),
        }
    }

    /// Get an iterator over the aggregators of the nodes of this branch head.
    #[inline]
    pub fn aggregators(&self) -> MapF<Iter<'_, Node<F,D,A,N,P>>, &'_ A> {
        self.nodes.iter().mapF(Node::get_aggregator)
    }

    /// Iterate over the subcubes of `domain` that correspond to the subnodes, in
    /// the same order as the subnodes are stored.
    #[inline]
    pub fn iter_subcubes<'b>(&self, domain : &'b Cube<F, N>)
    -> SubcubeIter<'b, F, N, P> {
        SubcubeIter {
            domain,
            branch_at : self.branch_at,
            index : 0,
        }
    }

    /// Iterate over all nodes and corresponding subcubes of self.
    #[inline]
    pub fn nodes_and_cubes<'a, 'b>(&'a self, domain : &'b Cube<F, N>)
    -> std::iter::Zip<Iter<'a, Node<F,D,A,N,P>>, SubcubeIter<'b, F, N, P>> {
        self.nodes.iter().zip(self.iter_subcubes(domain))
    }

    /// Mutably iterate over all nodes and corresponding subcubes of self.
    #[inline]
    pub fn nodes_and_cubes_mut<'a, 'b>(&'a mut self, domain : &'b Cube<F, N>)
    -> std::iter::Zip<IterMut<'a, Node<F,D,A,N,P>>, SubcubeIter<'b, F, N, P>> {
        // Build the subcube iterator first: it copies `branch_at`, so it does not
        // hold a borrow of `self` while the nodes are borrowed mutably.
        let subcube_iter = self.iter_subcubes(domain);
        self.nodes.iter_mut().zip(subcube_iter)
    }

    /// Insert `d` into every subnode whose subcube intersects
    /// `support.support_hint()`, passing `new_leaf_depth` down to the subnodes.
    pub fn insert<M : Depth, S : LocalAnalysis<F, A, N>>(
        &mut self,
        domain : &Cube<F,N>,
        d : D,
        new_leaf_depth : M,
        support : &S
    ) {
        let support_hint = support.support_hint();
        for (node, subcube) in self.nodes_and_cubes_mut(domain) {
            if support_hint.intersects(&subcube) {
                node.insert(
                    &subcube,
                    d,
                    new_leaf_depth,
                    support
                );
            }
        }
    }

    /// Construct a new instance for a different aggregator.
    ///
    /// Consumes `self` and converts each subnode using the subcube of `domain` it
    /// covers. The splitting point is kept.
    pub fn convert_aggregator<ANew, G>(
        self,
        generator : &G,
        domain : &Cube<F, N>
    ) -> Branches<F,D,ANew,N,P>
    where ANew : Aggregator,
          G : SupportGenerator<F, N, Id=D>,
          G::SupportType : LocalAnalysis<F, ANew, N> {
        let branch_at = self.branch_at;
        let subcube_iter = self.iter_subcubes(domain);
        let new_nodes = self.nodes.into_iter().zip(subcube_iter).map(|(node, subcube)| {
            node.convert_aggregator(generator, &subcube)
        });
        Branches {
            branch_at,
            nodes : collect_into_array_unchecked(new_nodes),
        }
    }

    /// Recalculate aggregator after changes to generator
    pub fn refresh_aggregator<G>(
        &mut self,
        generator : &G,
        domain : &Cube<F, N>
    ) where G : SupportGenerator<F, N, Id=D>,
            G::SupportType : LocalAnalysis<F, A, N> {
        for (node, subcube) in self.nodes_and_cubes_mut(domain) {
            node.refresh_aggregator(generator, &subcube)
        }
    }
}

impl<F : Float, D : Copy, A, const N : usize, const P : usize>
Node<F,D,A,N,P>
where Const<P> : BranchCount<N>,
      A : Aggregator {

    /// Creates an uninitialised node with a fresh aggregator.
    #[inline]
    pub fn new() -> Self {
        Node {
            data : NodeOption::Uninitialised,
            aggregator : A::new(),
        }
    }

    /// Returns the data stored in the leaf containing `x`.
    ///
    /// Descends through branches via [`Branches::get_node`]. Returns `None` if the
    /// descent ends in an uninitialised node.
    #[inline]
    pub fn get_leaf_data(&self, x : &Loc<F, N>) -> Option<&Vec<D>> {
        match self.data {
            NodeOption::Uninitialised => None,
            NodeOption::Leaf(ref data) => Some(data),
            NodeOption::Branches(ref b) => b.get_node(x).get_leaf_data(x),
        }
    }

    /// Returns a reference to the aggregator of this node.
    #[inline]
    pub fn get_aggregator(&self) -> &A {
        &self.aggregator
    }

    /// Insert `d` into the subtree rooted at this node, which covers `domain`.
    ///
    /// * An [`Uninitialised`](NodeOption::Uninitialised) node becomes a leaf if
    ///   `new_leaf_depth.lower()` is `None`. Otherwise it becomes branches, into
    ///   which `d` is inserted with the lowered depth.
    /// * A leaf simply stores `d`; existing leaves are never split further.
    /// * Branches forward `d` to their subnodes, with the depth lowered via
    ///   [`Depth::lower_or`].
    ///
    /// In every case the node's aggregator is updated afterwards.
    pub fn insert<M : Depth, S : LocalAnalysis <F, A, N>>(
        &mut self,
        domain : &Cube<F,N>,
        d : D,
        new_leaf_depth : M,
        support : &S
    ) {
        match &mut self.data {
            NodeOption::Uninitialised => {
                // Replace uninitialised node with a leaf or a branch
                self.data = match new_leaf_depth.lower() {
                    None => {
                        let a = support.local_analysis(&domain);
                        self.aggregator.aggregate(once(a));
                        // TODO: this is currently a dirty hard-coded heuristic;
                        // should add capacity as a parameter
                        let mut vec = Vec::with_capacity(2*P+1);
                        vec.push(d);
                        NodeOption::Leaf(Rc::new(vec))
                    },
                    Some(lower) => {
                        let b = Rc::new({
                            let mut b0 = Branches::new_with(domain, support);
                            b0.insert(domain, d, lower, support);
                            b0
                        });
                        self.aggregator.summarise(b.aggregators());
                        NodeOption::Branches(b)
                    }
                }
            },
            NodeOption::Leaf(leaf) => {
                // `make_mut` clones the vector first if it is shared with another tree.
                Rc::make_mut(leaf).push(d);
                let a = support.local_analysis(&domain);
                self.aggregator.aggregate(once(a));
            },
            NodeOption::Branches(b) => {
                Rc::make_mut(b).insert(domain, d, new_leaf_depth.lower_or(), support);
                self.aggregator.summarise(b.aggregators());
            },
        }
    }

    /// Construct a new instance for a different aggregator.
    ///
    /// Consumes the node. A leaf keeps its (shared) data vector and gets an
    /// aggregator built from the supports `generator` gives for its items. Branches
    /// are unwrapped from their `Rc` (cloned if shared) and converted recursively.
    pub fn convert_aggregator<ANew, G>(
        mut self,
        generator : &G,
        domain : &Cube<F, N>
    ) -> Node<F,D,ANew,N,P>
    where ANew : Aggregator,
          G : SupportGenerator<F, N, Id=D>,
          G::SupportType : LocalAnalysis<F, ANew, N> {

        // The mem::replace is needed due to the [`Drop`] implementation to extract self.data.
        match std::mem::replace(&mut self.data, NodeOption::Uninitialised) {
            NodeOption::Uninitialised => Node {
                data : NodeOption::Uninitialised,
                aggregator : ANew::new(),
            },
            NodeOption::Leaf(v) => {
                let mut anew = ANew::new();
                anew.aggregate(v.iter().map(|d| {
                    let support = generator.support_for(*d);
                    support.local_analysis(&domain)
                }));

                Node {
                    data : NodeOption::Leaf(v),
                    aggregator : anew,
                }
            },
            NodeOption::Branches(b) => {
                // TODO: now with Rc, convert_aggregator should be reference-based.
                let bnew = Rc::new(Rc::unwrap_or_clone(b).convert_aggregator(generator, domain));
                let mut anew = ANew::new();
                anew.summarise(bnew.aggregators());
                Node {
                    data : NodeOption::Branches(bnew),
                    aggregator : anew,
                }
            }
        }
    }

    /// Recompute the aggregator of this subtree after changes to `generator`.
    ///
    /// A leaf's aggregator is reset and rebuilt from the supports of its items.
    /// Branches are refreshed recursively and then summarised. Uninitialised nodes
    /// are left untouched.
    pub fn refresh_aggregator<G>(
        &mut self,
        generator : &G,
        domain : &Cube<F, N>
    ) where G : SupportGenerator<F, N, Id=D>,
            G::SupportType : LocalAnalysis<F, A, N> {
        match &mut self.data {
            NodeOption::Uninitialised => { },
            NodeOption::Leaf(v) => {
                self.aggregator = A::new();
                self.aggregator.aggregate(v.iter().map(|d| {
                    generator.support_for(*d)
                             .local_analysis(&domain)
                }));
            },
            NodeOption::Branches(b) => {
                // `make_mut` clones the branches first if they are shared with another tree.
                Rc::make_mut(b).refresh_aggregator(generator, domain);
                self.aggregator.summarise(b.aggregators());
            }
        }
    }
}

/// Yields references to the data items of a single leaf.
impl<'a, D> Iterator for BTIter<'a,D> {
    type Item = &'a D;

    #[inline]
    fn next(&mut self) -> Option<&'a D> {
        self.iter.next()
    }

    /// Forwards the exact length known by the underlying slice iterator.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}


//
// BT
//

/// Type-level lookup of the [`Node`] type for dimension `N`.
///
/// Exists to hide the `const P : usize` parameter of [`Node`] until const generics
/// are flexible enough to fix `P = pow(2, N)`. `impl_bt!` implements it for
/// [`BTNodeLookup`].
pub trait BTNode<F: Float, D: 'static + Copy, A: Aggregator, const N: usize> {
    /// The node type with the branch count fixed for dimension `N`.
    type Node: Clone + std::fmt::Debug;
}

/// Unit type on which [`BTNode`] is implemented for each supported dimension.
#[derive(Debug)]
pub struct BTNodeLookup;

/// Interface to a [`BT`] bisection tree.
///
/// The operations go through this trait so that callers need not name the
/// `const P : usize` branch-count parameter of the underlying [`Node`]s.
pub trait BTImpl<F : Float, const N : usize> : std::fmt::Debug + Clone + GlobalAnalysis<F, Self::Agg> {
    /// Type of the data items stored in the tree.
    type Data : 'static + Copy;
    /// Depth type used when creating new leaves.
    type Depth : Depth;
    /// Aggregator maintained at each node.
    type Agg : Aggregator;
    /// The same kind of tree with the aggregator replaced by `ANew`.
    type Converted<ANew> : BTImpl<F, N, Data=Self::Data, Agg=ANew> where ANew : Aggregator;

    /// Insert `d` into the tree. `support` determines which parts of the tree
    /// `d` is stored in and what it contributes to the aggregators.
    fn insert<S : LocalAnalysis<F, Self::Agg, N>>(
        &mut self,
        d : Self::Data,
        support : &S
    );

    /// Construct a new instance for a different aggregator. The new aggregators
    /// are computed from the supports that `generator` provides for the stored data.
    fn convert_aggregator<ANew, G>(self, generator : &G)
    -> Self::Converted<ANew>
    where ANew : Aggregator,
          G : SupportGenerator<F, N, Id=Self::Data>,
          G::SupportType : LocalAnalysis<F, ANew, N>;


    /// Refresh aggregator after changes to generator
    fn refresh_aggregator<G>(&mut self, generator : &G)
    where G : SupportGenerator<F, N, Id=Self::Data>,
          G::SupportType : LocalAnalysis<F, Self::Agg, N>;

    /// Iterate over the items stored in the leaf containing `x`.
    fn iter_at<'a>(&'a self, x : &'a Loc<F,N>) -> BTIter<'a, Self::Data>;

    /// Create a new, empty tree covering `domain`. New leaves are created at
    /// `depth` on insertion.
    fn new(domain : Cube<F, N>, depth : Self::Depth) -> Self;
}

/// The main bisection tree structure.
///
/// The interface operations are provided by [`BTImpl`], which hides the
/// `const P : usize` parameter of [`Node`] until const generics are flexible
/// enough to fix `P = pow(2, N)`.
#[derive(Debug, Clone)]
pub struct BT<M: Depth, F: Float, D: 'static + Copy, A: Aggregator, const N: usize>
where
    BTNodeLookup: BTNode<F, D, A, N>,
{
    /// Depth passed to the root node as the depth for newly created leaves.
    pub(super) depth: M,
    /// Cube covered by the whole tree.
    pub(super) domain: Cube<F, N>,
    /// Root node.
    pub(super) topnode: <BTNodeLookup as BTNode<F, D, A, N>>::Node,
}

// Implements `BTNode`, `BTImpl` and `GlobalAnalysis` for each listed dimension `$n`,
// fixing the branch count of the nodes to `pow(2, $n)`.
macro_rules! impl_bt {
    ($($n:literal)*) => { $(
        impl<F, D, A> BTNode<F, D, A, $n> for BTNodeLookup
        where F : Float,
              D : 'static + Copy + std::fmt::Debug,
              A : Aggregator {
            type Node = Node<F,D,A,$n,{pow(2, $n)}>;
        }

        impl<M,F,D,A> BTImpl<F,$n> for BT<M,F,D,A,$n>
        where M : Depth,
              F : Float,
              D : 'static + Copy + std::fmt::Debug,
              A : Aggregator {
            type Data = D;
            type Depth = M;
            type Agg = A;
            type Converted<ANew> = BT<M,F,D,ANew,$n> where ANew : Aggregator;

            /// Insert `d` into the tree, starting from the root with the tree's depth.
            fn insert<S: LocalAnalysis<F, A, $n>>(
                &mut self,
                d : D,
                support : &S
            ) {
                self.topnode.insert(
                    &self.domain,
                    d,
                    self.depth,
                    support
                );
            }

            /// Construct a new instance for a different aggregator
            fn convert_aggregator<ANew, G>(self, generator : &G) -> Self::Converted<ANew>
            where ANew : Aggregator,
                  G : SupportGenerator<F, $n, Id=D>,
                  G::SupportType : LocalAnalysis<F, ANew, $n> {
                let topnode = self.topnode.convert_aggregator(generator, &self.domain);
                BT {
                    depth : self.depth,
                    domain : self.domain,
                    topnode
                }
            }

            /// Refresh aggregator after changes to generator
            fn refresh_aggregator<G>(&mut self, generator : &G)
            where G : SupportGenerator<F, $n, Id=Self::Data>,
                G::SupportType : LocalAnalysis<F, Self::Agg, $n> {
                self.topnode.refresh_aggregator(generator, &self.domain);
            }

            /// Iterate elements at `x`. The iterator is empty if no leaf contains `x`.
            fn iter_at<'a>(&'a self, x : &'a Loc<F,$n>) -> BTIter<'a,D> {
                match self.topnode.get_leaf_data(x) {
                    Some(data) => BTIter { iter : data.iter() },
                    None => BTIter { iter : [].iter() }
                }
            }

            /// Create an empty tree on `domain` whose root node is uninitialised.
            fn new(domain : Cube<F, $n>, depth : M) -> Self {
                BT {
                    depth,
                    domain,
                    topnode : Node::new(),
                }
            }
        }

        impl<M,F,D,A> GlobalAnalysis<F,A> for BT<M,F,D,A,$n>
        where M : Depth,
              F : Float,
              D : 'static + Copy + std::fmt::Debug,
              A : Aggregator {
            /// The aggregator of the root node, which summarises the whole tree.
            fn global_analysis(&self) -> A {
                self.topnode.get_aggregator().clone()
            }
        }
    )* }
}

impl_bt!(1 2 3 4);

mercurial