src/run.rs

Sat, 10 Dec 2022 16:22:38 +0200

author
Tuomo Valkonen <tuomov@iki.fi>
date
Sat, 10 Dec 2022 16:22:38 +0200
changeset 22
9fb8ecb3da74
parent 20
90f77ad9a98d
child 23
9869fa1e0ccd
permissions
-rw-r--r--

Add rust-version specification to Cargo.toml

/*!
This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment.
*/

use numeric_literals::replace_float_literals;
use colored::Colorize;
use serde::{Serialize, Deserialize};
use serde_json;
use nalgebra::base::DVector;
use std::hash::Hash;
use chrono::{DateTime, Utc};
use cpu_time::ProcessTime;
use clap::ValueEnum;
use std::collections::HashMap;
use std::time::Instant;

use rand::prelude::{
    StdRng,
    SeedableRng
};
use rand_distr::Distribution;

use alg_tools::bisection_tree::*;
use alg_tools::iterate::{
    Timed,
    AlgIteratorOptions,
    Verbose,
    AlgIteratorFactory,
};
use alg_tools::logger::Logger;
use alg_tools::error::DynError;
use alg_tools::tabledump::TableDump;
use alg_tools::sets::Cube;
use alg_tools::mapping::RealMapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;
use alg_tools::euclidean::Euclidean;
use alg_tools::norms::{Norm, L1};
use alg_tools::lingrid::lingrid;
use alg_tools::sets::SetOrd;

use crate::kernels::*;
use crate::types::*;
use crate::measures::*;
use crate::measures::merging::SpikeMerging;
use crate::forward_model::*;
use crate::fb::{
    FBConfig,
    pointsource_fb,
    FBMetaAlgorithm, FBGenericConfig,
};
use crate::pdps::{
    PDPSConfig,
    L2Squared,
    pointsource_pdps,
};
use crate::frank_wolfe::{
    FWConfig,
    FWVariant,
    pointsource_fw,
    prepare_optimise_weights,
    optimise_weights,
};
use crate::subproblem::InnerSettings;
use crate::seminorms::*;
use crate::plot::*;
use crate::{AlgorithmOverrides, CommandLineArgs};
use crate::tolerance::Tolerance;

/// Available algorithms and their configurations.
///
/// Each variant wraps the configuration of one algorithm family. The concrete method
/// within a family (e.g. plain FB vs. FISTA-type inertia, fully corrective vs. relaxed
/// conditional gradient) is selected by fields of the wrapped configuration.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub enum AlgorithmConfig<F : Float> {
    /// Forward-backward type methods, run through `pointsource_fb`.
    FB(FBConfig<F>),
    /// Conditional gradient (Frank–Wolfe) type methods, run through `pointsource_fw`.
    FW(FWConfig<F>),
    /// Primal-dual proximal splitting, run through `pointsource_pdps`.
    PDPS(PDPSConfig<F>),
}

/// Build a [`Tolerance::Power`] from a command-line triple `[initial, factor, exponent]`.
///
/// The parameter is deliberately `&Vec<F>` rather than `&[F]`: callers pass this function
/// directly to `Option<&Vec<F>>::map`, where no deref coercion takes place.
///
/// # Panics
///
/// Panics with a descriptive message if `v` does not contain exactly three values.
fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> {
    assert_eq!(v.len(), 3,
               "tolerance must be given as exactly three values: initial, factor, exponent");
    Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] }
}

impl<F : ClapFloat> AlgorithmConfig<F> {
    /// Override supported parameters based on the command line.
    pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self {
        let override_fb_generic = |g : FBGenericConfig<F>| {
            FBGenericConfig {
                bootstrap_insertions : cli.bootstrap_insertions
                                          .as_ref()
                                          .map_or(g.bootstrap_insertions,
                                                  |n| Some((n[0], n[1]))),
                merge_every : cli.merge_every.unwrap_or(g.merge_every),
                merging : cli.merging.clone().unwrap_or(g.merging),
                final_merging : cli.final_merging.clone().unwrap_or(g.final_merging),
                tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance),
                .. g
            }
        };

        use AlgorithmConfig::*;
        match self {
            FB(fb) => FB(FBConfig {
                τ0 : cli.tau0.unwrap_or(fb.τ0),
                insertion : override_fb_generic(fb.insertion),
                .. fb
            }),
            PDPS(pdps) => PDPS(PDPSConfig {
                τ0 : cli.tau0.unwrap_or(pdps.τ0),
                σ0 : cli.sigma0.unwrap_or(pdps.σ0),
                acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
                insertion : override_fb_generic(pdps.insertion),
                .. pdps
            }),
            FW(fw) => FW(FWConfig {
                merging : cli.merging.clone().unwrap_or(fw.merging),
                tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance),
                .. fw
            })
        }
    }
}

/// Helper struct for tagging an [`AlgorithmConfig`] or [`Experiment`] with a name.
///
/// When serialised, the fields of `data` are flattened into the same object as `name`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Named<Data> {
    /// Identifying name; also used to build output paths (e.g. the per-experiment and
    /// per-algorithm output directories in `runall`).
    pub name : String,
    /// The tagged item.
    #[serde(flatten)]
    pub data : Data,
}

/// Shorthand algorithm configurations, to be used with the command line parser
// The `clap` name of each variant is also used as the algorithm's name in output
// paths (see `DefaultAlgorithm::to_named`), and the variant doc comments below are
// shown as command-line help text.
#[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum DefaultAlgorithm {
    /// The μFB forward-backward method
    #[clap(name = "fb")]
    FB,
    /// The μFISTA inertial forward-backward method
    #[clap(name = "fista")]
    FISTA,
    /// The “fully corrective” conditional gradient method
    #[clap(name = "fw")]
    FW,
    /// The “relaxed” conditional gradient method
    #[clap(name = "fwrelax")]
    FWRelax,
    /// The μPDPS primal-dual proximal splitting method
    #[clap(name = "pdps")]
    PDPS,
}

impl DefaultAlgorithm {
    /// Returns the algorithm configuration corresponding to the algorithm shorthand
    pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> {
        use DefaultAlgorithm::*;
        match *self {
            FB => AlgorithmConfig::FB(Default::default()),
            FISTA => AlgorithmConfig::FB(FBConfig{
                meta : FBMetaAlgorithm::InertiaFISTA,
                .. Default::default()
            }),
            FW => AlgorithmConfig::FW(Default::default()),
            FWRelax => AlgorithmConfig::FW(FWConfig{
                variant : FWVariant::Relaxed,
                .. Default::default()
            }),
            PDPS => AlgorithmConfig::PDPS(Default::default()),
        }
    }

    /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand
    pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> {
        self.to_named(self.default_config())
    }

    pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> {
        let name = self.to_possible_value().unwrap().get_name().to_string();
        Named{ name , data : alg }
    }
}


// // Floats cannot be hashed directly, so just hash the debug formatting
// // for use as file identifier.
// impl<F : Float> Hash for AlgorithmConfig<F> {
//     fn hash<H: Hasher>(&self, state: &mut H) {
//         format!("{:?}", self).hash(state);
//     }
// }

/// Plotting level configuration
// The derived `Ord` follows declaration order (`None < Data < Iter`), and callers compare
// levels (e.g. `cli.plot >= PlotLevel::Iter`), so a level enables everything below it.
// Do not reorder the variants.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, ValueEnum, Debug)]
pub enum PlotLevel {
    /// Plot nothing
    #[clap(name = "none")]
    None,
    /// Plot problem data
    #[clap(name = "data")]
    Data,
    /// Plot iterationwise state
    #[clap(name = "iter")]
    Iter,
}

/// Bisection tree used by the default operators: a [`BT`] with dynamic depth,
/// `usize` data and [`Bounds`] aggregation in `N` dimensions.
type DefaultBT<F, const N : usize> = BT<DynamicDepth, F, usize, Bounds<F>, N>;

/// Convolution seminorm operator $𝒟$ with kernel `K`, backed by a [`DefaultBT`].
type DefaultSeminormOp<F, K, const N : usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>;

/// Sensor-grid forward operator $𝒜$ built from `Sensor` and `Spread`, backed by a [`DefaultBT`].
type DefaultSG<F, Sensor, Spread, const N : usize>
    = SensorGrid<F, Sensor, Spread, DefaultBT<F, N>, N>;

/// This is a dirty workaround to rust-csv not supporting struct flattening etc.
///
/// One row of the per-iteration log written to `*_log.txt`. Serde serialises the fields
/// in declaration order, which presumably also fixes the CSV column order — keep the
/// order stable so existing logs stay comparable.
#[derive(Serialize)]
struct CSVLog<F> {
    /// Iteration number.
    iter : usize,
    /// CPU time in seconds, as reported by the timed iterator.
    cpu_time : f64,
    /// Objective value reported by the algorithm.
    value : F,
    /// Objective value after postprocessing (weight optimisation for the L²² data term);
    /// equal to `value` when no postprocessing was done.
    post_value : F,
    /// Number of spikes; copied from `IterInfo`.
    n_spikes : usize,
    /// Inner iterations; copied from `IterInfo`.
    inner_iters : usize,
    /// Merged spikes; copied from `IterInfo`.
    merged : usize,
    /// Pruned spikes; copied from `IterInfo`.
    pruned : usize,
    /// Iterations covered by this log entry; copied from `IterInfo`.
    this_iters : usize,
}

/// Collected experiment statistics
#[derive(Clone, Debug, Serialize)]
struct ExperimentStats<F : Float> {
    /// Signal-to-noise ratio in decibels: $10 \log_{10}(‖b‖^2/‖\text{noise}‖^2)$.
    ssnr : F,
    /// Ratio $‖\text{noise}‖/‖b‖$ of the noise norm to the norm of the noisy signal $b$.
    /// Usually in $[0, 1]$, but not bounded by 1 when the noise dominates.
    noise_ratio : F,
    /// When the experiment was run (UTC)
    when : DateTime<Utc>,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> ExperimentStats<F> {
    /// Compute [`ExperimentStats`] for a noisy `signal` and the `noise` that was added to it.
    ///
    /// Both statistics come from the squared 2-norms of the two arguments; `when` is the
    /// current UTC time.
    fn new<E : Euclidean<F>>(signal : &E, noise : &E) -> Self {
        let signal_energy = signal.norm2_squared();
        let noise_energy = noise.norm2_squared();
        let noise_ratio = (noise_energy / signal_energy).sqrt();
        let decibels = 10.0 * (signal_energy /  noise_energy).log10();
        ExperimentStats {
            when : Utc::now(),
            noise_ratio,
            ssnr : decibels,
        }
    }
}
/// Collected algorithm statistics, saved as `*_stats.json` for each algorithm run.
#[derive(Clone, Debug, Serialize)]
struct AlgorithmStats<F : Float> {
    /// Overall CPU time spent, in seconds
    cpu_time : F,
    /// Real (wall-clock) time spent, in seconds
    elapsed : F
}


/// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input
/// and outputs a [`DynError`].
///
/// The file is created (or truncated) and written through a [`std::io::BufWriter`], so
/// the many small writes of the pretty printer do not each become a system call. The
/// buffer is flushed explicitly because errors from an implicit flush on drop are ignored.
///
/// # Errors
///
/// Fails if the file cannot be created, if serialisation fails, or if writing or flushing
/// the output fails.
fn write_json<T : Serialize>(filename : String, data : &T) -> DynError {
    let file = std::fs::File::create(filename)?;
    let mut writer = std::io::BufWriter::new(file);
    serde_json::to_writer_pretty(&mut writer, data)?;
    std::io::Write::flush(&mut writer)?;
    Ok(())
}


/// Struct for experiment configurations
///
/// Describes a synthetic point-source problem. The noise-free data is $𝒜\hat μ$, with the
/// forward operator $𝒜$ built on a grid of `sensor_count` sensors over `domain`. The noisy
/// data adds a sample from `noise_distr`.
#[derive(Debug, Clone, Serialize)]
pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize>
where F : Float,
      [usize; N] : Serialize,
      NoiseDistr : Distribution<F>,
      S : Sensor<F, N>,
      P : Spread<F, N>,
      K : SimpleConvolutionKernel<F, N>,
{
    /// Domain $Ω$.
    pub domain : Cube<F, N>,
    /// Number of sensors along each dimension
    pub sensor_count : [usize; N],
    /// Noise distribution
    pub noise_distr : NoiseDistr,
    /// Seed for random noise generation (for repeatable experiments)
    pub noise_seed : u64,
    /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$.
    pub sensor : S,
    /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$.
    pub spread : P,
    /// Kernel $ρ$ of $𝒟$.
    pub kernel : K,
    /// True point sources
    pub μ_hat : DiscreteMeasure<Loc<F, N>, F>,
    /// Regularisation parameter
    pub α : F,
    /// For plotting : how wide should the kernels be plotted
    pub kernel_plot_width : F,
    /// Data term; determines which algorithms can be run on the experiment.
    pub dataterm : DataTerm,
    /// A map of default configurations for algorithms.
    ///
    /// Entries take precedence over [`DefaultAlgorithm::default_config`]. Not serialised.
    #[serde(skip)]
    pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>,
}

/// Trait for runnable experiments
pub trait RunnableExperiment<F : ClapFloat> {
    /// Run all algorithms provided, or default algorithms if none provided, on the experiment.
    ///
    /// Returns an error if setting up the output or saving results fails.
    fn runall(&self, cli : &CommandLineArgs,
              algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError;

    /// Return the named default configuration of `alg` for this experiment, with the
    /// command-line overrides in `cli` applied.
    fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
    -> Named<AlgorithmConfig<F>>;
}

impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for
Named<Experiment<F, NoiseDistr, S, K, P, N>>
where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
      [usize; N] : Serialize,
      S : Sensor<F, N> + Copy + Serialize,
      P : Spread<F, N> + Copy + Serialize,
      Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy,
      AutoConvolution<P> : BoundedBy<F, K>,
      K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> + Copy + Serialize,
      Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
      PlotLookup : Plotting<N>,
      DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      NoiseDistr : Distribution<F> + Serialize {

    /// Look up the configuration of `alg`. The experiment-specific entry in
    /// `algorithm_defaults` is used if present, otherwise [`DefaultAlgorithm::default_config`].
    /// The overrides in `cli` are then applied, and the result is tagged with the name of `alg`.
    fn algorithm_defaults(&self, alg : DefaultAlgorithm, cli : &AlgorithmOverrides<F>)
    -> Named<AlgorithmConfig<F>> {
        alg.to_named(
            self.data
                .algorithm_defaults
                .get(&alg)
                .map_or_else(|| alg.default_config(),
                            |config| config.clone())
                .cli_override(cli)
        )
    }

    /// Generate the noisy data of the experiment, save the setup, and run the algorithms.
    ///
    /// Results go to `{cli.outdir}/{experiment name}/`. Experiment-level JSON files and
    /// data plots are written there, and each algorithm writes its config, stats,
    /// reconstruction and iteration log with prefix `{algorithm name}_`. Algorithms not
    /// implemented for the experiment's data term are skipped with a message on stderr.
    fn runall(&self, cli : &CommandLineArgs,
              algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError {
        // Get experiment configuration
        let &Named {
            name : ref experiment_name,
            data : Experiment {
                domain, sensor_count, ref noise_distr, sensor, spread, kernel,
                ref μ_hat, α, kernel_plot_width, dataterm, noise_seed,
                ..
            }
        } = self;

        // Set up output directory
        let prefix = format!("{}/{}/", cli.outdir, experiment_name);

        // Set up algorithms
        let iterator_options = AlgIteratorOptions{
                max_iter : cli.max_iter,
                verbose_iter : cli.verbose_iter
                                  .map_or(Verbose::Logarithmic(10),
                                          |n| Verbose::Every(n)),
                quiet : cli.quiet,
        };
        // Without an explicit list, run one default algorithm suited to the data term.
        let algorithms = match (algs, dataterm) {
            (Some(algs), _) => algs,
            (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()],
            (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()],
        };

        // Set up operators
        let depth = DynamicDepth(8);
        let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
        let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel);

        // Set up random number generator.
        let mut rng = StdRng::seed_from_u64(noise_seed);

        // Generate the data and calculate SSNR statistic
        let b_hat = opA.apply(μ_hat);
        let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
        let b = &b_hat + &noise;
        // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
        // overloading log10 and conflicting with standard NumTraits one.
        let stats = ExperimentStats::new(&b, &noise);

        // Save experiment configuration and statistics
        let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t);
        std::fs::create_dir_all(&prefix)?;
        write_json(mkname_e("experiment"), self)?;
        write_json(mkname_e("config"), cli)?;
        write_json(mkname_e("stats"), &stats)?;

        plotall(cli, &prefix, &domain, &sensor, &kernel, &spread,
                &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?;

        // Setup shared by every algorithm run. It only depends on the experiment, so it is
        // computed once here instead of once per algorithm: the precomputed data and inner
        // solver settings for `optimise_weights` (used in postprocessing), the plotting
        // grid, and the number of iterates to plot.
        let findim_data = prepare_optimise_weights(&opA);
        let inner_config : InnerSettings<F> = Default::default();
        let inner_it = inner_config.iterator_options;
        let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]);
        let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };

        // Run the algorithm(s)
        for named @ Named { name : alg_name, data : alg } in algorithms.iter() {
            let this_prefix = format!("{}{}/", prefix, alg_name);

            let running = || if !cli.quiet {
                println!("{}\n{}\n{}",
                        format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
                        format!("{:?}", iterator_options).bright_black(),
                        format!("{:?}", alg).bright_black());
            };

            // Create Logger and IteratorFactory
            let mut logger = Logger::new();
            // Convert each timed iteration report into a CSV log row. With the L²² data
            // term, a postprocessed measure gets its weights re-optimised and the objective
            // re-evaluated as `post_value`.
            let logmap = |iter, Timed { cpu_time, data }| {
                let IterInfo {
                    value,
                    n_spikes,
                    inner_iters,
                    merged,
                    pruned,
                    postprocessing,
                    this_iters,
                    ..
                } = data;
                let post_value = match postprocessing {
                    None => value,
                    Some(mut μ) => {
                        match dataterm {
                            DataTerm::L2Squared => {
                                optimise_weights(
                                    &mut μ, &opA, &b, α, &findim_data, &inner_config,
                                    inner_it
                                );
                                dataterm.value_at_residual(opA.apply(&μ) - &b) + α * μ.norm(Radon)
                            },
                            _ => value,
                        }
                    }
                };
                CSVLog {
                    iter,
                    value,
                    post_value,
                    n_spikes,
                    cpu_time : cpu_time.as_secs_f64(),
                    inner_iters,
                    merged,
                    pruned,
                    this_iters
                }
            };
            let iterator = iterator_options.instantiate()
                                           .timed()
                                           .mapped(logmap)
                                           .into_log(&mut logger);

            // Create plotter and directory if needed.
            let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid);

            // Run the algorithm
            let start = Instant::now();
            let start_cpu = ProcessTime::now();
            let μ : DiscreteMeasure<Loc<F, N>, F> = match (alg, dataterm) {
                (AlgorithmConfig::FB(ref algconfig), DataTerm::L2Squared) => {
                    running();
                    pointsource_fb(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter)
                },
                (AlgorithmConfig::FW(ref algconfig), DataTerm::L2Squared) => {
                    running();
                    pointsource_fw(&opA, &b, α, &algconfig, iterator, plotter)
                },
                (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L2Squared) => {
                    running();
                    pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L2Squared)
                },
                (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L1) => {
                    running();
                    pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1)
                },
                _ =>  {
                    let msg = format!("Algorithm “{alg_name}” not implemented for \
                                       dataterm {dataterm:?}. Skipping.").red();
                    eprintln!("{}", msg);
                    continue
                }
            };
            let elapsed = start.elapsed().as_secs_f64();
            let cpu_time = start_cpu.elapsed().as_secs_f64();

            // Progress messages respect `cli.quiet`, like the “Running …” message above.
            if !cli.quiet {
                println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow());
                println!("{}", "Saving results…".green());
            }

            // Save results
            let mkname = |t| format!("{prefix}{alg_name}_{t}");

            write_json(mkname("config.json"), &named)?;
            write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
            μ.write_csv(mkname("reco.txt"))?;
            logger.write_csv(mkname("log.txt"))?;
        }

        Ok(())
    }
}

/// Plot experiment setup
///
/// Plots the sensor, kernel, spread and base sensor (sensor convolved with spread) over
/// `[-kernel_plot_width, kernel_plot_width]^N`. Over `domain` it plots $𝒟\hat μ$, the
/// noise term $𝒜^⊤(𝒜\hat μ - b)$, and $𝒜^⊤ b$ against $𝒜^⊤\hat b$ together with the true
/// spikes. It then saves the true solution `μ_hat` and the clean and noisy observables.
/// All output file names start with `prefix`.
///
/// NOTE(review): when `cli.plot` is below [`PlotLevel::Data`], the function returns `Ok(())`
/// immediately, so `orig.txt`, `b_hat` and `b_noisy` are not written either. Confirm that
/// this coupling of data output to the plotting level is intended.
#[replace_float_literals(F::cast_from(literal))]
fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>(
    cli : &CommandLineArgs,
    prefix : &str,
    domain : &Cube<F, N>,
    sensor : &Sensor,
    kernel : &Kernel,
    spread : &Spread,
    μ_hat : &DiscreteMeasure<Loc<F, N>, F>,
    op𝒟 : &𝒟,
    opA : &A,
    b_hat : &A::Observable,
    b : &A::Observable,
    kernel_plot_width : F,
) -> DynError
where F : Float + ToNalgebraRealField,
      Sensor : RealMapping<F, N> + Support<F, N> + Clone,
      Spread : RealMapping<F, N> + Support<F, N> + Clone,
      Kernel : RealMapping<F, N> + Support<F, N>,
      Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F>,
      𝒟::Codomain : RealMapping<F, N>,
      A : ForwardModel<Loc<F, N>, F>,
      A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>,
      PlotLookup : Plotting<N>,
      Cube<F, N> : SetOrd {

    if cli.plot < PlotLevel::Data {
        return Ok(())
    }

    let base = Convolution(sensor.clone(), spread.clone());

    let resolution = if N==1 { 100 } else { 40 };
    // Output file name: `prefix` followed by `n`.
    let pfx = |n| format!("{}{}", prefix, n);
    let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]);

    PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string());
    PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string());
    PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string());
    PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string());

    let plotgrid2 = lingrid(&domain, &[resolution; N]);

    let ω_hat = op𝒟.apply(μ_hat);
    let noise =  opA.preadjoint().apply(opA.apply(μ_hat) - b);
    PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"), "ω̂".to_string());
    PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"),
                               "noise Aᵀ(Aμ̂ - b)".to_string());

    let preadj_b =  opA.preadjoint().apply(b);
    let preadj_b_hat =  opA.preadjoint().apply(b_hat);
    PlotLookup::plot_into_file_spikes(
        "Aᵀb".to_string(), &preadj_b,
        "Aᵀb̂".to_string(), Some(&preadj_b_hat),
        plotgrid2, None, &μ_hat,
        pfx("omega_b")
    );

    // Save true solution and observables
    μ_hat.write_csv(pfx("orig.txt"))?;
    opA.write_observable(&b_hat, pfx("b_hat"))?;
    opA.write_observable(&b, pfx("b_noisy"))
}

mercurial