Mon, 21 Oct 2024 10:02:57 -0500
A simple test
Cargo.toml | file | annotate | diff | comparison | revisions | |
src/cube.rs | file | annotate | diff | comparison | revisions | |
src/dist.rs | file | annotate | diff | comparison | revisions | |
src/fb.rs | file | annotate | diff | comparison | revisions | |
src/main.rs | file | annotate | diff | comparison | revisions | |
src/manifold.rs | file | annotate | diff | comparison | revisions |
--- a/Cargo.toml Mon Oct 21 08:44:23 2024 -0500 +++ b/Cargo.toml Mon Oct 21 10:02:57 2024 -0500 @@ -20,3 +20,4 @@ [dependencies] serde = { version = "1.0", features = ["derive"] } alg_tools = { version = "~0.3.0-dev", path = "../alg_tools", default-features = false } +colored = "~2.0.0"
--- a/src/cube.rs Mon Oct 21 08:44:23 2024 -0500 +++ b/src/cube.rs Mon Oct 21 10:02:57 2024 -0500 @@ -3,7 +3,7 @@ use alg_tools::loc::Loc; use alg_tools::norms::{Norm, L2}; -use crate::manifold::ManifoldPoint; +use crate::manifold::{ManifoldPoint, EmbeddedManifoldPoint}; #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum Face {F1, F2, F3, F4, F5, F6} @@ -183,6 +183,19 @@ (Equal, Greater) => self.adjacent_faces()[3], } } + + /// Get embedded 3D coordinates + pub fn embedded_coords(&self, p : &Point) -> Loc<f64, 3> { + let &Loc([x, y]) = p; + Loc(match *self { + F1 => [x, y, 0.0], + F2 => [1.0, x, y], + F3 => [0.0, 1.0-x, y], + F4 => [x, 0.0, y], + F5 => [1.0 - x, 1.0, y], + F6 => [x, y, 1.0], + }) + } } #[derive(Clone, Debug, PartialEq)] @@ -192,6 +205,13 @@ } impl OnCube { + /// Creates a new point on the cube, given a face and face-relative coordinates + /// in [0, 1]^2 + pub fn new(face : Face, point : Point) -> Self { + assert!(face.is_in_face(&point)); + OnCube { face, point } + } + /// Calculates both the logarithmic map and distance to another point fn log_dist(&self, other : &Self) -> (<Self as ManifoldPoint>::Tangent, f64) { let mut best_len = f64::INFINITY; @@ -208,6 +228,16 @@ } } + +impl EmbeddedManifoldPoint for OnCube { + type EmbeddedCoords = Loc<f64, 3>; + + /// Get embedded 3D coordinates + fn embedded_coords(&self) -> Loc<f64, 3> { + self.face.embedded_coords(&self.point) + } +} + impl ManifoldPoint for OnCube { type Tangent = Point;
--- a/src/dist.rs Mon Oct 21 08:44:23 2024 -0500 +++ b/src/dist.rs Mon Oct 21 10:02:57 2024 -0500 @@ -4,15 +4,13 @@ use crate::fb::{Grad, Desc}; /// Structure for distance-to functions -pub struct DistTo<M : ManifoldPoint> { - base : M -} +pub struct DistTo<M : ManifoldPoint>(pub M); impl<M : ManifoldPoint> Apply<M> for DistTo<M> { type Output = f64; fn apply(&self, x : M) -> Self::Output { - self.base.dist_to(&x) + self.0.dist_to(&x) } } @@ -20,20 +18,18 @@ type Output = f64; fn apply(&self, x : &'a M) -> Self::Output { - self.base.dist_to(x) + self.0.dist_to(x) } } /// Structure for distance-to functions -pub struct DistToSquaredDiv2<M : ManifoldPoint> { - base : M -} +pub struct DistToSquaredDiv2<M : ManifoldPoint>(pub M); impl<M : ManifoldPoint> Apply<M> for DistToSquaredDiv2<M> { type Output = f64; fn apply(&self, x : M) -> Self::Output { - let d = self.base.dist_to(&x); + let d = self.0.dist_to(&x); d*d / 2.0 } } @@ -42,7 +38,7 @@ type Output = f64; fn apply(&self, x : &'a M) -> Self::Output { - let d = self.base.dist_to(x); + let d = self.0.dist_to(x); d*d / 2.0 } } @@ -55,6 +51,6 @@ impl<M : ManifoldPoint> Grad<M> for DistToSquaredDiv2<M> { fn grad(&self, x : &M) -> M::Tangent { - x.log(&self.base) + x.log(&self.0) } }
--- a/src/fb.rs Mon Oct 21 08:44:23 2024 -0500 +++ b/src/fb.rs Mon Oct 21 10:02:57 2024 -0500 @@ -1,10 +1,11 @@ -use alg_tools::iterate::AlgIteratorFactory; +use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; use alg_tools::mapping::{Mapping, Sum}; use serde::Serialize; use std::iter::Sum as SumTrait; +use colored::ColoredString; -use crate::manifold::ManifoldPoint; +use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; /// Trait for function objects that implement gradients pub trait Grad<M : ManifoldPoint> { @@ -45,6 +46,15 @@ point : M, } +impl<M : ManifoldPoint + EmbeddedManifoldPoint> LogRepr for IterInfo<M> { + fn logrepr(&self) -> ColoredString { + format!("{}\t {}", + self.value, + self.point.embedded_coords() + ).as_str().into() + } +} + pub fn forward_backward<M, F, G, I>( f : &F, g : &G, @@ -52,7 +62,7 @@ τ : f64, iterator : I ) -> M -where M : ManifoldPoint, +where M : ManifoldPoint + EmbeddedManifoldPoint, F : Desc<M> + Mapping<M, Codomain = f64>, G : Prox<M> + Mapping<M, Codomain = f64>, I : AlgIteratorFactory<IterInfo<M>> {
--- a/src/main.rs Mon Oct 21 08:44:23 2024 -0500 +++ b/src/main.rs Mon Oct 21 10:02:57 2024 -0500 @@ -5,6 +5,10 @@ #![allow(mixed_script_confusables)] #![allow(confusable_idents)] +use dist::DistToSquaredDiv2; +use fb::forward_backward; +use manifold::EmbeddedManifoldPoint; + mod manifold; mod fb; mod cube; @@ -12,5 +16,34 @@ mod zero; fn main() { + simple_test() +} +fn simple_test() { + use cube::*; + use alg_tools::loc::Loc; + use Face::*; + use zero::ZeroFn; + use alg_tools::mapping::Sum; + use alg_tools::iterate::{AlgIteratorOptions, Verbose}; + + let points = [ + OnCube::new(F1, Loc([0.5, 0.5])), + OnCube::new(F2, Loc([0.5, 0.5])), + OnCube::new(F4, Loc([0.1, 0.1])), + ]; + + //let x = points[0].clone(); + let x = OnCube::new(F6, Loc([0.5, 0.5])); + let f = Sum::new(points.into_iter().map(DistToSquaredDiv2)); + let g = ZeroFn::new(); + let τ = 0.1; + let iter = AlgIteratorOptions{ + max_iter : 100, + verbose_iter : Verbose::Every(1), + .. Default::default() + }; + + let x̂ = forward_backward(&f, &g, x, τ, iter); + println!("result = {}\n{:?}", x̂.embedded_coords(), &x̂); }
--- a/src/manifold.rs Mon Oct 21 08:44:23 2024 -0500 +++ b/src/manifold.rs Mon Oct 21 10:02:57 2024 -0500 @@ -19,3 +19,10 @@ fn tangent_origin(&self) -> Self::Tangent; } +/// Point on a manifold that possesses displayable embedded coordinates. +pub trait EmbeddedManifoldPoint : ManifoldPoint + std::fmt::Debug { + type EmbeddedCoords : std::fmt::Display; + + /// Convert a point on a manifold into embedded coordinates + fn embedded_coords(&self) -> Self::EmbeddedCoords; +}