Sat, 22 Oct 2022 18:12:49 +0300
Fix some unit tests after fundamental changes that made them invalid
//! Integration with nalgebra
2 | ||
use std::ops::Mul;

use nalgebra::{
    Matrix, Storage, StorageMut, OMatrix, Dim, DefaultAllocator, Scalar,
    ClosedMul, ClosedAdd, SimdComplexField, Vector, OVector, RealField,
    LpNorm, UniformNorm, EuclideanNorm
};
use nalgebra::Norm as NalgebraNorm;
use nalgebra::base::allocator::Allocator;
use nalgebra::base::constraint::{
    ShapeConstraint, SameNumberOfRows, SameNumberOfColumns
};
use nalgebra::base::dimension::*;
use num_traits::identities::{Zero, One};

use crate::linops::*;
use crate::norms::*;
use crate::norms::Dot;
use crate::types::Float;
20 | ||
21 | impl<SM,SV,N,M,K,E> Linear<Matrix<E,M,K,SV>> for Matrix<E,N,M,SM> | |
22 | where SM: Storage<E,N,M>, SV: Storage<E,M,K>, | |
23 | N : Dim, M : Dim, K : Dim, E : Scalar + ClosedMul + ClosedAdd + Zero + One, | |
24 | DefaultAllocator : Allocator<E,N,K>, | |
25 | DefaultAllocator : Allocator<E,M,K>, | |
26 | DefaultAllocator : Allocator<E,N,M>, | |
27 | DefaultAllocator : Allocator<E,M,N> { | |
28 | type Codomain = OMatrix<E,N,K>; | |
29 | ||
30 | #[inline] | |
31 | fn apply(&self, x : &Matrix<E,M,K,SV>) -> Self::Codomain { | |
32 | self.mul(x) | |
33 | } | |
34 | } | |
35 | ||
36 | impl<SM,SV1,SV2,N,M,K,E> GEMV<E, Matrix<E,M,K,SV1>, Matrix<E,N,K,SV2>> for Matrix<E,N,M,SM> | |
37 | where SM: Storage<E,N,M>, SV1: Storage<E,M,K>, SV2: StorageMut<E,N,K>, | |
38 | N : Dim, M : Dim, K : Dim, E : Scalar + ClosedMul + ClosedAdd + Zero + One + Float, | |
39 | DefaultAllocator : Allocator<E,N,K>, | |
40 | DefaultAllocator : Allocator<E,M,K>, | |
41 | DefaultAllocator : Allocator<E,N,M>, | |
42 | DefaultAllocator : Allocator<E,M,N> { | |
43 | ||
44 | #[inline] | |
45 | fn gemv(&self, y : &mut Matrix<E,N,K,SV2>, α : E, x : &Matrix<E,M,K,SV1>, β : E) { | |
46 | Matrix::gemm(y, α, self, x, β) | |
47 | } | |
48 | ||
49 | #[inline] | |
50 | fn apply_mut<'a>(&self, y : &mut Matrix<E,N,K,SV2>, x : &Matrix<E,M,K,SV1>) { | |
51 | self.mul_to(x, y) | |
52 | } | |
53 | } | |
54 | ||
55 | impl<SM,SV1,M,E> AXPY<E, Vector<E,M,SV1>> for Vector<E,M,SM> | |
56 | where SM: StorageMut<E,M>, SV1: Storage<E,M>, | |
57 | M : Dim, E : Scalar + ClosedMul + ClosedAdd + Zero + One + Float, | |
58 | DefaultAllocator : Allocator<E,M> { | |
59 | ||
60 | #[inline] | |
61 | fn axpy(&mut self, α : E, x : &Vector<E,M,SV1>, β : E) { | |
62 | Matrix::axpy(self, α, x, β) | |
63 | } | |
64 | ||
65 | #[inline] | |
66 | fn copy_from(&mut self, y : &Vector<E,M,SV1>) { | |
67 | Matrix::copy_from(self, y) | |
68 | } | |
69 | } | |
70 | ||
71 | impl<SM,M,E> Projection<E, Linfinity> for Vector<E,M,SM> | |
72 | where SM: StorageMut<E,M>, | |
73 | M : Dim, E : Scalar + ClosedMul + ClosedAdd + Zero + One + Float + RealField, | |
74 | DefaultAllocator : Allocator<E,M> { | |
75 | #[inline] | |
76 | fn proj_ball_mut(&mut self, ρ : E, _ : Linfinity) { | |
77 | self.iter_mut().for_each(|v| *v = num_traits::clamp(*v, -ρ, ρ)) | |
78 | } | |
79 | } | |
80 | ||
81 | impl<'own,SV1,SV2,SM,N,M,K,E> Adjointable<Matrix<E,M,K,SV1>,Matrix<E,N,K,SV2>> | |
82 | for Matrix<E,N,M,SM> | |
83 | where SM: Storage<E,N,M>, SV1: Storage<E,M,K>, SV2: Storage<E,N,K>, | |
84 | N : Dim, M : Dim, K : Dim, E : Scalar + ClosedMul + ClosedAdd + Zero + One + SimdComplexField, | |
85 | DefaultAllocator : Allocator<E,N,K>, | |
86 | DefaultAllocator : Allocator<E,M,K>, | |
87 | DefaultAllocator : Allocator<E,N,M>, | |
88 | DefaultAllocator : Allocator<E,M,N> { | |
89 | type AdjointCodomain = OMatrix<E,M,K>; | |
90 | type Adjoint<'a> = OMatrix<E,M,N> where SM : 'a; | |
91 | ||
92 | #[inline] | |
93 | fn adjoint(&self) -> Self::Adjoint<'_> { | |
94 | Matrix::adjoint(self) | |
95 | } | |
96 | } | |
97 | ||
98 | impl<E,M,S,Si> Dot<Vector<E,M,Si>,E> | |
99 | for Vector<E,M,S> | |
100 | where M : Dim, | |
101 | E : Float + Scalar + ClosedMul + ClosedAdd + Zero + One, | |
102 | S : Storage<E,M>, | |
103 | Si : Storage<E,M>, | |
104 | DefaultAllocator : Allocator<E,M> { | |
105 | ||
106 | #[inline] | |
107 | fn dot(&self, other : &Vector<E,M,Si>) -> E { | |
108 | Vector::<E,M,S>::dot(self, other) | |
109 | } | |
110 | } | |
111 | ||
112 | /// This function is [`nalgebra::EuclideanNorm::metric_distance`] without the `sqrt`. | |
113 | #[inline] | |
114 | fn metric_distance_squared<T, R1, C1, S1, R2, C2, S2>( | |
115 | /*ed: &EuclideanNorm,*/ | |
116 | m1: &Matrix<T, R1, C1, S1>, | |
117 | m2: &Matrix<T, R2, C2, S2>, | |
118 | ) -> T::SimdRealField | |
119 | where | |
120 | T: SimdComplexField, | |
121 | R1: Dim, | |
122 | C1: Dim, | |
123 | S1: Storage<T, R1, C1>, | |
124 | R2: Dim, | |
125 | C2: Dim, | |
126 | S2: Storage<T, R2, C2>, | |
127 | ShapeConstraint: SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C2>, | |
128 | { | |
129 | m1.zip_fold(m2, T::SimdRealField::zero(), |acc, a, b| { | |
130 | let diff = a - b; | |
131 | acc + diff.simd_modulus_squared() | |
132 | }) | |
133 | } | |
134 | ||
// TODO: `Euclidean` should allow the second operand (e.g. `other` in `dist2_squared`) to use a different storage type than `Self`.
136 | ||
137 | impl<E,M,S> Euclidean<E> | |
138 | for Vector<E,M,S> | |
139 | where M : Dim, | |
140 | S : StorageMut<E,M>, | |
141 | E : Float + Scalar + ClosedMul + ClosedAdd + Zero + One + RealField, | |
142 | DefaultAllocator : Allocator<E,M> { | |
143 | ||
144 | type Output = OVector<E, M>; | |
145 | ||
146 | #[inline] | |
147 | fn similar_origin(&self) -> OVector<E, M> { | |
148 | OVector::zeros_generic(M::from_usize(self.len()), Const) | |
149 | } | |
150 | ||
151 | #[inline] | |
152 | fn norm2_squared(&self) -> E { | |
153 | Vector::<E,M,S>::norm_squared(self) | |
154 | } | |
155 | ||
156 | #[inline] | |
157 | fn dist2_squared(&self, other : &Self) -> E { | |
158 | metric_distance_squared(self, other) | |
159 | } | |
160 | } | |
161 | ||
162 | impl<E,M,S> StaticEuclidean<E> | |
163 | for Vector<E,M,S> | |
164 | where M : DimName, | |
165 | S : StorageMut<E,M>, | |
166 | E : Float + Scalar + ClosedMul + ClosedAdd + Zero + One + RealField, | |
167 | DefaultAllocator : Allocator<E,M> { | |
168 | ||
169 | #[inline] | |
170 | fn origin() -> OVector<E, M> { | |
171 | OVector::zeros() | |
172 | } | |
173 | } | |
174 | ||
175 | impl<E,M,S> Norm<E, L1> | |
176 | for Vector<E,M,S> | |
177 | where M : Dim, | |
178 | S : StorageMut<E,M>, | |
179 | E : Float + Scalar + ClosedMul + ClosedAdd + Zero + One + RealField, | |
180 | DefaultAllocator : Allocator<E,M> { | |
181 | ||
182 | #[inline] | |
183 | fn norm(&self, _ : L1) -> E { | |
184 | LpNorm(1).norm(self) | |
185 | } | |
186 | } | |
187 | ||
188 | impl<E,M,S> Dist<E, L1> | |
189 | for Vector<E,M,S> | |
190 | where M : Dim, | |
191 | S : StorageMut<E,M>, | |
192 | E : Float + Scalar + ClosedMul + ClosedAdd + Zero + One + RealField, | |
193 | DefaultAllocator : Allocator<E,M> { | |
194 | #[inline] | |
195 | fn dist(&self, other : &Self, _ : L1) -> E { | |
196 | LpNorm(1).metric_distance(self, other) | |
197 | } | |
198 | } | |
199 | ||
200 | impl<E,M,S> Norm<E, L2> | |
201 | for Vector<E,M,S> | |
202 | where M : Dim, | |
203 | S : StorageMut<E,M>, | |
204 | E : Float + Scalar + ClosedMul + ClosedAdd + Zero + One + RealField, | |
205 | DefaultAllocator : Allocator<E,M> { | |
206 | ||
207 | #[inline] | |
208 | fn norm(&self, _ : L2) -> E { | |
209 | LpNorm(2).norm(self) | |
210 | } | |
211 | } | |
212 | ||
213 | impl<E,M,S> Dist<E, L2> | |
214 | for Vector<E,M,S> | |
215 | where M : Dim, | |
216 | S : StorageMut<E,M>, | |
217 | E : Float + Scalar + ClosedMul + ClosedAdd + Zero + One + RealField, | |
218 | DefaultAllocator : Allocator<E,M> { | |
219 | #[inline] | |
220 | fn dist(&self, other : &Self, _ : L2) -> E { | |
221 | LpNorm(2).metric_distance(self, other) | |
222 | } | |
223 | } | |
224 | ||
225 | impl<E,M,S> Norm<E, Linfinity> | |
226 | for Vector<E,M,S> | |
227 | where M : Dim, | |
228 | S : StorageMut<E,M>, | |
229 | E : Float + Scalar + ClosedMul + ClosedAdd + Zero + One + RealField, | |
230 | DefaultAllocator : Allocator<E,M> { | |
231 | ||
232 | #[inline] | |
233 | fn norm(&self, _ : Linfinity) -> E { | |
234 | UniformNorm.norm(self) | |
235 | } | |
236 | } | |
237 | ||
238 | impl<E,M,S> Dist<E, Linfinity> | |
239 | for Vector<E,M,S> | |
240 | where M : Dim, | |
241 | S : StorageMut<E,M>, | |
242 | E : Float + Scalar + ClosedMul + ClosedAdd + Zero + One + RealField, | |
243 | DefaultAllocator : Allocator<E,M> { | |
244 | #[inline] | |
245 | fn dist(&self, other : &Self, _ : Linfinity) -> E { | |
246 | UniformNorm.metric_distance(self, other) | |
247 | } | |
248 | } | |
249 | ||
250 | /// Helper trait to hide the symbols of `nalgebra::RealField` | |
251 | /// while allowing nalgebra to be used in subroutines. | |
252 | pub trait ToNalgebraRealField : Float { | |
253 | type NalgebraType : RealField; | |
254 | type MixedType : RealField + Float; | |
255 | ||
256 | fn to_nalgebra(self) -> Self::NalgebraType; | |
257 | fn to_nalgebra_mixed(self) -> Self::MixedType; | |
258 | ||
259 | fn from_nalgebra(t : Self::NalgebraType) -> Self; | |
260 | fn from_nalgebra_mixed(t : Self::MixedType) -> Self; | |
261 | } | |
262 | ||
263 | impl ToNalgebraRealField for f32 { | |
264 | type NalgebraType = f32; | |
265 | type MixedType = f32; | |
266 | ||
267 | #[inline] | |
268 | fn to_nalgebra(self) -> Self::NalgebraType { self } | |
269 | ||
270 | #[inline] | |
271 | fn to_nalgebra_mixed(self) -> Self::MixedType { self } | |
272 | ||
273 | #[inline] | |
274 | fn from_nalgebra(t : Self::NalgebraType) -> Self { t } | |
275 | ||
276 | #[inline] | |
277 | fn from_nalgebra_mixed(t : Self::MixedType) -> Self { t } | |
278 | ||
279 | } | |
280 | ||
281 | impl ToNalgebraRealField for f64 { | |
282 | type NalgebraType = f64; | |
283 | type MixedType = f64; | |
284 | ||
285 | #[inline] | |
286 | fn to_nalgebra(self) -> Self::NalgebraType { self } | |
287 | ||
288 | #[inline] | |
289 | fn to_nalgebra_mixed(self) -> Self::MixedType { self } | |
290 | ||
291 | #[inline] | |
292 | fn from_nalgebra(t : Self::NalgebraType) -> Self { t } | |
293 | ||
294 | #[inline] | |
295 | fn from_nalgebra_mixed(t : Self::MixedType) -> Self { t } | |
296 | } | |
297 |