Sat, 22 Oct 2022 18:12:49 +0300
Fix some unit tests after fundamental changes that made them invalid
use std::slice::{Iter,IterMut}; use std::iter::once; use std::rc::Rc; use serde::{Serialize, Deserialize}; pub use nalgebra::Const; use itertools::izip; use crate::iter::{MapF,Mappable}; use crate::types::{Float, Num}; use crate::coefficients::pow; use crate::maputil::{ array_init, map2, map2_indexed, collect_into_array_unchecked }; pub use crate::sets::Cube; pub use crate::loc::Loc; use super::support::*; use super::aggregator::*; #[derive(Clone,Debug)] pub enum NodeOption<F : Num, D, A : Aggregator, const N : usize, const P : usize> { // TODO: Could optimise Uninitialised away by simply treat Leaf with an empty Vec as // something that can be still replaced with Branches. Uninitialised, // TODO: replace with QuickVec fast and w/o allocs on single elements. Leaf(Rc<Vec<D>>), Branches(Rc<Branches<F, D, A, N, P>>), } /// Node of a [`BT`] bisection tree. #[derive(Clone,Debug)] pub struct Node<F : Num, D, A : Aggregator, const N : usize, const P : usize> { pub(super) data : NodeOption<F, D, A, N, P>, /// Aggregator for `data`. pub(super) aggregator : A, } /// Branch information of a [`Node`] of a [`BT`] bisection tree. #[derive(Clone,Debug)] pub struct Branches<F : Num, D, A : Aggregator, const N : usize, const P : usize> { /// Point for subdivision of the (unstored) [`Cube`] corresponding to the node. pub(super) branch_at : Loc<F, N>, /// Subnodes pub(super) nodes : [Node<F, D, A, N, P>; P], } /// Dirty workaround to broken Rust drop, see [https://github.com/rust-lang/rust/issues/58068](). impl<F : Num, D, A : Aggregator, const N : usize, const P : usize> Drop for Node<F, D, A, N, P> { fn drop(&mut self) { use NodeOption as NO; let process = |brc : Rc<Branches<F, D, A, N, P>>, to_drop : &mut Vec<Rc<Branches<F, D, A, N, P>>>| { // We only drop Branches if we have the only strong reference. Rc::try_unwrap(brc).ok().map(|branches| branches.nodes.map(|mut node| { if let NO::Branches(brc2) = std::mem::replace(&mut node.data, NO::Uninitialised) { to_drop.push(brc2) } })); }; // We mark Self as NodeOption::Uninitialised, extracting the real contents. // If we have subprocess, we need to process them. if let NO::Branches(brc1) = std::mem::replace(&mut self.data, NO::Uninitialised) { // We store a queue of Rc<Branches> to drop into a vector let mut to_drop = Vec::new(); process(brc1, &mut to_drop); // While there are any Branches in the drop queue vector, we continue the process, // pushing all internal branching nodes into the queue. while let Some(brc) = to_drop.pop() { process(brc, &mut to_drop) } } } } pub trait Depth : 'static + Copy + std::fmt::Debug { type Lower : Depth; fn lower(&self) -> Option<Self::Lower>; fn lower_or(&self) -> Self::Lower; } #[derive(Copy,Clone,Debug,Serialize,Deserialize)] pub struct DynamicDepth(pub u8); impl Depth for DynamicDepth { type Lower = Self; #[inline] fn lower(&self) -> Option<Self> { if self.0>0 { Some(DynamicDepth(self.0-1)) } else { None } } #[inline] fn lower_or(&self) -> Self { DynamicDepth(if self.0>0 { self.0 - 1 } else { 0 }) } } impl Depth for Const<0> { type Lower = Self; fn lower(&self) -> Option<Self::Lower> { None } fn lower_or(&self) -> Self::Lower { Const } } macro_rules! impl_constdepth { ($($n:literal)*) => { $( impl Depth for Const<$n> { type Lower = Const<{$n-1}>; fn lower(&self) -> Option<Self::Lower> { Some(Const) } fn lower_or(&self) -> Self::Lower { Const } } )* }; } impl_constdepth!(1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32); pub trait BranchCount<const N : usize> {} macro_rules! impl_branchcount { ($($n:literal)*) => { $( impl BranchCount<$n> for Const<{pow(2, $n)}>{} )* } } impl_branchcount!(1 2 3 4 5 6 7 8); impl<F : Float, D, A, const N : usize, const P : usize> Branches<F,D,A,N,P> where Const<P> : BranchCount<N>, A : Aggregator { fn get_node_index(&self, x : &Loc<F, N>) -> usize { izip!(0..P, x.iter(), self.branch_at.iter()).map(|(i, x_i, branch_i)| if x_i > branch_i { 1<<i } else { 0 } ).sum() } #[inline] pub fn get_node(&self, x : &Loc<F,N>) -> &Node<F,D,A,N,P> { &self.nodes[self.get_node_index(x)] } } pub struct BTIter<'a, D> { iter : Iter<'a, D>, } pub struct SubcubeIter<'b, F : Float, const N : usize, const P : usize> { domain : &'b Cube<F, N>, branch_at : Loc<F, N>, index : usize, } #[inline] fn get_subcube<F : Float, const N : usize>(branch_at : &Loc<F, N>, domain : &Cube<F, N>, i : usize) -> Cube<F, N> { map2_indexed(branch_at, domain, move |j, &branch, &[start, end]| { if i & (1 << j) != 0 { [branch, end] } else { [start, branch] } }).into() } impl<'a, 'b, F : Float, const N : usize, const P : usize> Iterator for SubcubeIter<'b, F, N, P> { type Item = Cube<F, N>; #[inline] fn next(&mut self) -> Option<Self::Item> { if self.index < P { let i = self.index; self.index += 1; Some(get_subcube(&self.branch_at, self.domain, i)) } else { None } } } impl<F : Float, D : Copy, A, const N : usize, const P : usize> Branches<F,D,A,N,P> where Const<P> : BranchCount<N>, A : Aggregator { pub fn new_with<S : LocalAnalysis <F, A, N>>( domain : &Cube<F,N>, support : &S ) -> Self { let hint = support.bisection_hint(domain); let branch_at = map2(&hint, domain, |h, r| { h.unwrap_or_else(|| (r[0]+r[1])/F::TWO).max(r[0]).min(r[1]) }).into(); Branches{ branch_at : branch_at, nodes : array_init(|| Node::new()), } } /// Get an iterator over the aggregators of the nodes of this branch head. #[inline] pub fn aggregators(&self) -> MapF<Iter<'_, Node<F,D,A,N,P>>, &'_ A> { self.nodes.iter().mapF(Node::get_aggregator) } #[inline] pub fn iter_subcubes<'b>(&self, domain : &'b Cube<F, N>) -> SubcubeIter<'b, F, N, P> { SubcubeIter { domain : domain, branch_at : self.branch_at, index : 0, } } /// Iterate over all nodes and corresponding subcubes of self. #[inline] pub fn nodes_and_cubes<'a, 'b>(&'a self, domain : &'b Cube<F, N>) -> std::iter::Zip<Iter<'a, Node<F,D,A,N,P>>, SubcubeIter<'b, F, N, P>> { self.nodes.iter().zip(self.iter_subcubes(domain)) } /// Mutably iterate over all nodes and corresponding subcubes of self. #[inline] pub fn nodes_and_cubes_mut<'a, 'b>(&'a mut self, domain : &'b Cube<F, N>) -> std::iter::Zip<IterMut<'a, Node<F,D,A,N,P>>, SubcubeIter<'b, F, N, P>> { let subcube_iter = self.iter_subcubes(domain); self.nodes.iter_mut().zip(subcube_iter) } /// Insert data into the branch. pub fn insert<M : Depth, S : LocalAnalysis<F, A, N>>( &mut self, domain : &Cube<F,N>, d : D, new_leaf_depth : M, support : &S ) { let support_hint = support.support_hint(); for (node, subcube) in self.nodes_and_cubes_mut(&domain) { if support_hint.intersects(&subcube) { node.insert( &subcube, d, new_leaf_depth, support ); } } } /// Construct a new instance for a different aggregator pub fn convert_aggregator<ANew, G>( self, generator : &G, domain : &Cube<F, N> ) -> Branches<F,D,ANew,N,P> where ANew : Aggregator, G : SupportGenerator<F, N, Id=D>, G::SupportType : LocalAnalysis<F, ANew, N> { let branch_at = self.branch_at; let subcube_iter = self.iter_subcubes(domain); let new_nodes = self.nodes.into_iter().zip(subcube_iter).map(|(node, subcube)| { // TODO: avoid clone node.convert_aggregator(generator, &subcube) }); Branches { branch_at : branch_at, nodes : collect_into_array_unchecked(new_nodes), } } /// Recalculate aggregator after changes to generator pub fn refresh_aggregator<G>( &mut self, generator : &G, domain : &Cube<F, N> ) where G : SupportGenerator<F, N, Id=D>, G::SupportType : LocalAnalysis<F, A, N> { for (node, subcube) in self.nodes_and_cubes_mut(domain) { node.refresh_aggregator(generator, &subcube) } } } impl<F : Float, D : Copy, A, const N : usize, const P : usize> Node<F,D,A,N,P> where Const<P> : BranchCount<N>, A : Aggregator { #[inline] pub fn new() -> Self { Node { data : NodeOption::Uninitialised, aggregator : A::new(), } } /// Get leaf data #[inline] pub fn get_leaf_data(&self, x : &Loc<F, N>) -> Option<&Vec<D>> { match self.data { NodeOption::Uninitialised => None, NodeOption::Leaf(ref data) => Some(data), NodeOption::Branches(ref b) => b.get_node(x).get_leaf_data(x), } } /// Returns a reference to the aggregator of this node #[inline] pub fn get_aggregator(&self) -> &A { &self.aggregator } /// Insert `d` into the tree. If an `Incomplete` node is encountered, a new /// leaf is created at a minimum depth of `new_leaf_depth` pub fn insert<M : Depth, S : LocalAnalysis <F, A, N>>( &mut self, domain : &Cube<F,N>, d : D, new_leaf_depth : M, support : &S ) { match &mut self.data { NodeOption::Uninitialised => { // Replace uninitialised node with a leaf or a branch self.data = match new_leaf_depth.lower() { None => { let a = support.local_analysis(&domain); self.aggregator.aggregate(once(a)); // TODO: this is currently a dirty hard-coded heuristic; // should add capacity as a parameter let mut vec = Vec::with_capacity(2*P+1); vec.push(d); NodeOption::Leaf(Rc::new(vec)) }, Some(lower) => { let b = Rc::new({ let mut b0 = Branches::new_with(domain, support); b0.insert(domain, d, lower, support); b0 }); self.aggregator.summarise(b.aggregators()); NodeOption::Branches(b) } } }, NodeOption::Leaf(leaf) => { Rc::make_mut(leaf).push(d); let a = support.local_analysis(&domain); self.aggregator.aggregate(once(a)); }, NodeOption::Branches(b) => { Rc::make_mut(b).insert(domain, d, new_leaf_depth.lower_or(), support); self.aggregator.summarise(b.aggregators()); }, } } /// Construct a new instance for a different aggregator pub fn convert_aggregator<ANew, G>( mut self, generator : &G, domain : &Cube<F, N> ) -> Node<F,D,ANew,N,P> where ANew : Aggregator, G : SupportGenerator<F, N, Id=D>, G::SupportType : LocalAnalysis<F, ANew, N> { // The mem::replace is needed due to the [`Drop`] implementation to extract self.data. match std::mem::replace(&mut self.data, NodeOption::Uninitialised) { NodeOption::Uninitialised => Node { data : NodeOption::Uninitialised, aggregator : ANew::new(), }, NodeOption::Leaf(v) => { let mut anew = ANew::new(); anew.aggregate(v.iter().map(|d| { let support = generator.support_for(*d); support.local_analysis(&domain) })); Node { data : NodeOption::Leaf(v), aggregator : anew, } }, NodeOption::Branches(b) => { // TODO: now with Rc, convert_aggregator should be reference-based. let bnew = Rc::new(Rc::unwrap_or_clone(b).convert_aggregator(generator, domain)); let mut anew = ANew::new(); anew.summarise(bnew.aggregators()); Node { data : NodeOption::Branches(bnew), aggregator : anew, } } } } /// Refresh aggregator after changes to generator pub fn refresh_aggregator<G>( &mut self, generator : &G, domain : &Cube<F, N> ) where G : SupportGenerator<F, N, Id=D>, G::SupportType : LocalAnalysis<F, A, N> { match &mut self.data { NodeOption::Uninitialised => { }, NodeOption::Leaf(v) => { self.aggregator = A::new(); self.aggregator.aggregate(v.iter().map(|d| { generator.support_for(*d) .local_analysis(&domain) })); }, NodeOption::Branches(ref mut b) => { // TODO: now with Rc, convert_aggregator should be reference-based. Rc::make_mut(b).refresh_aggregator(generator, domain); self.aggregator.summarise(b.aggregators()); } } } } impl<'a, D> Iterator for BTIter<'a,D> { type Item = &'a D; #[inline] fn next(&mut self) -> Option<&'a D> { self.iter.next() } } // // BT // /// Internal structure to hide the `const P : usize` parameter of [`Node`] until /// const generics are flexible enough to fix `P=pow(2, N)`. pub trait BTNode<F, D, A, const N : usize> where F : Float, D : 'static + Copy, A : Aggregator { type Node : Clone + std::fmt::Debug; } #[derive(Debug)] pub struct BTNodeLookup; /// Interface to a [`BT`] bisection tree. pub trait BTImpl<F : Float, const N : usize> : std::fmt::Debug + Clone + GlobalAnalysis<F, Self::Agg> { type Data : 'static + Copy; type Depth : Depth; type Agg : Aggregator; type Converted<ANew> : BTImpl<F, N, Data=Self::Data, Agg=ANew> where ANew : Aggregator; /// Insert d into the `BisectionTree`. fn insert<S : LocalAnalysis<F, Self::Agg, N>>( &mut self, d : Self::Data, support : &S ); /// Construct a new instance for a different aggregator fn convert_aggregator<ANew, G>(self, generator : &G) -> Self::Converted<ANew> where ANew : Aggregator, G : SupportGenerator<F, N, Id=Self::Data>, G::SupportType : LocalAnalysis<F, ANew, N>; /// Refresh aggregator after changes to generator fn refresh_aggregator<G>(&mut self, generator : &G) where G : SupportGenerator<F, N, Id=Self::Data>, G::SupportType : LocalAnalysis<F, Self::Agg, N>; /// Iterarate items at x fn iter_at<'a>(&'a self, x : &'a Loc<F,N>) -> BTIter<'a, Self::Data>; /// Create a new instance fn new(domain : Cube<F, N>, depth : Self::Depth) -> Self; } /// The main bisection tree structure. The interface operations are via [`BTImpl`] /// to hide the `const P : usize` parameter until const generics are flexible enough /// to fix `P=pow(2, N)`. #[derive(Clone,Debug)] pub struct BT< M : Depth, F : Float, D : 'static + Copy, A : Aggregator, const N : usize, > where BTNodeLookup : BTNode<F, D, A, N> { pub(super) depth : M, pub(super) domain : Cube<F, N>, pub(super) topnode : <BTNodeLookup as BTNode<F, D, A, N>>::Node, } macro_rules! impl_bt { ($($n:literal)*) => { $( impl<F, D, A> BTNode<F, D, A, $n> for BTNodeLookup where F : Float, D : 'static + Copy + std::fmt::Debug, A : Aggregator { type Node = Node<F,D,A,$n,{pow(2, $n)}>; } impl<M,F,D,A> BTImpl<F,$n> for BT<M,F,D,A,$n> where M : Depth, F : Float, D : 'static + Copy + std::fmt::Debug, A : Aggregator { type Data = D; type Depth = M; type Agg = A; type Converted<ANew> = BT<M,F,D,ANew,$n> where ANew : Aggregator; /// Insert `d` into the tree. fn insert<S: LocalAnalysis<F, A, $n>>( &mut self, d : D, support : &S ) { self.topnode.insert( &self.domain, d, self.depth, support ); } /// Construct a new instance for a different aggregator fn convert_aggregator<ANew, G>(self, generator : &G) -> Self::Converted<ANew> where ANew : Aggregator, G : SupportGenerator<F, $n, Id=D>, G::SupportType : LocalAnalysis<F, ANew, $n> { let topnode = self.topnode.convert_aggregator(generator, &self.domain); BT { depth : self.depth, domain : self.domain, topnode } } /// Refresh aggregator after changes to generator fn refresh_aggregator<G>(&mut self, generator : &G) where G : SupportGenerator<F, $n, Id=Self::Data>, G::SupportType : LocalAnalysis<F, Self::Agg, $n> { self.topnode.refresh_aggregator(generator, &self.domain); } /// Iterate elements at `x`. fn iter_at<'a>(&'a self, x : &'a Loc<F,$n>) -> BTIter<'a,D> { match self.topnode.get_leaf_data(x) { Some(data) => BTIter { iter : data.iter() }, None => BTIter { iter : [].iter() } } } fn new(domain : Cube<F, $n>, depth : M) -> Self { BT { depth : depth, domain : domain, topnode : Node::new(), } } } impl<M,F,D,A> GlobalAnalysis<F,A> for BT<M,F,D,A,$n> where M : Depth, F : Float, D : 'static + Copy + std::fmt::Debug, A : Aggregator { fn global_analysis(&self) -> A { self.topnode.get_aggregator().clone() } } )* } } impl_bt!(1 2 3 4);