src/euclidean/wrap.rs

branch
dev
changeset 147
d6009939e832
parent 146
3f9a03f95457
child 149
2f1798c65fd6
equal deleted inserted replaced
146:3f9a03f95457 147:d6009939e832
1 /*! 1 /*!
2 Wrappers for implemention [`Euclidean`] operations. 2 Wrappers for implemention [`Euclidean`] operations.
3 */ 3 */
4 4
5 use super::Euclidean; 5 pub trait Wrapped {
6 use crate::instance::{Decomposition, DecompositionMut, Instance, InstanceMut, MyCow}; 6 type Unwrapped;
7 use crate::linops::AXPY; 7 type UnwrappedMut;
8 use crate::loc::Loc; 8 type UnwrappedOutput;
9 use crate::mapping::Space; 9 fn get_view(&self) -> Self::Unwrapped;
10 use crate::norms::{HasDual, Norm, NormExponent, Normed, PairNorm, L2}; 10 fn get_view_mut(&mut self) -> Self::UnwrappedMut;
11 use crate::types::{Float, Num}; 11 fn wrap(output: Self::UnwrappedOutput) -> Self;
12 use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; 12 }
13 use serde::{Deserialize, Serialize}; 13
14 use std::clone::Clone; 14 #[macro_export]
15 15 macro_rules! wrap {
16 macro_rules! impl_unary { 16 (impl_unary<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => {
17 ($type:ty, $trait:ident, $fn:ident) => { 17 impl<$($qual)*> $trait for $type {
18 impl $trait for $type {
19 type Output = Self::WrappedOutput; 18 type Output = Self::WrappedOutput;
20 fn $fn(self) -> Self::Output { 19 fn $fn(self) -> Self::Output {
21 $type::wrap(self.get_view().$fn()) 20 Self::wrap(self.get_view().$fn())
22 } 21 }
23 } 22 }
24 }; 23 };
25 } 24 (impl_binary<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => {
26 25 impl<$($qual)*> $trait<$type> for $type {
27 macro_rules! impl_binary { 26 type Output = Self::WrappedOutput;
28 ($type:ty, $trait:ident, $fn:ident) => { 27 fn $fn(self, other: $type) -> Self::Output {
29 impl $trait<$type> for $type { 28 Self::wrap(self.get_view().$fn(other.get_view()))
30 type Output = Self::WrappedOutput; 29 }
31 fn $fn(self, other: $type) -> Self::Output { 30 }
32 $type::wrap(self.get_view().$fn(other.get_view())) 31
33 } 32 impl<'a, $($qual)*> $trait<$type> for &'a $type {
34 } 33 type Output = Self::WrappedOutput;
35 34 fn $fn(self, other: $type) -> Self::Output {
36 impl<'a> $trait<$type> for &'a $type { 35 Self::wrap(self.get_view().$fn(other.get_view()))
37 type Output = Self::WrappedOutput; 36 }
38 fn $fn(self, other: $type) -> Self::Output { 37 }
39 $type::wrap(self.get_view().$fn(other.get_view())) 38
40 } 39 impl<'a, 'b, $($qual)*> $trait<&'b $type> for &'a $type {
41 } 40 type Output = Self::WrappedOutput;
42 41 fn $fn(self, other: $type) -> Self::Output {
43 impl<'a, 'b> $trait<&'b $type> for &'a $type { 42 Self::wrap(self.get_view().$fn(other.get_view()))
44 type Output = Self::WrappedOutput; 43 }
45 fn $fn(self, other: $type) -> Self::Output { 44 }
46 $type::wrap(self.get_view().$fn(other.get_view())) 45
47 } 46 impl<'b, $($qual)*> $trait<&'b $type> for $type {
48 } 47 type Output = Self::WrappedOutput;
49 48 fn $fn(self, other: $type) -> Self::Output {
50 impl<'b> $trait<&'b $type> for $type { 49 Self::wrap(self.get_view().$fn(other.get_view()))
51 type Output = Self::WrappedOutput; 50 }
52 fn $fn(self, other: $type) -> Self::Output { 51 }
53 $type::wrap(self.get_view().$fn(other.get_view())) 52 };
54 } 53 (impl_scalar<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => {
55 } 54 impl<$($qual)*, F: Num> $trait<F> for $type,
56 };
57 }
58
59 macro_rules! impl_scalar {
60 ($type:ty, $trait:ident, $fn:ident) => {
61 impl<F: Num> $trait<F> for $type,
62 where 55 where
63 $type::Unwrapped: $trait<F>, 56 $type::Unwrapped: $trait<F>,
64 { 57 {
65 type Output = Self::WrappedOutput; 58 type Output = Self::WrappedOutput;
66 fn $fn(self, t: F) -> Self::Output { 59 fn $fn(self, t: F) -> Self::Output {
67 $type::wrap(self.get_view().$fn(t)) 60 Self::wrap(self.get_view().$fn(t))
68 } 61 }
69 } 62 }
70 63
71 impl<'a, F: Num> $trait<F> for &'a $type, 64 impl<'a, $($qual)*, F: Num> $trait<F> for &'a $type,
72 where 65 where
73 $type::Unwrapped: $trait<F>, 66 $type::Unwrapped: $trait<F>,
74 { 67 {
75 type Output = Self::WrappedOutput; 68 type Output = Self::WrappedOutput;
76 fn $fn(self, t: F) -> Self::Output { 69 fn $fn(self, t: F) -> Self::Output {
77 $type::wrap(self.get_view().$fn(t)) 70 Self::wrap(self.get_view().$fn(t))
78 } 71 }
79 } 72 }
80 73
81 }; 74 };
82 } 75 (impl_scalar_lhs<$($qual:tt)*> $type:ty, $trait:path, $fn:ident, $F:ty) => {
83 76 impl<$($qual)*> $trait<$type> for $F
84 macro_rules! impl_scalar_lhs {
85 ($type:ty, $trait:ident, $fn:ident, $F:ty) => {
86 impl<A, B> $trait<$type> for $F
87 where 77 where
88 $F: $type::Unwrapped, 78 $F: $type::Unwrapped,
89 { 79 {
90 type Output = Self::WrappedOutput; 80 type Output = Self::WrappedOutput;
91 fn $fn(self, rhs: $type) -> Self::Output { 81 fn $fn(self, rhs: $type) -> Self::Output {
92 $type::wrap(self.$fn(rhs.get_view())) 82 Self::wrap(self.$fn(rhs.get_view()))
93 } 83 }
94 } 84 }
95 }; 85 };
96 } 86 (impl_binary_mut<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => {
97 87 impl<$($qual)*> $trait<$type> for $type {
98 macro_rules! impl_binary_mut {
99 ($type:ty, $trait:ident, $fn:ident) => {
100 impl $trait<$type> for $type {
101 fn $fn(&mut self, rhs: $type) { 88 fn $fn(&mut self, rhs: $type) {
102 self.get_view_mut().$fn(rhs.get_view()) 89 self.get_view_mut().$fn(rhs.get_view())
103 } 90 }
104 } 91 }
105 92
106 impl<'b> $trait<&'b $type> for $type { 93 impl<'b, $($qual)*> $trait<&'b $type> for $type {
107 fn $fn(&mut self, rhs: $type) { 94 fn $fn(&mut self, rhs: $type) {
108 self.get_view_mut().$fn(rhs.get_view()) 95 self.get_view_mut().$fn(rhs.get_view())
109 } 96 }
110 } 97 }
111 }; 98 };
112 } 99 (impl_scalar_mut<$($qual:tt)*> $type:ty, $trait:path, $fn:ident) => {
113 100 impl<$($qual)*> $trait<F> for $type
114 macro_rules! impl_scalar_mut {
115 ($type:ty, $trait:ident, $fn:ident) => {
116 impl $trait<F> for $type
117 where 101 where
118 $type::UnwrappedMut: $trait<F>, 102 $type::UnwrappedMut: $trait<F>,
119 { 103 {
120 fn $fn(&mut self, t: F) { 104 fn $fn(&mut self, t: F) {
121 self.unwrap_mut().$fn(t) 105 self.unwrap_mut().$fn(t)
122 } 106 }
123 } 107 }
124 }; 108 };
125 }
126
127 macro_rules! wrap {
128 ($type:ty) => { 109 ($type:ty) => {
129 impl_unary!($type, Neg, neg); 110 $crate::wrap!(imp<> do $type);
130 impl_binary!($type, Add, add); 111 };
131 impl_binary!($type, Sub, sub); 112 (imp<$($qual:tt)*> $type:ty) => {
132 impl_scalar!($type, Mul, mul); 113 $crate::wrap!(impl_unary<$($qual)*> $type, std::ops::Neg, neg);
133 impl_scalar!($type, Div, div); 114 $crate::wrap!(impl_binary<$($qual)*> $type, std::ops::Add, add);
134 impl_scalar_lhs!($type, Mul, mul, f32); 115 $crate::wrap!(impl_binary<$($qual)*> $type, std::ops::Sub, sub);
135 impl_scalar_lhs!($type, Mul, mul, f64); 116 $crate::wrap!(impl_scalar<$($qual)*> $type, std::ops::Mul, mul);
136 impl_scalar_lhs!($type, Div, div, f32); 117 $crate::wrap!(impl_scalar<$($qual)*> $type, std::ops::Div, div);
137 impl_scalar_lhs!($type, Div, div, f64); 118 $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Mul, mul, f32);
138 impl_binary_mut!($type, AddAssign, add_assign); 119 $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Mul, mul, f64);
139 impl_binary_mut!($type, SubAssign, sub_assign); 120 $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Div, div, f32);
140 impl_scalar_mut!($type, MulAssign, mul_assign); 121 $crate::wrap!(impl_scalar_lhs<$($qual)*> $type, std::ops::Div, div, f64);
141 impl_scalar_mut!($type, DivAssign, div_assign); 122 $crate::wrap!(impl_binary_mut<$($qual)*> $type, std::ops::AddAssign, add_assign);
123 $crate::wrap!(impl_binary_mut<$($qual)*> $type, std::ops::SubAssign, sub_assign);
124 $crate::wrap!(impl_scalar_mut<$($qual)*> $type, std::ops::MulAssign, mul_assign);
125 $crate::wrap!(impl_scalar_mut<$($qual)*> $type, std::ops::DivAssign, div_assign);
142 126
143 /// We only support 'closed' `Euclidean` `Pair`s, as more general ones cause 127 /// We only support 'closed' `Euclidean` `Pair`s, as more general ones cause
144 /// compiler overflows. 128 /// compiler overflows.
145 impl<F: Float> Euclidean<F> for $type 129 impl<$($qual)* F: $crate::types::Float> $crate::euclidean::Euclidean<F> for $type
146 where 130 where
147 //Pair<A, B>: Euclidean<F>, 131 //Pair<A, B>: Euclidean<F>,
148 Self: Sized 132 Self: $crate::euclidean::wrap::Wrapped
149 + Mul<F, Output = <Self as AXPY>::Owned> 133 + Sized
150 + MulAssign<F> 134 + std::ops::Mul<F, Output = <Self as $crate::linops::AXPY>::Owned>
151 + Div<F, Output = <Self as AXPY>::Owned> 135 + std::ops::MulAssign<F>
152 + DivAssign<F> 136 + std::ops::Div<F, Output = <Self as $crate::linops::AXPY>::Owned>
153 + Add<Self, Output = <Self as AXPY>::Owned> 137 + std::ops::DivAssign<F>
154 + Sub<Self, Output = <Self as AXPY>::Owned> 138 + std::ops::Add<Self, Output = <Self as $crate::linops::AXPY>::Owned>
155 + for<'b> Add<&'b Self, Output = <Self as AXPY>::Owned> 139 + std::ops::Sub<Self, Output = <Self as $crate::linops::AXPY>::Owned>
156 + for<'b> Sub<&'b Self, Output = <Self as AXPY>::Owned> 140 + for<'b> std::ops::Add<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned>
157 + AddAssign<Self> 141 + for<'b> std::ops::Sub<&'b Self, Output = <Self as $crate::linops::AXPY>::Owned>
158 + for<'b> AddAssign<&'b Self> 142 + std::ops::AddAssign<Self>
159 + SubAssign<Self> 143 + for<'b> std::ops::AddAssign<&'b Self>
160 + for<'b> SubAssign<&'b Self> 144 + std::ops::SubAssign<Self>
161 + Neg<Output = <Self as AXPY>::Owned>, 145 + for<'b> std::ops::SubAssign<&'b Self>
162 { 146 + std::ops::Neg<Output = <Self as $crate::linops::AXPY>::Owned>,
163 fn dot<I: Instance<Self>>(&self, other: I) -> F { 147 {
148 fn dot<I: $crate::instance::Instance<Self>>(&self, other: I) -> F {
164 other.eval_decompose(|x| self.get_view().dot(x.get_view())) 149 other.eval_decompose(|x| self.get_view().dot(x.get_view()))
165 } 150 }
166 151
167 fn norm2_squared(&self) -> F { 152 fn norm2_squared(&self) -> F {
168 self.get_view().norm2_squared() 153 self.get_view().norm2_squared()
169 } 154 }
170 155
171 fn dist2_squared<I: Instance<Self>>(&self, other: I) -> F { 156 fn dist2_squared<I: $crate::instance::Instance<Self>>(&self, other: I) -> F {
172 other.eval_decompose(|x| self.get_view().dist2_squared(x.get_view())) 157 other.eval_decompose(|x| self.get_view().dist2_squared(x.get_view()))
173 } 158 }
174 } 159 }
175 160
176 impl<F> AXPY for $type 161 impl<$($qual)* F : $crate::types::Float> $crate::linops::AXPY for $type
177 where 162 where
178 Self::Unwrapped : AXPY<Field = $F>, 163 Self : $crate::euclidean::wrap::Wrapped,
179 F: Num, 164 Self::Unwrapped : $crate::linops::AXPY<Field = F>,
180 Self: MulAssign<F> + DivAssign<F>, 165 Self: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
181 Self::Unwrapped: MulAssign<F> + DivAssign<F>, 166 Self::Unwrapped: std::ops::MulAssign<F> + std::ops::DivAssign<F>,
182 { 167 {
183 type Field = F; 168 type Field = F;
184 type Owned = Pair<A::Owned, B::Owned>; 169 type Owned = Self;
185 170
186 fn axpy<I: Instance<Self>>(&mut self, α: F, x: I, β: F) { 171 fn axpy<I: $crate::instance::Instance<Self>>(&mut self, α: F, x: I, β: F) {
187 x.eval_decompose(|v| { 172 x.eval_decompose(|v| {
188 self.get_mut_view().axpy(α, v.get_view(), β) 173 self.get_mut_view().axpy(α, v.get_view(), β)
189 }) 174 })
190 } 175 }
191 176
192 fn copy_from<I: Instance<Self>>(&mut self, x: I) { 177 fn copy_from<I: $crate::instance::Instance<Self>>(&mut self, x: I) {
193 x.eval_decompose(|Pair(u, v)| { 178 x.eval_decompose(|Pair(u, v)| {
194 self.get_mut_view().copy_from(v.get_view()) 179 self.get_mut_view().copy_from(v.get_view())
195 }) 180 })
196 } 181 }
197 182
198 fn scale_from<I: Instance<Self>>(&mut self, α: F, x: I) { 183 fn scale_from<I: $crate::instance::Instance<Self>>(&mut self, α: F, x: I) {
199 x.eval_decompose(|v| { 184 x.eval_decompose(|v| {
200 self.get_mut_view().scale_from(α, v.get_view()) 185 self.get_mut_view().scale_from(α, v.get_view())
201 }) 186 })
202 } 187 }
203 188
204 /// Return a similar zero as `self`. 189 /// Return a similar zero as `self`.
205 fn similar_origin(&self) -> Self::Owned { 190 fn similar_origin(&self) -> Self::Owned {
206 $type::wrap(self.get_view().similar_origin()) 191 Self::wrap(self.get_view().similar_origin())
207 } 192 }
208 193
209 /// Set self to zero. 194 /// Set self to zero.
210 fn set_zero(&mut self) { 195 fn set_zero(&mut self) {
211 self.get_mut_view().set_zero() 196 self.get_mut_view().set_zero()
212 } 197 }
213 } 198 }
214 199
215 impl<A: Space, B: Space> Space for $type { 200 impl<$($qual)*> $crate::instance::Space for $type
216 type Decomp = $type::Unwrapped::Decomp 201 where Self : $crate::euclidean::wrap::Wrapped {
202 type Decomp = Self::Unwrapped::Decomp;
217 } 203 }
218 }; 204 };
219 } 205 }

mercurial