src/euclidean/wrap.rs

branch
dev
changeset 149
2f1798c65fd6
parent 147
d6009939e832
child 174
53ab61a41d70
equal deleted inserted replaced
148:26ef556870fd 149:2f1798c65fd6
1 /*! 1 /*!
2 Wrappers for implemention [`Euclidean`] operations. 2 Wrappers for implemention [`Euclidean`] operations.
3 */ 3 */
4 4
5 pub trait Wrapped { 5 use crate::euclidean::Euclidean;
6 type Unwrapped; 6 use crate::instance::Space;
7 type UnwrappedMut; 7 use crate::types::Float;
8 type UnwrappedOutput; 8
9 pub trait Wrapped: Space {
10 type WrappedField: Float;
11 type Unwrapped: Euclidean<Self::WrappedField>;
12 type UnwrappedMut: Euclidean<Self::WrappedField>;
13 type UnwrappedOutput: Euclidean<Self::WrappedField>;
14 type WrappedOutput;
9 fn get_view(&self) -> Self::Unwrapped; 15 fn get_view(&self) -> Self::Unwrapped;
10 fn get_view_mut(&mut self) -> Self::UnwrappedMut; 16 fn get_view_mut(&mut self) -> Self::UnwrappedMut;
11 fn wrap(output: Self::UnwrappedOutput) -> Self; 17 fn wrap(output: Self::UnwrappedOutput) -> Self::WrappedOutput;
12 } 18 }
13 19
14 #[macro_export] 20 #[macro_export]
15 macro_rules! wrap { 21 macro_rules! wrap {
16 (impl_unary<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => { 22 // Rust macros are totally fucked up. $trait:path does not work, have to
17 impl<$($qual)*> $trait for $type { 23 // manually code paths through $($trait:ident)::+.
18 type Output = Self::WrappedOutput; 24 (impl_unary $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
25 impl<$($qual)*> $($trait)::+ for $type {
26 type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
19 fn $fn(self) -> Self::Output { 27 fn $fn(self) -> Self::Output {
20 Self::wrap(self.get_view().$fn()) 28 Self::wrap(self.get_view().$fn())
21 } 29 }
22 } 30 }
23 }; 31 };
24 (impl_binary<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => { 32 (impl_binary $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
25 impl<$($qual)*> $trait<$type> for $type { 33 impl<$($qual)*> $($trait)::+<$type> for $type {
26 type Output = Self::WrappedOutput; 34 type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
27 fn $fn(self, other: $type) -> Self::Output { 35 fn $fn(self, other: $type) -> Self::Output {
28 Self::wrap(self.get_view().$fn(other.get_view())) 36 Self::wrap(self.get_view().$fn(other.get_view()))
29 } 37 }
30 } 38 }
31 39
32 impl<'a, $($qual)*> $trait<$type> for &'a $type { 40 impl<'a, $($qual)*> $($trait)::+<$type> for &'a $type {
33 type Output = Self::WrappedOutput; 41 type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
34 fn $fn(self, other: $type) -> Self::Output { 42 fn $fn(self, other: $type) -> Self::Output {
43 <$type>::wrap(self.get_view().$fn(other.get_view()))
44 }
45 }
46
47 impl<'a, 'b, $($qual)*> $($trait)::+<&'b $type> for &'a $type {
48 type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
49 fn $fn(self, other: &'b $type) -> Self::Output {
50 <$type>::wrap(self.get_view().$fn(other.get_view()))
51 }
52 }
53
54 impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type {
55 type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
56 fn $fn(self, other: &'b $type) -> Self::Output {
35 Self::wrap(self.get_view().$fn(other.get_view())) 57 Self::wrap(self.get_view().$fn(other.get_view()))
36 } 58 }
37 } 59 }
38 60 };
39 impl<'a, 'b, $($qual)*> $trait<&'b $type> for &'a $type { 61 (impl_scalar $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
40 type Output = Self::WrappedOutput; 62 impl<$($qual)*> $($trait)::+<$F> for $type
41 fn $fn(self, other: $type) -> Self::Output { 63 // where
42 Self::wrap(self.get_view().$fn(other.get_view())) 64 // $type: $crate::euclidean::wrap::Wrapped<WrappedField = F>,
43 } 65 // //$type::Unwrapped: $($trait)::+<F>,
44 } 66 {
45 67 type Output = <Self as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
46 impl<'b, $($qual)*> $trait<&'b $type> for $type { 68 fn $fn(self, t: $F) -> Self::Output {
47 type Output = Self::WrappedOutput;
48 fn $fn(self, other: $type) -> Self::Output {
49 Self::wrap(self.get_view().$fn(other.get_view()))
50 }
51 }
52 };
53 (impl_scalar<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => {
54 impl<$($qual)*, F: Num> $trait<F> for $type,
55 where
56 $type::Unwrapped: $trait<F>,
57 {
58 type Output = Self::WrappedOutput;
59 fn $fn(self, t: F) -> Self::Output {
60 Self::wrap(self.get_view().$fn(t)) 69 Self::wrap(self.get_view().$fn(t))
61 } 70 }
62 } 71 }
63 72
64 impl<'a, $($qual)*, F: Num> $trait<F> for &'a $type, 73 impl<'a, $($qual)*> $($trait)::+<$F> for &'a $type
65 where 74 // where
66 $type::Unwrapped: $trait<F>, 75 // $type: $crate::euclidean::wrap::Wrapped<WrappedField = F>,
67 { 76 // //$type::Unwrapped: $($trait)::+<F>,
68 type Output = Self::WrappedOutput; 77 {
69 fn $fn(self, t: F) -> Self::Output { 78 type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
70 Self::wrap(self.get_view().$fn(t)) 79 fn $fn(self, t: $F) -> Self::Output {
71 } 80 <$type>::wrap(self.get_view().$fn(t))
72 } 81 }
73 82 }
74 }; 83
75 (impl_scalar_lhs<$($qual:tt)*> $type:ty, $trait:path, $fn:ident, $F:ty) => { 84 };
76 impl<$($qual)*> $trait<$type> for $F 85 (impl_scalar_lhs $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
77 where 86 impl<$($qual)*> $($trait)::+<$type> for $F
78 $F: $type::Unwrapped, 87 // where
79 { 88 // $type: $crate::euclidean::wrap::Wrapped<WrappedField = $F>,
80 type Output = Self::WrappedOutput; 89 // // where
90 // // $F: $($trait)::+<$type::Unwrapped>,
91 {
92 type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput;
81 fn $fn(self, rhs: $type) -> Self::Output { 93 fn $fn(self, rhs: $type) -> Self::Output {
82 Self::wrap(self.$fn(rhs.get_view())) 94 <$type>::wrap(self.$fn(rhs.get_view()))
83 } 95 }
84 } 96 }
85 }; 97 };
86 (impl_binary_mut<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => { 98 (impl_binary_mut $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
87 impl<$($qual)*> $trait<$type> for $type { 99 impl<$($qual)*> $($trait)::+<$type> for $type {
88 fn $fn(&mut self, rhs: $type) { 100 fn $fn(&mut self, rhs: $type) {
89 self.get_view_mut().$fn(rhs.get_view()) 101 self.get_view_mut().$fn(rhs.get_view())
90 } 102 }
91 } 103 }
92 104
93 impl<'b, $($qual)*> $trait<&'b $type> for $type { 105 impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type {
94 fn $fn(&mut self, rhs: $type) { 106 fn $fn(&mut self, rhs: &'b $type) {
95 self.get_view_mut().$fn(rhs.get_view()) 107 self.get_view_mut().$fn(rhs.get_view())
96 } 108 }
97 } 109 }
98 }; 110 };
99 (impl_scalar_mut<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => { 111 (impl_scalar_mut $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => {
100 impl<$($qual)*> $trait<F> for $type 112 impl<$($qual)*> $($trait)::+<$F> for $type
101 where 113 // where
102 $type::UnwrappedMut: $trait<F>, 114 // $type: $crate::euclidean::wrap::Wrapped<WrappedField = F>,
103 { 115 // // where
104 fn $fn(&mut self, t: F) { 116 // // $type::UnwrappedMut: $($trait)::+<$($trait)::+<F>>,
105 self.unwrap_mut().$fn(t) 117 {
106 } 118 fn $fn(&mut self, t: $F) {
107 } 119 self.get_view_mut().$fn(t)
108 }; 120 }
109 ($type:ty) => { 121 }
110 $crate::wrap!(imp<> do $type); 122 };
111 }; 123 // ($type:ty) => {
112 (imp<$($qual:tt)*> $type:ty) => { 124 // $crate::wrap!(imp<> do $type);
113 $crate::wrap!(impl_unary<$($qual)*> $type, std::ops::Neg, neg); 125 // };
114 $crate::wrap!(impl_binary<$($qual)*> $type, std::ops::Add, add); 126 ($F:ty; $type:ty where $($qual:tt)*) => {
115 $crate::wrap!(impl_binary<$($qual)*> $type, std::ops::Sub, sub); 127 $crate::wrap!(impl_unary $type, std::ops::Neg, neg where $($qual)*);
116 $crate::wrap!(impl_scalar<$($qual)*> $type, std::ops::Mul, mul); 128 $crate::wrap!(impl_binary $type, std::ops::Add, add where $($qual)*);
117 $crate::wrap!(impl_scalar<$($qual)*> $type, std::ops::Div, div); 129 $crate::wrap!(impl_binary $type, std::ops::Sub, sub where $($qual)*);
118 $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Mul, mul, f32); 130 $crate::wrap!(impl_scalar $F, $type, std::ops::Mul, mul where $($qual)*);
119 $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Mul, mul, f64); 131 $crate::wrap!(impl_scalar $F, $type, std::ops::Div, div where $($qual)*);
120 $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Div, div, f32); 132 $crate::wrap!(impl_scalar_lhs $F, $type, std::ops::Mul, mul where $($qual)*);
121 $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Div, div, f64); 133 $crate::wrap!(impl_binary_mut $type, std::ops::AddAssign, add_assign where $($qual)*);
122 $crate::wrap!(impl_binary_mut<$($qual)*> $type, std::ops::AddAssign, add_assign); 134 $crate::wrap!(impl_binary_mut $type, std::ops::SubAssign, sub_assign where $($qual)*);
123 $crate::wrap!(impl_binary_mut<$($qual)*> $type, std::ops::SubAssign, sub_assign); 135 $crate::wrap!(impl_scalar_mut $F, $type, std::ops::MulAssign, mul_assign where $($qual)*);
124 $crate::wrap!(impl_scalar_mut<$($qual)*> $type, std::ops::MulAssign, mul_assign); 136 $crate::wrap!(impl_scalar_mut $F, $type, std::ops::DivAssign, div_assign where $($qual)*);
125 $crate::wrap!(impl_scalar_mut<$($qual)*> $type, std::ops::DivAssign, div_assign); 137
126 138 impl<$($qual)*> $crate::euclidean::Euclidean<$F> for $type
127 /// We only support 'closed' `Euclidean` `Pair`s, as more general ones cause 139 // where
128 /// compiler overflows. 140 // Self: $crate::euclidean::wrap::Wrapped<WrappedField = $F>
129 impl<$($qual)* F: $crate::types::Float> $crate::euclidean::Euclidean<F> for $type 141 // + Sized
130 where 142 // + std::ops::Mul<F, Output = <Self as $crate::linops::AXPY>::Owned>
131 //Pair<A, B>: Euclidean<F>, 143 // + std::ops::MulAssign<F>
132 Self: $crate::euclidean::wrap::Wrapped 144 // + std::ops::Div<F, Output = <Self as $crate::linops::AXPY>::Owned>
133 + Sized 145 // + std::ops::DivAssign<F>
134 + std::ops::Mul<F, Output = <Self as $crate::linops::AXPY>::Owned> 146 // + std::ops::Add<Self, Output = <Self as $crate::linops::AXPY>::Owned>
135 + std::ops::MulAssign<F> 147 // + std::ops::Sub<Self, Output = <Self as $crate::linops::AXPY>::Owned>
136 + std::ops::Div<F, Output = <Self as $crate::linops::AXPY>::Owned> 148 // + for<'b> std::ops::Add<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned>
137 + std::ops::DivAssign<F> 149 // + for<'b> std::ops::Sub<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned>
138 + std::ops::Add<Self, Output = <Self as $crate::linops::AXPY>::Owned> 150 // + std::ops::AddAssign<Self>
139 + std::ops::Sub<Self, Output = <Self as $crate::linops::AXPY>::Owned> 151 // + for<'b> std::ops::AddAssign<&'b Self>
140 + for<'b> std::ops::Add<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned> 152 // + std::ops::SubAssign<Self>
141 + for<'b> std::ops::Sub<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned> 153 // + for<'b> std::ops::SubAssign<&'b Self>
142 + std::ops::AddAssign<Self> 154 // + std::ops::Neg<Output = <Self as $crate::linops::AXPY>::Owned>,
143 + for<'b> std::ops::AddAssign<&'b Self> 155 {
144 + std::ops::SubAssign<Self> 156 fn dot<I: $crate::instance::Instance<Self>>(&self, other: I) -> $F {
145 + for<'b> std::ops::SubAssign<&'b Self> 157 other.eval_decompose(|x| self.get_view().dot(&x.get_view()))
146 + std::ops::Neg<Output = <Self as $crate::linops::AXPY>::Owned>, 158 }
147 { 159
148 fn dot<I: $crate::instance::Instance<Self>>(&self, other: I) -> F { 160 fn norm2_squared(&self) -> $F {
149 other.eval_decompose(|x| self.get_view().dot(x.get_view()))
150 }
151
152 fn norm2_squared(&self) -> F {
153 self.get_view().norm2_squared() 161 self.get_view().norm2_squared()
154 } 162 }
155 163
156 fn dist2_squared<I: $crate::instance::Instance<Self>>(&self, other: I) -> F { 164 fn dist2_squared<I: $crate::instance::Instance<Self>>(&self, other: I) -> $F {
157 other.eval_decompose(|x| self.get_view().dist2_squared(x.get_view())) 165 other.eval_decompose(|x| self.get_view().dist2_squared(x.get_view()))
158 } 166 }
159 } 167 }
160 168
161 impl<$($qual)* F : $crate::types::Float> $crate::linops::AXPY for $type 169 impl<$($qual)*> $crate::linops::AXPY for $type
162 where 170 // where
163 Self : $crate::euclidean::wrap::Wrapped, 171 // Self : $crate::euclidean::wrap::Wrapped<WrappedField = $F>,
164 Self::Unwrapped : $crate::linops::AXPY<Field = F>, 172 // Self::Unwrapped : $crate::linops::AXPY<Field = F>,
165 Self: std::ops::MulAssign<F> + std::ops::DivAssign<F>, 173 // Self: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
166 Self::Unwrapped: std::ops::MulAssign<F> + std::ops::DivAssign<F>, 174 // Self::Unwrapped: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
167 { 175 {
168 type Field = F; 176 type Field = $F;
169 type Owned = Self; 177 type Owned = Self;
170 178
171 fn axpy<I: $crate::instance::Instance<Self>>(&mut self, α: F, x: I, β: F) { 179 fn axpy<I: $crate::instance::Instance<Self>>(&mut self, α: $F, x: I, β: $F) {
172 x.eval_decompose(|v| { 180 x.eval_decompose(|v| {
173 self.get_mut_view().axpy(α, v.get_view(), β) 181 self.get_view_mut().axpy(α, v.get_view(), β)
174 }) 182 })
175 } 183 }
176 184
177 fn copy_from<I: $crate::instance::Instance<Self>>(&mut self, x: I) { 185 fn copy_from<I: $crate::instance::Instance<Self>>(&mut self, x: I) {
178 x.eval_decompose(|Pair(u, v)| { 186 x.eval_decompose(|v| {
179 self.get_mut_view().copy_from(v.get_view()) 187 self.get_view_mut().copy_from(v.get_view())
180 }) 188 })
181 } 189 }
182 190
183 fn scale_from<I: $crate::instance::Instance<Self>>(&mut self, α: F, x: I) { 191 fn scale_from<I: $crate::instance::Instance<Self>>(&mut self, α: $F, x: I) {
184 x.eval_decompose(|v| { 192 x.eval_decompose(|v| {
185 self.get_mut_view().scale_from(α, v.get_view()) 193 self.get_mut_view().scale_from(α, v.get_view())
186 }) 194 })
187 } 195 }
188 196
195 fn set_zero(&mut self) { 203 fn set_zero(&mut self) {
196 self.get_mut_view().set_zero() 204 self.get_mut_view().set_zero()
197 } 205 }
198 } 206 }
199 207
200 impl<$($qual)*> $crate::instance::Space for $type 208 impl<$($qual)*> $crate::instance::Space for $type {
201 where Self : $crate::euclidean::wrap::Wrapped { 209 type Decomp = <<Self as $crate::euclidean::wrap::Wrapped>::Unwrapped as $crate::instance::Space>::Decomp;
202 type Decomp = Self::Unwrapped::Decomp;
203 } 210 }
204 }; 211 };
205 } 212 }

mercurial