| 1 use crate::mapping::{ |
1 use crate::instance::{ClosedSpace, Instance, Ownable, Space}; |
| 2 BasicDecomposition, DifferentiableImpl, DifferentiableMapping, Instance, Mapping, Space, |
2 use crate::mapping::{BasicDecomposition, DifferentiableImpl, DifferentiableMapping, Mapping}; |
| 3 }; |
|
| 4 use crate::types::Float; |
3 use crate::types::Float; |
| 5 use numeric_literals::replace_float_literals; |
4 use numeric_literals::replace_float_literals; |
| 6 use std::iter::Sum; |
5 use std::iter::Sum; |
| 7 use std::marker::PhantomData; |
6 use std::marker::PhantomData; |
| 8 use std::sync::Arc; |
7 use std::sync::Arc; |
| 35 bt: BT, |
34 bt: BT, |
| 36 generator: Arc<G>, |
35 generator: Arc<G>, |
| 37 _phantoms: PhantomData<F>, |
36 _phantoms: PhantomData<F>, |
| 38 } |
37 } |
| 39 |
38 |
| 40 impl<F: Float, G, BT, const N: usize> Space for BTFN<F, G, BT, N> |
39 impl<F: Float, G, BT, const N: usize> Ownable for BTFN<F, G, BT, N> |
| 41 where |
40 where |
| 42 G: SupportGenerator<N, F, Id = BT::Data>, |
41 G: SupportGenerator<N, F, Id = BT::Data>, |
| 43 G::SupportType: LocalAnalysis<F, BT::Agg, N>, |
42 G::SupportType: LocalAnalysis<F, BT::Agg, N>, |
| 44 BT: BTImpl<N, F>, |
43 BT: BTImpl<N, F>, |
| 45 { |
44 { |
| |
45 type OwnedVariant = Self; |
| |
46 |
| |
47 fn into_owned(self) -> Self::OwnedVariant { |
| |
48 self |
| |
49 } |
| |
50 |
| |
51 /// Returns an owned instance of a reference. |
| |
52 fn clone_owned(&self) -> Self::OwnedVariant { |
| |
53 self.clone() |
| |
54 } |
| |
55 } |
| |
56 |
| |
57 impl<F: Float, G, BT, const N: usize> Space for BTFN<F, G, BT, N> |
| |
58 where |
| |
59 G: SupportGenerator<N, F, Id = BT::Data>, |
| |
60 G::SupportType: LocalAnalysis<F, BT::Agg, N>, |
| |
61 BT: BTImpl<N, F>, |
| |
62 { |
| |
63 type OwnedSpace = Self; |
| 46 type Decomp = BasicDecomposition; |
64 type Decomp = BasicDecomposition; |
| 47 } |
65 } |
| 48 |
66 |
| 49 impl<F: Float, G, BT, const N: usize> BTFN<F, G, BT, N> |
67 impl<F: Float, G, BT, const N: usize> BTFN<F, G, BT, N> |
| 50 where |
68 where |
| 62 pub fn new(bt: BT, generator: G) -> Self { |
80 pub fn new(bt: BT, generator: G) -> Self { |
| 63 Self::new_arc(bt, Arc::new(generator)) |
81 Self::new_arc(bt, Arc::new(generator)) |
| 64 } |
82 } |
| 65 |
83 |
| 66 fn new_arc(bt: BT, generator: Arc<G>) -> Self { |
84 fn new_arc(bt: BT, generator: Arc<G>) -> Self { |
| 67 BTFN { |
85 BTFN { bt, generator, _phantoms: std::marker::PhantomData } |
| 68 bt: bt, |
|
| 69 generator: generator, |
|
| 70 _phantoms: std::marker::PhantomData, |
|
| 71 } |
|
| 72 } |
86 } |
| 73 |
87 |
| 74 /// Create a new BTFN support generator and a pre-initialised bisection tree, |
88 /// Create a new BTFN support generator and a pre-initialised bisection tree, |
| 75 /// cloning the tree and refreshing aggregators. |
89 /// cloning the tree and refreshing aggregators. |
| 76 /// |
90 /// |
| 159 where |
173 where |
| 160 G: SupportGenerator<N, F>, |
174 G: SupportGenerator<N, F>, |
| 161 { |
175 { |
| 162 /// Create a new [`PreBTFN`] with no bisection tree. |
176 /// Create a new [`PreBTFN`] with no bisection tree. |
| 163 pub fn new_pre(generator: G) -> Self { |
177 pub fn new_pre(generator: G) -> Self { |
| 164 BTFN { |
178 BTFN { bt: (), generator: Arc::new(generator), _phantoms: std::marker::PhantomData } |
| 165 bt: (), |
|
| 166 generator: Arc::new(generator), |
|
| 167 _phantoms: std::marker::PhantomData, |
|
| 168 } |
|
| 169 } |
179 } |
| 170 } |
180 } |
| 171 |
181 |
| 172 impl<F: Float, G, BT, const N: usize> BTFN<F, G, BT, N> |
182 impl<F: Float, G, BT, const N: usize> BTFN<F, G, BT, N> |
| 173 where |
183 where |
| 186 |
196 |
| 187 for (d, support) in both.all_right_data() { |
197 for (d, support) in both.all_right_data() { |
| 188 bt.insert(d, &support); |
198 bt.insert(d, &support); |
| 189 } |
199 } |
| 190 |
200 |
| 191 BTFN { |
201 BTFN { bt: bt, generator: Arc::new(both), _phantoms: std::marker::PhantomData } |
| 192 bt: bt, |
|
| 193 generator: Arc::new(both), |
|
| 194 _phantoms: std::marker::PhantomData, |
|
| 195 } |
|
| 196 } |
202 } |
| 197 } |
203 } |
| 198 |
204 |
| 199 macro_rules! make_btfn_add { |
205 macro_rules! make_btfn_add { |
| 200 ($lhs:ty, $preprocess:path, $($extra_trait:ident)?) => { |
206 ($lhs:ty, $preprocess:path, $($extra_trait:ident)?) => { |
| 409 impl<F: Float, G, BT, V, const N: usize> Mapping<Loc<N, F>> for BTFN<F, G, BT, N> |
415 impl<F: Float, G, BT, V, const N: usize> Mapping<Loc<N, F>> for BTFN<F, G, BT, N> |
| 410 where |
416 where |
| 411 BT: BTImpl<N, F>, |
417 BT: BTImpl<N, F>, |
| 412 G: SupportGenerator<N, F, Id = BT::Data>, |
418 G: SupportGenerator<N, F, Id = BT::Data>, |
| 413 G::SupportType: LocalAnalysis<F, BT::Agg, N> + Mapping<Loc<N, F>, Codomain = V>, |
419 G::SupportType: LocalAnalysis<F, BT::Agg, N> + Mapping<Loc<N, F>, Codomain = V>, |
| 414 V: Sum + Space, |
420 V: Sum + ClosedSpace, |
| 415 { |
421 { |
| 416 type Codomain = V; |
422 type Codomain = V; |
| 417 |
423 |
| 418 fn apply<I: Instance<Loc<N, F>>>(&self, x: I) -> Self::Codomain { |
424 fn apply<I: Instance<Loc<N, F>>>(&self, x: I) -> Self::Codomain { |
| 419 let xc = x.cow(); |
425 let xc = x.cow(); |
| 428 where |
434 where |
| 429 BT: BTImpl<N, F>, |
435 BT: BTImpl<N, F>, |
| 430 G: SupportGenerator<N, F, Id = BT::Data>, |
436 G: SupportGenerator<N, F, Id = BT::Data>, |
| 431 G::SupportType: |
437 G::SupportType: |
| 432 LocalAnalysis<F, BT::Agg, N> + DifferentiableMapping<Loc<N, F>, DerivativeDomain = V>, |
438 LocalAnalysis<F, BT::Agg, N> + DifferentiableMapping<Loc<N, F>, DerivativeDomain = V>, |
| 433 V: Sum + Space, |
439 V: Sum + ClosedSpace, |
| 434 { |
440 { |
| 435 type Derivative = V; |
441 type Derivative = V; |
| 436 |
442 |
| 437 fn differential_impl<I: Instance<Loc<N, F>>>(&self, x: I) -> Self::Derivative { |
443 fn differential_impl<I: Instance<Loc<N, F>>>(&self, x: I) -> Self::Derivative { |
| 438 let xc = x.cow(); |
444 let xc = x.cow(); |
| 844 G: SupportGenerator<N, F, Id = BT::Data>, |
850 G: SupportGenerator<N, F, Id = BT::Data>, |
| 845 G::SupportType: Mapping<Loc<N, F>, Codomain = F> + LocalAnalysis<F, Bounds<F>, N>, |
851 G::SupportType: Mapping<Loc<N, F>, Codomain = F> + LocalAnalysis<F, Bounds<F>, N>, |
| 846 Cube<N, F>: P2Minimise<Loc<N, F>, F>, |
852 Cube<N, F>: P2Minimise<Loc<N, F>, F>, |
| 847 { |
853 { |
| 848 fn maximise(&mut self, tolerance: F, max_steps: usize) -> (Loc<N, F>, F) { |
854 fn maximise(&mut self, tolerance: F, max_steps: usize) -> (Loc<N, F>, F) { |
| 849 let refiner = P2Refiner { |
855 let refiner = P2Refiner { tolerance, max_steps, how: RefineMax, bound: None }; |
| 850 tolerance, |
|
| 851 max_steps, |
|
| 852 how: RefineMax, |
|
| 853 bound: None, |
|
| 854 }; |
|
| 855 self.bt |
856 self.bt |
| 856 .search_and_refine(refiner, &self.generator) |
857 .search_and_refine(refiner, &self.generator) |
| 857 .expect("Refiner failure.") |
858 .expect("Refiner failure.") |
| 858 .unwrap() |
859 .unwrap() |
| 859 } |
860 } |
| 862 &mut self, |
863 &mut self, |
| 863 bound: F, |
864 bound: F, |
| 864 tolerance: F, |
865 tolerance: F, |
| 865 max_steps: usize, |
866 max_steps: usize, |
| 866 ) -> Option<(Loc<N, F>, F)> { |
867 ) -> Option<(Loc<N, F>, F)> { |
| 867 let refiner = P2Refiner { |
868 let refiner = P2Refiner { tolerance, max_steps, how: RefineMax, bound: Some(bound) }; |
| 868 tolerance, |
|
| 869 max_steps, |
|
| 870 how: RefineMax, |
|
| 871 bound: Some(bound), |
|
| 872 }; |
|
| 873 self.bt |
869 self.bt |
| 874 .search_and_refine(refiner, &self.generator) |
870 .search_and_refine(refiner, &self.generator) |
| 875 .expect("Refiner failure.") |
871 .expect("Refiner failure.") |
| 876 } |
872 } |
| 877 |
873 |
| 878 fn minimise(&mut self, tolerance: F, max_steps: usize) -> (Loc<N, F>, F) { |
874 fn minimise(&mut self, tolerance: F, max_steps: usize) -> (Loc<N, F>, F) { |
| 879 let refiner = P2Refiner { |
875 let refiner = P2Refiner { tolerance, max_steps, how: RefineMin, bound: None }; |
| 880 tolerance, |
|
| 881 max_steps, |
|
| 882 how: RefineMin, |
|
| 883 bound: None, |
|
| 884 }; |
|
| 885 self.bt |
876 self.bt |
| 886 .search_and_refine(refiner, &self.generator) |
877 .search_and_refine(refiner, &self.generator) |
| 887 .expect("Refiner failure.") |
878 .expect("Refiner failure.") |
| 888 .unwrap() |
879 .unwrap() |
| 889 } |
880 } |
| 892 &mut self, |
883 &mut self, |
| 893 bound: F, |
884 bound: F, |
| 894 tolerance: F, |
885 tolerance: F, |
| 895 max_steps: usize, |
886 max_steps: usize, |
| 896 ) -> Option<(Loc<N, F>, F)> { |
887 ) -> Option<(Loc<N, F>, F)> { |
| 897 let refiner = P2Refiner { |
888 let refiner = P2Refiner { tolerance, max_steps, how: RefineMin, bound: Some(bound) }; |
| 898 tolerance, |
|
| 899 max_steps, |
|
| 900 how: RefineMin, |
|
| 901 bound: Some(bound), |
|
| 902 }; |
|
| 903 self.bt |
889 self.bt |
| 904 .search_and_refine(refiner, &self.generator) |
890 .search_and_refine(refiner, &self.generator) |
| 905 .expect("Refiner failure.") |
891 .expect("Refiner failure.") |
| 906 } |
892 } |
| 907 |
893 |
| 908 fn has_upper_bound(&mut self, bound: F, tolerance: F, max_steps: usize) -> bool { |
894 fn has_upper_bound(&mut self, bound: F, tolerance: F, max_steps: usize) -> bool { |
| 909 let refiner = BoundRefiner { |
895 let refiner = BoundRefiner { bound, tolerance, max_steps, how: RefineMax }; |
| 910 bound, |
|
| 911 tolerance, |
|
| 912 max_steps, |
|
| 913 how: RefineMax, |
|
| 914 }; |
|
| 915 self.bt |
896 self.bt |
| 916 .search_and_refine(refiner, &self.generator) |
897 .search_and_refine(refiner, &self.generator) |
| 917 .expect("Refiner failure.") |
898 .expect("Refiner failure.") |
| 918 } |
899 } |
| 919 |
900 |
| 920 fn has_lower_bound(&mut self, bound: F, tolerance: F, max_steps: usize) -> bool { |
901 fn has_lower_bound(&mut self, bound: F, tolerance: F, max_steps: usize) -> bool { |
| 921 let refiner = BoundRefiner { |
902 let refiner = BoundRefiner { bound, tolerance, max_steps, how: RefineMin }; |
| 922 bound, |
|
| 923 tolerance, |
|
| 924 max_steps, |
|
| 925 how: RefineMin, |
|
| 926 }; |
|
| 927 self.bt |
903 self.bt |
| 928 .search_and_refine(refiner, &self.generator) |
904 .search_and_refine(refiner, &self.generator) |
| 929 .expect("Refiner failure.") |
905 .expect("Refiner failure.") |
| 930 } |
906 } |
| 931 } |
907 } |