src/bisection_tree/btfn.rs

branch
dev
changeset 150
c4e394a9c84c
parent 125
25b6c8b20d1b
child 152
dab30b331f49
child 162
bea0c3841ced
equal deleted inserted replaced
149:2f1798c65fd6 150:c4e394a9c84c
1 use crate::mapping::{ 1 use crate::instance::{ClosedSpace, Instance, Ownable, Space};
2 BasicDecomposition, DifferentiableImpl, DifferentiableMapping, Instance, Mapping, Space, 2 use crate::mapping::{BasicDecomposition, DifferentiableImpl, DifferentiableMapping, Mapping};
3 };
4 use crate::types::Float; 3 use crate::types::Float;
5 use numeric_literals::replace_float_literals; 4 use numeric_literals::replace_float_literals;
6 use std::iter::Sum; 5 use std::iter::Sum;
7 use std::marker::PhantomData; 6 use std::marker::PhantomData;
8 use std::sync::Arc; 7 use std::sync::Arc;
35 bt: BT, 34 bt: BT,
36 generator: Arc<G>, 35 generator: Arc<G>,
37 _phantoms: PhantomData<F>, 36 _phantoms: PhantomData<F>,
38 } 37 }
39 38
40 impl<F: Float, G, BT, const N: usize> Space for BTFN<F, G, BT, N> 39 impl<F: Float, G, BT, const N: usize> Ownable for BTFN<F, G, BT, N>
41 where 40 where
42 G: SupportGenerator<N, F, Id = BT::Data>, 41 G: SupportGenerator<N, F, Id = BT::Data>,
43 G::SupportType: LocalAnalysis<F, BT::Agg, N>, 42 G::SupportType: LocalAnalysis<F, BT::Agg, N>,
44 BT: BTImpl<N, F>, 43 BT: BTImpl<N, F>,
45 { 44 {
45 type OwnedVariant = Self;
46
47 fn into_owned(self) -> Self::OwnedVariant {
48 self
49 }
50
51 /// Returns an owned instance of a reference.
52 fn clone_owned(&self) -> Self::OwnedVariant {
53 self.clone()
54 }
55 }
56
57 impl<F: Float, G, BT, const N: usize> Space for BTFN<F, G, BT, N>
58 where
59 G: SupportGenerator<N, F, Id = BT::Data>,
60 G::SupportType: LocalAnalysis<F, BT::Agg, N>,
61 BT: BTImpl<N, F>,
62 {
63 type OwnedSpace = Self;
46 type Decomp = BasicDecomposition; 64 type Decomp = BasicDecomposition;
47 } 65 }
48 66
49 impl<F: Float, G, BT, const N: usize> BTFN<F, G, BT, N> 67 impl<F: Float, G, BT, const N: usize> BTFN<F, G, BT, N>
50 where 68 where
62 pub fn new(bt: BT, generator: G) -> Self { 80 pub fn new(bt: BT, generator: G) -> Self {
63 Self::new_arc(bt, Arc::new(generator)) 81 Self::new_arc(bt, Arc::new(generator))
64 } 82 }
65 83
66 fn new_arc(bt: BT, generator: Arc<G>) -> Self { 84 fn new_arc(bt: BT, generator: Arc<G>) -> Self {
67 BTFN { 85 BTFN { bt, generator, _phantoms: std::marker::PhantomData }
68 bt: bt,
69 generator: generator,
70 _phantoms: std::marker::PhantomData,
71 }
72 } 86 }
73 87
74 /// Create a new BTFN support generator and a pre-initialised bisection tree, 88 /// Create a new BTFN support generator and a pre-initialised bisection tree,
75 /// cloning the tree and refreshing aggregators. 89 /// cloning the tree and refreshing aggregators.
76 /// 90 ///
159 where 173 where
160 G: SupportGenerator<N, F>, 174 G: SupportGenerator<N, F>,
161 { 175 {
162 /// Create a new [`PreBTFN`] with no bisection tree. 176 /// Create a new [`PreBTFN`] with no bisection tree.
163 pub fn new_pre(generator: G) -> Self { 177 pub fn new_pre(generator: G) -> Self {
164 BTFN { 178 BTFN { bt: (), generator: Arc::new(generator), _phantoms: std::marker::PhantomData }
165 bt: (),
166 generator: Arc::new(generator),
167 _phantoms: std::marker::PhantomData,
168 }
169 } 179 }
170 } 180 }
171 181
172 impl<F: Float, G, BT, const N: usize> BTFN<F, G, BT, N> 182 impl<F: Float, G, BT, const N: usize> BTFN<F, G, BT, N>
173 where 183 where
186 196
187 for (d, support) in both.all_right_data() { 197 for (d, support) in both.all_right_data() {
188 bt.insert(d, &support); 198 bt.insert(d, &support);
189 } 199 }
190 200
191 BTFN { 201 BTFN { bt: bt, generator: Arc::new(both), _phantoms: std::marker::PhantomData }
192 bt: bt,
193 generator: Arc::new(both),
194 _phantoms: std::marker::PhantomData,
195 }
196 } 202 }
197 } 203 }
198 204
199 macro_rules! make_btfn_add { 205 macro_rules! make_btfn_add {
200 ($lhs:ty, $preprocess:path, $($extra_trait:ident)?) => { 206 ($lhs:ty, $preprocess:path, $($extra_trait:ident)?) => {
409 impl<F: Float, G, BT, V, const N: usize> Mapping<Loc<N, F>> for BTFN<F, G, BT, N> 415 impl<F: Float, G, BT, V, const N: usize> Mapping<Loc<N, F>> for BTFN<F, G, BT, N>
410 where 416 where
411 BT: BTImpl<N, F>, 417 BT: BTImpl<N, F>,
412 G: SupportGenerator<N, F, Id = BT::Data>, 418 G: SupportGenerator<N, F, Id = BT::Data>,
413 G::SupportType: LocalAnalysis<F, BT::Agg, N> + Mapping<Loc<N, F>, Codomain = V>, 419 G::SupportType: LocalAnalysis<F, BT::Agg, N> + Mapping<Loc<N, F>, Codomain = V>,
414 V: Sum + Space, 420 V: Sum + ClosedSpace,
415 { 421 {
416 type Codomain = V; 422 type Codomain = V;
417 423
418 fn apply<I: Instance<Loc<N, F>>>(&self, x: I) -> Self::Codomain { 424 fn apply<I: Instance<Loc<N, F>>>(&self, x: I) -> Self::Codomain {
419 let xc = x.cow(); 425 let xc = x.cow();
428 where 434 where
429 BT: BTImpl<N, F>, 435 BT: BTImpl<N, F>,
430 G: SupportGenerator<N, F, Id = BT::Data>, 436 G: SupportGenerator<N, F, Id = BT::Data>,
431 G::SupportType: 437 G::SupportType:
432 LocalAnalysis<F, BT::Agg, N> + DifferentiableMapping<Loc<N, F>, DerivativeDomain = V>, 438 LocalAnalysis<F, BT::Agg, N> + DifferentiableMapping<Loc<N, F>, DerivativeDomain = V>,
433 V: Sum + Space, 439 V: Sum + ClosedSpace,
434 { 440 {
435 type Derivative = V; 441 type Derivative = V;
436 442
437 fn differential_impl<I: Instance<Loc<N, F>>>(&self, x: I) -> Self::Derivative { 443 fn differential_impl<I: Instance<Loc<N, F>>>(&self, x: I) -> Self::Derivative {
438 let xc = x.cow(); 444 let xc = x.cow();
844 G: SupportGenerator<N, F, Id = BT::Data>, 850 G: SupportGenerator<N, F, Id = BT::Data>,
845 G::SupportType: Mapping<Loc<N, F>, Codomain = F> + LocalAnalysis<F, Bounds<F>, N>, 851 G::SupportType: Mapping<Loc<N, F>, Codomain = F> + LocalAnalysis<F, Bounds<F>, N>,
846 Cube<N, F>: P2Minimise<Loc<N, F>, F>, 852 Cube<N, F>: P2Minimise<Loc<N, F>, F>,
847 { 853 {
848 fn maximise(&mut self, tolerance: F, max_steps: usize) -> (Loc<N, F>, F) { 854 fn maximise(&mut self, tolerance: F, max_steps: usize) -> (Loc<N, F>, F) {
849 let refiner = P2Refiner { 855 let refiner = P2Refiner { tolerance, max_steps, how: RefineMax, bound: None };
850 tolerance,
851 max_steps,
852 how: RefineMax,
853 bound: None,
854 };
855 self.bt 856 self.bt
856 .search_and_refine(refiner, &self.generator) 857 .search_and_refine(refiner, &self.generator)
857 .expect("Refiner failure.") 858 .expect("Refiner failure.")
858 .unwrap() 859 .unwrap()
859 } 860 }
862 &mut self, 863 &mut self,
863 bound: F, 864 bound: F,
864 tolerance: F, 865 tolerance: F,
865 max_steps: usize, 866 max_steps: usize,
866 ) -> Option<(Loc<N, F>, F)> { 867 ) -> Option<(Loc<N, F>, F)> {
867 let refiner = P2Refiner { 868 let refiner = P2Refiner { tolerance, max_steps, how: RefineMax, bound: Some(bound) };
868 tolerance,
869 max_steps,
870 how: RefineMax,
871 bound: Some(bound),
872 };
873 self.bt 869 self.bt
874 .search_and_refine(refiner, &self.generator) 870 .search_and_refine(refiner, &self.generator)
875 .expect("Refiner failure.") 871 .expect("Refiner failure.")
876 } 872 }
877 873
878 fn minimise(&mut self, tolerance: F, max_steps: usize) -> (Loc<N, F>, F) { 874 fn minimise(&mut self, tolerance: F, max_steps: usize) -> (Loc<N, F>, F) {
879 let refiner = P2Refiner { 875 let refiner = P2Refiner { tolerance, max_steps, how: RefineMin, bound: None };
880 tolerance,
881 max_steps,
882 how: RefineMin,
883 bound: None,
884 };
885 self.bt 876 self.bt
886 .search_and_refine(refiner, &self.generator) 877 .search_and_refine(refiner, &self.generator)
887 .expect("Refiner failure.") 878 .expect("Refiner failure.")
888 .unwrap() 879 .unwrap()
889 } 880 }
892 &mut self, 883 &mut self,
893 bound: F, 884 bound: F,
894 tolerance: F, 885 tolerance: F,
895 max_steps: usize, 886 max_steps: usize,
896 ) -> Option<(Loc<N, F>, F)> { 887 ) -> Option<(Loc<N, F>, F)> {
897 let refiner = P2Refiner { 888 let refiner = P2Refiner { tolerance, max_steps, how: RefineMin, bound: Some(bound) };
898 tolerance,
899 max_steps,
900 how: RefineMin,
901 bound: Some(bound),
902 };
903 self.bt 889 self.bt
904 .search_and_refine(refiner, &self.generator) 890 .search_and_refine(refiner, &self.generator)
905 .expect("Refiner failure.") 891 .expect("Refiner failure.")
906 } 892 }
907 893
908 fn has_upper_bound(&mut self, bound: F, tolerance: F, max_steps: usize) -> bool { 894 fn has_upper_bound(&mut self, bound: F, tolerance: F, max_steps: usize) -> bool {
909 let refiner = BoundRefiner { 895 let refiner = BoundRefiner { bound, tolerance, max_steps, how: RefineMax };
910 bound,
911 tolerance,
912 max_steps,
913 how: RefineMax,
914 };
915 self.bt 896 self.bt
916 .search_and_refine(refiner, &self.generator) 897 .search_and_refine(refiner, &self.generator)
917 .expect("Refiner failure.") 898 .expect("Refiner failure.")
918 } 899 }
919 900
920 fn has_lower_bound(&mut self, bound: F, tolerance: F, max_steps: usize) -> bool { 901 fn has_lower_bound(&mut self, bound: F, tolerance: F, max_steps: usize) -> bool {
921 let refiner = BoundRefiner { 902 let refiner = BoundRefiner { bound, tolerance, max_steps, how: RefineMin };
922 bound,
923 tolerance,
924 max_steps,
925 how: RefineMin,
926 };
927 self.bt 903 self.bt
928 .search_and_refine(refiner, &self.generator) 904 .search_and_refine(refiner, &self.generator)
929 .expect("Refiner failure.") 905 .expect("Refiner failure.")
930 } 906 }
931 } 907 }

mercurial