src/bisection_tree/bt.rs

branch
dev
changeset 96
962c8e346ab9
parent 9
f40dfaf2166d
child 97
4e80fb049dca
equal deleted inserted replaced
95:9cb8225a3a41 96:962c8e346ab9
1
2 /*! 1 /*!
3 Bisection tree basics, [`BT`] type and the [`BTImpl`] trait. 2 Bisection tree basics, [`BT`] type and the [`BTImpl`] trait.
4 */ 3 */
5 4
5 use itertools::izip;
6 pub(super) use nalgebra::Const;
7 use serde::{Deserialize, Serialize};
8 use std::iter::once;
6 use std::slice::IterMut; 9 use std::slice::IterMut;
7 use std::iter::once;
8 use std::sync::Arc; 10 use std::sync::Arc;
9 use serde::{Serialize, Deserialize}; 11
10 pub(super) use nalgebra::Const; 12 use super::aggregator::*;
11 use itertools::izip; 13 use super::support::*;
12 14 use crate::coefficients::pow;
15 use crate::loc::Loc;
16 use crate::maputil::{array_init, collect_into_array_unchecked, map2, map2_indexed};
17 use crate::parallelism::{with_task_budget, TaskBudget};
18 use crate::sets::Cube;
13 use crate::types::{Float, Num}; 19 use crate::types::{Float, Num};
14 use crate::parallelism::{with_task_budget, TaskBudget};
15 use crate::coefficients::pow;
16 use crate::maputil::{
17 array_init,
18 map2,
19 map2_indexed,
20 collect_into_array_unchecked
21 };
22 use crate::sets::Cube;
23 use crate::loc::Loc;
24 use super::support::*;
25 use super::aggregator::*;
26 20
27 /// An enum that indicates whether a [`Node`] of a [`BT`] is uninitialised, leaf, or branch. 21 /// An enum that indicates whether a [`Node`] of a [`BT`] is uninitialised, leaf, or branch.
28 /// 22 ///
29 /// For the type and const parameters, see the [module level documentation][super]. 23 /// For the type and const parameters, see the [module level documentation][super].
30 #[derive(Clone,Debug)] 24 #[derive(Clone, Debug)]
31 pub(super) enum NodeOption<F : Num, D, A : Aggregator, const N : usize, const P : usize> { 25 pub(super) enum NodeOption<F: Num, D, A: Aggregator, const N: usize, const P: usize> {
32 /// Indicates an uninitialised node; may become a branch or a leaf. 26 /// Indicates an uninitialised node; may become a branch or a leaf.
33 // TODO: Could optimise Uninitialised away by simply treating Leaf with an empty Vec as 27 // TODO: Could optimise Uninitialised away by simply treating Leaf with an empty Vec as
34 // something that can still be replaced with Branches. 28 // something that can still be replaced with Branches.
35 Uninitialised, 29 Uninitialised,
36 /// Indicates a leaf node containing a copy-on-write reference-counted vector 30 /// Indicates a leaf node containing a copy-on-write reference-counted vector
42 36
43 /// Node of a [`BT`] bisection tree. 37 /// Node of a [`BT`] bisection tree.
44 /// 38 ///
45 /// For the type and const parameters, see the [module level documentation][super]. 39 /// For the type and const parameters, see the [module level documentation][super].
46 #[derive(Clone, Debug)] 40 #[derive(Clone, Debug)]
47 pub struct Node<F : Num, D, A : Aggregator, const N : usize, const P : usize> { 41 pub struct Node<F: Num, D, A: Aggregator, const N: usize, const P: usize> {
48 /// The data or branches under the node. 42 /// The data or branches under the node.
49 pub(super) data : NodeOption<F, D, A, N, P>, 43 pub(super) data: NodeOption<F, D, A, N, P>,
50 /// Aggregator for `data`. 44 /// Aggregator for `data`.
51 pub(super) aggregator : A, 45 pub(super) aggregator: A,
52 } 46 }
53 47
54 /// Branching information of a [`Node`] of a [`BT`] bisection tree into `P` subnodes. 48 /// Branching information of a [`Node`] of a [`BT`] bisection tree into `P` subnodes.
55 /// 49 ///
56 /// For the type and const parameters, see the [module level documentation][super]. 50 /// For the type and const parameters, see the [module level documentation][super].
57 #[derive(Clone, Debug)] 51 #[derive(Clone, Debug)]
58 pub(super) struct Branches<F : Num, D, A : Aggregator, const N : usize, const P : usize> { 52 pub(super) struct Branches<F: Num, D, A: Aggregator, const N: usize, const P: usize> {
59 /// Point for subdivision of the (unstored) [`Cube`] corresponding to the node. 53 /// Point for subdivision of the (unstored) [`Cube`] corresponding to the node.
60 pub(super) branch_at : Loc<F, N>, 54 pub(super) branch_at: Loc<F, N>,
61 /// Subnodes 55 /// Subnodes
62 pub(super) nodes : [Node<F, D, A, N, P>; P], 56 pub(super) nodes: [Node<F, D, A, N, P>; P],
63 } 57 }
64 58
65 /// Dirty workaround to broken Rust drop, see <https://github.com/rust-lang/rust/issues/58068>. 59 /// Dirty workaround to broken Rust drop, see <https://github.com/rust-lang/rust/issues/58068>.
66 impl<F : Num, D, A : Aggregator, const N : usize, const P : usize> 60 impl<F: Num, D, A: Aggregator, const N: usize, const P: usize> Drop for Node<F, D, A, N, P> {
67 Drop for Node<F, D, A, N, P> {
68 fn drop(&mut self) { 61 fn drop(&mut self) {
69 use NodeOption as NO; 62 use NodeOption as NO;
70 63
71 let process = |brc : Arc<Branches<F, D, A, N, P>>, 64 let process = |brc: Arc<Branches<F, D, A, N, P>>,
72 to_drop : &mut Vec<Arc<Branches<F, D, A, N, P>>>| { 65 to_drop: &mut Vec<Arc<Branches<F, D, A, N, P>>>| {
73 // We only drop Branches if we have the only strong reference. 66 // We only drop Branches if we have the only strong reference.
74 // FIXME: update the RwLocks on Nodes. 67 // FIXME: update the RwLocks on Nodes.
75 Arc::try_unwrap(brc).ok().map(|branches| branches.nodes.map(|mut node| { 68 Arc::try_unwrap(brc).ok().map(|branches| {
76 if let NO::Branches(brc2) = std::mem::replace(&mut node.data, NO::Uninitialised) { 69 branches.nodes.map(|mut node| {
77 to_drop.push(brc2) 70 if let NO::Branches(brc2) = std::mem::replace(&mut node.data, NO::Uninitialised)
78 } 71 {
79 })); 72 to_drop.push(brc2)
73 }
74 })
75 });
80 }; 76 };
81 77
82 // We mark Self as NodeOption::Uninitialised, extracting the real contents. 78 // We mark Self as NodeOption::Uninitialised, extracting the real contents.
83 // If we have sub-branches, we need to process them. 79 // If we have sub-branches, we need to process them.
84 if let NO::Branches(brc1) = std::mem::replace(&mut self.data, NO::Uninitialised) { 80 if let NO::Branches(brc1) = std::mem::replace(&mut self.data, NO::Uninitialised) {
96 } 92 }
97 93
98 /// Trait for the depth of a [`BT`]. 94 /// Trait for the depth of a [`BT`].
99 /// 95 ///
100 /// This will generally be either a runtime [`DynamicDepth`] or compile-time [`Const`] depth. 96 /// This will generally be either a runtime [`DynamicDepth`] or compile-time [`Const`] depth.
101 pub trait Depth : 'static + Copy + Send + Sync + std::fmt::Debug { 97 pub trait Depth: 'static + Copy + Send + Sync + std::fmt::Debug {
102 /// Lower depth type. 98 /// Lower depth type.
103 type Lower : Depth; 99 type Lower: Depth;
104 100
105 /// Returns a lower depth, if there still is one. 101 /// Returns a lower depth, if there still is one.
106 fn lower(&self) -> Option<Self::Lower>; 102 fn lower(&self) -> Option<Self::Lower>;
107 103
108 /// Returns a lower depth or self if this is the lowest depth. 104 /// Returns a lower depth or self if this is the lowest depth.
111 /// Returns the numeric value of the depth 107 /// Returns the numeric value of the depth
112 fn value(&self) -> u32; 108 fn value(&self) -> u32;
113 } 109 }
114 110
115 /// Dynamic (runtime) [`Depth`] for a [`BT`]. 111 /// Dynamic (runtime) [`Depth`] for a [`BT`].
116 #[derive(Copy,Clone,Debug,Serialize,Deserialize)] 112 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
117 pub struct DynamicDepth( 113 pub struct DynamicDepth(
118 /// The depth 114 /// The depth
119 pub u8 115 pub u8,
120 ); 116 );
121 117
122 impl Depth for DynamicDepth { 118 impl Depth for DynamicDepth {
123 type Lower = Self; 119 type Lower = Self;
124 #[inline] 120 #[inline]
125 fn lower(&self) -> Option<Self> { 121 fn lower(&self) -> Option<Self> {
126 if self.0>0 { 122 if self.0 > 0 {
127 Some(DynamicDepth(self.0-1)) 123 Some(DynamicDepth(self.0 - 1))
128 } else { 124 } else {
129 None 125 None
130 } 126 }
131 } 127 }
132 128
133 #[inline] 129 #[inline]
134 fn lower_or(&self) -> Self { 130 fn lower_or(&self) -> Self {
135 DynamicDepth(if self.0>0 { self.0 - 1 } else { 0 }) 131 DynamicDepth(if self.0 > 0 { self.0 - 1 } else { 0 })
136 } 132 }
137 133
138 #[inline] 134 #[inline]
139 fn value(&self) -> u32 { 135 fn value(&self) -> u32 {
140 self.0 as u32 136 self.0 as u32
141 } 137 }
142 } 138 }
143 139
144 impl Depth for Const<0> { 140 impl Depth for Const<0> {
145 type Lower = Self; 141 type Lower = Self;
146 fn lower(&self) -> Option<Self::Lower> { None } 142 fn lower(&self) -> Option<Self::Lower> {
147 fn lower_or(&self) -> Self::Lower { Const } 143 None
148 fn value(&self) -> u32 { 0 } 144 }
145 fn lower_or(&self) -> Self::Lower {
146 Const
147 }
148 fn value(&self) -> u32 {
149 0
150 }
149 } 151 }
150 152
151 macro_rules! impl_constdepth { 153 macro_rules! impl_constdepth {
152 ($($n:literal)*) => { $( 154 ($($n:literal)*) => { $(
153 impl Depth for Const<$n> { 155 impl Depth for Const<$n> {
163 /// Trait for counting the branching factor of a [`BT`] of dimension `N`. 165 /// Trait for counting the branching factor of a [`BT`] of dimension `N`.
164 /// 166 ///
165 /// The const parameter `P` from the [module level documentation][super] is required to satisfy 167 /// The const parameter `P` from the [module level documentation][super] is required to satisfy
166 /// `Const<P> : BranchCount<N>`. 168 /// `Const<P> : BranchCount<N>`.
167 /// This trait is implemented for `P=pow(2, N)` for small `N`. 169 /// This trait is implemented for `P=pow(2, N)` for small `N`.
168 pub trait BranchCount<const N : usize> {} 170 pub trait BranchCount<const N: usize> {}
169 macro_rules! impl_branchcount { 171 macro_rules! impl_branchcount {
170 ($($n:literal)*) => { $( 172 ($($n:literal)*) => { $(
171 impl BranchCount<$n> for Const<{pow(2, $n)}>{} 173 impl BranchCount<$n> for Const<{pow(2, $n)}>{}
172 )* } 174 )* }
173 } 175 }
174 impl_branchcount!(1 2 3 4 5 6 7 8); 176 impl_branchcount!(1 2 3 4 5 6 7 8);
175 177
176 impl<F : Float, D, A, const N : usize, const P : usize> Branches<F,D,A,N,P> 178 impl<F: Float, D, A, const N: usize, const P: usize> Branches<F, D, A, N, P>
177 where Const<P> : BranchCount<N>, 179 where
178 A : Aggregator 180 Const<P>: BranchCount<N>,
181 A: Aggregator,
179 { 182 {
180 /// Returns the index in {0, …, `P`-1} for the branch to which the point `x` corresponds. 183 /// Returns the index in {0, …, `P`-1} for the branch to which the point `x` corresponds.
181 /// 184 ///
182 /// This only takes the branch subdivision point $d$ into account, so is always successful. 185 /// This only takes the branch subdivision point $d$ into account, so is always successful.
183 /// Thus, for this point, each branch corresponds to a quadrant of $ℝ^N$ relative to $d$. 186 /// Thus, for this point, each branch corresponds to a quadrant of $ℝ^N$ relative to $d$.
184 fn get_node_index(&self, x : &Loc<F, N>) -> usize { 187 fn get_node_index(&self, x: &Loc<F, N>) -> usize {
185 izip!(0..P, x.iter(), self.branch_at.iter()).map(|(i, x_i, branch_i)| 188 izip!(0..P, x.iter(), self.branch_at.iter())
186 if x_i > branch_i { 1<<i } else { 0 } 189 .map(|(i, x_i, branch_i)| if x_i > branch_i { 1 << i } else { 0 })
187 ).sum() 190 .sum()
188 } 191 }
189 192
190 /// Returns the node within `Self` containing the point `x`. 193 /// Returns the node within `Self` containing the point `x`.
191 /// 194 ///
192 /// This only takes the branch subdivision point $d$ into account, so is always successful. 195 /// This only takes the branch subdivision point $d$ into account, so is always successful.
193 /// Thus, for this point, each branch corresponds to a quadrant of $ℝ^N$ relative to $d$. 196 /// Thus, for this point, each branch corresponds to a quadrant of $ℝ^N$ relative to $d$.
194 #[inline] 197 #[inline]
195 fn get_node(&self, x : &Loc<F,N>) -> &Node<F,D,A,N,P> { 198 fn get_node(&self, x: &Loc<F, N>) -> &Node<F, D, A, N, P> {
196 &self.nodes[self.get_node_index(x)] 199 &self.nodes[self.get_node_index(x)]
197 } 200 }
198 } 201 }
199 202
200 /// An iterator over the $P=2^N$ subcubes of a [`Cube`] subdivided at a point `d`. 203 /// An iterator over the $P=2^N$ subcubes of a [`Cube`] subdivided at a point `d`.
201 pub(super) struct SubcubeIter<'b, F : Float, const N : usize, const P : usize> { 204 pub(super) struct SubcubeIter<'b, F: Float, const N: usize, const P: usize> {
202 domain : &'b Cube<F, N>, 205 domain: &'b Cube<F, N>,
203 branch_at : Loc<F, N>, 206 branch_at: Loc<F, N>,
204 index : usize, 207 index: usize,
205 } 208 }
206 209
207 /// Returns the `i`:th subcube of `domain` subdivided at `branch_at`. 210 /// Returns the `i`:th subcube of `domain` subdivided at `branch_at`.
208 #[inline] 211 #[inline]
209 fn get_subcube<F : Float, const N : usize>( 212 fn get_subcube<F: Float, const N: usize>(
210 branch_at : &Loc<F, N>, 213 branch_at: &Loc<F, N>,
211 domain : &Cube<F, N>, 214 domain: &Cube<F, N>,
212 i : usize 215 i: usize,
213 ) -> Cube<F, N> { 216 ) -> Cube<F, N> {
214 map2_indexed(branch_at, domain, move |j, &branch, &[start, end]| { 217 map2_indexed(branch_at, domain, move |j, &branch, &[start, end]| {
215 if i & (1 << j) != 0 { 218 if i & (1 << j) != 0 {
216 [branch, end] 219 [branch, end]
217 } else { 220 } else {
218 [start, branch] 221 [start, branch]
219 } 222 }
220 }).into() 223 })
221 } 224 .into()
222 225 }
223 impl<'a, 'b, F : Float, const N : usize, const P : usize> Iterator 226
224 for SubcubeIter<'b, F, N, P> { 227 impl<'a, 'b, F: Float, const N: usize, const P: usize> Iterator for SubcubeIter<'b, F, N, P> {
225 type Item = Cube<F, N>; 228 type Item = Cube<F, N>;
226 #[inline] 229 #[inline]
227 fn next(&mut self) -> Option<Self::Item> { 230 fn next(&mut self) -> Option<Self::Item> {
228 if self.index < P { 231 if self.index < P {
229 let i = self.index; 232 let i = self.index;
233 None 236 None
234 } 237 }
235 } 238 }
236 } 239 }
237 240
238 impl<F : Float, D, A, const N : usize, const P : usize> 241 impl<F: Float, D, A, const N: usize, const P: usize> Branches<F, D, A, N, P>
239 Branches<F,D,A,N,P> 242 where
240 where Const<P> : BranchCount<N>, 243 Const<P>: BranchCount<N>,
241 A : Aggregator, 244 A: Aggregator,
242 D : 'static + Copy + Send + Sync { 245 D: 'static + Copy + Send + Sync,
243 246 {
244 /// Creates a new node branching structure, subdividing `domain` based on the 247 /// Creates a new node branching structure, subdividing `domain` based on the
245 /// [hint][Support::support_hint] of `support`. 248 /// [hint][Support::support_hint] of `support`.
246 pub(super) fn new_with<S : LocalAnalysis <F, A, N>>( 249 pub(super) fn new_with<S: LocalAnalysis<F, A, N>>(domain: &Cube<F, N>, support: &S) -> Self {
247 domain : &Cube<F,N>,
248 support : &S
249 ) -> Self {
250 let hint = support.bisection_hint(domain); 250 let hint = support.bisection_hint(domain);
251 let branch_at = map2(&hint, domain, |h, r| { 251 let branch_at = map2(&hint, domain, |h, r| {
252 h.unwrap_or_else(|| (r[0]+r[1])/F::TWO).max(r[0]).min(r[1]) 252 h.unwrap_or_else(|| (r[0] + r[1]) / F::TWO)
253 }).into(); 253 .max(r[0])
254 Branches{ 254 .min(r[1])
255 branch_at : branch_at, 255 })
256 nodes : array_init(|| Node::new()), 256 .into();
257 Branches {
258 branch_at: branch_at,
259 nodes: array_init(|| Node::new()),
257 } 260 }
258 } 261 }
259 262
260 /// Summarises the aggregators of these branches into `agg` 263 /// Summarises the aggregators of these branches into `agg`
261 pub(super) fn summarise_into(&self, agg : &mut A) { 264 pub(super) fn summarise_into(&self, agg: &mut A) {
262 // We need to create an array of the aggregators clones due to the RwLock. 265 // We need to create an array of the aggregators clones due to the RwLock.
263 agg.summarise(self.nodes.iter().map(Node::get_aggregator)); 266 agg.summarise(self.nodes.iter().map(Node::get_aggregator));
264 } 267 }
265 268
266 /// Returns an iterator over the subcubes of `domain` subdivided at the branching point 269 /// Returns an iterator over the subcubes of `domain` subdivided at the branching point
267 /// of `self`. 270 /// of `self`.
268 #[inline] 271 #[inline]
269 pub(super) fn iter_subcubes<'b>(&self, domain : &'b Cube<F, N>) 272 pub(super) fn iter_subcubes<'b>(&self, domain: &'b Cube<F, N>) -> SubcubeIter<'b, F, N, P> {
270 -> SubcubeIter<'b, F, N, P> {
271 SubcubeIter { 273 SubcubeIter {
272 domain : domain, 274 domain: domain,
273 branch_at : self.branch_at, 275 branch_at: self.branch_at,
274 index : 0, 276 index: 0,
275 } 277 }
276 } 278 }
277 279
278 /* 280 /*
279 /// Returns an iterator over all nodes and corresponding subcubes of `self`. 281 /// Returns an iterator over all nodes and corresponding subcubes of `self`.
284 } 286 }
285 */ 287 */
286 288
287 /// Mutably iterate over all nodes and corresponding subcubes of `self`. 289 /// Mutably iterate over all nodes and corresponding subcubes of `self`.
288 #[inline] 290 #[inline]
289 pub(super) fn nodes_and_cubes_mut<'a, 'b>(&'a mut self, domain : &'b Cube<F, N>) 291 pub(super) fn nodes_and_cubes_mut<'a, 'b>(
290 -> std::iter::Zip<IterMut<'a, Node<F,D,A,N,P>>, SubcubeIter<'b, F, N, P>> { 292 &'a mut self,
293 domain: &'b Cube<F, N>,
294 ) -> std::iter::Zip<IterMut<'a, Node<F, D, A, N, P>>, SubcubeIter<'b, F, N, P>> {
291 let subcube_iter = self.iter_subcubes(domain); 295 let subcube_iter = self.iter_subcubes(domain);
292 self.nodes.iter_mut().zip(subcube_iter) 296 self.nodes.iter_mut().zip(subcube_iter)
293 } 297 }
294 298
295 /// Call `f` on all `(subnode, subcube)` pairs in multiple threads, if `guard` so deems. 299 /// Call `f` on all `(subnode, subcube)` pairs in multiple threads, if `guard` so deems.
296 #[inline] 300 #[inline]
297 fn recurse<'scope, 'smaller, 'refs>( 301 fn recurse<'scope, 'smaller, 'refs>(
298 &'smaller mut self, 302 &'smaller mut self,
299 domain : &'smaller Cube<F, N>, 303 domain: &'smaller Cube<F, N>,
300 task_budget : TaskBudget<'scope, 'refs>, 304 task_budget: TaskBudget<'scope, 'refs>,
301 guard : impl Fn(&Node<F,D,A,N,P>, &Cube<F, N>) -> bool + Send + 'smaller, 305 guard: impl Fn(&Node<F, D, A, N, P>, &Cube<F, N>) -> bool + Send + 'smaller,
302 mut f : impl for<'a> FnMut(&mut Node<F,D,A,N,P>, &Cube<F, N>, TaskBudget<'smaller, 'a>) 306 mut f: impl for<'a> FnMut(&mut Node<F, D, A, N, P>, &Cube<F, N>, TaskBudget<'smaller, 'a>)
303 + Send + Copy + 'smaller 307 + Send
304 ) where 'scope : 'smaller { 308 + Copy
309 + 'smaller,
310 ) where
311 'scope: 'smaller,
312 {
305 let subs = self.nodes_and_cubes_mut(domain); 313 let subs = self.nodes_and_cubes_mut(domain);
306 task_budget.zoom(move |s| { 314 task_budget.zoom(move |s| {
307 for (node, subcube) in subs { 315 for (node, subcube) in subs {
308 if guard(node, &subcube) { 316 if guard(node, &subcube) {
309 s.execute(move |new_budget| f(node, &subcube, new_budget)) 317 s.execute(move |new_budget| f(node, &subcube, new_budget))
319 /// * `d` is the data to be inserted 327 /// * `d` is the data to be inserted
320 /// * `new_leaf_depth` is the depth relative to `self` at which the data is to be inserted. 328 /// * `new_leaf_depth` is the depth relative to `self` at which the data is to be inserted.
321 /// * `support` is the [`Support`] that is used to determine with which subcubes of `domain` 329 /// * `support` is the [`Support`] that is used to determine with which subcubes of `domain`
322 /// (at subdivision depth `new_leaf_depth`) the data `d` is to be associated. 330 /// (at subdivision depth `new_leaf_depth`) the data `d` is to be associated.
323 /// 331 ///
324 pub(super) fn insert<'refs, 'scope, M : Depth, S : LocalAnalysis<F, A, N>>( 332 pub(super) fn insert<'refs, 'scope, M: Depth, S: LocalAnalysis<F, A, N>>(
325 &mut self, 333 &mut self,
326 domain : &Cube<F,N>, 334 domain: &Cube<F, N>,
327 d : D, 335 d: D,
328 new_leaf_depth : M, 336 new_leaf_depth: M,
329 support : &S, 337 support: &S,
330 task_budget : TaskBudget<'scope, 'refs>, 338 task_budget: TaskBudget<'scope, 'refs>,
331 ) { 339 ) {
332 let support_hint = support.support_hint(); 340 let support_hint = support.support_hint();
333 self.recurse(domain, task_budget, 341 self.recurse(
334 |_, subcube| support_hint.intersects(&subcube), 342 domain,
335 move |node, subcube, new_budget| node.insert(subcube, d, new_leaf_depth, support, 343 task_budget,
336 new_budget)); 344 |_, subcube| support_hint.intersects(&subcube),
345 move |node, subcube, new_budget| {
346 node.insert(subcube, d, new_leaf_depth, support, new_budget)
347 },
348 );
337 } 349 }
338 350
339 /// Construct a new instance of the branch for a different aggregator. 351 /// Construct a new instance of the branch for a different aggregator.
340 /// 352 ///
341 /// The `generator` is used to convert the data of type `D` of the branch into corresponding 353 /// The `generator` is used to convert the data of type `D` of the branch into corresponding
342 /// [`Support`]s. The `domain` is the cube corresponding to `self`. 354 /// [`Support`]s. The `domain` is the cube corresponding to `self`.
343 /// The type parameter `ANew` is the new aggregator, and needs to be implemented for the 355 /// The type parameter `ANew` is the new aggregator, and needs to be implemented for the
344 /// generator's `SupportType`. 356 /// generator's `SupportType`.
345 pub(super) fn convert_aggregator<ANew, G>( 357 pub(super) fn convert_aggregator<ANew, G>(
346 self, 358 self,
347 generator : &G, 359 generator: &G,
348 domain : &Cube<F, N> 360 domain: &Cube<F, N>,
349 ) -> Branches<F,D,ANew,N,P> 361 ) -> Branches<F, D, ANew, N, P>
350 where ANew : Aggregator, 362 where
351 G : SupportGenerator<F, N, Id=D>, 363 ANew: Aggregator,
352 G::SupportType : LocalAnalysis<F, ANew, N> { 364 G: SupportGenerator<F, N, Id = D>,
365 G::SupportType: LocalAnalysis<F, ANew, N>,
366 {
353 let branch_at = self.branch_at; 367 let branch_at = self.branch_at;
354 let subcube_iter = self.iter_subcubes(domain); 368 let subcube_iter = self.iter_subcubes(domain);
355 let new_nodes = self.nodes.into_iter().zip(subcube_iter).map(|(node, subcube)| { 369 let new_nodes = self
356 Node::convert_aggregator(node, generator, &subcube) 370 .nodes
357 }); 371 .into_iter()
372 .zip(subcube_iter)
373 .map(|(node, subcube)| Node::convert_aggregator(node, generator, &subcube));
358 Branches { 374 Branches {
359 branch_at : branch_at, 375 branch_at: branch_at,
360 nodes : collect_into_array_unchecked(new_nodes), 376 nodes: collect_into_array_unchecked(new_nodes),
361 } 377 }
362 } 378 }
363 379
364 /// Recalculate aggregator after changes to generator. 380 /// Recalculate aggregator after changes to generator.
365 /// 381 ///
366 /// The `generator` is used to convert the data of type `D` of the branch into corresponding 382 /// The `generator` is used to convert the data of type `D` of the branch into corresponding
367 /// [`Support`]s. The `domain` is the cube corresponding to `self`. 383 /// [`Support`]s. The `domain` is the cube corresponding to `self`.
368 pub(super) fn refresh_aggregator<'refs, 'scope, G>( 384 pub(super) fn refresh_aggregator<'refs, 'scope, G>(
369 &mut self, 385 &mut self,
370 generator : &G, 386 generator: &G,
371 domain : &Cube<F, N>, 387 domain: &Cube<F, N>,
372 task_budget : TaskBudget<'scope, 'refs>, 388 task_budget: TaskBudget<'scope, 'refs>,
373 ) where G : SupportGenerator<F, N, Id=D>, 389 ) where
374 G::SupportType : LocalAnalysis<F, A, N> { 390 G: SupportGenerator<F, N, Id = D>,
375 self.recurse(domain, task_budget, 391 G::SupportType: LocalAnalysis<F, A, N>,
376 |_, _| true, 392 {
377 move |node, subcube, new_budget| node.refresh_aggregator(generator, subcube, 393 self.recurse(
378 new_budget)); 394 domain,
379 } 395 task_budget,
380 } 396 |_, _| true,
381 397 move |node, subcube, new_budget| {
382 impl<F : Float, D, A, const N : usize, const P : usize> 398 node.refresh_aggregator(generator, subcube, new_budget)
383 Node<F,D,A,N,P> 399 },
384 where Const<P> : BranchCount<N>, 400 );
385 A : Aggregator, 401 }
386 D : 'static + Copy + Send + Sync { 402 }
387 403
404 impl<F: Float, D, A, const N: usize, const P: usize> Node<F, D, A, N, P>
405 where
406 Const<P>: BranchCount<N>,
407 A: Aggregator,
408 D: 'static + Copy + Send + Sync,
409 {
388 /// Create a new node 410 /// Create a new node
389 #[inline] 411 #[inline]
390 pub(super) fn new() -> Self { 412 pub(super) fn new() -> Self {
391 Node { 413 Node {
392 data : NodeOption::Uninitialised, 414 data: NodeOption::Uninitialised,
393 aggregator : A::new(), 415 aggregator: A::new(),
394 } 416 }
395 } 417 }
396 418
397 /* 419 /*
398 /// Get leaf data 420 /// Get leaf data
405 } 427 }
406 }*/ 428 }*/
407 429
408 /// Get leaf data iterator 430 /// Get leaf data iterator
409 #[inline] 431 #[inline]
410 pub(super) fn get_leaf_data_iter(&self, x : &Loc<F, N>) -> Option<std::slice::Iter<'_, D>> { 432 pub(super) fn get_leaf_data_iter(&self, x: &Loc<F, N>) -> Option<std::slice::Iter<'_, D>> {
411 match self.data { 433 match self.data {
412 NodeOption::Uninitialised => None, 434 NodeOption::Uninitialised => None,
413 NodeOption::Leaf(ref data) => Some(data.iter()), 435 NodeOption::Leaf(ref data) => Some(data.iter()),
414 NodeOption::Branches(ref b) => b.get_node(x).get_leaf_data_iter(x), 436 NodeOption::Branches(ref b) => b.get_node(x).get_leaf_data_iter(x),
415 } 437 }
432 /// 454 ///
433 /// If `self` is already [`NodeOption::Leaf`], the data is inserted directly in this node. 455 /// If `self` is already [`NodeOption::Leaf`], the data is inserted directly in this node.
434 /// If `self` is a [`NodeOption::Branches`], the data is passed to branches whose subcubes 456 /// If `self` is a [`NodeOption::Branches`], the data is passed to branches whose subcubes
435 /// `support` intersects. If an [`NodeOption::Uninitialised`] node is encountered, a new leaf is 457 /// `support` intersects. If an [`NodeOption::Uninitialised`] node is encountered, a new leaf is
436 /// created at a minimum depth of `new_leaf_depth`. 458 /// created at a minimum depth of `new_leaf_depth`.
437 pub(super) fn insert<'refs, 'scope, M : Depth, S : LocalAnalysis <F, A, N>>( 459 pub(super) fn insert<'refs, 'scope, M: Depth, S: LocalAnalysis<F, A, N>>(
438 &mut self, 460 &mut self,
439 domain : &Cube<F,N>, 461 domain: &Cube<F, N>,
440 d : D, 462 d: D,
441 new_leaf_depth : M, 463 new_leaf_depth: M,
442 support : &S, 464 support: &S,
443 task_budget : TaskBudget<'scope, 'refs>, 465 task_budget: TaskBudget<'scope, 'refs>,
444 ) { 466 ) {
445 match &mut self.data { 467 match &mut self.data {
446 NodeOption::Uninitialised => { 468 NodeOption::Uninitialised => {
447 // Replace uninitialised node with a leaf or a branch 469 // Replace uninitialised node with a leaf or a branch
448 self.data = match new_leaf_depth.lower() { 470 self.data = match new_leaf_depth.lower() {
449 None => { 471 None => {
450 let a = support.local_analysis(&domain); 472 let a = support.local_analysis(&domain);
451 self.aggregator.aggregate(once(a)); 473 self.aggregator.aggregate(once(a));
452 // TODO: this is currently a dirty hard-coded heuristic; 474 // TODO: this is currently a dirty hard-coded heuristic;
453 // should add capacity as a parameter 475 // should add capacity as a parameter
454 let mut vec = Vec::with_capacity(2*P+1); 476 let mut vec = Vec::with_capacity(2 * P + 1);
455 vec.push(d); 477 vec.push(d);
456 NodeOption::Leaf(vec) 478 NodeOption::Leaf(vec)
457 }, 479 }
458 Some(lower) => { 480 Some(lower) => {
459 let b = Arc::new({ 481 let b = Arc::new({
460 let mut b0 = Branches::new_with(domain, support); 482 let mut b0 = Branches::new_with(domain, support);
461 b0.insert(domain, d, lower, support, task_budget); 483 b0.insert(domain, d, lower, support, task_budget);
462 b0 484 b0
463 }); 485 });
464 b.summarise_into(&mut self.aggregator); 486 b.summarise_into(&mut self.aggregator);
465 NodeOption::Branches(b) 487 NodeOption::Branches(b)
466 } 488 }
467 } 489 }
468 }, 490 }
469 NodeOption::Leaf(leaf) => { 491 NodeOption::Leaf(leaf) => {
470 leaf.push(d); 492 leaf.push(d);
471 let a = support.local_analysis(&domain); 493 let a = support.local_analysis(&domain);
472 self.aggregator.aggregate(once(a)); 494 self.aggregator.aggregate(once(a));
473 }, 495 }
474 NodeOption::Branches(b) => { 496 NodeOption::Branches(b) => {
475 // FIXME: recursion that may cause stack overflow if the tree becomes 497 // FIXME: recursion that may cause stack overflow if the tree becomes
476 // very deep, e.g. due to [`BTSearch::search_and_refine`]. 498 // very deep, e.g. due to [`BTSearch::search_and_refine`].
477 let bm = Arc::make_mut(b); 499 let bm = Arc::make_mut(b);
478 bm.insert(domain, d, new_leaf_depth.lower_or(), support, task_budget); 500 bm.insert(domain, d, new_leaf_depth.lower_or(), support, task_budget);
479 bm.summarise_into(&mut self.aggregator); 501 bm.summarise_into(&mut self.aggregator);
480 }, 502 }
481 } 503 }
482 } 504 }
483 505
484 /// Construct a new instance of the node for a different aggregator 506 /// Construct a new instance of the node for a different aggregator
485 /// 507 ///
487 /// [`Support`]s. The `domain` is the cube corresponding to `self`. 509 /// [`Support`]s. The `domain` is the cube corresponding to `self`.
488 /// The type parameter `ANew` is the new aggregator, and needs to be implemented for the 510 /// The type parameter `ANew` is the new aggregator, and needs to be implemented for the
489 /// generator's `SupportType`. 511 /// generator's `SupportType`.
490 pub(super) fn convert_aggregator<ANew, G>( 512 pub(super) fn convert_aggregator<ANew, G>(
491 mut self, 513 mut self,
492 generator : &G, 514 generator: &G,
493 domain : &Cube<F, N> 515 domain: &Cube<F, N>,
494 ) -> Node<F,D,ANew,N,P> 516 ) -> Node<F, D, ANew, N, P>
495 where ANew : Aggregator, 517 where
496 G : SupportGenerator<F, N, Id=D>, 518 ANew: Aggregator,
497 G::SupportType : LocalAnalysis<F, ANew, N> { 519 G: SupportGenerator<F, N, Id = D>,
498 520 G::SupportType: LocalAnalysis<F, ANew, N>,
521 {
499 // The mem::replace is needed due to the [`Drop`] implementation to extract self.data. 522 // The mem::replace is needed due to the [`Drop`] implementation to extract self.data.
500 match std::mem::replace(&mut self.data, NodeOption::Uninitialised) { 523 match std::mem::replace(&mut self.data, NodeOption::Uninitialised) {
501 NodeOption::Uninitialised => Node { 524 NodeOption::Uninitialised => Node {
502 data : NodeOption::Uninitialised, 525 data: NodeOption::Uninitialised,
503 aggregator : ANew::new(), 526 aggregator: ANew::new(),
504 }, 527 },
505 NodeOption::Leaf(v) => { 528 NodeOption::Leaf(v) => {
506 let mut anew = ANew::new(); 529 let mut anew = ANew::new();
507 anew.aggregate(v.iter().map(|d| { 530 anew.aggregate(v.iter().map(|d| {
508 let support = generator.support_for(*d); 531 let support = generator.support_for(*d);
509 support.local_analysis(&domain) 532 support.local_analysis(&domain)
510 })); 533 }));
511 534
512 Node { 535 Node {
513 data : NodeOption::Leaf(v), 536 data: NodeOption::Leaf(v),
514 aggregator : anew, 537 aggregator: anew,
515 } 538 }
516 }, 539 }
517 NodeOption::Branches(b) => { 540 NodeOption::Branches(b) => {
518 // FIXME: recursion that may cause stack overflow if the tree becomes 541 // FIXME: recursion that may cause stack overflow if the tree becomes
519 // very deep, e.g. due to [`BTSearch::search_and_refine`]. 542 // very deep, e.g. due to [`BTSearch::search_and_refine`].
520 let bnew = Arc::unwrap_or_clone(b).convert_aggregator(generator, domain); 543 let bnew = Arc::unwrap_or_clone(b).convert_aggregator(generator, domain);
521 let mut anew = ANew::new(); 544 let mut anew = ANew::new();
522 bnew.summarise_into(&mut anew); 545 bnew.summarise_into(&mut anew);
523 Node { 546 Node {
524 data : NodeOption::Branches(Arc::new(bnew)), 547 data: NodeOption::Branches(Arc::new(bnew)),
525 aggregator : anew, 548 aggregator: anew,
526 } 549 }
527 } 550 }
528 } 551 }
529 } 552 }
530 553
532 /// 555 ///
533 /// The `generator` is used to convert the data of type `D` of the node into corresponding 556 /// The `generator` is used to convert the data of type `D` of the node into corresponding
534 /// [`Support`]s. The `domain` is the cube corresponding to `self`. 557 /// [`Support`]s. The `domain` is the cube corresponding to `self`.
535 pub(super) fn refresh_aggregator<'refs, 'scope, G>( 558 pub(super) fn refresh_aggregator<'refs, 'scope, G>(
536 &mut self, 559 &mut self,
537 generator : &G, 560 generator: &G,
538 domain : &Cube<F, N>, 561 domain: &Cube<F, N>,
539 task_budget : TaskBudget<'scope, 'refs>, 562 task_budget: TaskBudget<'scope, 'refs>,
540 ) where G : SupportGenerator<F, N, Id=D>, 563 ) where
541 G::SupportType : LocalAnalysis<F, A, N> { 564 G: SupportGenerator<F, N, Id = D>,
565 G::SupportType: LocalAnalysis<F, A, N>,
566 {
542 match &mut self.data { 567 match &mut self.data {
543 NodeOption::Uninitialised => { }, 568 NodeOption::Uninitialised => {}
544 NodeOption::Leaf(v) => { 569 NodeOption::Leaf(v) => {
545 self.aggregator = A::new(); 570 self.aggregator = A::new();
546 self.aggregator.aggregate(v.iter().map(|d| { 571 self.aggregator.aggregate(
547 generator.support_for(*d) 572 v.iter()
548 .local_analysis(&domain) 573 .map(|d| generator.support_for(*d).local_analysis(&domain)),
549 })); 574 );
550 }, 575 }
551 NodeOption::Branches(ref mut b) => { 576 NodeOption::Branches(ref mut b) => {
552 // FIXME: recursion that may cause stack overflow if the tree becomes 577 // FIXME: recursion that may cause stack overflow if the tree becomes
553 // very deep, e.g. due to [`BTSearch::search_and_refine`]. 578 // very deep, e.g. due to [`BTSearch::search_and_refine`].
554 let bm = Arc::make_mut(b); 579 let bm = Arc::make_mut(b);
555 bm.refresh_aggregator(generator, domain, task_budget); 580 bm.refresh_aggregator(generator, domain, task_budget);
561 586
562 /// Helper trait for working with [`Node`]s without the knowledge of `P`. 587 /// Helper trait for working with [`Node`]s without the knowledge of `P`.
563 /// 588 ///
564 /// This can be removed and the methods implemented directly on [`BT`] once Rust's const generics 589 /// This can be removed and the methods implemented directly on [`BT`] once Rust's const generics
565 /// are flexible enough to allow fixing `P=pow(2, N)`. 590 /// are flexible enough to allow fixing `P=pow(2, N)`.
566 pub trait BTNode<F, D, A, const N : usize> 591 pub trait BTNode<F, D, A, const N: usize>
567 where F : Float, 592 where
568 D : 'static + Copy, 593 F: Float,
569 A : Aggregator { 594 D: 'static + Copy,
570 type Node : Clone + std::fmt::Debug; 595 A: Aggregator,
596 {
597 type Node: Clone + std::fmt::Debug;
571 } 598 }
572 599
573 /// Helper structure for looking up a [`Node`] without the knowledge of `P`. 600 /// Helper structure for looking up a [`Node`] without the knowledge of `P`.
574 /// 601 ///
575 /// This can be removed once Rust's const generics are flexible enough to allow fixing 602 /// This can be removed once Rust's const generics are flexible enough to allow fixing
578 pub struct BTNodeLookup; 605 pub struct BTNodeLookup;
579 606
580 /// Basic interface to a [`BT`] bisection tree. 607 /// Basic interface to a [`BT`] bisection tree.
581 /// 608 ///
582 /// Further routines are provided by the [`BTSearch`][super::refine::BTSearch] trait. 609 /// Further routines are provided by the [`BTSearch`][super::refine::BTSearch] trait.
583 pub trait BTImpl<F : Float, const N : usize> : std::fmt::Debug + Clone + GlobalAnalysis<F, Self::Agg> { 610 pub trait BTImpl<F: Float, const N: usize>:
611 std::fmt::Debug + Clone + GlobalAnalysis<F, Self::Agg>
612 {
584 /// The data type stored in the tree 613 /// The data type stored in the tree
585 type Data : 'static + Copy + Send + Sync; 614 type Data: 'static + Copy + Send + Sync;
586 /// The depth type of the tree 615 /// The depth type of the tree
587 type Depth : Depth; 616 type Depth: Depth;
588 /// The type for the [aggregate information][Aggregator] about the `Data` stored in each node 617 /// The type for the [aggregate information][Aggregator] about the `Data` stored in each node
589 /// of the tree. 618 /// of the tree.
590 type Agg : Aggregator; 619 type Agg: Aggregator;
591 /// The type of the tree with the aggregator converted to `ANew`. 620 /// The type of the tree with the aggregator converted to `ANew`.
592 type Converted<ANew> : BTImpl<F, N, Data=Self::Data, Agg=ANew> where ANew : Aggregator; 621 type Converted<ANew>: BTImpl<F, N, Data = Self::Data, Agg = ANew>
622 where
623 ANew: Aggregator;
593 624
594 /// Insert the data `d` into the tree for `support`. 625 /// Insert the data `d` into the tree for `support`.
595 /// 626 ///
596 /// Every leaf node of the tree that intersects the `support` will contain a copy of 627 /// Every leaf node of the tree that intersects the `support` will contain a copy of
597 /// `d`. 628 /// `d`.
598 fn insert<S : LocalAnalysis<F, Self::Agg, N>>( 629 fn insert<S: LocalAnalysis<F, Self::Agg, N>>(&mut self, d: Self::Data, support: &S);
599 &mut self,
600 d : Self::Data,
601 support : &S
602 );
603 630
604 /// Construct a new instance of the tree for a different aggregator 631 /// Construct a new instance of the tree for a different aggregator
605 /// 632 ///
606 /// The `generator` is used to convert the data of type [`Self::Data`] contained in the tree 633 /// The `generator` is used to convert the data of type [`Self::Data`] contained in the tree
607 /// into corresponding [`Support`]s. 634 /// into corresponding [`Support`]s.
608 fn convert_aggregator<ANew, G>(self, generator : &G) 635 fn convert_aggregator<ANew, G>(self, generator: &G) -> Self::Converted<ANew>
609 -> Self::Converted<ANew> 636 where
610 where ANew : Aggregator, 637 ANew: Aggregator,
611 G : SupportGenerator<F, N, Id=Self::Data>, 638 G: SupportGenerator<F, N, Id = Self::Data>,
612 G::SupportType : LocalAnalysis<F, ANew, N>; 639 G::SupportType: LocalAnalysis<F, ANew, N>;
613
614 640
615 /// Refreshes the aggregator of the three after possible changes to the support generator. 641 /// Refreshes the aggregator of the three after possible changes to the support generator.
616 /// 642 ///
617 /// The `generator` is used to convert the data of type [`Self::Data`] contained in the tree 643 /// The `generator` is used to convert the data of type [`Self::Data`] contained in the tree
618 /// into corresponding [`Support`]s. 644 /// into corresponding [`Support`]s.
619 fn refresh_aggregator<G>(&mut self, generator : &G) 645 fn refresh_aggregator<G>(&mut self, generator: &G)
620 where G : SupportGenerator<F, N, Id=Self::Data>, 646 where
621 G::SupportType : LocalAnalysis<F, Self::Agg, N>; 647 G: SupportGenerator<F, N, Id = Self::Data>,
648 G::SupportType: LocalAnalysis<F, Self::Agg, N>;
622 649
623 /// Returns an iterator over all [`Self::Data`] items at the point `x` of the domain. 650 /// Returns an iterator over all [`Self::Data`] items at the point `x` of the domain.
624 fn iter_at(&self, x : &Loc<F,N>) -> std::slice::Iter<'_, Self::Data>; 651 fn iter_at(&self, x: &Loc<F, N>) -> std::slice::Iter<'_, Self::Data>;
625 652
626 /* 653 /*
627 /// Returns all [`Self::Data`] items at the point `x` of the domain. 654 /// Returns all [`Self::Data`] items at the point `x` of the domain.
628 fn data_at(&self, x : &Loc<F,N>) -> Arc<Vec<Self::Data>>; 655 fn data_at(&self, x : &Loc<F,N>) -> Arc<Vec<Self::Data>>;
629 */ 656 */
630 657
631 /// Create a new tree on `domain` of indicated `depth`. 658 /// Create a new tree on `domain` of indicated `depth`.
632 fn new(domain : Cube<F, N>, depth : Self::Depth) -> Self; 659 fn new(domain: Cube<F, N>, depth: Self::Depth) -> Self;
633 } 660 }
634 661
635 /// The main bisection tree structure. 662 /// The main bisection tree structure.
636 /// 663 ///
637 /// It should be accessed via the [`BTImpl`] trait to hide the `const P : usize` parameter until 664 /// It should be accessed via the [`BTImpl`] trait to hide the `const P : usize` parameter until
638 /// const generics are flexible enough to fix `P=pow(2, N)` and thus also get rid of 665 /// const generics are flexible enough to fix `P=pow(2, N)` and thus also get rid of
639 /// the `BTNodeLookup : BTNode<F, D, A, N>` trait bound. 666 /// the `BTNodeLookup : BTNode<F, D, A, N>` trait bound.
640 #[derive(Clone,Debug)] 667 #[derive(Clone, Debug)]
641 pub struct BT< 668 pub struct BT<M: Depth, F: Float, D: 'static + Copy, A: Aggregator, const N: usize>
642 M : Depth, 669 where
643 F : Float, 670 BTNodeLookup: BTNode<F, D, A, N>,
644 D : 'static + Copy, 671 {
645 A : Aggregator,
646 const N : usize,
647 > where BTNodeLookup : BTNode<F, D, A, N> {
648 /// The depth of the tree (initial, before refinement) 672 /// The depth of the tree (initial, before refinement)
649 pub(super) depth : M, 673 pub(super) depth: M,
650 /// The domain of the toplevel node 674 /// The domain of the toplevel node
651 pub(super) domain : Cube<F, N>, 675 pub(super) domain: Cube<F, N>,
652 /// The toplevel node of the tree 676 /// The toplevel node of the tree
653 pub(super) topnode : <BTNodeLookup as BTNode<F, D, A, N>>::Node, 677 pub(super) topnode: <BTNodeLookup as BTNode<F, D, A, N>>::Node,
654 } 678 }
655 679
656 macro_rules! impl_bt { 680 macro_rules! impl_bt {
657 ($($n:literal)*) => { $( 681 ($($n:literal)*) => { $(
658 impl<F, D, A> BTNode<F, D, A, $n> for BTNodeLookup 682 impl<F, D, A> BTNode<F, D, A, $n> for BTNodeLookup
737 } 761 }
738 )* } 762 )* }
739 } 763 }
740 764
741 impl_bt!(1 2 3 4); 765 impl_bt!(1 2 3 4);
742

mercurial