Wed, 06 Nov 2024 10:00:06 -0500
Implement prox for DistTo
13 | 1 | /*! |
2 | Optimisation on non-Riemannian manifolds. | |
3 | */ | |
4 | 4 | |
5 | // We use unicode. We would like to use much more of it than Rust allows. | |
6 | // Live with it. Embrace it. | |
7 | #![allow(uncommon_codepoints)] | |
8 | #![allow(mixed_script_confusables)] | |
9 | #![allow(confusable_idents)] | |
1 | 10 | |
12 | 11 | use serde::Serialize; |
16 | 12 | use alg_tools::logger::Logger; |
13 | use alg_tools::tabledump::{TableDump, write_csv}; | |
14 | use alg_tools::error::DynError; | |
15 | use alg_tools::lingrid::LinSpace; | |
16 | use alg_tools::loc::Loc; | |
17 | use alg_tools::types::*; | |
18 | use alg_tools::mapping::{Sum, Apply}; | |
19 | use alg_tools::iterate::{AlgIteratorOptions, AlgIteratorFactory, Verbose}; | |
20 | use image::{ImageFormat, ImageBuffer, Rgb}; | |
21 | ||
7 | 22 | use dist::DistToSquaredDiv2; |
12 | 23 | use fb::{forward_backward, IterInfo}; |
7 | 24 | use manifold::EmbeddedManifoldPoint; |
11 | 25 | use cube::*; |
16 | 26 | use Face::*; |
27 | use zero::ZeroFn; | |
7 | 28 | |
1 | 29 | mod manifold; |
4 | 30 | mod fb; |
1 | 31 | mod cube; |
5 | 32 | mod dist; |
6
df9628092285
Add a zero function on manifolds
Tuomo Valkonen <tuomov@iki.fi>
parents:
5
diff
changeset
|
33 | mod zero; |
1 | 34 | |
13 | 35 | /// Program entry point |
1 | 36 | fn main() { |
12 | 37 | simple_cube_test().unwrap() |
7 | 38 | } |
1 | 39 | |
12 | 40 | /// Helper structure for saving the log into a CSV file |
41 | #[derive(Serialize)] | |
42 | struct CSVLog { | |
43 | iter : usize, | |
44 | value : f64, | |
45 | face : Face, | |
46 | x : f64, | |
47 | y : f64, | |
48 | z : f64 | |
49 | } | |
50 | ||
/// Directory into which all result files (logs, CSV samples, PNG images) are written.
const PREFIX: &str = "res";
53 | ||
54 | /// A simple test on the cube | |
55 | fn simple_cube_test() -> DynError { | |
7 | 56 | |
57 | let points = [ | |
9 | 58 | //OnCube::new(F1, Loc([0.5, 0.5])), |
59 | //OnCube::new(F2, Loc([0.5, 0.5])), | |
60 | //OnCube::new(F4, Loc([0.1, 0.1])), | |
7 | 61 | OnCube::new(F1, Loc([0.5, 0.5])), |
9 | 62 | OnCube::new(F3, Loc([0.5, 0.5])), |
7 | 63 | OnCube::new(F2, Loc([0.5, 0.5])), |
64 | ]; | |
65 | ||
66 | //let x = points[0].clone(); | |
9 | 67 | // OnCube::new(F3, Loc([0.5, 0.5])); goes to opposite side |
68 | let x = OnCube::new(F3, Loc([0.5, 0.4])); | |
7 | 69 | let f = Sum::new(points.into_iter().map(DistToSquaredDiv2)); |
70 | let g = ZeroFn::new(); | |
71 | let τ = 0.1; | |
12 | 72 | |
73 | let mut logger = Logger::new(); | |
74 | let logmap = |iter, IterInfo { value, point } : IterInfo<OnCube>| { | |
75 | let Loc([x,y,z]) = point.embedded_coords(); | |
76 | let face = point.face(); | |
77 | CSVLog { iter, value, face, x, y, z } | |
78 | }; | |
7 | 79 | let iter = AlgIteratorOptions{ |
80 | max_iter : 100, | |
81 | verbose_iter : Verbose::Every(1), | |
82 | .. Default::default() | |
12 | 83 | }.mapped(logmap) |
84 | .into_log(&mut logger); | |
7 | 85 | |
86 | let x̂ = forward_backward(&f, &g, x, τ, iter); | |
87 | println!("result = {}\n{:?}", x̂.embedded_coords(), &x̂); | |
11 | 88 | |
12 | 89 | std::fs::create_dir_all(PREFIX)?; |
90 | ||
91 | logger.write_csv(format!("{PREFIX}/log.txt"))?; | |
92 | ||
11 | 93 | for face in Face::all() { |
16 | 94 | write_face_csv(format!("{PREFIX}/{face}"), face, 64, |x| f.apply(x) + g.apply(x))?; |
95 | write_face_img(format!("{PREFIX}/{face}"), face, 128, |x| f.apply(x) + g.apply(x))?; | |
11 | 96 | } |
12 | 97 | |
98 | Ok(()) | |
1 | 99 | } |
11 | 100 | |
12 | 101 | /// Writes the values of `f` on `face` of a [`OnCube`] into a PNG file |
102 | /// with resolution `n × n`. | |
16 | 103 | fn write_face_img(filename : String, face : Face, n : usize, mut f : impl FnMut(&OnCube) -> f64) -> DynError { |
11 | 104 | let mut img = ImageBuffer::new(n as u32, n as u32); |
105 | let grid = LinSpace { | |
106 | start : Loc([0.0, 0.0]), | |
107 | end : Loc([1.0, 1.0]), | |
108 | count : [n, n] | |
109 | }; | |
110 | let rawdata : Vec<_> = grid.into_iter() | |
111 | .map(|x| f(&OnCube::new(face, x))) | |
112 | .collect(); | |
113 | let a = rawdata.iter().copied().reduce(f64::max).unwrap(); | |
114 | img.pixels_mut() | |
115 | .zip(rawdata) | |
116 | .for_each(|(p, v)| { | |
117 | let t = v/a; | |
118 | let rgb = [1.0-t, 1.0-t, 1.0]; | |
119 | *p = Rgb(rgb.map(|v| (v*(u8::RANGE_MAX as f64)) as u8)) | |
120 | }); | |
121 | ||
12 | 122 | img.save_with_format(format!("{filename}.png"), ImageFormat::Png)?; |
123 | ||
124 | Ok(()) | |
11 | 125 | } |
16 | 126 | |
127 | /// Writes the values of `f` on `face` of a [`OnCube`] into a CSV file | |
128 | /// with resolution `n × n`. | |
129 | fn write_face_csv(filename : String, face : Face, n : usize, mut f : impl FnMut(&OnCube) -> f64) -> DynError { | |
130 | ||
131 | #[derive(Serialize)] | |
132 | struct CSVFace { u : f64, v : f64, value : f64 } | |
133 | ||
134 | let grid = LinSpace { | |
135 | start : Loc([0.0, 0.0]), | |
136 | end : Loc([1.0, 1.0]), | |
137 | count : [n, n] | |
138 | }; | |
139 | ||
140 | let data = grid.into_iter() | |
141 | .map(|p@Loc([u,v])| CSVFace{ u, v, value : f(&OnCube::new(face, p)) }); | |
142 | ||
143 | write_csv(data, format!("{filename}.csv"))?; | |
144 | ||
145 | Ok(()) | |
146 | } |