src/euclidean/wrap.rs

branch
dev
changeset 146
3f9a03f95457
child 147
d6009939e832
equal deleted inserted replaced
145:0b9aecd7bb76 146:3f9a03f95457
1 /*!
Wrappers for implementing [`Euclidean`] operations.
3 */
4
5 use super::Euclidean;
6 use crate::instance::{Decomposition, DecompositionMut, Instance, InstanceMut, MyCow};
7 use crate::linops::AXPY;
8 use crate::loc::Loc;
9 use crate::mapping::Space;
10 use crate::norms::{HasDual, Norm, NormExponent, Normed, PairNorm, L2};
11 use crate::types::{Float, Num};
12 use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
13 use serde::{Deserialize, Serialize};
14 use std::clone::Clone;
15
/// Implements the unary operator `$trait` (method `$fn`) for the wrapper type
/// `$type` by applying it to `self.get_view()` and re-wrapping the result with
/// `wrap`.
///
/// NOTE(review): `get_view`, `wrap` and `WrappedOutput` come from a wrapper
/// trait that is not defined in this file; `Self::WrappedOutput` may need full
/// qualification (`<$type as Trait>::WrappedOutput`) — TODO confirm.
macro_rules! impl_unary {
    ($type:ty, $trait:ident, $fn:ident) => {
        impl $trait for $type {
            type Output = Self::WrappedOutput;
            fn $fn(self) -> Self::Output {
                // A `ty` fragment cannot be followed directly by `::`; it must
                // be wrapped in `<…>` to name an associated item on it.
                <$type>::wrap(self.get_view().$fn())
            }
        }
    };
}
26
/// Implements the binary operator `$trait` (method `$fn`) for all four
/// owned/borrowed combinations of `$type` operands. Each impl combines the two
/// `get_view()` values and re-wraps the result with `wrap`.
///
/// NOTE(review): `get_view`, `wrap` and `WrappedOutput` come from a wrapper
/// trait that is not defined in this file; `Self::WrappedOutput` may need full
/// qualification (`<$type as Trait>::WrappedOutput`) — TODO confirm.
macro_rules! impl_binary {
    ($type:ty, $trait:ident, $fn:ident) => {
        impl $trait<$type> for $type {
            type Output = Self::WrappedOutput;
            fn $fn(self, other: $type) -> Self::Output {
                <$type>::wrap(self.get_view().$fn(other.get_view()))
            }
        }

        // The remaining impls reuse the owned impl's output type: in two of
        // them `Self` is a reference, which has no `WrappedOutput` of its own.
        impl<'a> $trait<$type> for &'a $type {
            type Output = <$type as $trait<$type>>::Output;
            fn $fn(self, other: $type) -> Self::Output {
                <$type>::wrap(self.get_view().$fn(other.get_view()))
            }
        }

        impl<'a, 'b> $trait<&'b $type> for &'a $type {
            type Output = <$type as $trait<$type>>::Output;
            // The right-hand side is borrowed, matching `$trait<&'b $type>`.
            fn $fn(self, other: &'b $type) -> Self::Output {
                <$type>::wrap(self.get_view().$fn(other.get_view()))
            }
        }

        impl<'b> $trait<&'b $type> for $type {
            type Output = <$type as $trait<$type>>::Output;
            // The right-hand side is borrowed, matching `$trait<&'b $type>`.
            fn $fn(self, other: &'b $type) -> Self::Output {
                <$type>::wrap(self.get_view().$fn(other.get_view()))
            }
        }
    };
}
58
/// Implements `$type $trait F` (method `$fn`) for any scalar `F: Num`, for both
/// owned and borrowed `$type`. The scalar is applied to `self.get_view()` and
/// the result is re-wrapped with `wrap`.
///
/// NOTE(review): `Unwrapped`, `WrappedOutput`, `get_view` and `wrap` come from
/// a wrapper trait that is not defined in this file; the `<$type>::…` and
/// `Self::…` projections may need full qualification
/// (`<$type as Trait>::…`) — TODO confirm.
macro_rules! impl_scalar {
    ($type:ty, $trait:ident, $fn:ident) => {
        impl<F: Num> $trait<F> for $type
        where
            <$type>::Unwrapped: $trait<F>,
        {
            type Output = Self::WrappedOutput;
            fn $fn(self, t: F) -> Self::Output {
                <$type>::wrap(self.get_view().$fn(t))
            }
        }

        impl<'a, F: Num> $trait<F> for &'a $type
        where
            <$type>::Unwrapped: $trait<F>,
        {
            // `Self` is a reference here, so reuse the owned impl's output.
            type Output = <$type as $trait<F>>::Output;
            fn $fn(self, t: F) -> Self::Output {
                <$type>::wrap(self.get_view().$fn(t))
            }
        }
    };
}
83
/// Implements `$F $trait $type` (method `$fn`), with the concrete scalar type
/// `$F` on the left-hand side. It computes `$F $trait Unwrapped` on
/// `rhs.get_view()` and re-wraps the result with `wrap`.
///
/// NOTE(review): `Unwrapped`, `WrappedOutput`, `get_view` and `wrap` come from
/// a wrapper trait that is not defined in this file; the `<$type>::…`
/// projections may need full qualification (`<$type as Trait>::…`) — TODO
/// confirm.
macro_rules! impl_scalar_lhs {
    ($type:ty, $trait:ident, $fn:ident, $F:ty) => {
        impl $trait<$type> for $F
        where
            $F: $trait<<$type>::Unwrapped>,
        {
            // `Self` is the scalar here, so the output type has to be named
            // through the wrapper type.
            type Output = <$type>::WrappedOutput;
            fn $fn(self, rhs: $type) -> Self::Output {
                <$type>::wrap(self.$fn(rhs.get_view()))
            }
        }
    };
}
97
/// Implements the compound-assignment operator `$trait` (method `$fn`) for
/// `$type`, with both an owned and a borrowed right-hand side. The operator is
/// applied to `self.get_view_mut()`, using `rhs.get_view()` as the operand.
macro_rules! impl_binary_mut {
    ($type:ty, $trait:ident, $fn:ident) => {
        impl $trait<$type> for $type {
            fn $fn(&mut self, rhs: $type) {
                self.get_view_mut().$fn(rhs.get_view())
            }
        }

        impl<'b> $trait<&'b $type> for $type {
            // The right-hand side is borrowed, matching `$trait<&'b $type>`.
            fn $fn(&mut self, rhs: &'b $type) {
                self.get_view_mut().$fn(rhs.get_view())
            }
        }
    };
}
113
/// Implements the compound-assignment operator `$trait<F>` (method `$fn`) for
/// `$type` and any scalar `F: Num`, applying it to `self.get_view_mut()`.
///
/// NOTE(review): `UnwrappedMut` and `get_view_mut` come from a wrapper trait
/// that is not defined in this file; `<$type>::UnwrappedMut` may need full
/// qualification (`<$type as Trait>::UnwrappedMut`) — TODO confirm.
macro_rules! impl_scalar_mut {
    ($type:ty, $trait:ident, $fn:ident) => {
        impl<F: Num> $trait<F> for $type
        where
            <$type>::UnwrappedMut: $trait<F>,
        {
            fn $fn(&mut self, t: F) {
                // Same mutable-view accessor as `impl_binary_mut!`.
                self.get_view_mut().$fn(t)
            }
        }
    };
}
126
/// Implements, for the wrapper type `$type`:
///
/// - the operators `Neg`, `Add`, `Sub`, `Mul`/`Div` by a scalar on either side
///   (left-hand side only for `f32` and `f64`), `AddAssign`, `SubAssign`,
///   `MulAssign` and `DivAssign`;
/// - [`Euclidean`], [`AXPY`] and [`Space`].
///
/// Every item forwards to the view returned by `get_view` / `get_view_mut`,
/// and owned results are produced with `wrap`.
///
/// NOTE(review): `get_view`, `get_view_mut`, `wrap`, `Unwrapped` and
/// `WrappedOutput` come from a wrapper trait that is not defined in this file.
/// The `Self::…` projections on the concrete `$type` may need full
/// qualification (`<$type as Trait>::…`) — TODO confirm.
macro_rules! wrap {
    ($type:ty) => {
        impl_unary!($type, Neg, neg);
        impl_binary!($type, Add, add);
        impl_binary!($type, Sub, sub);
        impl_scalar!($type, Mul, mul);
        impl_scalar!($type, Div, div);
        impl_scalar_lhs!($type, Mul, mul, f32);
        impl_scalar_lhs!($type, Mul, mul, f64);
        impl_scalar_lhs!($type, Div, div, f32);
        impl_scalar_lhs!($type, Div, div, f64);
        impl_binary_mut!($type, AddAssign, add_assign);
        impl_binary_mut!($type, SubAssign, sub_assign);
        impl_scalar_mut!($type, MulAssign, mul_assign);
        impl_scalar_mut!($type, DivAssign, div_assign);

        // The operator bounds require every owned result to be
        // `<Self as AXPY>::Owned`. That type is what `wrap` produces (see the
        // `AXPY` impl below).
        impl<F: Float> Euclidean<F> for $type
        where
            Self: Sized
                + Mul<F, Output = <Self as AXPY>::Owned>
                + MulAssign<F>
                + Div<F, Output = <Self as AXPY>::Owned>
                + DivAssign<F>
                + Add<Self, Output = <Self as AXPY>::Owned>
                + Sub<Self, Output = <Self as AXPY>::Owned>
                + for<'b> Add<&'b Self, Output = <Self as AXPY>::Owned>
                + for<'b> Sub<&'b Self, Output = <Self as AXPY>::Owned>
                + AddAssign<Self>
                + for<'b> AddAssign<&'b Self>
                + SubAssign<Self>
                + for<'b> SubAssign<&'b Self>
                + Neg<Output = <Self as AXPY>::Owned>,
        {
            fn dot<I: Instance<Self>>(&self, other: I) -> F {
                other.eval_decompose(|x| self.get_view().dot(x.get_view()))
            }

            fn norm2_squared(&self) -> F {
                self.get_view().norm2_squared()
            }

            fn dist2_squared<I: Instance<Self>>(&self, other: I) -> F {
                other.eval_decompose(|x| self.get_view().dist2_squared(x.get_view()))
            }
        }

        impl<F> AXPY for $type
        where
            Self::Unwrapped: AXPY<Field = F>,
            F: Num,
            Self: MulAssign<F> + DivAssign<F>,
            Self::Unwrapped: MulAssign<F> + DivAssign<F>,
        {
            type Field = F;
            // `similar_origin` returns the result of `wrap`, so the owned type
            // is the wrapper's `WrappedOutput`. The operator impls generated
            // above use the same output type.
            type Owned = Self::WrappedOutput;

            fn axpy<I: Instance<Self>>(&mut self, α: F, x: I, β: F) {
                x.eval_decompose(|v| self.get_view_mut().axpy(α, v.get_view(), β))
            }

            fn copy_from<I: Instance<Self>>(&mut self, x: I) {
                // Copy the whole view of `x`, not just a component of it.
                x.eval_decompose(|v| self.get_view_mut().copy_from(v.get_view()))
            }

            fn scale_from<I: Instance<Self>>(&mut self, α: F, x: I) {
                x.eval_decompose(|v| self.get_view_mut().scale_from(α, v.get_view()))
            }

            /// Returns a zero that is similar to `self`, built from the view's
            /// `similar_origin` and re-wrapped.
            fn similar_origin(&self) -> Self::Owned {
                <$type>::wrap(self.get_view().similar_origin())
            }

            /// Sets `self` to zero through its mutable view.
            fn set_zero(&mut self) {
                self.get_view_mut().set_zero()
            }
        }

        // NOTE(review): this reuses the view's decomposition. The closures
        // above call `get_view()` on decomposed values, so confirm that this
        // decomposition yields `$type`-like values — TODO confirm.
        impl Space for $type {
            type Decomp = <Self::Unwrapped as Space>::Decomp;
        }
    };
}

mercurial