src/bisection_tree/bt.rs

Thu, 01 May 2025 13:06:58 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 01 May 2025 13:06:58 -0500
branch
dev
changeset 124
6aa955ad8122
parent 97
4e80fb049dca
permissions
-rw-r--r--

Transpose loc parameters to allow f64 defaults

/*!
Bisection tree basics, [`BT`] type and the [`BTImpl`] trait.
*/

use itertools::izip;
pub(super) use nalgebra::Const;
use serde::{Deserialize, Serialize};
use std::iter::once;
use std::slice::IterMut;
use std::sync::Arc;

use super::aggregator::*;
use super::support::*;
use crate::coefficients::pow;
use crate::loc::Loc;
use crate::maputil::{array_init, collect_into_array_unchecked, map2, map2_indexed};
use crate::parallelism::{with_task_budget, TaskBudget};
use crate::sets::Cube;
use crate::types::{Float, Num};

/// An enum that indicates whether a [`Node`] of a [`BT`] is uninitialised, leaf, or branch.
///
/// For the type and const parameters, see the [module level documentation][super].
#[derive(Clone, Debug)]
pub(super) enum NodeOption<F: Num, D, A: Aggregator, const N: usize, const P: usize> {
    /// Indicates an uninitialised node; may become a branch or a leaf.
    // TODO: Could optimise Uninitialised away by simply treating Leaf with an empty Vec as
    // something that can be still replaced with Branches.
    Uninitialised,
    /// Indicates a leaf node containing a vector of data of type `D`.
    ///
    /// The vector is owned directly by the node and is not reference-counted: cloning the
    /// node clones the vector, whereas [`NodeOption::Branches`] shares its subtree.
    Leaf(Vec<D>),
    /// Indicates a branch node, containing a copy-on-write reference to the [`Branches`].
    Branches(Arc<Branches<F, D, A, N, P>>),
}

/// Node of a [`BT`] bisection tree.
///
/// For the type and const parameters, see the [module level documentation][super].
#[derive(Clone, Debug)]
pub struct Node<F: Num, D, A: Aggregator, const N: usize, const P: usize> {
    /// The data or branches under the node.
    pub(super) data: NodeOption<F, D, A, N, P>,
    /// Aggregator for `data`, i.e., summary information about everything under the node.
    pub(super) aggregator: A,
}

/// Branching information of a [`Node`] of a [`BT`] bisection tree into `P` subnodes.
///
/// For the type and const parameters, see the [module level documentation][super].
#[derive(Clone, Debug)]
pub(super) struct Branches<F: Num, D, A: Aggregator, const N: usize, const P: usize> {
    /// Point for subdivision of the (unstored) [`Cube`] corresponding to the node.
    pub(super) branch_at: Loc<N, F>,
    /// Subnodes, one per subcube of the node's cube split at `branch_at`.
    ///
    /// Bit `j` of a subnode's index is set when that subnode lies on the upper side of
    /// `branch_at` along coordinate `j`.
    pub(super) nodes: [Node<F, D, A, N, P>; P],
}

/// Drops the tree under a node iteratively instead of through the recursive compiler-generated
/// drop glue; a workaround for <https://github.com/rust-lang/rust/issues/58068>.
///
/// Subtrees are detached into an explicit work queue and processed in a loop, so dropping a
/// deep tree does not recurse through every level of [`Branches`].
impl<F: Num, D, A: Aggregator, const N: usize, const P: usize> Drop for Node<F, D, A, N, P> {
    fn drop(&mut self) {
        use NodeOption as NO;

        // Mark `self` as uninitialised, extracting the real contents. Only branch nodes have
        // subtrees that need special handling; leaves drop their `Vec` normally.
        if let NO::Branches(top) = std::mem::replace(&mut self.data, NO::Uninitialised) {
            // Queue of detached subtrees still to be dropped.
            let mut to_drop = vec![top];

            while let Some(brc) = to_drop.pop() {
                // Only take the subtree apart if we hold the last strong reference; otherwise
                // `into_inner` just decrements the count. Unlike `Arc::try_unwrap(..).ok()`,
                // `Arc::into_inner` guarantees that when several threads drop the last
                // references concurrently, exactly one of them obtains the contents, so the
                // contents never fall through to the recursive drop path.
                let Some(branches) = Arc::into_inner(brc) else {
                    continue;
                };
                for mut node in branches.nodes {
                    let data = std::mem::replace(&mut node.data, NO::Uninitialised);
                    if let NO::Branches(sub) = data {
                        to_drop.push(sub);
                    }
                    // `node` is now dropped with `Uninitialised` data, which is trivial.
                }
            }
        }
    }
}

/// Trait for the depth of a [`BT`].
///
/// Implementations are generally either the runtime [`DynamicDepth`] or one of the
/// compile-time [`Const`] depths.
pub trait Depth: Copy + Send + Sync + std::fmt::Debug + 'static {
    /// The depth type one level further down.
    type Lower: Depth;

    /// The numeric value of this depth.
    fn value(&self) -> u32;

    /// One level lower, or `None` when there is no lower depth.
    fn lower(&self) -> Option<Self::Lower>;

    /// One level lower, or (an equivalent of) this depth when it is already the lowest.
    fn lower_or(&self) -> Self::Lower;
}

/// Dynamic (runtime) [`Depth`] for a [`BT`].
///
/// The wrapped value is the depth itself; [`Depth::lower`] decrements it by one until it
/// reaches zero.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct DynamicDepth(
    /// The depth
    pub u8,
);

impl Depth for DynamicDepth {
    type Lower = Self;
    #[inline]
    fn lower(&self) -> Option<Self> {
        if self.0 > 0 {
            Some(DynamicDepth(self.0 - 1))
        } else {
            None
        }
    }

    #[inline]
    fn lower_or(&self) -> Self {
        DynamicDepth(if self.0 > 0 { self.0 - 1 } else { 0 })
    }

    #[inline]
    fn value(&self) -> u32 {
        self.0 as u32
    }
}

/// The compile-time depth zero: the bottom of the chain of [`Const`] depths.
impl Depth for Const<0> {
    type Lower = Self;

    #[inline]
    fn lower(&self) -> Option<Self::Lower> {
        // Nothing lies below depth zero.
        None
    }

    #[inline]
    fn lower_or(&self) -> Self::Lower {
        // Saturate: stay at depth zero.
        *self
    }

    #[inline]
    fn value(&self) -> u32 {
        0
    }
}

// Implements [`Depth`] for each compile-time depth `Const<n>` with `n` in the invocation list.
//
// `Const<n>` has `Const<{n-1}>` as its lower depth, so `lower` always returns `Some` for the
// listed `n >= 1`; the chain bottoms out at `Const<0>`, which is implemented separately.
macro_rules! impl_constdepth {
    ($($n:literal)*) => { $(
        impl Depth for Const<$n> {
            type Lower = Const<{$n-1}>;
            fn lower(&self) -> Option<Self::Lower> { Some(Const) }
            fn lower_or(&self) -> Self::Lower { Const }
            fn value(&self) -> u32 { $n }
        }
    )* };
}
// Compile-time depths 1 through 32.
impl_constdepth!(1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32);

/// Trait for counting the branching factor of a [`BT`] of dimension `N`.
///
/// The const parameter `P` from the [module level documentation][super] is required to satisfy
/// `Const<P> : BranchCount<N>`.
/// This trait is implemented for `P=pow(2, N)` with `N` from 1 to 8.
pub trait BranchCount<const N: usize> {}

// Implements `BranchCount<n>` for `Const<{pow(2, n)}>`, tying the branching factor `P` of an
// `n`-dimensional tree to `2^n`.
macro_rules! impl_branchcount {
    ($($n:literal)*) => { $(
        impl BranchCount<$n> for Const<{pow(2, $n)}>{}
    )* }
}
impl_branchcount!(1 2 3 4 5 6 7 8);

impl<F: Float, D, A, const N: usize, const P: usize> Branches<F, D, A, N, P>
where
    Const<P>: BranchCount<N>,
    A: Aggregator,
{
    /// Returns the index in {0, …, `P`-1} for the branch to which the point `x` corresponds.
    ///
    /// Bit `i` of the index is set exactly when `x` lies strictly above the branch subdivision
    /// point $d$ along coordinate `i`; points on the boundary go to the lower side.
    /// This only takes $d$ into account, so it always succeeds: each branch corresponds to an
    /// orthant of $ℝ^N$ relative to $d$.
    fn get_node_index(&self, x: &Loc<N, F>) -> usize {
        // The bit position is the coordinate index, which ranges over the `N` coordinates.
        izip!(0..N, x.iter(), self.branch_at.iter())
            .map(|(i, x_i, branch_i)| if x_i > branch_i { 1 << i } else { 0 })
            .sum()
    }

    /// Returns the node within `Self` containing the point `x`.
    ///
    /// This only takes the branch subdivision point $d$ into account, so it always succeeds.
    /// Thus, for this purpose, each branch corresponds to an orthant of $ℝ^N$ relative to $d$.
    #[inline]
    fn get_node(&self, x: &Loc<N, F>) -> &Node<F, D, A, N, P> {
        &self.nodes[self.get_node_index(x)]
    }
}

/// An iterator over the $P=2^N$ subcubes of a [`Cube`] subdivided at a point `d`.
///
/// The `i`:th item is the subcube that lies above `branch_at` along exactly those coordinates
/// `j` for which bit `j` of `i` is set, matching the order of the `nodes` of [`Branches`].
pub(super) struct SubcubeIter<'b, F: Float, const N: usize, const P: usize> {
    /// The cube being subdivided.
    domain: &'b Cube<N, F>,
    /// The subdivision point.
    branch_at: Loc<N, F>,
    /// Index of the next subcube to produce; iteration ends when it reaches `P`.
    index: usize,
}

/// Returns the `i`:th subcube of `domain` subdivided at `branch_at`.
///
/// Along coordinate `j` the subcube is the upper part `[branch_at[j], end]` when bit `j` of `i`
/// is set, and the lower part `[start, branch_at[j]]` otherwise.
#[inline]
fn get_subcube<F: Float, const N: usize>(
    branch_at: &Loc<N, F>,
    domain: &Cube<N, F>,
    i: usize,
) -> Cube<N, F> {
    let ranges = map2_indexed(branch_at, domain, move |j, &split, &[lo, hi]| {
        // Bit `j` of `i` picks the side of the split along coordinate `j`.
        let upper_half = (i >> j) & 1 == 1;
        if upper_half {
            [split, hi]
        } else {
            [lo, split]
        }
    });
    ranges.into()
}

impl<F: Float, const N: usize, const P: usize> Iterator for SubcubeIter<'_, F, N, P> {
    type Item = Cube<N, F>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.index < P {
            let i = self.index;
            self.index += 1;
            Some(get_subcube(&self.branch_at, self.domain, i))
        } else {
            None
        }
    }

    /// The number of remaining subcubes is known exactly: `P` minus those already produced.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = P.saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact.
impl<F: Float, const N: usize, const P: usize> ExactSizeIterator for SubcubeIter<'_, F, N, P> {}

// Once `index` reaches `P` it is never incremented again, so `next` keeps returning `None`.
impl<F: Float, const N: usize, const P: usize> std::iter::FusedIterator
    for SubcubeIter<'_, F, N, P>
{
}

impl<F: Float, D, A, const N: usize, const P: usize> Branches<F, D, A, N, P>
where
    Const<P>: BranchCount<N>,
    A: Aggregator,
    D: 'static + Copy + Send + Sync,
{
    /// Creates a new node branching structure, subdividing `domain` based on the
    /// [bisection hint][Support::bisection_hint] of `support`.
    ///
    /// Along each coordinate, the branching point is the hinted value when one is given, and
    /// the midpoint of `domain` otherwise; either way it is clamped into `domain`.
    /// All subnodes start out uninitialised.
    pub(super) fn new_with<S: LocalAnalysis<F, A, N> + Support<N, F>>(
        domain: &Cube<N, F>,
        support: &S,
    ) -> Self {
        let hint = support.bisection_hint(domain);
        let branch_at = map2(&hint, domain, |h, r| {
            h.unwrap_or_else(|| (r[0] + r[1]) / F::TWO)
                .max(r[0])
                .min(r[1])
        })
        .into();
        Branches {
            branch_at,
            nodes: array_init(|| Node::new()),
        }
    }

    /// Summarises the aggregators of these branches into `agg`.
    pub(super) fn summarise_into(&self, agg: &mut A) {
        agg.summarise(self.nodes.iter().map(Node::get_aggregator));
    }

    /// Returns an iterator over the subcubes of `domain` subdivided at the branching point
    /// of `self`.
    ///
    /// The subcubes are produced in the same order as the subnodes in `self.nodes`.
    #[inline]
    pub(super) fn iter_subcubes<'b>(&self, domain: &'b Cube<N, F>) -> SubcubeIter<'b, F, N, P> {
        SubcubeIter {
            domain,
            branch_at: self.branch_at,
            index: 0,
        }
    }

    /*
    /// Returns an iterator over all nodes and corresponding subcubes of `self`.
    #[inline]
    pub(super) fn nodes_and_cubes<'a, 'b>(&'a self, domain : &'b Cube<N, F>)
    -> std::iter::Zip<Iter<'a, Node<F,D,A,N,P>>, SubcubeIter<'b, F, N, P>> {
        self.nodes.iter().zip(self.iter_subcubes(domain))
    }
    */

    /// Mutably iterate over all nodes and corresponding subcubes of `self`.
    #[inline]
    pub(super) fn nodes_and_cubes_mut<'a, 'b>(
        &'a mut self,
        domain: &'b Cube<N, F>,
    ) -> std::iter::Zip<IterMut<'a, Node<F, D, A, N, P>>, SubcubeIter<'b, F, N, P>> {
        let subcube_iter = self.iter_subcubes(domain);
        self.nodes.iter_mut().zip(subcube_iter)
    }

    /// Calls `f` on each `(subnode, subcube)` pair for which `guard` returns `true`.
    ///
    /// Each call is submitted as a task through `task_budget`, so the calls may run in
    /// multiple threads; `f` receives the task budget for its own nested work.
    #[inline]
    fn recurse<'scope, 'smaller, 'refs>(
        &'smaller mut self,
        domain: &'smaller Cube<N, F>,
        task_budget: TaskBudget<'scope, 'refs>,
        guard: impl Fn(&Node<F, D, A, N, P>, &Cube<N, F>) -> bool + Send + 'smaller,
        mut f: impl for<'a> FnMut(&mut Node<F, D, A, N, P>, &Cube<N, F>, TaskBudget<'smaller, 'a>)
            + Send
            + Copy
            + 'smaller,
    ) where
        'scope: 'smaller,
    {
        let subs = self.nodes_and_cubes_mut(domain);
        task_budget.zoom(move |s| {
            for (node, subcube) in subs {
                if guard(node, &subcube) {
                    s.execute(move |new_budget| f(node, &subcube, new_budget))
                }
            }
        });
    }

    /// Insert data into the branch.
    ///
    /// The parameters are as follows:
    ///  * `domain` is the cube corresponding to this branch.
    ///  * `d` is the data to be inserted
    ///  * `new_leaf_depth` is the depth relative to `self` at which the data is to be inserted.
    ///  * `support` is the [`Support`] that is used to determine which subcubes of `domain`
    ///     (at subdivision depth `new_leaf_depth`) the data `d` is to be associated with;
    ///     only subcubes intersecting its [support hint][Support::support_hint] are visited.
    ///  * `task_budget` is used to possibly process the subnodes in multiple threads.
    pub(super) fn insert<'refs, 'scope, M: Depth, S: LocalAnalysis<F, A, N> + Support<N, F>>(
        &mut self,
        domain: &Cube<N, F>,
        d: D,
        new_leaf_depth: M,
        support: &S,
        task_budget: TaskBudget<'scope, 'refs>,
    ) {
        let support_hint = support.support_hint();
        self.recurse(
            domain,
            task_budget,
            |_, subcube| support_hint.intersects(&subcube),
            move |node, subcube, new_budget| {
                node.insert(subcube, d, new_leaf_depth, support, new_budget)
            },
        );
    }

    /// Construct a new instance of the branch for a different aggregator.
    ///
    /// The `generator` is used to convert the data of type `D` of the branch into corresponding
    /// [`Support`]s. The `domain` is the cube corresponding to `self`.
    /// The type parameter `ANew` is the new aggregator, and needs to be implemented for the
    /// generator's `SupportType`.
    pub(super) fn convert_aggregator<ANew, G>(
        self,
        generator: &G,
        domain: &Cube<N, F>,
    ) -> Branches<F, D, ANew, N, P>
    where
        ANew: Aggregator,
        G: SupportGenerator<N, F, Id = D>,
        G::SupportType: LocalAnalysis<F, ANew, N>,
    {
        let branch_at = self.branch_at;
        let subcube_iter = self.iter_subcubes(domain);
        let new_nodes = self
            .nodes
            .into_iter()
            .zip(subcube_iter)
            .map(|(node, subcube)| Node::convert_aggregator(node, generator, &subcube));
        Branches {
            branch_at,
            nodes: collect_into_array_unchecked(new_nodes),
        }
    }

    /// Recalculate aggregator after changes to generator.
    ///
    /// The `generator` is used to convert the data of type `D` of the branch into corresponding
    /// [`Support`]s. The `domain` is the cube corresponding to `self`.
    pub(super) fn refresh_aggregator<'refs, 'scope, G>(
        &mut self,
        generator: &G,
        domain: &Cube<N, F>,
        task_budget: TaskBudget<'scope, 'refs>,
    ) where
        G: SupportGenerator<N, F, Id = D>,
        G::SupportType: LocalAnalysis<F, A, N>,
    {
        self.recurse(
            domain,
            task_budget,
            |_, _| true,
            move |node, subcube, new_budget| {
                node.refresh_aggregator(generator, subcube, new_budget)
            },
        );
    }
}

impl<F: Float, D, A, const N: usize, const P: usize> Node<F, D, A, N, P>
where
    Const<P>: BranchCount<N>,
    A: Aggregator,
    D: 'static + Copy + Send + Sync,
{
    /// Create a new, uninitialised node with an empty aggregator.
    #[inline]
    pub(super) fn new() -> Self {
        Node {
            data: NodeOption::Uninitialised,
            aggregator: A::new(),
        }
    }

    /*
    /// Get leaf data
    #[inline]
    pub(super) fn get_leaf_data(&self, x : &Loc<N, F>) -> Option<&Vec<D>> {
        match self.data {
            NodeOption::Uninitialised => None,
            NodeOption::Leaf(ref data) => Some(data),
            NodeOption::Branches(ref b) => b.get_node(x).get_leaf_data(x),
        }
    }*/

    /// Returns an iterator over the data of the leaf under this node that contains `x`.
    ///
    /// Returns `None` if the descent towards `x` ends at an uninitialised node.
    #[inline]
    pub(super) fn get_leaf_data_iter(&self, x: &Loc<N, F>) -> Option<std::slice::Iter<'_, D>> {
        match self.data {
            NodeOption::Uninitialised => None,
            NodeOption::Leaf(ref data) => Some(data.iter()),
            NodeOption::Branches(ref b) => b.get_node(x).get_leaf_data_iter(x),
        }
    }

    /// Returns a reference to the aggregator of this node
    #[inline]
    pub(super) fn get_aggregator(&self) -> &A {
        &self.aggregator
    }

    /// Insert data under the node.
    ///
    /// The parameters are as follows:
    ///  * `domain` is the cube corresponding to this node.
    ///  * `d` is the data to be inserted
    ///  * `new_leaf_depth` is the depth relative to `self` at which new leaves are created.
    ///  * `support` is the [`Support`] that is used to determine which subcubes of `domain`
    ///     (at subdivision depth `new_leaf_depth`) the data `d` is to be associated with.
    ///  * `task_budget` is passed on when recursing into branches.
    ///
    /// If `self` is already [`NodeOption::Leaf`], the data is inserted directly in this node.
    /// If `self` is a [`NodeOption::Branches`], the data is passed to branches whose subcubes
    /// `support` intersects. If an [`NodeOption::Uninitialised`] node is encountered, a new leaf is
    /// created at a minimum depth of `new_leaf_depth`.
    pub(super) fn insert<'refs, 'scope, M: Depth, S: LocalAnalysis<F, A, N> + Support<N, F>>(
        &mut self,
        domain: &Cube<N, F>,
        d: D,
        new_leaf_depth: M,
        support: &S,
        task_budget: TaskBudget<'scope, 'refs>,
    ) {
        match &mut self.data {
            NodeOption::Uninitialised => {
                // Replace uninitialised node with a leaf or a branch
                self.data = match new_leaf_depth.lower() {
                    None => {
                        let a = support.local_analysis(&domain);
                        self.aggregator.aggregate(once(a));
                        // TODO: this is currently a dirty hard-coded heuristic;
                        // should add capacity as a parameter
                        let mut vec = Vec::with_capacity(2 * P + 1);
                        vec.push(d);
                        NodeOption::Leaf(vec)
                    }
                    Some(lower) => {
                        let b = Arc::new({
                            let mut b0 = Branches::new_with(domain, support);
                            b0.insert(domain, d, lower, support, task_budget);
                            b0
                        });
                        b.summarise_into(&mut self.aggregator);
                        NodeOption::Branches(b)
                    }
                }
            }
            NodeOption::Leaf(leaf) => {
                leaf.push(d);
                let a = support.local_analysis(&domain);
                self.aggregator.aggregate(once(a));
            }
            NodeOption::Branches(b) => {
                // FIXME: recursion that may cause stack overflow if the tree becomes
                // very deep, e.g. due to [`BTSearch::search_and_refine`].
                let bm = Arc::make_mut(b);
                bm.insert(domain, d, new_leaf_depth.lower_or(), support, task_budget);
                bm.summarise_into(&mut self.aggregator);
            }
        }
    }

    /// Construct a new instance of the node for a different aggregator
    ///
    /// The `generator` is used to convert the data of type `D` of the node into corresponding
    /// [`Support`]s. The `domain` is the cube corresponding to `self`.
    /// The type parameter `ANew` is the new aggregator, and needs to be implemented for the
    /// generator's `SupportType`.
    pub(super) fn convert_aggregator<ANew, G>(
        mut self,
        generator: &G,
        domain: &Cube<N, F>,
    ) -> Node<F, D, ANew, N, P>
    where
        ANew: Aggregator,
        G: SupportGenerator<N, F, Id = D>,
        G::SupportType: LocalAnalysis<F, ANew, N>,
    {
        // The mem::replace is needed due to the [`Drop`] implementation to extract self.data.
        match std::mem::replace(&mut self.data, NodeOption::Uninitialised) {
            NodeOption::Uninitialised => Node {
                data: NodeOption::Uninitialised,
                aggregator: ANew::new(),
            },
            NodeOption::Leaf(v) => {
                let mut anew = ANew::new();
                anew.aggregate(v.iter().map(|d| {
                    let support = generator.support_for(*d);
                    support.local_analysis(&domain)
                }));

                Node {
                    data: NodeOption::Leaf(v),
                    aggregator: anew,
                }
            }
            NodeOption::Branches(b) => {
                // FIXME: recursion that may cause stack overflow if the tree becomes
                // very deep, e.g. due to [`BTSearch::search_and_refine`].
                let bnew = Arc::unwrap_or_clone(b).convert_aggregator(generator, domain);
                let mut anew = ANew::new();
                bnew.summarise_into(&mut anew);
                Node {
                    data: NodeOption::Branches(Arc::new(bnew)),
                    aggregator: anew,
                }
            }
        }
    }

    /// Refresh aggregator after changes to generator.
    ///
    /// The `generator` is used to convert the data of type `D` of the node into corresponding
    /// [`Support`]s. The `domain` is the cube corresponding to `self`.
    pub(super) fn refresh_aggregator<'refs, 'scope, G>(
        &mut self,
        generator: &G,
        domain: &Cube<N, F>,
        task_budget: TaskBudget<'scope, 'refs>,
    ) where
        G: SupportGenerator<N, F, Id = D>,
        G::SupportType: LocalAnalysis<F, A, N>,
    {
        match &mut self.data {
            NodeOption::Uninitialised => {}
            NodeOption::Leaf(v) => {
                // Rebuild from scratch: the supports of the stored data may have changed.
                self.aggregator = A::new();
                self.aggregator.aggregate(
                    v.iter()
                        .map(|d| generator.support_for(*d).local_analysis(&domain)),
                );
            }
            // No explicit `ref mut`: matching on `&mut self.data` already binds `b` by mutable
            // reference (explicit binding modifiers here are rejected in Rust 2024).
            NodeOption::Branches(b) => {
                // FIXME: recursion that may cause stack overflow if the tree becomes
                // very deep, e.g. due to [`BTSearch::search_and_refine`].
                let bm = Arc::make_mut(b);
                bm.refresh_aggregator(generator, domain, task_budget);
                bm.summarise_into(&mut self.aggregator);
            }
        }
    }
}

/// Helper trait for working with [`Node`]s without the knowledge of `P`.
///
/// The associated type [`BTNode::Node`] gives the concrete node type, with `P` fixed, for
/// dimension `N`; the implementations for [`BTNodeLookup`] are generated by `impl_bt!`.
///
/// This can be removed and the methods implemented directly on [`BT`] once Rust's const generics
/// are flexible enough to allow fixing `P=pow(2, N)`.
pub trait BTNode<F, D, A, const N: usize>
where
    F: Float,
    D: 'static + Copy,
    A: Aggregator,
{
    /// The concrete node type for dimension `N`.
    type Node: Clone + std::fmt::Debug;
}

/// Helper structure for looking up a [`Node`] without the knowledge of `P`.
///
/// It carries no data; it only serves as the implementor of [`BTNode`], whose associated
/// `Node` type gives the concrete node type for each supported dimension.
///
/// This can be removed once Rust's const generics are flexible enough to allow fixing
/// `P=pow(2, N)`.
#[derive(Debug)]
pub struct BTNodeLookup;

/// Basic interface to a [`BT`] bisection tree.
///
/// Further routines are provided by the [`BTSearch`][super::refine::BTSearch] trait.
pub trait BTImpl<const N: usize, F: Float = f64>:
    std::fmt::Debug + Clone + GlobalAnalysis<F, Self::Agg>
{
    /// The data type stored in the tree
    type Data: 'static + Copy + Send + Sync;
    /// The depth type of the tree
    type Depth: Depth;
    /// The type for the [aggregate information][Aggregator] about the `Data` stored in each node
    /// of the tree.
    type Agg: Aggregator;
    /// The type of the tree with the aggregator converted to `ANew`.
    type Converted<ANew>: BTImpl<N, F, Data = Self::Data, Agg = ANew>
    where
        ANew: Aggregator;

    /// Insert the data `d` into the tree for `support`.
    ///
    /// Every leaf node of the tree that intersects the `support` will contain a copy of
    /// `d`.
    fn insert<S: LocalAnalysis<F, Self::Agg, N> + Support<N, F>>(
        &mut self,
        d: Self::Data,
        support: &S,
    );

    /// Construct a new instance of the tree for a different aggregator
    ///
    /// The `generator` is used to convert the data of type [`Self::Data`] contained in the tree
    /// into corresponding [`Support`]s.
    fn convert_aggregator<ANew, G>(self, generator: &G) -> Self::Converted<ANew>
    where
        ANew: Aggregator,
        G: SupportGenerator<N, F, Id = Self::Data>,
        G::SupportType: LocalAnalysis<F, ANew, N>;

    /// Refreshes the aggregator of the tree after possible changes to the support generator.
    ///
    /// The `generator` is used to convert the data of type [`Self::Data`] contained in the tree
    /// into corresponding [`Support`]s.
    fn refresh_aggregator<G>(&mut self, generator: &G)
    where
        G: SupportGenerator<N, F, Id = Self::Data>,
        G::SupportType: LocalAnalysis<F, Self::Agg, N>;

    /// Returns an iterator over all [`Self::Data`] items at the point `x` of the domain.
    fn iter_at(&self, x: &Loc<N, F>) -> std::slice::Iter<'_, Self::Data>;

    /*
    /// Returns all [`Self::Data`] items at the point `x` of the domain.
    fn data_at(&self, x : &Loc<N, F>) -> Arc<Vec<Self::Data>>;
    */

    /// Create a new tree on `domain` of indicated `depth`.
    fn new(domain: Cube<N, F>, depth: Self::Depth) -> Self;
}

/// The main bisection tree structure.
///
/// It should be accessed via the [`BTImpl`] trait to hide the `const P : usize` parameter until
/// const generics are flexible enough to fix `P=pow(2, N)` and thus also get rid of
/// the `BTNodeLookup : BTNode<F, D, A, N>` trait bound.
#[derive(Clone, Debug)]
pub struct BT<M: Depth, F: Float, D: 'static + Copy, A: Aggregator, const N: usize>
where
    BTNodeLookup: BTNode<F, D, A, N>,
{
    /// The depth of the tree (initial, before refinement); passed to the top-level node as
    /// the depth at which new leaves are created on insertion.
    pub(super) depth: M,
    /// The domain of the toplevel node
    pub(super) domain: Cube<N, F>,
    /// The toplevel node of the tree, of the node type looked up through [`BTNode`].
    pub(super) topnode: <BTNodeLookup as BTNode<F, D, A, N>>::Node,
}

// For each listed dimension `$n`, implements the dimension-specific glue:
//  * the [`BTNode`] lookup that fixes `P = pow(2, $n)`,
//  * the [`BTImpl`] interface of [`BT`], delegating to the top-level node, and
//  * [`GlobalAnalysis`] for [`BT`], which is the aggregator of the top-level node.
macro_rules! impl_bt {
    ($($n:literal)*) => { $(
        impl<F, D, A> BTNode<F, D, A, $n> for BTNodeLookup
        where F : Float,
              D : 'static + Copy + Send + Sync + std::fmt::Debug,
              A : Aggregator {
            type Node = Node<F,D,A,$n,{pow(2, $n)}>;
        }

        impl<M,F,D,A> BTImpl<$n, F> for BT<M,F,D,A,$n>
        where M : Depth,
              F : Float,
              D : 'static + Copy + Send + Sync + std::fmt::Debug,
              A : Aggregator {
            type Data = D;
            type Depth = M;
            type Agg = A;
            type Converted<ANew> = BT<M,F,D,ANew,$n> where ANew : Aggregator;

            fn insert<S: LocalAnalysis<F, A, $n> + Support< $n, F>>(
                &mut self,
                d : D,
                support : &S
            ) {
                // The tree's depth is the depth, relative to the top node, of new leaves.
                with_task_budget(|task_budget|
                    self.topnode.insert(
                        &self.domain,
                        d,
                        self.depth,
                        support,
                        task_budget
                    )
                )
            }

            fn convert_aggregator<ANew, G>(self, generator : &G) -> Self::Converted<ANew>
            where ANew : Aggregator,
                  G : SupportGenerator< $n, F, Id=D>,
                  G::SupportType : LocalAnalysis<F, ANew, $n> {
                let topnode = self.topnode.convert_aggregator(generator, &self.domain);

                BT {
                    depth : self.depth,
                    domain : self.domain,
                    topnode
                }
            }

            fn refresh_aggregator<G>(&mut self, generator : &G)
            where G : SupportGenerator< $n, F, Id=Self::Data>,
                  G::SupportType : LocalAnalysis<F, Self::Agg, $n> {
                with_task_budget(|task_budget|
                    self.topnode.refresh_aggregator(generator, &self.domain, task_budget)
                )
            }

            /*fn data_at(&self, x : &Loc<$n, F>) -> Arc<Vec<D>> {
                self.topnode.get_leaf_data(x).unwrap_or_else(|| Arc::new(Vec::new()))
            }*/

            fn iter_at(&self, x : &Loc<$n, F>) -> std::slice::Iter<'_, D> {
                // Points in still-uninitialised parts of the tree carry no data.
                self.topnode.get_leaf_data_iter(x).unwrap_or_else(|| [].iter())
            }

            fn new(domain : Cube<$n, F>, depth : M) -> Self {
                BT {
                    depth,
                    domain,
                    topnode : Node::new(),
                }
            }
        }

        impl<M,F,D,A> GlobalAnalysis<F,A> for BT<M,F,D,A,$n>
        where M : Depth,
              F : Float,
              D : 'static + Copy + Send + Sync + std::fmt::Debug,
              A : Aggregator {
            fn global_analysis(&self) -> A {
                self.topnode.get_aggregator().clone()
            }
        }
    )* }
}

impl_bt!(1 2 3 4);

mercurial