Thu, 07 Nov 2024 12:49:51 -0500
colors
13 | 1 | /*! |
2 | Optimisation on non-Riemannian manifolds. | |
3 | */ | |
4 | 4 | |
5 | // We use unicode. We would like to use much more of it than Rust allows. | |
6 | // Live with it. Embrace it. | |
7 | #![allow(uncommon_codepoints)] | |
8 | #![allow(mixed_script_confusables)] | |
9 | #![allow(confusable_idents)] | |
1 | 10 | |
19 | 11 | mod manifold; |
12 | mod fb; | |
13 | mod cube; | |
14 | mod dist; | |
15 | mod zero; | |
16 | mod scaled; | |
17 | ||
20 | 18 | use serde::{Serialize, Deserialize}; |
16 | 19 | use alg_tools::logger::Logger; |
20 | use alg_tools::tabledump::{TableDump, write_csv}; | |
21 | use alg_tools::error::DynError; | |
22 | use alg_tools::lingrid::LinSpace; | |
23 | use alg_tools::loc::Loc; | |
24 | use alg_tools::types::*; | |
25 | use alg_tools::mapping::{Sum, Apply}; | |
26 | use alg_tools::iterate::{AlgIteratorOptions, AlgIteratorFactory, Verbose}; | |
24 | 27 | use alg_tools::mapping::Mapping; |
16 | 28 | use image::{ImageFormat, ImageBuffer, Rgb}; |
29 | ||
19 | 30 | use dist::{DistTo, DistToSquaredDiv2}; |
24 | 31 | use fb::{forward_backward, IterInfo, Desc, Prox}; |
7 | 32 | use manifold::EmbeddedManifoldPoint; |
11 | 33 | use cube::*; |
16 | 34 | use Face::*; |
19 | 35 | #[allow(unused_imports)] |
16 | 36 | use zero::ZeroFn; |
19 | 37 | use scaled::Scaled; |
1 | 38 | |
13 | 39 | /// Program entry point |
1 | 40 | fn main() { |
12 | 41 | simple_cube_test().unwrap() |
7 | 42 | } |
1 | 43 | |
20 | 44 | /// Helper structure for saving a point on a cube into a CSV file |
45 | #[derive(Serialize,Deserialize,Debug)] | |
46 | struct CSVPoint { | |
12 | 47 | face : Face, |
48 | x : f64, | |
49 | y : f64, | |
50 | z : f64 | |
51 | } | |
52 | ||
20 | 53 | impl From<&OnCube> for CSVPoint { |
54 | fn from(point : &OnCube) -> Self { | |
55 | let Loc([x,y,z]) = point.embedded_coords(); | |
56 | let face = point.face(); | |
57 | CSVPoint { face, x, y, z } | |
58 | } | |
59 | } | |
60 | ||
61 | /// Helper structure for saving the log into a CSV file | |
62 | #[derive(Serialize,Deserialize,Debug)] | |
63 | struct CSVLog { | |
64 | iter : usize, | |
65 | value : f64, | |
66 | // serde is junk | |
67 | //#[serde(flatten)] | |
68 | //point : CSVPoint | |
69 | face : Face, | |
70 | x : f64, | |
71 | y : f64, | |
72 | z : f64 | |
73 | } | |
74 | ||
75 | ||
/// Directory, relative to the working directory, into which all results
/// (CSV files and PNG images) are written.
static PREFIX : &'static str = "res";
78 | ||
79 | /// A simple test on the cube | |
80 | fn simple_cube_test() -> DynError { | |
7 | 81 | |
82 | let points = [ | |
35 | 83 | // OnCube::new(F1, Loc([0.5, 0.7])), |
20 | 84 | OnCube::new(F2, Loc([0.3, 0.5])), |
35 | 85 | // OnCube::new(F4, Loc([0.9, 0.9])), |
86 | //OnCube::new(F6, Loc([0.4, 0.3])), | |
87 | // OnCube::new(F4, Loc([0.3, 0.7])), | |
88 | // OnCube::new(F3, Loc([0.3, 0.7])), | |
7 | 89 | ]; |
90 | ||
20 | 91 | let origin = OnCube::new(F4, Loc([0.5, 0.5])); |
92 | ||
93 | write_points(format!("{PREFIX}/data"), points.iter())?; | |
94 | write_points(format!("{PREFIX}/origin"), std::iter::once(&origin))?; | |
95 | ||
7 | 96 | let f = Sum::new(points.into_iter().map(DistToSquaredDiv2)); |
35 | 97 | let g = ZeroFn::new(); |
98 | //let g = Scaled::new(0.5, DistTo(origin)); | |
32 | 99 | let τ = 0.1; |
12 | 100 | |
24 | 101 | std::fs::create_dir_all(PREFIX)?; |
102 | for face in Face::all() { | |
103 | write_face_csv(format!("{PREFIX}/{face}"), face, 32, |x| f.apply(x) + g.apply(x))?; | |
104 | } | |
31 | 105 | write_face_imgs(128, |x| f.apply(x) + g.apply(x))?; |
24 | 106 | |
107 | run_and_save("x1", &f, &g, OnCube::new(F3, Loc([0.1, 0.7])), τ)?; | |
33 | 108 | run_and_save("x2", &f, &g, OnCube::new(F2, Loc([0.3, 0.1])), τ)?; |
24 | 109 | run_and_save("x3", &f, &g, OnCube::new(F6, Loc([0.6, 0.2])), τ) |
110 | } | |
111 | ||
25 | 112 | /// Runs [forward_backward] and saves the results. |
24 | 113 | pub fn run_and_save<F, G>( |
114 | name : &str, | |
115 | f : &F, | |
116 | g : &G, | |
117 | x : OnCube, | |
118 | τ : f64, | |
119 | ) -> DynError | |
120 | where F : Desc<OnCube> + Mapping<OnCube, Codomain = f64>, | |
121 | G : Prox<OnCube> + Mapping<OnCube, Codomain = f64> { | |
122 | ||
12 | 123 | let mut logger = Logger::new(); |
124 | let logmap = |iter, IterInfo { value, point } : IterInfo<OnCube>| { | |
20 | 125 | let CSVPoint {x , y, z, face} = CSVPoint::from(&point); |
126 | CSVLog { | |
127 | iter, value, //point : CSVPoint::from(&point) | |
128 | x, y, z, face | |
129 | } | |
12 | 130 | }; |
7 | 131 | let iter = AlgIteratorOptions{ |
32 | 132 | max_iter : 20, |
7 | 133 | verbose_iter : Verbose::Every(1), |
134 | .. Default::default() | |
12 | 135 | }.mapped(logmap) |
136 | .into_log(&mut logger); | |
7 | 137 | |
24 | 138 | let x̂ = forward_backward(f, g, x, τ, iter); |
7 | 139 | println!("result = {}\n{:?}", x̂.embedded_coords(), &x̂); |
11 | 140 | |
24 | 141 | logger.write_csv(format!("{PREFIX}/{name}_log.csv"))?; |
12 | 142 | |
143 | Ok(()) | |
1 | 144 | } |
11 | 145 | |
12 | 146 | /// Writes the values of `f` on `face` of a [`OnCube`] into a PNG file |
147 | /// with resolution `n × n`. | |
31 | 148 | fn write_face_imgs(n : usize, mut f : impl FnMut(&OnCube) -> f64) -> DynError { |
11 | 149 | let grid = LinSpace { |
150 | start : Loc([0.0, 0.0]), | |
151 | end : Loc([1.0, 1.0]), | |
152 | count : [n, n] | |
153 | }; | |
31 | 154 | |
155 | let mut m = 0.0; | |
156 | let mut datas = Vec::new(); | |
157 | ||
158 | for face in Face::all() { | |
159 | let rawdata : Vec<_> = grid.into_iter() | |
160 | .map(|Loc([x,y])| f(&OnCube::new(face, Loc([x, 1.0-y])))) | |
161 | .collect(); | |
162 | m = rawdata.iter().copied().fold(m, f64::max); | |
163 | datas.push((face, rawdata)); | |
164 | } | |
11 | 165 | |
31 | 166 | for (face, rawdata) in datas { |
167 | let mut img = ImageBuffer::new(n as u32, n as u32); | |
168 | img.pixels_mut() | |
169 | .zip(rawdata) | |
170 | .for_each(|(p, v)| { | |
171 | let t = v/m; | |
172 | // A very colourful option for bug hunting. | |
36 | 173 | let rgb = [(50.0*t).cos(), (20.0*t).sin(), (3.0*t).cos()]; |
174 | //let rgb = [1.0-t, 1.0-t, 1.0]; | |
31 | 175 | *p = Rgb(rgb.map(|v| (v*(u8::RANGE_MAX as f64)) as u8)) |
176 | }); | |
177 | ||
178 | img.save_with_format(format!("{PREFIX}/{face}.png"), ImageFormat::Png)?; | |
179 | } | |
12 | 180 | |
181 | Ok(()) | |
11 | 182 | } |
16 | 183 | |
184 | /// Writes the values of `f` on `face` of a [`OnCube`] into a CSV file | |
185 | /// with resolution `n × n`. | |
186 | fn write_face_csv(filename : String, face : Face, n : usize, mut f : impl FnMut(&OnCube) -> f64) -> DynError { | |
187 | ||
188 | #[derive(Serialize)] | |
189 | struct CSVFace { u : f64, v : f64, value : f64 } | |
190 | ||
191 | let grid = LinSpace { | |
192 | start : Loc([0.0, 0.0]), | |
193 | end : Loc([1.0, 1.0]), | |
194 | count : [n, n] | |
195 | }; | |
196 | ||
197 | let data = grid.into_iter() | |
198 | .map(|p@Loc([u,v])| CSVFace{ u, v, value : f(&OnCube::new(face, p)) }); | |
199 | ||
200 | write_csv(data, format!("{filename}.csv"))?; | |
201 | ||
202 | Ok(()) | |
203 | } | |
20 | 204 | |
205 | /// Writes a list of points on a [`OnCube`] into a CSV file | |
206 | fn write_points<'a, I : Iterator<Item=&'a OnCube>>(filename : String, list : I) -> DynError { | |
207 | ||
208 | write_csv(list.map(CSVPoint::from), format!("{filename}.csv"))?; | |
209 | Ok(()) | |
210 | } | |
211 |