src/main.rs

Thu, 07 Nov 2024 13:04:20 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 07 Nov 2024 13:04:20 -0500
changeset 31
49227d097d14
parent 25
9ac11616a2c5
child 32
eb07aee01d57
permissions
-rw-r--r--

image saving changes

/*!
Optimisation on non-Riemannian manifolds.
*/

// We use unicode. We would like to use much more of it than Rust allows.
// Live with it. Embrace it.
#![allow(uncommon_codepoints)]
#![allow(mixed_script_confusables)]
#![allow(confusable_idents)]

mod manifold;
mod fb;
mod cube;
mod dist;
mod zero;
mod scaled;

use serde::{Serialize, Deserialize};
use alg_tools::logger::Logger;
use alg_tools::tabledump::{TableDump, write_csv};
use alg_tools::error::DynError;
use alg_tools::lingrid::LinSpace;
use alg_tools::loc::Loc;
use alg_tools::types::*;
use alg_tools::mapping::{Sum, Apply};
use alg_tools::iterate::{AlgIteratorOptions, AlgIteratorFactory, Verbose};
use alg_tools::mapping::Mapping;
use image::{ImageFormat, ImageBuffer, Rgb};

use dist::{DistTo, DistToSquaredDiv2};
use fb::{forward_backward, IterInfo, Desc, Prox};
use manifold::EmbeddedManifoldPoint;
use cube::*;
use Face::*;
#[allow(unused_imports)]
use zero::ZeroFn;
use scaled::Scaled;

/// Program entry point.
///
/// Runs [`simple_cube_test`] and panics (via `unwrap`) if it returns an error.
fn main() {
    let outcome = simple_cube_test();
    outcome.unwrap()
}

/// A point on the cube flattened into one CSV row.
///
/// Stores the face the point lies on together with its embedded coordinates
/// `x`, `y`, `z`. The field order determines the CSV column order.
#[derive(Debug, Serialize, Deserialize)]
struct CSVPoint {
    face : Face,
    x : f64,
    y : f64,
    z : f64,
}

impl From<&OnCube> for CSVPoint {
    /// Builds a CSV row from `point`'s embedded coordinates and its face.
    fn from(point : &OnCube) -> Self {
        // The embedded coordinates are a three-component `Loc`.
        let Loc([x, y, z]) = point.embedded_coords();
        Self {
            face : point.face(),
            x,
            y,
            z,
        }
    }
}

/// One row of the iteration log written to CSV.
///
/// Holds the iteration number, the `value` reported for that iteration, and the
/// current point in the same layout as [`CSVPoint`].
// The point's fields are repeated here instead of embedding a `CSVPoint` with
// `#[serde(flatten)]`. NOTE(review): presumably the CSV serialiser behind
// `write_csv` / `Logger::write_csv` does not support flattened structs — confirm
// before switching to `flatten`.
#[derive(Debug, Serialize, Deserialize)]
struct CSVLog {
    iter : usize,
    value : f64,
    face : Face,
    x : f64,
    y : f64,
    z : f64,
}


/// Directory, relative to the current working directory, that receives all
/// output files (CSV tables, PNG images and iteration logs).
static PREFIX: &str = "res";

/// A simple test on the cube
fn simple_cube_test() -> DynError {
    
    let points = [
        OnCube::new(F1, Loc([0.5, 0.7])),
        OnCube::new(F2, Loc([0.3, 0.5])),
        OnCube::new(F4, Loc([0.9, 0.9])),
        OnCube::new(F6, Loc([0.4, 0.3])),
        OnCube::new(F4, Loc([0.3, 0.7])),
        OnCube::new(F3, Loc([0.3, 0.7])),
    ];

    let origin = OnCube::new(F4, Loc([0.5, 0.5]));

    write_points(format!("{PREFIX}/data"), points.iter())?;
    write_points(format!("{PREFIX}/origin"), std::iter::once(&origin))?;

    let f = Sum::new(points.into_iter().map(DistToSquaredDiv2));
    //let g = ZeroFn::new();
    let g = Scaled::new(0.5, DistTo(origin));
    let τ = 0.05;
    
    std::fs::create_dir_all(PREFIX)?;
    for face in Face::all() {
        write_face_csv(format!("{PREFIX}/{face}"), face, 32, |x| f.apply(x) + g.apply(x))?;
    }
    write_face_imgs(128, |x| f.apply(x) + g.apply(x))?;

    run_and_save("x1", &f, &g, OnCube::new(F3, Loc([0.1, 0.7])), τ)?;
    run_and_save("x2", &f, &g, OnCube::new(F2, Loc([0.1, 0.7])), τ)?;
    run_and_save("x3", &f, &g, OnCube::new(F6, Loc([0.6, 0.2])), τ)
}

/// Runs [`forward_backward`] and saves the results.
///
/// Calls [`forward_backward`] with `f` (bounded by [`Desc`]) and `g` (bounded by
/// [`Prox`]), starting from `x`, with parameter `τ`. The iteration is capped at
/// 100 iterations and every iteration is reported verbosely.
///
/// Afterwards it does two things:
/// - prints the resulting point;
/// - writes one [`CSVLog`] row per logged iteration to `{PREFIX}/{name}_log.csv`.
///
/// # Errors
/// Returns an error if writing the log file fails.
pub fn run_and_save<F, G>(
    name : &str,
    f : &F,
    g : &G,
    x : OnCube,
    τ : f64,
) -> DynError
where
    F : Desc<OnCube> + Mapping<OnCube, Codomain = f64>,
    G : Prox<OnCube> + Mapping<OnCube, Codomain = f64>,
{
    let mut logger = Logger::new();

    // Flatten each iteration's info into a single CSV row.
    let to_row = |iter, info : IterInfo<OnCube>| {
        let p = CSVPoint::from(&info.point);
        CSVLog {
            iter,
            value : info.value,
            face : p.face,
            x : p.x,
            y : p.y,
            z : p.z,
        }
    };

    let options = AlgIteratorOptions {
        max_iter : 100,
        verbose_iter : Verbose::Every(1),
        .. Default::default()
    };
    let iterator = options.mapped(to_row).into_log(&mut logger);

    let x̂ = forward_backward(f, g, x, τ, iterator);
    println!("result = {}\n{:?}", x̂.embedded_coords(), &x̂);

    let log_file = format!("{PREFIX}/{name}_log.csv");
    logger.write_csv(log_file)?;

    Ok(())
}

/// Writes the values of `f` on every face of the cube into PNG images.
///
/// Each face goes to its own file `{PREFIX}/{face}.png`, with resolution `n × n`.
///
/// All faces share one colour scale. A value `v` is shaded by `t = v / m`, where `m`
/// is the largest value sampled over all faces, so `t = 0` is white and `t = 1` is
/// blue. If no sampled value is positive, every pixel is shaded as `t = 0`.
///
/// # Errors
/// Returns an error if `n` does not fit in a `u32` or if saving an image fails.
fn write_face_imgs(n : usize, mut f : impl FnMut(&OnCube) -> f64) -> DynError {
    // Image dimensions are `u32`; refuse sizes that would silently truncate.
    let side = u32::try_from(n)?;

    let grid = LinSpace {
        start : Loc([0.0, 0.0]),
        end : Loc([1.0, 1.0]),
        count : [n, n]
    };

    // Sample every face first, so the normalisation maximum `m` is global.
    let mut m = 0.0;
    let mut datas = Vec::new();

    for face in Face::all() {
        // The second coordinate is flipped (`1.0 - y`) when sampling.
        // NOTE(review): assumes the grid iterates in the same row-major order as
        // `pixels_mut`, so the flip puts face coordinate 1.0 at the top row — confirm.
        let rawdata : Vec<_> = grid.into_iter()
                                .map(|Loc([x,y])| f(&OnCube::new(face, Loc([x, 1.0-y]))))
                                .collect();
        m = rawdata.iter().copied().fold(m, f64::max);
        datas.push((face, rawdata));
    }

    for (face, rawdata) in datas {
        let mut img = ImageBuffer::new(side, side);
        img.pixels_mut()
            .zip(rawdata)
            .for_each(|(p, v)| {
                // Without this guard, `m == 0` would make `v / m` NaN (for `v == 0`),
                // which `as u8` maps to 0 and so renders as blue instead of white.
                let t = if m > 0.0 { v / m } else { 0.0 };
                // A very colourful option for bug hunting.
                //let rgb = [(50.0*t).cos(), (20.0*t).sin(), (3.0*t).cos()];
                let rgb = [1.0-t, 1.0-t, 1.0];
                // Float-to-int `as` casts saturate, so out-of-range channels clamp to 0..=255.
                *p = Rgb(rgb.map(|v| (v*(u8::RANGE_MAX as f64)) as u8))
            });

        img.save_with_format(format!("{PREFIX}/{face}.png"), ImageFormat::Png)?;
    }

    Ok(())
}

/// Samples `f` on `face` of the cube over an `n × n` grid spanning `[0, 1]²`.
///
/// The samples are written to `{filename}.csv`, one row per grid point, with columns
/// `u`, `v` (face coordinates) and `value`.
///
/// # Errors
/// Returns an error if writing the CSV file fails.
fn write_face_csv(filename : String, face : Face, n : usize, mut f : impl FnMut(&OnCube) -> f64) -> DynError {
    /// A single CSV row: face-local coordinates and the function value there.
    #[derive(Serialize)]
    struct CSVFace { u : f64, v : f64, value : f64 }

    let samples = LinSpace {
        start : Loc([0.0, 0.0]),
        end : Loc([1.0, 1.0]),
        count : [n, n],
    };

    let rows = samples.into_iter().map(|p| {
        let Loc([u, v]) = p;
        let value = f(&OnCube::new(face, p));
        CSVFace { u, v, value }
    });

    let path = format!("{filename}.csv");
    write_csv(rows, path)?;
    Ok(())
}

/// Writes the points yielded by `list` into `{filename}.csv`, one [`CSVPoint`] row each.
///
/// # Errors
/// Returns an error if writing the CSV file fails.
fn write_points<'a, I>(filename : String, list : I) -> DynError
where
    I : Iterator<Item = &'a OnCube>,
{
    let rows = list.map(|point| CSVPoint::from(point));
    let path = format!("{filename}.csv");
    write_csv(rows, path)?;
    Ok(())
}

mercurial