A simple test

Mon, 21 Oct 2024 10:02:57 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Mon, 21 Oct 2024 10:02:57 -0500
changeset 7
8979a6638424
parent 6
df9628092285
child 8
17d71ca4ce84

A simple test

Cargo.toml file | annotate | diff | comparison | revisions
src/cube.rs file | annotate | diff | comparison | revisions
src/dist.rs file | annotate | diff | comparison | revisions
src/fb.rs file | annotate | diff | comparison | revisions
src/main.rs file | annotate | diff | comparison | revisions
src/manifold.rs file | annotate | diff | comparison | revisions
--- a/Cargo.toml	Mon Oct 21 08:44:23 2024 -0500
+++ b/Cargo.toml	Mon Oct 21 10:02:57 2024 -0500
@@ -20,3 +20,4 @@
 [dependencies]
 serde = { version = "1.0", features = ["derive"] }
 alg_tools = { version = "~0.3.0-dev", path = "../alg_tools", default-features = false }
+colored = "~2.0.0"
--- a/src/cube.rs	Mon Oct 21 08:44:23 2024 -0500
+++ b/src/cube.rs	Mon Oct 21 10:02:57 2024 -0500
@@ -3,7 +3,7 @@
 
 use alg_tools::loc::Loc;
 use alg_tools::norms::{Norm, L2};
-use crate::manifold::ManifoldPoint;
+use crate::manifold::{ManifoldPoint, EmbeddedManifoldPoint};
 
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum Face {F1, F2, F3, F4, F5, F6}
@@ -183,6 +183,19 @@
             (Equal, Greater) => self.adjacent_faces()[3],
         }
     }
+
+    /// Get embedded 3D coordinates
+    pub fn embedded_coords(&self, p : &Point) -> Loc<f64, 3> {
+        let &Loc([x, y]) = p;
+        Loc(match *self {
+            F1 => [x, y, 0.0],
+            F2 => [1.0, x, y],
+            F3 => [0.0, 1.0-x, y],
+            F4 => [x, 0.0, y],
+            F5 => [1.0 - x, 1.0, y],
+            F6 => [x, y, 1.0],
+        })
+    }
 }
 
 #[derive(Clone, Debug, PartialEq)]
@@ -192,6 +205,13 @@
 }
 
 impl OnCube {
+    /// Creates a new point on the cube, given a face and face-relative coordinates
+    /// in [0, 1]^2
+    pub fn new(face : Face, point : Point) -> Self {
+        assert!(face.is_in_face(&point));
+        OnCube { face, point }
+    }
+
     /// Calculates both the logarithmic map and distance to another point
     fn log_dist(&self, other : &Self) -> (<Self as ManifoldPoint>::Tangent, f64) {
         let mut best_len = f64::INFINITY;
@@ -208,6 +228,16 @@
     }
 }
 
+
+impl EmbeddedManifoldPoint for OnCube {
+    type EmbeddedCoords = Loc<f64, 3>;
+
+    /// Get embedded 3D coordinates
+    fn embedded_coords(&self) -> Loc<f64, 3> {
+        self.face.embedded_coords(&self.point)
+    }
+}
+
 impl ManifoldPoint for OnCube {
     type Tangent = Point;
 
--- a/src/dist.rs	Mon Oct 21 08:44:23 2024 -0500
+++ b/src/dist.rs	Mon Oct 21 10:02:57 2024 -0500
@@ -4,15 +4,13 @@
 use crate::fb::{Grad, Desc};
 
 /// Structure for distance-to functions
-pub struct DistTo<M : ManifoldPoint> {
-    base : M
-}
+pub struct DistTo<M : ManifoldPoint>(pub M);
 
 impl<M : ManifoldPoint> Apply<M> for DistTo<M> {
     type Output = f64;
 
     fn apply(&self, x : M) -> Self::Output {
-        self.base.dist_to(&x)
+        self.0.dist_to(&x)
     }
 }
 
@@ -20,20 +18,18 @@
     type Output = f64;
 
     fn apply(&self, x : &'a M) -> Self::Output {
-        self.base.dist_to(x)
+        self.0.dist_to(x)
     }
 }
 
 /// Structure for distance-to functions
-pub struct DistToSquaredDiv2<M : ManifoldPoint> {
-    base : M
-}
+pub struct DistToSquaredDiv2<M : ManifoldPoint>(pub M);
 
 impl<M : ManifoldPoint> Apply<M> for DistToSquaredDiv2<M> {
     type Output = f64;
 
     fn apply(&self, x : M) -> Self::Output {
-        let d = self.base.dist_to(&x);
+        let d = self.0.dist_to(&x);
         d*d / 2.0
     }
 }
@@ -42,7 +38,7 @@
     type Output = f64;
 
     fn apply(&self, x : &'a M) -> Self::Output {
-        let d = self.base.dist_to(x);
+        let d = self.0.dist_to(x);
         d*d / 2.0
     }
 }
@@ -55,6 +51,6 @@
 
 impl<M : ManifoldPoint> Grad<M> for DistToSquaredDiv2<M> {
     fn grad(&self, x : &M) -> M::Tangent {
-       x.log(&self.base)
+       x.log(&self.0)
     }
 }
--- a/src/fb.rs	Mon Oct 21 08:44:23 2024 -0500
+++ b/src/fb.rs	Mon Oct 21 10:02:57 2024 -0500
@@ -1,10 +1,11 @@
 
-use alg_tools::iterate::AlgIteratorFactory;
+use alg_tools::iterate::{AlgIteratorFactory, LogRepr};
 use alg_tools::mapping::{Mapping, Sum};
 use serde::Serialize;
 use std::iter::Sum as SumTrait;
+use colored::ColoredString;
 
-use crate::manifold::ManifoldPoint;
+use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint};
 
 /// Trait for function objects that implement gradients
 pub trait Grad<M : ManifoldPoint> {
@@ -45,6 +46,15 @@
     point : M,
 }
 
+impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> {
+    fn logrepr(&self) -> ColoredString {
+        format!("{}\t {}",
+            self.value,
+            self.point.embedded_coords()
+        ).as_str().into()
+    }
+}
+
 pub fn forward_backward<M, F, G, I>(
     f : &F,
     g : &G,
@@ -52,7 +62,7 @@
     τ : f64,
     iterator : I
 ) -> M
-where M : ManifoldPoint,
+where M : ManifoldPoint + EmbeddedManifoldPoint,
       F : Desc<M> +  Mapping<M, Codomain = f64>,
       G : Prox<M> +  Mapping<M, Codomain = f64>,
       I : AlgIteratorFactory<IterInfo<M>> {
--- a/src/main.rs	Mon Oct 21 08:44:23 2024 -0500
+++ b/src/main.rs	Mon Oct 21 10:02:57 2024 -0500
@@ -5,6 +5,10 @@
 #![allow(mixed_script_confusables)]
 #![allow(confusable_idents)]
 
+use dist::DistToSquaredDiv2;
+use fb::forward_backward;
+use manifold::EmbeddedManifoldPoint;
+
 mod manifold;
 mod fb;
 mod cube;
@@ -12,5 +16,34 @@
 mod zero;
 
 fn main() {
+    simple_test()
+}
 
+fn simple_test() {
+    use cube::*;
+    use alg_tools::loc::Loc;
+    use Face::*;
+    use zero::ZeroFn;
+    use alg_tools::mapping::Sum;
+    use alg_tools::iterate::{AlgIteratorOptions, Verbose};
+    
+    let points = [
+        OnCube::new(F1, Loc([0.5, 0.5])),
+        OnCube::new(F2, Loc([0.5, 0.5])),
+        OnCube::new(F4, Loc([0.1, 0.1])),
+    ];
+
+    //let x = points[0].clone();
+    let x = OnCube::new(F6, Loc([0.5, 0.5]));
+    let f = Sum::new(points.into_iter().map(DistToSquaredDiv2));
+    let g = ZeroFn::new();
+    let τ = 0.1;
+    let iter = AlgIteratorOptions{
+        max_iter : 100,
+        verbose_iter : Verbose::Every(1),
+        .. Default::default()
+    };
+
+    let x̂ = forward_backward(&f, &g, x, τ, iter);
+    println!("result = {}\n{:?}", x̂.embedded_coords(), &x̂);
 }
--- a/src/manifold.rs	Mon Oct 21 08:44:23 2024 -0500
+++ b/src/manifold.rs	Mon Oct 21 10:02:57 2024 -0500
@@ -19,3 +19,10 @@
     fn tangent_origin(&self) -> Self::Tangent;
 }
 
+/// Point on a manifold that possesses displayable embedded coordinates.
+pub trait EmbeddedManifoldPoint : ManifoldPoint + std::fmt::Debug {
+    type EmbeddedCoords : std::fmt::Display;
+
+    /// Convert a point on a manifold into embedded coordinates
+    fn embedded_coords(&self) -> Self::EmbeddedCoords;
+}

mercurial