# HG changeset patch # User Tuomo Valkonen # Date 1729522977 18000 # Node ID 8979a6638424deac90717f19fd8c07f6869ee38e # Parent df962809228557f0290517dcec415f6e4d56e97c A simple test diff -r df9628092285 -r 8979a6638424 Cargo.toml --- a/Cargo.toml Mon Oct 21 08:44:23 2024 -0500 +++ b/Cargo.toml Mon Oct 21 10:02:57 2024 -0500 @@ -20,3 +20,4 @@ [dependencies] serde = { version = "1.0", features = ["derive"] } alg_tools = { version = "~0.3.0-dev", path = "../alg_tools", default-features = false } +colored = "~2.0.0" diff -r df9628092285 -r 8979a6638424 src/cube.rs --- a/src/cube.rs Mon Oct 21 08:44:23 2024 -0500 +++ b/src/cube.rs Mon Oct 21 10:02:57 2024 -0500 @@ -3,7 +3,7 @@ use alg_tools::loc::Loc; use alg_tools::norms::{Norm, L2}; -use crate::manifold::ManifoldPoint; +use crate::manifold::{ManifoldPoint, EmbeddedManifoldPoint}; #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum Face {F1, F2, F3, F4, F5, F6} @@ -183,6 +183,19 @@ (Equal, Greater) => self.adjacent_faces()[3], } } + + /// Get embedded 3D coordinates + pub fn embedded_coords(&self, p : &Point) -> Loc { + let &Loc([x, y]) = p; + Loc(match *self { + F1 => [x, y, 0.0], + F2 => [1.0, x, y], + F3 => [0.0, 1.0-x, y], + F4 => [x, 0.0, y], + F5 => [1.0 - x, 1.0, y], + F6 => [x, y, 1.0], + }) + } } #[derive(Clone, Debug, PartialEq)] @@ -192,6 +205,13 @@ } impl OnCube { + /// Creates a new point on the cube, given a face and face-relative coordinates + /// in [0, 1]^2 + pub fn new(face : Face, point : Point) -> Self { + assert!(face.is_in_face(&point)); + OnCube { face, point } + } + /// Calculates both the logarithmic map and distance to another point fn log_dist(&self, other : &Self) -> (::Tangent, f64) { let mut best_len = f64::INFINITY; @@ -208,6 +228,16 @@ } } + +impl EmbeddedManifoldPoint for OnCube { + type EmbeddedCoords = Loc; + + /// Get embedded 3D coordinates + fn embedded_coords(&self) -> Loc { + self.face.embedded_coords(&self.point) + } +} + impl ManifoldPoint for OnCube { type Tangent = Point; diff -r df9628092285 -r 8979a6638424 src/dist.rs --- a/src/dist.rs Mon Oct 21 08:44:23 2024 -0500 +++ b/src/dist.rs Mon Oct 21 10:02:57 2024 -0500 @@ -4,15 +4,13 @@ use crate::fb::{Grad, Desc}; /// Structure for distance-to functions -pub struct DistTo { - base : M -} +pub struct DistTo(pub M); impl Apply for DistTo { type Output = f64; fn apply(&self, x : M) -> Self::Output { - self.base.dist_to(&x) + self.0.dist_to(&x) } } @@ -20,20 +18,18 @@ type Output = f64; fn apply(&self, x : &'a M) -> Self::Output { - self.base.dist_to(x) + self.0.dist_to(x) } } /// Structure for distance-to functions -pub struct DistToSquaredDiv2 { - base : M -} +pub struct DistToSquaredDiv2(pub M); impl Apply for DistToSquaredDiv2 { type Output = f64; fn apply(&self, x : M) -> Self::Output { - let d = self.base.dist_to(&x); + let d = self.0.dist_to(&x); d*d / 2.0 } } @@ -42,7 +38,7 @@ type Output = f64; fn apply(&self, x : &'a M) -> Self::Output { - let d = self.base.dist_to(x); + let d = self.0.dist_to(x); d*d / 2.0 } } @@ -55,6 +51,6 @@ impl Grad for DistToSquaredDiv2 { fn grad(&self, x : &M) -> M::Tangent { - x.log(&self.base) + x.log(&self.0) } } diff -r df9628092285 -r 8979a6638424 src/fb.rs --- a/src/fb.rs Mon Oct 21 08:44:23 2024 -0500 +++ b/src/fb.rs Mon Oct 21 10:02:57 2024 -0500 @@ -1,10 +1,11 @@ -use alg_tools::iterate::AlgIteratorFactory; +use alg_tools::iterate::{AlgIteratorFactory, LogRepr}; use alg_tools::mapping::{Mapping, Sum}; use serde::Serialize; use std::iter::Sum as SumTrait; +use colored::ColoredString; -use crate::manifold::ManifoldPoint; +use crate::manifold::{EmbeddedManifoldPoint, ManifoldPoint}; /// Trait for function objects that implement gradients pub trait Grad { @@ -45,6 +46,15 @@ point : M, } +impl LogRepr for IterInfo { + fn logrepr(&self) -> ColoredString { + format!("{}\t {}", + self.value, + self.point.embedded_coords() + ).as_str().into() + } +} + pub fn forward_backward( f : &F, g : &G, @@ -52,7 +62,7 @@ τ : f64, iterator : I ) -> M -where M : ManifoldPoint, +where M : ManifoldPoint + EmbeddedManifoldPoint, F : Desc + Mapping, G : Prox + Mapping, I : AlgIteratorFactory> { diff -r df9628092285 -r 8979a6638424 src/main.rs --- a/src/main.rs Mon Oct 21 08:44:23 2024 -0500 +++ b/src/main.rs Mon Oct 21 10:02:57 2024 -0500 @@ -5,6 +5,10 @@ #![allow(mixed_script_confusables)] #![allow(confusable_idents)] +use dist::DistToSquaredDiv2; +use fb::forward_backward; +use manifold::EmbeddedManifoldPoint; + mod manifold; mod fb; mod cube; @@ -12,5 +16,34 @@ mod zero; fn main() { + simple_test() +} +fn simple_test() { + use cube::*; + use alg_tools::loc::Loc; + use Face::*; + use zero::ZeroFn; + use alg_tools::mapping::Sum; + use alg_tools::iterate::{AlgIteratorOptions, Verbose}; + + let points = [ + OnCube::new(F1, Loc([0.5, 0.5])), + OnCube::new(F2, Loc([0.5, 0.5])), + OnCube::new(F4, Loc([0.1, 0.1])), + ]; + + //let x = points[0].clone(); + let x = OnCube::new(F6, Loc([0.5, 0.5])); + let f = Sum::new(points.into_iter().map(DistToSquaredDiv2)); + let g = ZeroFn::new(); + let τ = 0.1; + let iter = AlgIteratorOptions{ + max_iter : 100, + verbose_iter : Verbose::Every(1), + .. Default::default() + }; + + let x̂ = forward_backward(&f, &g, x, τ, iter); + println!("result = {}\n{:?}", x̂.embedded_coords(), &x̂); } diff -r df9628092285 -r 8979a6638424 src/manifold.rs --- a/src/manifold.rs Mon Oct 21 08:44:23 2024 -0500 +++ b/src/manifold.rs Mon Oct 21 10:02:57 2024 -0500 @@ -19,3 +19,10 @@ fn tangent_origin(&self) -> Self::Tangent; } +/// Point on a manifold that possesses displayable embedded coordinates. +pub trait EmbeddedManifoldPoint : ManifoldPoint + std::fmt::Debug { + type EmbeddedCoords : std::fmt::Display; + + /// Convert a point on a manifold into embedded coordinates + fn embedded_coords(&self) -> Self::EmbeddedCoords; +}