src/norms.rs

branch
dev
changeset 71
511bf440e24b
parent 70
672aec2e1acd
child 72
44a4f258a1ff
equal deleted inserted replaced
70:672aec2e1acd 71:511bf440e24b
48 /// Exponent type for 2,1-[`Norm`]. 48 /// Exponent type for 2,1-[`Norm`].
49 /// (1-norm over a domain Ω, 2-norm of a vector at each point of the domain.) 49 /// (1-norm over a domain Ω, 2-norm of a vector at each point of the domain.)
50 #[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)] 50 #[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)]
51 pub struct L21; 51 pub struct L21;
52 impl NormExponent for L21 {} 52 impl NormExponent for L21 {}
53
54 impl<C : Constant, E : NormExponent> NormExponent for Weighted<E, C> {}
55 53
56 /// Norms for pairs (a, b). ‖(a,b)‖ = ‖(‖a‖_A, ‖b‖_B)‖_J 54 /// Norms for pairs (a, b). ‖(a,b)‖ = ‖(‖a‖_A, ‖b‖_B)‖_J
57 /// For use with [`crate::direct_product::Pair`] 55 /// For use with [`crate::direct_product::Pair`]
58 #[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)] 56 #[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)]
59 pub struct PairNorm<A, B, J>(pub A, pub B, pub J); 57 pub struct PairNorm<A, B, J>(pub A, pub B, pub J);
177 fn dist<I : Instance<Self>>(&self, other : I, huber : HuberL1<F>) -> F { 175 fn dist<I : Instance<Self>>(&self, other : I, huber : HuberL1<F>) -> F {
178 huber.apply(self.dist2_squared(other)) 176 huber.apply(self.dist2_squared(other))
179 } 177 }
180 } 178 }
181 179
182 impl<C, F, E, D> Norm<F, Weighted<E, C>> for D
183 where
184 F : Float,
185 D : Norm<F, E>,
186 C : Constant<Type = F>,
187 E : NormExponent,
188 {
189 fn norm(&self, e : Weighted<E, C>) -> F {
190 let v = e.weight.value();
191 assert!(v > F::ZERO);
192 v * self.norm(e.base_fn)
193 }
194 }
195
196 // impl<F : Float, E : Norm<F, L2>> Norm<F, L21> for Vec<E> { 180 // impl<F : Float, E : Norm<F, L2>> Norm<F, L21> for Vec<E> {
197 // fn norm(&self, _l21 : L21) -> F { 181 // fn norm(&self, _l21 : L21) -> F {
198 // self.iter().map(|e| e.norm(L2)).sum() 182 // self.iter().map(|e| e.norm(L2)).sum()
199 // } 183 // }
200 // } 184 // }
281 fn dual_exponent(&self) -> Self::DualExp { 265 fn dual_exponent(&self) -> Self::DualExp {
282 L1 266 L1
283 } 267 }
284 } 268 }
285 269
286 impl<C : Constant, E : HasDualExponent> HasDualExponent for Weighted<E, C> { 270 #[macro_export]
287 type DualExp = Weighted<E::DualExp, C::Type>; 271 macro_rules! impl_weighted_norm {
288 272 ($exponent : ty) => {
289 fn dual_exponent(&self) -> Self::DualExp { 273 impl<C, F, D> Norm<F, Weighted<$exponent, C>> for D
290 Weighted { 274 where
291 weight : C::Type::ONE / self.weight.value(), 275 F : Float,
292 base_fn : self.base_fn.dual_exponent() 276 D : Norm<F, $exponent>,
277 C : Constant<Type = F>,
278 {
279 fn norm(&self, e : Weighted<$exponent, C>) -> F {
280 let v = e.weight.value();
281 assert!(v > F::ZERO);
282 v * self.norm(e.base_fn)
283 }
293 } 284 }
294 } 285
295 } 286 impl<C : Constant> NormExponent for Weighted<$exponent, C> {}
296 287
297 impl<C, F, E, T> Projection<F, Weighted<E, C>> for T 288 impl<C : Constant> HasDualExponent for Weighted<$exponent, C>
298 where 289 where $exponent : HasDualExponent {
299 T : Projection<F, E>, 290 type DualExp = Weighted<<$exponent as HasDualExponent>::DualExp, C::Type>;
300 F : Float, 291
301 C : Constant<Type = F>, 292 fn dual_exponent(&self) -> Self::DualExp {
302 E : NormExponent, 293 Weighted {
303 { 294 weight : C::Type::ONE / self.weight.value(),
304 fn proj_ball(self, ρ : F, q : Weighted<E, C>) -> Self { 295 base_fn : self.base_fn.dual_exponent()
305 self.proj_ball(ρ / q.weight.value(), q.base_fn) 296 }
306 } 297 }
307 298 }
308 fn proj_ball_mut(&mut self, ρ : F, q : Weighted<E, C>) { 299
309 self.proj_ball_mut(ρ / q.weight.value(), q.base_fn) 300 impl<C, F, T> Projection<F, Weighted<$exponent , C>> for T
310 } 301 where
311 } 302 T : Projection<F, $exponent >,
303 F : Float,
304 C : Constant<Type = F>,
305 {
306 fn proj_ball(self, ρ : F, q : Weighted<$exponent , C>) -> Self {
307 self.proj_ball(ρ / q.weight.value(), q.base_fn)
308 }
309
310 fn proj_ball_mut(&mut self, ρ : F, q : Weighted<$exponent , C>) {
311 self.proj_ball_mut(ρ / q.weight.value(), q.base_fn)
312 }
313 }
314 }
315 }
316
317 //impl_weighted_norm!(L1);
318 //impl_weighted_norm!(L2);
319 //impl_weighted_norm!(Linfinity);
320

mercurial