Sat, 19 Oct 2024 10:46:13 -0500
Some distance functions etc.
src/cube.rs | file | annotate | diff | comparison | revisions | |
src/dist.rs | file | annotate | diff | comparison | revisions | |
src/fb.rs | file | annotate | diff | comparison | revisions | |
src/main.rs | file | annotate | diff | comparison | revisions | |
src/manifold.rs | file | annotate | diff | comparison | revisions |
--- a/src/cube.rs Fri Oct 18 19:52:06 2024 -0500 +++ b/src/cube.rs Sat Oct 19 10:46:13 2024 -0500 @@ -191,6 +191,23 @@ point : Point, } +impl OnCube { + /// Calculates both the logarithmic map and distance to another point + fn log_dist(&self, other : &Self) -> (<Self as ManifoldPoint>::Tangent, f64) { + let mut best_len = f64::INFINITY; + let mut best_tan = Loc([0.0, 0.0]); + for path in self.face.paths(other.face) { + let tan = self.face.convert(&path, &other.point) - &self.point; + let len = tan.norm(L2); + if len < best_len { + best_tan = tan; + best_len = len; + } + } + (best_tan, best_len) + } +} + impl ManifoldPoint for OnCube { type Tangent = Point; @@ -209,17 +226,11 @@ } fn log(&self, other : &Self) -> Self::Tangent { - let mut best_len = f64::INFINITY; - let mut best_tan = Loc([0.0, 0.0]); - for path in self.face.paths(other.face) { - let tan = self.face.convert(&path, &other.point) - &self.point; - let len = tan.norm(L2); - if len < best_len { - best_tan = tan; - best_len = len; - } - } - best_tan + self.log_dist(other).0 + } + + fn dist_to(&self, other : &Self) -> f64 { + self.log_dist(other).1 } }
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/dist.rs Sat Oct 19 10:46:13 2024 -0500 @@ -0,0 +1,60 @@ + +use alg_tools::mapping::Apply; +use crate::manifold::ManifoldPoint; +use crate::fb::{Grad, Desc}; + +/// Structure for distance-to functions +pub struct DistTo<M : ManifoldPoint> { + base : M +} + +impl<M : ManifoldPoint> Apply<M> for DistTo<M> { + type Output = f64; + + fn apply(&self, x : M) -> Self::Output { + self.base.dist_to(&x) + } +} + +impl<'a, M : ManifoldPoint> Apply<&'a M> for DistTo<M> { + type Output = f64; + + fn apply(&self, x : &'a M) -> Self::Output { + self.base.dist_to(x) + } +} + +/// Structure for distance-to functions +pub struct DistToSquaredDiv2<M : ManifoldPoint> { + base : M +} + +impl<M : ManifoldPoint> Apply<M> for DistToSquaredDiv2<M> { + type Output = f64; + + fn apply(&self, x : M) -> Self::Output { + let d = self.base.dist_to(&x); + d*d / 2.0 + } +} + +impl<'a, M : ManifoldPoint> Apply<&'a M> for DistToSquaredDiv2<M> { + type Output = f64; + + fn apply(&self, x : &'a M) -> Self::Output { + let d = self.base.dist_to(x); + d*d / 2.0 + } +} + +impl<M : ManifoldPoint> Desc<M> for DistToSquaredDiv2<M> { + fn desc(&self, τ : f64, x : M) -> M { + x.exp(&(self.grad(&x) * τ)) + } +} + +impl<M : ManifoldPoint> Grad<M> for DistToSquaredDiv2<M> { + fn grad(&self, x : &M) -> M::Tangent { + x.log(&self.base) + } +}
--- a/src/fb.rs Fri Oct 18 19:52:06 2024 -0500 +++ b/src/fb.rs Sat Oct 19 10:46:13 2024 -0500 @@ -2,17 +2,41 @@ use alg_tools::iterate::AlgIteratorFactory; use alg_tools::mapping::{Mapping, Sum}; use serde::Serialize; +use std::iter::Sum as SumTrait; use crate::manifold::ManifoldPoint; +/// Trait for function objects that implement gradients +pub trait Grad<M : ManifoldPoint> { + fn grad(&self, x : &M) -> M::Tangent; +} + /// Trait for function objects that implement gradient steps pub trait Desc<M : ManifoldPoint> { - fn desc(&self, τ : f64, pt : M) -> M; + fn desc(&self, τ : f64, x : M) -> M; +} + +/*impl<M : ManifoldPoint, T : Grad<M>> Desc<M> for T { + fn desc(&self, τ : f64, x : M) -> M { + x.exp(self.grad(x) * τ) + } +}*/ + +impl<M, T > Desc<M> for Sum<M, T> +where M : ManifoldPoint, + T : Grad<M> + Mapping<M, Codomain=f64>, + M::Tangent : SumTrait { + fn desc(&self, τ : f64, x : M) -> M { + let t : M::Tangent = self.iter() + .map(|f| f.grad(&x)) + .sum(); + x.exp(&(t * τ)) + } } /// Trait for function objects that implement proximal steps pub trait Prox<M : ManifoldPoint> { - fn prox(&self, τ : f64, pt : M) -> M; + fn prox(&self, τ : f64, x : M) -> M; } #[derive(Clone,Debug,Serialize)]
--- a/src/main.rs Fri Oct 18 19:52:06 2024 -0500 +++ b/src/main.rs Sat Oct 19 10:46:13 2024 -0500 @@ -8,6 +8,7 @@ mod manifold; mod fb; mod cube; +mod dist; fn main() {
--- a/src/manifold.rs Fri Oct 18 19:52:06 2024 -0500 +++ b/src/manifold.rs Sat Oct 19 10:46:13 2024 -0500 @@ -1,3 +1,5 @@ + +use alg_tools::euclidean::Euclidean; /// A point on a manifold pub trait ManifoldPoint : Clone + PartialEq { @@ -9,4 +11,8 @@ /// Logarithmic map fn log(&self, other : &Self) -> Self::Tangent; + + /// Distance to `other` + fn dist_to(&self, other : &Self) -> f64; } +