diff -r 9738b51d90d7 -r 4f468d35fa29 src/run.rs --- a/src/run.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/run.rs Thu Feb 26 11:38:43 2026 -0500 @@ -2,133 +2,81 @@ This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment. */ -use numeric_literals::replace_float_literals; -use colored::Colorize; -use serde::{Serialize, Deserialize}; -use serde_json; -use nalgebra::base::DVector; -use std::hash::Hash; -use chrono::{DateTime, Utc}; -use cpu_time::ProcessTime; -use clap::ValueEnum; -use std::collections::HashMap; -use std::time::Instant; - -use rand::prelude::{ - StdRng, - SeedableRng -}; -use rand_distr::Distribution; - -use alg_tools::bisection_tree::*; -use alg_tools::iterate::{ - Timed, - AlgIteratorOptions, - Verbose, - AlgIteratorFactory, - LoggingIteratorFactory, - TimingIteratorFactory, - BasicAlgIteratorFactory, -}; -use alg_tools::logger::Logger; -use alg_tools::error::{ - DynError, - DynResult, -}; -use alg_tools::tabledump::TableDump; -use alg_tools::sets::Cube; -use alg_tools::mapping::{ - RealMapping, - DifferentiableMapping, - DifferentiableRealMapping, - Instance -}; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::euclidean::Euclidean; -use alg_tools::lingrid::{lingrid, LinSpace}; -use alg_tools::sets::SetOrd; -use alg_tools::linops::{RowOp, IdOp /*, ZeroOp*/}; -use alg_tools::discrete_gradient::{Grad, ForwardNeumann}; -use alg_tools::convex::Zero; -use alg_tools::maputil::map3; -use alg_tools::direct_product::Pair; - -use crate::kernels::*; -use crate::types::*; -use crate::measures::*; -use crate::measures::merging::{SpikeMerging,SpikeMergingMethod}; -use crate::forward_model::*; +use crate::fb::{pointsource_fb_reg, pointsource_fista_reg, FBConfig, InsertionConfig}; use crate::forward_model::sensor_grid::{ - SensorGrid, - SensorGridBT, //SensorGridBTFN, Sensor, + SensorGrid, + SensorGridBT, Spread, }; - -use crate::fb::{ - FBConfig, - FBGenericConfig, - pointsource_fb_reg, - pointsource_fista_reg, +use crate::forward_model::*; +use crate::forward_pdps::{pointsource_fb_pair, pointsource_forward_pdps_pair, ForwardPDPSConfig}; +use crate::frank_wolfe::{pointsource_fw_reg, FWConfig, FWVariant, RegTermFW}; +use crate::kernels::*; +use crate::measures::merging::{SpikeMerging, SpikeMergingMethod}; +use crate::measures::*; +use crate::pdps::{pointsource_pdps_reg, PDPSConfig}; +use crate::plot::*; +use crate::prox_penalty::{ + ProxPenalty, ProxTerm, RadonSquared, StepLengthBound, StepLengthBoundPD, StepLengthBoundPair, }; -use crate::sliding_fb::{ - SlidingFBConfig, - TransportConfig, - pointsource_sliding_fb_reg -}; +use crate::regularisation::{NonnegRadonRegTerm, RadonRegTerm, Regularisation, SlidingRegTerm}; +use crate::seminorms::*; +use crate::sliding_fb::{pointsource_sliding_fb_reg, SlidingFBConfig, TransportConfig}; use crate::sliding_pdps::{ - SlidingPDPSConfig, - pointsource_sliding_pdps_pair -}; -use crate::forward_pdps::{ - ForwardPDPSConfig, - pointsource_forward_pdps_pair + pointsource_sliding_fb_pair, pointsource_sliding_pdps_pair, SlidingPDPSConfig, }; -use crate::pdps::{ - PDPSConfig, - pointsource_pdps_reg, -}; -use crate::frank_wolfe::{ - FWConfig, - FWVariant, - pointsource_fw_reg, - //WeightOptim, +use crate::subproblem::{InnerMethod, InnerSettings}; +use crate::tolerance::Tolerance; +use crate::types::*; +use crate::{AlgorithmOverrides, CommandLineArgs}; +use alg_tools::bisection_tree::*; +use alg_tools::bounds::{Bounded, MinMaxMapping}; +use alg_tools::convex::{Conjugable, Norm222, Prox, Zero}; +use alg_tools::direct_product::Pair; +use alg_tools::discrete_gradient::{ForwardNeumann, Grad}; +use alg_tools::error::{DynError, DynResult}; +use alg_tools::euclidean::{ClosedEuclidean, Euclidean}; +use alg_tools::iterate::{ + AlgIteratorFactory, AlgIteratorOptions, BasicAlgIteratorFactory, LoggingIteratorFactory, Timed, + TimingIteratorFactory, ValueIteratorFactory, Verbose, }; -use crate::subproblem::{InnerSettings, InnerMethod}; -use crate::seminorms::*; -use crate::plot::*; -use crate::{AlgorithmOverrides, CommandLineArgs}; -use crate::tolerance::Tolerance; -use crate::regularisation::{ - Regularisation, - RadonRegTerm, - NonnegRadonRegTerm -}; -use crate::dataterm::{ - L1, - L2Squared, +use alg_tools::lingrid::lingrid; +use alg_tools::linops::{IdOp, RowOp, AXPY}; +use alg_tools::logger::Logger; +use alg_tools::mapping::{ + DataTerm, DifferentiableMapping, DifferentiableRealMapping, Instance, RealMapping, }; -use crate::prox_penalty::{ - RadonSquared, - //ProxPenalty, -}; -use alg_tools::norms::{L2, NormExponent}; +use alg_tools::maputil::map3; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::norms::{NormExponent, L1, L2}; use alg_tools::operator_arithmetic::Weighted; +use alg_tools::sets::Cube; +use alg_tools::sets::SetOrd; +use alg_tools::tabledump::TableDump; use anyhow::anyhow; +use chrono::{DateTime, Utc}; +use clap::ValueEnum; +use colored::Colorize; +use cpu_time::ProcessTime; +use nalgebra::base::DVector; +use numeric_literals::replace_float_literals; +use rand::prelude::{SeedableRng, StdRng}; +use rand_distr::Distribution; +use serde::{Deserialize, Serialize}; +use serde_json; +use std::collections::HashMap; +use std::hash::Hash; +use std::time::Instant; +use thiserror::Error; -/// Available proximal terms -#[derive(Copy, Clone, Debug, Serialize, Deserialize)] -pub enum ProxTerm { - /// Partial-to-wave operator 𝒟. - Wave, - /// Radon-norm squared - RadonSquared -} +//#[cfg(feature = "pyo3")] +//use pyo3::pyclass; /// Available algorithms and their configurations #[derive(Copy, Clone, Debug, Serialize, Deserialize)] -pub enum AlgorithmConfig { +pub enum AlgorithmConfig { FB(FBConfig, ProxTerm), FISTA(FBConfig, ProxTerm), FW(FWConfig), @@ -138,91 +86,114 @@ SlidingPDPS(SlidingPDPSConfig, ProxTerm), } -fn unpack_tolerance(v : &Vec) -> Tolerance { +fn unpack_tolerance(v: &Vec) -> Tolerance { assert!(v.len() == 3); - Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } + Tolerance::Power { initial: v[0], factor: v[1], exponent: v[2] } } -impl AlgorithmConfig { +impl AlgorithmConfig { /// Override supported parameters based on the command line. - pub fn cli_override(self, cli : &AlgorithmOverrides) -> Self { - let override_merging = |g : SpikeMergingMethod| { - SpikeMergingMethod { - enabled : cli.merge.unwrap_or(g.enabled), - radius : cli.merge_radius.unwrap_or(g.radius), - interp : cli.merge_interp.unwrap_or(g.interp), - } + pub fn cli_override(self, cli: &AlgorithmOverrides) -> Self { + let override_merging = |g: SpikeMergingMethod| SpikeMergingMethod { + enabled: cli.merge.unwrap_or(g.enabled), + radius: cli.merge_radius.unwrap_or(g.radius), + interp: cli.merge_interp.unwrap_or(g.interp), }; - let override_fb_generic = |g : FBGenericConfig| { - FBGenericConfig { - bootstrap_insertions : cli.bootstrap_insertions - .as_ref() - .map_or(g.bootstrap_insertions, - |n| Some((n[0], n[1]))), - merge_every : cli.merge_every.unwrap_or(g.merge_every), - merging : override_merging(g.merging), - final_merging : cli.final_merging.unwrap_or(g.final_merging), - fitness_merging : cli.fitness_merging.unwrap_or(g.fitness_merging), - tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance), - .. g - } + let override_fb_generic = |g: InsertionConfig| InsertionConfig { + bootstrap_insertions: cli + .bootstrap_insertions + .as_ref() + .map_or(g.bootstrap_insertions, |n| Some((n[0], n[1]))), + merge_every: cli.merge_every.unwrap_or(g.merge_every), + merging: override_merging(g.merging), + final_merging: cli.final_merging.unwrap_or(g.final_merging), + fitness_merging: cli.fitness_merging.unwrap_or(g.fitness_merging), + tolerance: cli + .tolerance + .as_ref() + .map(unpack_tolerance) + .unwrap_or(g.tolerance), + ..g }; - let override_transport = |g : TransportConfig| { - TransportConfig { - θ0 : cli.theta0.unwrap_or(g.θ0), - tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con), - adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), - .. g - } + let override_transport = |g: TransportConfig| TransportConfig { + θ0: cli.theta0.unwrap_or(g.θ0), + tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con), + adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), + ..g }; use AlgorithmConfig::*; match self { - FB(fb, prox) => FB(FBConfig { - τ0 : cli.tau0.unwrap_or(fb.τ0), - generic : override_fb_generic(fb.generic), - .. fb - }, prox), - FISTA(fb, prox) => FISTA(FBConfig { - τ0 : cli.tau0.unwrap_or(fb.τ0), - generic : override_fb_generic(fb.generic), - .. fb - }, prox), - PDPS(pdps, prox) => PDPS(PDPSConfig { - τ0 : cli.tau0.unwrap_or(pdps.τ0), - σ0 : cli.sigma0.unwrap_or(pdps.σ0), - acceleration : cli.acceleration.unwrap_or(pdps.acceleration), - generic : override_fb_generic(pdps.generic), - .. pdps - }, prox), + FB(fb, prox) => FB( + FBConfig { + τ0: cli.tau0.unwrap_or(fb.τ0), + σp0: cli.sigmap0.unwrap_or(fb.σp0), + insertion: override_fb_generic(fb.insertion), + ..fb + }, + prox, + ), + FISTA(fb, prox) => FISTA( + FBConfig { + τ0: cli.tau0.unwrap_or(fb.τ0), + σp0: cli.sigmap0.unwrap_or(fb.σp0), + insertion: override_fb_generic(fb.insertion), + ..fb + }, + prox, + ), + PDPS(pdps, prox) => PDPS( + PDPSConfig { + τ0: cli.tau0.unwrap_or(pdps.τ0), + σ0: cli.sigma0.unwrap_or(pdps.σ0), + acceleration: cli.acceleration.unwrap_or(pdps.acceleration), + generic: override_fb_generic(pdps.generic), + ..pdps + }, + prox, + ), FW(fw) => FW(FWConfig { - merging : override_merging(fw.merging), - tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), - .. fw + merging: override_merging(fw.merging), + tolerance: cli + .tolerance + .as_ref() + .map(unpack_tolerance) + .unwrap_or(fw.tolerance), + ..fw }), - SlidingFB(sfb, prox) => SlidingFB(SlidingFBConfig { - τ0 : cli.tau0.unwrap_or(sfb.τ0), - transport : override_transport(sfb.transport), - insertion : override_fb_generic(sfb.insertion), - .. sfb - }, prox), - SlidingPDPS(spdps, prox) => SlidingPDPS(SlidingPDPSConfig { - τ0 : cli.tau0.unwrap_or(spdps.τ0), - σp0 : cli.sigmap0.unwrap_or(spdps.σp0), - σd0 : cli.sigma0.unwrap_or(spdps.σd0), - //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), - transport : override_transport(spdps.transport), - insertion : override_fb_generic(spdps.insertion), - .. spdps - }, prox), - ForwardPDPS(fpdps, prox) => ForwardPDPS(ForwardPDPSConfig { - τ0 : cli.tau0.unwrap_or(fpdps.τ0), - σp0 : cli.sigmap0.unwrap_or(fpdps.σp0), - σd0 : cli.sigma0.unwrap_or(fpdps.σd0), - //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), - insertion : override_fb_generic(fpdps.insertion), - .. fpdps - }, prox), + SlidingFB(sfb, prox) => SlidingFB( + SlidingFBConfig { + τ0: cli.tau0.unwrap_or(sfb.τ0), + σp0: cli.sigmap0.unwrap_or(sfb.σp0), + transport: override_transport(sfb.transport), + insertion: override_fb_generic(sfb.insertion), + ..sfb + }, + prox, + ), + SlidingPDPS(spdps, prox) => SlidingPDPS( + SlidingPDPSConfig { + τ0: cli.tau0.unwrap_or(spdps.τ0), + σp0: cli.sigmap0.unwrap_or(spdps.σp0), + σd0: cli.sigma0.unwrap_or(spdps.σd0), + //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), + transport: override_transport(spdps.transport), + insertion: override_fb_generic(spdps.insertion), + ..spdps + }, + prox, + ), + ForwardPDPS(fpdps, prox) => ForwardPDPS( + ForwardPDPSConfig { + τ0: cli.tau0.unwrap_or(fpdps.τ0), + σp0: cli.sigmap0.unwrap_or(fpdps.σp0), + σd0: cli.sigma0.unwrap_or(fpdps.σd0), + //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), + insertion: override_fb_generic(fpdps.insertion), + ..fpdps + }, + prox, + ), } } } @@ -230,13 +201,14 @@ /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Named { - pub name : String, + pub name: String, #[serde(flatten)] - pub data : Data, + pub data: Data, } /// Shorthand algorithm configurations, to be used with the command line parser #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] +//#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))] pub enum DefaultAlgorithm { /// The μFB forward-backward method #[clap(name = "fb")] @@ -264,7 +236,6 @@ ForwardPDPS, // Radon variants - /// The μFB forward-backward method with radon-norm squared proximal term #[clap(name = "radon_fb")] RadonFB, @@ -287,70 +258,70 @@ impl DefaultAlgorithm { /// Returns the algorithm configuration corresponding to the algorithm shorthand - pub fn default_config(&self) -> AlgorithmConfig { + pub fn default_config(&self) -> AlgorithmConfig { use DefaultAlgorithm::*; - let radon_insertion = FBGenericConfig { - merging : SpikeMergingMethod{ interp : false, .. Default::default() }, - inner : InnerSettings { - method : InnerMethod::PDPS, // SSN not implemented - .. Default::default() + let radon_insertion = InsertionConfig { + merging: SpikeMergingMethod { interp: false, ..Default::default() }, + inner: InnerSettings { + method: InnerMethod::PDPS, // SSN not implemented + ..Default::default() }, - .. Default::default() + ..Default::default() }; match *self { FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave), FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave), FW => AlgorithmConfig::FW(Default::default()), - FWRelax => AlgorithmConfig::FW(FWConfig{ - variant : FWVariant::Relaxed, - .. Default::default() - }), + FWRelax => { + AlgorithmConfig::FW(FWConfig { variant: FWVariant::Relaxed, ..Default::default() }) + } PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave), SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave), SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave), ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave), // Radon variants - RadonFB => AlgorithmConfig::FB( - FBConfig{ generic : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + FBConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonFISTA => AlgorithmConfig::FISTA( - FBConfig{ generic : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + FBConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonPDPS => AlgorithmConfig::PDPS( - PDPSConfig{ generic : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + PDPSConfig { generic: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonSlidingFB => AlgorithmConfig::SlidingFB( - SlidingFBConfig{ insertion : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + SlidingFBConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS( - SlidingPDPSConfig{ insertion : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + SlidingPDPSConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonForwardPDPS => AlgorithmConfig::ForwardPDPS( - ForwardPDPSConfig{ insertion : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + ForwardPDPSConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), } } /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand - pub fn get_named(&self) -> Named> { + pub fn get_named(&self) -> Named> { self.to_named(self.default_config()) } - pub fn to_named(self, alg : AlgorithmConfig) -> Named> { - let name = self.to_possible_value().unwrap().get_name().to_string(); - Named{ name , data : alg } + pub fn to_named(self, alg: AlgorithmConfig) -> Named> { + Named { name: self.name(), data: alg } + } + + pub fn name(self) -> String { + self.to_possible_value().unwrap().get_name().to_string() } } - // // Floats cannot be hashed directly, so just hash the debug formatting // // for use as file identifier. // impl Hash for AlgorithmConfig { @@ -379,347 +350,389 @@ } } -type DefaultBT = BT< - DynamicDepth, - F, - usize, - Bounds, - N ->; -type DefaultSeminormOp = ConvolutionOp, N>; -type DefaultSG = SensorGrid::< - F, - Sensor, - Spread, - DefaultBT, - N ->; +type DefaultBT = BT, N>; +type DefaultSeminormOp = ConvolutionOp, N>; +type DefaultSG = + SensorGrid, N>; /// This is a dirty workaround to rust-csv not supporting struct flattening etc. #[derive(Serialize)] struct CSVLog { - iter : usize, - cpu_time : f64, - value : F, - relative_value : F, + iter: usize, + cpu_time: f64, + value: F, + relative_value: F, //post_value : F, - n_spikes : usize, - inner_iters : usize, - merged : usize, - pruned : usize, - this_iters : usize, + n_spikes: usize, + inner_iters: usize, + merged: usize, + pruned: usize, + this_iters: usize, + epsilon: F, } /// Collected experiment statistics #[derive(Clone, Debug, Serialize)] -struct ExperimentStats { +struct ExperimentStats { /// Signal-to-noise ratio in decibels - ssnr : F, + ssnr: F, /// Proportion of noise in the signal as a number in $[0, 1]$. - noise_ratio : F, + noise_ratio: F, /// When the experiment was run (UTC) - when : DateTime, + when: DateTime, } #[replace_float_literals(F::cast_from(literal))] -impl ExperimentStats { +impl ExperimentStats { /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal. - fn new>(signal : &E, noise : &E) -> Self { + fn new>(signal: &E, noise: &E) -> Self { let s = signal.norm2_squared(); let n = noise.norm2_squared(); let noise_ratio = (n / s).sqrt(); - let ssnr = 10.0 * (s / n).log10(); - ExperimentStats { - ssnr, - noise_ratio, - when : Utc::now(), - } + let ssnr = 10.0 * (s / n).log10(); + ExperimentStats { ssnr, noise_ratio, when: Utc::now() } } } /// Collected algorithm statistics #[derive(Clone, Debug, Serialize)] -struct AlgorithmStats { +struct AlgorithmStats { /// Overall CPU time spent - cpu_time : F, + cpu_time: F, /// Real time spent - elapsed : F + elapsed: F, } - /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input /// and outputs a [`DynError`]. -fn write_json(filename : String, data : &T) -> DynError { +fn write_json(filename: String, data: &T) -> DynError { serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; Ok(()) } - /// Struct for experiment configurations #[derive(Debug, Clone, Serialize)] -pub struct ExperimentV2 -where F : Float + ClapFloat, - [usize; N] : Serialize, - NoiseDistr : Distribution, - S : Sensor, - P : Spread, - K : SimpleConvolutionKernel, +pub struct ExperimentV2 +where + F: Float + ClapFloat, + [usize; N]: Serialize, + NoiseDistr: Distribution, + S: Sensor, + P: Spread, + K: SimpleConvolutionKernel, { /// Domain $Ω$. - pub domain : Cube, + pub domain: Cube, /// Number of sensors along each dimension - pub sensor_count : [usize; N], + pub sensor_count: [usize; N], /// Noise distribution - pub noise_distr : NoiseDistr, + pub noise_distr: NoiseDistr, /// Seed for random noise generation (for repeatable experiments) - pub noise_seed : u64, + pub noise_seed: u64, /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. - pub sensor : S, + pub sensor: S, /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. - pub spread : P, + pub spread: P, /// Kernel $ρ$ of $𝒟$. - pub kernel : K, + pub kernel: K, /// True point sources - pub μ_hat : RNDM, + pub μ_hat: RNDM, /// Regularisation term and parameter - pub regularisation : Regularisation, + pub regularisation: Regularisation, /// For plotting : how wide should the kernels be plotted - pub kernel_plot_width : F, + pub kernel_plot_width: F, /// Data term - pub dataterm : DataTerm, + pub dataterm: DataTermType, /// A map of default configurations for algorithms - pub algorithm_overrides : HashMap>, + pub algorithm_overrides: HashMap>, /// Default merge radius - pub default_merge_radius : F, + pub default_merge_radius: F, } #[derive(Debug, Clone, Serialize)] -pub struct ExperimentBiased -where F : Float + ClapFloat, - [usize; N] : Serialize, - NoiseDistr : Distribution, - S : Sensor, - P : Spread, - K : SimpleConvolutionKernel, - B : Mapping, Codomain = F> + Serialize + std::fmt::Debug, +pub struct ExperimentBiased +where + F: Float + ClapFloat, + [usize; N]: Serialize, + NoiseDistr: Distribution, + S: Sensor, + P: Spread, + K: SimpleConvolutionKernel, + B: Mapping, Codomain = F> + Serialize + std::fmt::Debug, { /// Basic setup - pub base : ExperimentV2, + pub base: ExperimentV2, /// Weight of TV term - pub λ : F, + pub λ: F, /// Bias function - pub bias : B, + pub bias: B, } /// Trait for runnable experiments -pub trait RunnableExperiment { +pub trait RunnableExperiment { /// Run all algorithms provided, or default algorithms if none provided, on the experiment. - fn runall(&self, cli : &CommandLineArgs, - algs : Option>>>) -> DynError; + fn runall( + &self, + cli: &CommandLineArgs, + algs: Option>>>, + ) -> DynError; /// Return algorithm default config - fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides; -} - -/// Helper function to print experiment start message and save setup. -/// Returns saving prefix. -fn start_experiment( - experiment : &Named, - cli : &CommandLineArgs, - stats : S, -) -> DynResult -where - E : Serialize + std::fmt::Debug, - S : Serialize, -{ - let Named { name : experiment_name, data } = experiment; + fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides; - println!("{}\n{}", - format!("Performing experiment {}…", experiment_name).cyan(), - format!("Experiment settings: {}", serde_json::to_string(&data)?).bright_black()); - - // Set up output directory - let prefix = format!("{}/{}/", cli.outdir, experiment_name); - - // Save experiment configuration and statistics - let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); - std::fs::create_dir_all(&prefix)?; - write_json(mkname_e("experiment"), experiment)?; - write_json(mkname_e("config"), cli)?; - write_json(mkname_e("stats"), &stats)?; - - Ok(prefix) + /// Experiment name + fn name(&self) -> &str; } /// Error codes for running an algorithm on an experiment. -enum RunError { +#[derive(Error, Debug)] +pub enum RunError { /// Algorithm not implemented for this experiment + #[error("Algorithm not implemented for this experiment")] NotImplemented, } use RunError::*; -type DoRunAllIt<'a, F, const N : usize> = LoggingIteratorFactory< +type DoRunAllIt<'a, F, const N: usize> = LoggingIteratorFactory< 'a, - Timed>, - TimingIteratorFactory>> + Timed>, + TimingIteratorFactory>>, >; -/// Helper function to run all algorithms on an experiment. -fn do_runall Deserialize<'b>, Z, const N : usize>( - experiment_name : &String, - prefix : &String, - cli : &CommandLineArgs, - algorithms : Vec>>, - plotgrid : LinSpace, [usize; N]>, - mut save_extra : impl FnMut(String, Z) -> DynError, - mut do_alg : impl FnMut( - &AlgorithmConfig, - DoRunAllIt, - SeqPlotter, - String, - ) -> Result<(RNDM, Z), RunError>, -) -> DynError -where - PlotLookup : Plotting, +pub trait RunnableExperimentExtras: + RunnableExperiment + Serialize + Sized { - let mut logs = Vec::new(); + /// Helper function to print experiment start message and save setup. + /// Returns saving prefix. + fn start(&self, cli: &CommandLineArgs) -> DynResult { + let experiment_name = self.name(); + let ser = serde_json::to_string(self); - let iterator_options = AlgIteratorOptions{ - max_iter : cli.max_iter, - verbose_iter : cli.verbose_iter - .map_or(Verbose::LogarithmicCap{base : 10, cap : 2}, - |n| Verbose::Every(n)), - quiet : cli.quiet, - }; + println!( + "{}\n{}", + format!("Performing experiment {}…", experiment_name).cyan(), + format!( + "Experiment settings: {}", + if let Ok(ref s) = ser { + s + } else { + "" + } + ) + .bright_black(), + ); - // Run the algorithm(s) - for named @ Named { name : alg_name, data : alg } in algorithms.iter() { - let this_prefix = format!("{}{}/", prefix, alg_name); + // Set up output directory + let prefix = format!("{}/{}/", cli.outdir, experiment_name); + + // Save experiment configuration and statistics + std::fs::create_dir_all(&prefix)?; + write_json(format!("{prefix}experiment.json"), self)?; + write_json(format!("{prefix}config.json"), cli)?; - // Create Logger and IteratorFactory - let mut logger = Logger::new(); - let iterator = iterator_options.instantiate() - .timed() - .into_log(&mut logger); + Ok(prefix) + } - let running = if !cli.quiet { - format!("{}\n{}\n{}\n", - format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), - format!("Iteration settings: {}", serde_json::to_string(&iterator_options)?).bright_black(), - format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black()) - } else { - "".to_string() - }; - // - // The following is for postprocessing, which has been disabled anyway. - // - // let reg : Box> = match regularisation { - // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)), - // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)), - // }; - //let findim_data = reg.prepare_optimise_weights(&opA, &b); - //let inner_config : InnerSettings = Default::default(); - //let inner_it = inner_config.iterator_options; + /// Helper function to run all algorithms on an experiment. + fn do_runall( + &self, + prefix: &String, + cli: &CommandLineArgs, + algorithms: Vec>>, + mut make_plotter: impl FnMut(String) -> Plot, + mut save_extra: impl FnMut(String, Z) -> DynError, + init: P, + mut do_alg: impl FnMut( + (&AlgorithmConfig, DoRunAllIt, Plot, P, String), + ) -> DynResult<(RNDM, Z)>, + ) -> DynError + where + F: for<'b> Deserialize<'b>, + PlotLookup: Plotting, + P: Clone, + { + let experiment_name = self.name(); - // Create plotter and directory if needed. - let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; - let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()); - - let start = Instant::now(); - let start_cpu = ProcessTime::now(); + let mut logs = Vec::new(); - let (μ, z) = match do_alg(alg, iterator, plotter, running) { - Ok(μ) => μ, - Err(RunError::NotImplemented) => { - let msg = format!("Algorithm “{alg_name}” not implemented for {experiment_name}. \ - Skipping.").red(); - eprintln!("{}", msg); - continue - } + let iterator_options = AlgIteratorOptions { + max_iter: cli.max_iter, + verbose_iter: cli + .verbose_iter + .map_or(Verbose::LogarithmicCap { base: 10, cap: 2 }, |n| { + Verbose::Every(n) + }), + quiet: cli.quiet, }; - let elapsed = start.elapsed().as_secs_f64(); - let cpu_time = start_cpu.elapsed().as_secs_f64(); + // Run the algorithm(s) + for named @ Named { name: alg_name, data: alg } in algorithms.iter() { + let this_prefix = format!("{}{}/", prefix, alg_name); - println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); + // Create Logger and IteratorFactory + let mut logger = Logger::new(); + let iterator = iterator_options.instantiate().timed().into_log(&mut logger); - // Save results - println!("{}", "Saving results …".green()); + let running = if !cli.quiet { + format!( + "{}\n{}\n{}\n", + format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), + format!( + "Iteration settings: {}", + serde_json::to_string(&iterator_options)? + ) + .bright_black(), + format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black() + ) + } else { + "".to_string() + }; + // + // The following is for postprocessing, which has been disabled anyway. + // + // let reg : Box> = match regularisation { + // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)), + // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)), + // }; + //let findim_data = reg.prepare_optimise_weights(&opA, &b); + //let inner_config : InnerSettings = Default::default(); + //let inner_it = inner_config.iterator_options; - let mkname = |t| format!("{prefix}{alg_name}_{t}"); + // Create plotter and directory if needed. + let plotter = make_plotter(this_prefix); + + let start = Instant::now(); + let start_cpu = ProcessTime::now(); - write_json(mkname("config.json"), &named)?; - write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; - μ.write_csv(mkname("reco.txt"))?; - save_extra(mkname(""), z)?; - //logger.write_csv(mkname("log.txt"))?; - logs.push((mkname("log.txt"), logger)); + let (μ, z) = match do_alg((alg, iterator, plotter, init.clone(), running)) { + Ok(μ) => μ, + Err(e) => { + let msg = format!( + "Skipping algorithm “{alg_name}” for {experiment_name} due to error: {e}" + ) + .red(); + eprintln!("{}", msg); + continue; + } + }; + + let elapsed = start.elapsed().as_secs_f64(); + let cpu_time = start_cpu.elapsed().as_secs_f64(); + + println!( + "{}", + format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow() + ); + + // Save results + println!("{}", "Saving results …".green()); + + let mkname = |t| format!("{prefix}{alg_name}_{t}"); + + write_json(mkname("config.json"), &named)?; + write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; + μ.write_csv(mkname("reco.txt"))?; + save_extra(mkname(""), z)?; + //logger.write_csv(mkname("log.txt"))?; + logs.push((mkname("log.txt"), logger)); + } + + save_logs( + logs, + format!("{prefix}valuerange.json"), + cli.load_valuerange, + ) } +} - save_logs(logs, format!("{prefix}valuerange.json"), cli.load_valuerange) +impl RunnableExperimentExtras for E +where + F: ClapFloat, + Self: RunnableExperiment + Serialize, +{ } #[replace_float_literals(F::cast_from(literal))] -impl RunnableExperiment for -Named> +impl RunnableExperiment + for Named> where - F : ClapFloat + nalgebra::RealField + ToNalgebraRealField - + Default + for<'b> Deserialize<'b>, - [usize; N] : Serialize, - S : Sensor + Copy + Serialize + std::fmt::Debug, - P : Spread + Copy + Serialize + std::fmt::Debug, - Convolution: Spread + Bounded + LocalAnalysis, N> + Copy - // TODO: shold not have differentiability as a requirement, but - // decide availability of sliding based on it. - //+ for<'b> Differentiable<&'b Loc, Output = Loc>, - // TODO: very weird that rust only compiles with Differentiable - // instead of the above one on references, which is required by - // poitsource_sliding_fb_reg. - + DifferentiableRealMapping - + Lipschitz, - for<'b> as DifferentiableMapping>>::Differential<'b> : Lipschitz, // TODO: should not be required generally, only for sliding_fb. - AutoConvolution

: BoundedBy, - K : SimpleConvolutionKernel + F: ClapFloat + + nalgebra::RealField + + ToNalgebraRealField + + Default + + for<'b> Deserialize<'b>, + [usize; N]: Serialize, + S: Sensor + Copy + Serialize + std::fmt::Debug, + P: Spread + Copy + Serialize + std::fmt::Debug, + Convolution: Spread + + Bounded + LocalAnalysis, N> - + Copy + Serialize + std::fmt::Debug, - Cube: P2Minimise, F> + SetOrd, - PlotLookup : Plotting, - DefaultBT : SensorGridBT + BTSearch, + + Copy + // TODO: shold not have differentiability as a requirement, but + // decide availability of sliding based on it. + //+ for<'b> Differentiable<&'b Loc, Output = Loc>, + // TODO: very weird that rust only compiles with Differentiable + // instead of the above one on references, which is required by + // poitsource_sliding_fb_reg. + + DifferentiableRealMapping + + Lipschitz, + for<'b> as DifferentiableMapping>>::Differential<'b>: + Lipschitz, // TODO: should not be required generally, only for sliding_fb. + AutoConvolution

: BoundedBy, + K: SimpleConvolutionKernel + + LocalAnalysis, N> + + Copy + + Serialize + + std::fmt::Debug, + Cube: P2Minimise, F> + SetOrd, + PlotLookup: Plotting, + DefaultBT: SensorGridBT + BTSearch, BTNodeLookup: BTNode, N>, - RNDM : SpikeMerging, - NoiseDistr : Distribution + Serialize + std::fmt::Debug, - // DefaultSG : ForwardModel, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector>, - // PreadjointCodomain : Space + Bounded + DifferentiableRealMapping, - // DefaultSeminormOp : ProxPenalty, N>, - // DefaultSeminormOp : ProxPenalty, N>, - // RadonSquared : ProxPenalty, N>, - // RadonSquared : ProxPenalty, N>, + RNDM: SpikeMerging, + NoiseDistr: Distribution + Serialize + std::fmt::Debug, { + fn name(&self) -> &str { + self.name.as_ref() + } - fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides { + fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides { AlgorithmOverrides { - merge_radius : Some(self.data.default_merge_radius), - .. self.data.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) + merge_radius: Some(self.data.default_merge_radius), + ..self + .data + .algorithm_overrides + .get(&alg) + .cloned() + .unwrap_or(Default::default()) } } - fn runall(&self, cli : &CommandLineArgs, - algs : Option>>>) -> DynError { + fn runall( + &self, + cli: &CommandLineArgs, + algs: Option>>>, + ) -> DynError { // Get experiment configuration - let &Named { - name : ref experiment_name, - data : ExperimentV2 { - domain, sensor_count, ref noise_distr, sensor, spread, kernel, - ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, - .. - } - } = self; + let &ExperimentV2 { + domain, + sensor_count, + ref noise_distr, + sensor, + spread, + kernel, + ref μ_hat, + regularisation, + kernel_plot_width, + dataterm, + noise_seed, + .. + } = &self.data; // Set up algorithms let algorithms = match (algs, dataterm) { (Some(algs), _) => algs, - (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], - (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], + (None, DataTermType::L222) => vec![DefaultAlgorithm::FB.get_named()], + (None, DataTermType::L1) => vec![DefaultAlgorithm::PDPS.get_named()], }; // Set up operators @@ -738,255 +751,348 @@ // overloading log10 and conflicting with standard NumTraits one. let stats = ExperimentStats::new(&b, &noise); - let prefix = start_experiment(&self, cli, stats)?; + let prefix = self.start(cli)?; + write_json(format!("{prefix}stats.json"), &stats)?; - plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, - &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; + plotall( + cli, + &prefix, + &domain, + &sensor, + &kernel, + &spread, + &μ_hat, + &op𝒟, + &opA, + &b_hat, + &b, + kernel_plot_width, + )?; - let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); + let plotgrid = lingrid(&domain, &[if N == 1 { 1000 } else { 100 }; N]); + let make_plotter = |this_prefix| { + let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; + SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()) + }; let save_extra = |_, ()| Ok(()); - do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, - |alg, iterator, plotter, running| - { - let μ = match alg { - AlgorithmConfig::FB(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_fb_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_fb_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - _ => Err(NotImplemented) - } - }, - AlgorithmConfig::FISTA(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_fista_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_fista_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_fista_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_fista_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - _ => Err(NotImplemented), - } - }, - AlgorithmConfig::SlidingFB(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_sliding_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_sliding_fb_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_sliding_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_sliding_fb_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - _ => Err(NotImplemented), - } - }, - AlgorithmConfig::PDPS(ref algconfig, prox) => { - print!("{running}"); - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - pointsource_pdps_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, L2Squared - ) - }), - (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - pointsource_pdps_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, L2Squared - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ - pointsource_pdps_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, L1 - ) - }), - (Regularisation::Radon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ - pointsource_pdps_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, L1 - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - pointsource_pdps_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, L2Squared - ) - }), - (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - pointsource_pdps_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, L2Squared - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ - pointsource_pdps_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, L1 - ) - }), - (Regularisation::Radon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ - pointsource_pdps_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, L1 - ) - }), - // _ => Err(NotImplemented), - } - }, - AlgorithmConfig::FW(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ - print!("{running}"); - pointsource_fw_reg(&opA, &b, RadonRegTerm(α), - algconfig, iterator, plotter) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ - print!("{running}"); - pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α), - algconfig, iterator, plotter) - }), - _ => Err(NotImplemented), - } - }, - _ => Err(NotImplemented), - }?; - Ok((μ, ())) - }) + let μ0 = None; // Zero init + + match (dataterm, regularisation) { + (DataTermType::L1, Regularisation::Radon(α)) => { + let f = DataTerm::new(opA, b, L1.as_mapping()); + let reg = RadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| { + run_pdps(&f, ®, &RadonSquared, p, |p| { + run_pdps(&f, ®, &op𝒟, p, |_| Err(NotImplemented.into())) + }) + .map(|μ| (μ, ())) + }, + ) + } + (DataTermType::L1, Regularisation::NonnegRadon(α)) => { + let f = DataTerm::new(opA, b, L1.as_mapping()); + let reg = NonnegRadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| { + run_pdps(&f, ®, &RadonSquared, p, |p| { + run_pdps(&f, ®, &op𝒟, p, |_| Err(NotImplemented.into())) + }) + .map(|μ| (μ, ())) + }, + ) + } + (DataTermType::L222, Regularisation::Radon(α)) => { + let f = DataTerm::new(opA, b, Norm222::new()); + let reg = RadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| { + run_fb(&f, ®, &RadonSquared, p, |p| { + run_pdps(&f, ®, &RadonSquared, p, |p| { + run_fb(&f, ®, &op𝒟, p, |p| { + run_pdps(&f, ®, &op𝒟, p, |p| { + run_fw(&f, ®, p, |_| Err(NotImplemented.into())) + }) + }) + }) + }) + .map(|μ| (μ, ())) + }, + ) + } + (DataTermType::L222, Regularisation::NonnegRadon(α)) => { + let f = DataTerm::new(opA, b, Norm222::new()); + let reg = NonnegRadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| { + run_fb(&f, ®, &RadonSquared, p, |p| { + run_pdps(&f, ®, &RadonSquared, p, |p| { + run_fb(&f, ®, &op𝒟, p, |p| { + run_pdps(&f, ®, &op𝒟, p, |p| { + run_fw(&f, ®, p, |_| Err(NotImplemented.into())) + }) + }) + }) + }) + .map(|μ| (μ, ())) + }, + ) + } + } + } +} + +/// Runs PDPS if `alg` so requests and `prox_penalty` matches. +/// +/// Due to the structure of the PDPS, the data term `f` has to have a specific form. +/// +/// `cont` gives a continuation attempt to find the algorithm matching the description `alg`. +pub fn run_pdps<'a, F, A, Phi, Reg, P, I, Plot, const N: usize>( + f: &'a DataTerm, A, Phi>, + reg: &Reg, + prox_penalty: &P, + (alg, iterator, plotter, μ0, running): ( + &AlgorithmConfig, + I, + Plot, + Option>, + String, + ), + cont: impl FnOnce( + (&AlgorithmConfig, I, Plot, Option>, String), + ) -> DynResult>, +) -> DynResult> +where + F: Float + ToNalgebraRealField, + A: ForwardModel, F>, + Phi: Conjugable, + for<'b> Phi::Conjugate<'b>: Prox, + for<'b> &'b A::Observable: Instance, + A::Observable: AXPY, + Reg: SlidingRegTerm, F>, + P: ProxPenalty, A::PreadjointCodomain, Reg, F> + StepLengthBoundPD>, + RNDM: SpikeMerging, + I: AlgIteratorFactory>, + Plot: Plotter>, +{ + match alg { + &AlgorithmConfig::PDPS(ref algconfig, prox_type) if prox_type == P::prox_type() => { + print!("{running}"); + pointsource_pdps_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) + } + _ => cont((alg, iterator, plotter, μ0, running)), } } +/// Runs FB-style algorithms if `alg` so requests and `prox_penalty` matches. +/// +/// `cont` gives a continuation attempt to find the algorithm matching the description `alg`. +pub fn run_fb( + f: &Dat, + reg: &Reg, + prox_penalty: &P, + (alg, iterator, plotter, μ0, running): ( + &AlgorithmConfig, + I, + Plot, + Option>, + String, + ), + cont: impl FnOnce( + (&AlgorithmConfig, I, Plot, Option>, String), + ) -> DynResult>, +) -> DynResult> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F> + BoundedCurvature, + Dat::DerivativeDomain: DifferentiableRealMapping + ClosedMul, + RNDM: SpikeMerging, + Reg: SlidingRegTerm, F>, + P: ProxPenalty, Dat::DerivativeDomain, Reg, F> + StepLengthBound, + Plot: Plotter>, +{ + let pt = P::prox_type(); + + match alg { + &AlgorithmConfig::FB(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_fb_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) + } + &AlgorithmConfig::FISTA(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_fista_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) + } + &AlgorithmConfig::SlidingFB(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_sliding_fb_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) + } + _ => cont((alg, iterator, plotter, μ0, running)), + } +} + +/// Runs FB-style algorithms if `alg` so requests and `prox_penalty` matches. +/// +/// For the moment, due to restrictions of the Frank–Wolfe implementation, only the +/// $L^2$-squared data term is enabled through the type signatures. +/// +/// `cont` gives a continuation attempt to find the algorithm matching the description `alg`. +pub fn run_fw<'a, F, A, Reg, I, Plot, const N: usize>( + f: &'a DataTerm, A, Norm222>, + reg: &Reg, + (alg, iterator, plotter, μ0, running): ( + &AlgorithmConfig, + I, + Plot, + Option>, + String, + ), + cont: impl FnOnce( + (&AlgorithmConfig, I, Plot, Option>, String), + ) -> DynResult>, +) -> DynResult> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + A: ForwardModel, F>, + A::PreadjointCodomain: MinMaxMapping, F>, + for<'b> &'b A::PreadjointCodomain: Instance, + Cube: P2Minimise, F>, + RNDM: SpikeMerging, + Reg: RegTermFW, N>, + Plot: Plotter>, +{ + match alg { + &AlgorithmConfig::FW(ref algconfig) => { + print!("{running}"); + pointsource_fw_reg(f, reg, algconfig, iterator, plotter, μ0) + } + _ => cont((alg, iterator, plotter, μ0, running)), + } +} #[replace_float_literals(F::cast_from(literal))] -impl RunnableExperiment for -Named> +impl RunnableExperiment + for Named> where - F : ClapFloat + nalgebra::RealField + ToNalgebraRealField - + Default + for<'b> Deserialize<'b>, - [usize; N] : Serialize, - S : Sensor + Copy + Serialize + std::fmt::Debug, - P : Spread + Copy + Serialize + std::fmt::Debug, - Convolution: Spread + Bounded + LocalAnalysis, N> + Copy - // TODO: shold not have differentiability as a requirement, but - // decide availability of sliding based on it. - //+ for<'b> Differentiable<&'b Loc, Output = Loc>, - // TODO: very weird that rust only compiles with Differentiable - // instead of the above one on references, which is required by - // poitsource_sliding_fb_reg. - + DifferentiableRealMapping - + Lipschitz, - for<'b> as DifferentiableMapping>>::Differential<'b> : Lipschitz, // TODO: should not be required generally, only for sliding_fb. - AutoConvolution

: BoundedBy, - K : SimpleConvolutionKernel + F: ClapFloat + + nalgebra::RealField + + ToNalgebraRealField + + Default + + for<'b> Deserialize<'b>, + [usize; N]: Serialize, + S: Sensor + Copy + Serialize + std::fmt::Debug, + P: Spread + Copy + Serialize + std::fmt::Debug, + Convolution: Spread + + Bounded + LocalAnalysis, N> - + Copy + Serialize + std::fmt::Debug, - Cube: P2Minimise, F> + SetOrd, - PlotLookup : Plotting, - DefaultBT : SensorGridBT + BTSearch, + + Copy + // TODO: shold not have differentiability as a requirement, but + // decide availability of sliding based on it. + //+ for<'b> Differentiable<&'b Loc, Output = Loc>, + // TODO: very weird that rust only compiles with Differentiable + // instead of the above one on references, which is required by + // poitsource_sliding_fb_reg. + + DifferentiableRealMapping + + Lipschitz, + for<'b> as DifferentiableMapping>>::Differential<'b>: + Lipschitz, // TODO: should not be required generally, only for sliding_fb. + AutoConvolution

: BoundedBy, + K: SimpleConvolutionKernel + + LocalAnalysis, N> + + Copy + + Serialize + + std::fmt::Debug, + Cube: P2Minimise, F> + SetOrd, + PlotLookup: Plotting, + DefaultBT: SensorGridBT + BTSearch, BTNodeLookup: BTNode, N>, - RNDM : SpikeMerging, - NoiseDistr : Distribution + Serialize + std::fmt::Debug, - B : Mapping, Codomain = F> + Serialize + std::fmt::Debug, - // DefaultSG : ForwardModel, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector>, - // PreadjointCodomain : Bounded + DifferentiableRealMapping, + RNDM: SpikeMerging, + NoiseDistr: Distribution + Serialize + std::fmt::Debug, + B: Mapping, Codomain = F> + Serialize + std::fmt::Debug, + nalgebra::DVector: ClosedMul, + // This is mainly required for the final Mul requirement to be defined + // DefaultSG: ForwardModel< + // RNDM, + // F, + // PreadjointCodomain = PreadjointCodomain, + // Observable = DVector, + // >, + // PreadjointCodomain: Bounded + DifferentiableRealMapping + std::ops::Mul, + // Pair>: std::ops::Mul, // DefaultSeminormOp : ProxPenalty, N>, // DefaultSeminormOp : ProxPenalty, N>, // RadonSquared : ProxPenalty, N>, // RadonSquared : ProxPenalty, N>, { + fn name(&self) -> &str { + self.name.as_ref() + } - fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides { + fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides { AlgorithmOverrides { - merge_radius : Some(self.data.base.default_merge_radius), - .. self.data.base.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) + merge_radius: Some(self.data.base.default_merge_radius), + ..self + .data + .base + .algorithm_overrides + .get(&alg) + .cloned() + .unwrap_or(Default::default()) } } - fn runall(&self, cli : &CommandLineArgs, - algs : Option>>>) -> DynError { + fn runall( + &self, + cli: &CommandLineArgs, + algs: Option>>>, + ) -> DynError { // Get experiment configuration - let &Named { - name : ref experiment_name, - data : ExperimentBiased { - λ, - ref bias, - base : ExperimentV2 { - domain, sensor_count, ref noise_distr, sensor, spread, kernel, - ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, + let &ExperimentBiased { + λ, + ref bias, + base: + ExperimentV2 { + domain, + sensor_count, + ref noise_distr, + sensor, + spread, + kernel, + ref μ_hat, + regularisation, + kernel_plot_width, + dataterm, + noise_seed, .. - } - } - } = self; + }, + } = &self.data; // Set up algorithms let algorithms = match (algs, dataterm) { @@ -1000,173 +1106,304 @@ let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); let opAext = RowOp(opA.clone(), IdOp::new()); let fnR = Zero::new(); - let h = map3(domain.span_start(), domain.span_end(), sensor_count, - |a, b, n| (b-a)/F::cast_from(n)) - .into_iter() - .reduce(NumTraitsFloat::max) - .unwrap(); + let h = map3( + domain.span_start(), + domain.span_end(), + sensor_count, + |a, b, n| (b - a) / F::cast_from(n), + ) + .into_iter() + .reduce(NumTraitsFloat::max) + .unwrap(); let z = DVector::zeros(sensor_count.iter().product()); let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap(); let y = opKz.apply(&z); - let fnH = Weighted{ base_fn : L1.as_mapping(), weight : λ}; // TODO: L_{2,1} - // let zero_y = y.clone(); - // let zeroBTFN = opA.preadjoint().apply(&zero_y); - // let opKμ = ZeroOp::new(&zero_y, zeroBTFN); + let fnH = Weighted { base_fn: L1.as_mapping(), weight: λ }; // TODO: L_{2,1} + // let zero_y = y.clone(); + // let zeroBTFN = opA.preadjoint().apply(&zero_y); + // let opKμ = ZeroOp::new(&zero_y, zeroBTFN); // Set up random number generator. let mut rng = StdRng::seed_from_u64(noise_seed); // Generate the data and calculate SSNR statistic - let bias_vec = DVector::from_vec(opA.grid() - .into_iter() - .map(|v| bias.apply(v)) - .collect::>()); - let b_hat : DVector<_> = opA.apply(μ_hat) + &bias_vec; + let bias_vec = DVector::from_vec( + opA.grid() + .into_iter() + .map(|v| bias.apply(v)) + .collect::>(), + ); + let b_hat: DVector<_> = opA.apply(μ_hat) + &bias_vec; let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); let b = &b_hat + &noise; // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField // overloading log10 and conflicting with standard NumTraits one. let stats = ExperimentStats::new(&b, &noise); - let prefix = start_experiment(&self, cli, stats)?; + let prefix = self.start(cli)?; + write_json(format!("{prefix}stats.json"), &stats)?; - plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, - &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; + plotall( + cli, + &prefix, + &domain, + &sensor, + &kernel, + &spread, + &μ_hat, + &op𝒟, + &opA, + &b_hat, + &b, + kernel_plot_width, + )?; opA.write_observable(&bias_vec, format!("{prefix}bias"))?; - let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); + let plotgrid = lingrid(&domain, &[if N == 1 { 1000 } else { 100 }; N]); + let make_plotter = |this_prefix| { + let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; + SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()) + }; let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z")); - // Run the algorithms - do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, - |alg, iterator, plotter, running| - { - let Pair(μ, z) = match alg { - AlgorithmConfig::ForwardPDPS(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_forward_pdps_pair( - &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_forward_pdps_pair( - &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_forward_pdps_pair( - &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_forward_pdps_pair( - &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - _ => Err(NotImplemented) - } - }, - AlgorithmConfig::SlidingPDPS(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_sliding_pdps_pair( - &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_sliding_pdps_pair( - &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_sliding_pdps_pair( - &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_sliding_pdps_pair( - &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - _ => Err(NotImplemented) - } - }, - _ => Err(NotImplemented) - }?; - Ok((μ, z)) - }) + let μ0 = None; // Zero init + + match (dataterm, regularisation) { + (DataTermType::L222, Regularisation::Radon(α)) => { + let f = DataTerm::new(opAext, b, Norm222::new()); + let reg = RadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + (μ0, z, y), + |p| { + run_pdps_pair(&f, ®, &RadonSquared, &opKz, &fnR, &fnH, p, |q| { + run_pdps_pair(&f, ®, &op𝒟, &opKz, &fnR, &fnH, q, |_| { + Err(NotImplemented.into()) + }) + }) + .map(|Pair(μ, z)| (μ, z)) + }, + ) + } + (DataTermType::L222, Regularisation::NonnegRadon(α)) => { + let f = DataTerm::new(opAext, b, Norm222::new()); + let reg = NonnegRadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + (μ0, z, y), + |p| { + run_pdps_pair(&f, ®, &RadonSquared, &opKz, &fnR, &fnH, p, |q| { + run_pdps_pair(&f, ®, &op𝒟, &opKz, &fnR, &fnH, q, |_| { + Err(NotImplemented.into()) + }) + }) + .map(|Pair(μ, z)| (μ, z)) + }, + ) + } + _ => Err(NotImplemented.into()), + } + } +} + +type MeasureZ = Pair, Z>; + +pub fn run_pdps_pair( + f: &Dat, + reg: &Reg, + prox_penalty: &P, + opKz: &KOpZ, + fnR: &R, + fnH: &H, + (alg, iterator, plotter, μ0zy, running): ( + &AlgorithmConfig, + I, + Plot, + (Option>, Z, Y), + String, + ), + cont: impl FnOnce( + ( + &AlgorithmConfig, + I, + Plot, + (Option>, Z, Y), + String, + ), + ) -> DynResult, Z>>, +) -> DynResult, Z>> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F, DerivativeDomain = Pair> + + BoundedCurvature, + S: DifferentiableRealMapping + ClosedMul, + for<'a> Pair<&'a P, &'a IdOp>: StepLengthBoundPair, + //Pair: ClosedMul, + RNDM: SpikeMerging, + Reg: SlidingRegTerm, F>, + P: ProxPenalty, S, Reg, F>, + KOpZ: BoundedLinear + + GEMV + + SimplyAdjointable, + KOpZ::SimpleAdjoint: GEMV, + Y: ClosedEuclidean + Clone, + for<'b> &'b Y: Instance, + Z: ClosedEuclidean + Clone + ClosedMul, + for<'b> &'b Z: Instance, + R: Prox, + H: Conjugable, + for<'b> H::Conjugate<'b>: Prox, + Plot: Plotter>, +{ + let pt = P::prox_type(); + + match alg { + &AlgorithmConfig::ForwardPDPS(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_forward_pdps_pair( + f, + reg, + prox_penalty, + algconfig, + iterator, + plotter, + μ0zy, + opKz, + fnR, + fnH, + ) + } + &AlgorithmConfig::SlidingPDPS(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_sliding_pdps_pair( + f, + reg, + prox_penalty, + algconfig, + iterator, + plotter, + μ0zy, + opKz, + fnR, + fnH, + ) + } + _ => cont((alg, iterator, plotter, μ0zy, running)), + } +} + +pub fn run_fb_pair( + f: &Dat, + reg: &Reg, + prox_penalty: &P, + fnR: &R, + (alg, iterator, plotter, μ0z, running): ( + &AlgorithmConfig, + I, + Plot, + (Option>, Z), + String, + ), + cont: impl FnOnce( + ( + &AlgorithmConfig, + I, + Plot, + (Option>, Z), + String, + ), + ) -> DynResult, Z>>, +) -> DynResult, Z>> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory>, + Dat: DifferentiableMapping, Codomain = F, DerivativeDomain = Pair> + + BoundedCurvature, + S: DifferentiableRealMapping + ClosedMul, + RNDM: SpikeMerging, + Reg: SlidingRegTerm, F>, + P: ProxPenalty, S, Reg, F>, + for<'a> Pair<&'a P, &'a IdOp>: StepLengthBoundPair, + Z: ClosedEuclidean + AXPY + Clone, + for<'b> &'b Z: Instance, + R: Prox, + Plot: Plotter>, + // We should not need to explicitly require this: + for<'b> &'b Loc<0, F>: Instance>, +{ + let pt = P::prox_type(); + + match alg { + &AlgorithmConfig::FB(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_fb_pair(f, reg, prox_penalty, algconfig, iterator, plotter, μ0z, fnR) + } + &AlgorithmConfig::SlidingFB(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_sliding_fb_pair( + f, + reg, + prox_penalty, + algconfig, + iterator, + plotter, + μ0z, + fnR, + ) + } + _ => cont((alg, iterator, plotter, μ0z, running)), } } #[derive(Copy, Clone, Debug, Serialize, Deserialize)] -struct ValueRange { - ini : F, - min : F, +struct ValueRange { + ini: F, + min: F, } -impl ValueRange { - fn expand_with(self, other : Self) -> Self { - ValueRange { - ini : self.ini.max(other.ini), - min : self.min.min(other.min), - } +impl ValueRange { + fn expand_with(self, other: Self) -> Self { + ValueRange { ini: self.ini.max(other.ini), min: self.min.min(other.min) } } } /// Calculative minimum and maximum values of all the `logs`, and save them into /// corresponding file names given as the first elements of the tuples in the vectors. -fn save_logs Deserialize<'b>, const N : usize>( - logs : Vec<(String, Logger>>)>, - valuerange_file : String, - load_valuerange : bool, +fn save_logs Deserialize<'b>>( + logs: Vec<(String, Logger>>)>, + valuerange_file: String, + load_valuerange: bool, ) -> DynError { // Process logs for relative values println!("{}", "Processing logs…"); // Find minimum value and initial value within a single log - let proc_single_log = |log : &Logger>>| { + let proc_single_log = |log: &Logger>>| { let d = log.data(); - let mi = d.iter() - .map(|i| i.data.value) - .reduce(NumTraitsFloat::min); + let mi = d.iter().map(|i| i.data.value).reduce(NumTraitsFloat::min); d.first() - .map(|i| i.data.value) - .zip(mi) - .map(|(ini, min)| ValueRange{ ini, min }) + .map(|i| i.data.value) + .zip(mi) + .map(|(ini, min)| ValueRange { ini, min }) }; // Find minimum and maximum value over all logs - let mut v = logs.iter() - .filter_map(|&(_, ref log)| proc_single_log(log)) - .reduce(|v1, v2| v1.expand_with(v2)) - .ok_or(anyhow!("No algorithms found"))?; + let mut v = logs + .iter() + .filter_map(|&(_, ref log)| proc_single_log(log)) + .reduce(|v1, v2| v1.expand_with(v2)) + .ok_or(anyhow!("No algorithms found"))?; // Load existing range if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() { @@ -1183,10 +1420,11 @@ pruned, //postprocessing, this_iters, + ε, .. } = data; // let post_value = match (postprocessing, dataterm) { - // (Some(mut μ), DataTerm::L2Squared) => { + // (Some(mut μ), DataTermType::L222) => { // // Comparison postprocessing is only implemented for the case handled // // by the FW variants. // reg.optimise_weights( @@ -1198,18 +1436,19 @@ // }, // _ => value, // }; - let relative_value = (value - v.min)/(v.ini - v.min); + let relative_value = (value - v.min) / (v.ini - v.min); CSVLog { iter, value, relative_value, //post_value, n_spikes, - cpu_time : cpu_time.as_secs_f64(), + cpu_time: cpu_time.as_secs_f64(), inner_iters, merged, pruned, - this_iters + this_iters, + epsilon: ε, } }; @@ -1224,45 +1463,48 @@ Ok(()) } - /// Plot experiment setup #[replace_float_literals(F::cast_from(literal))] -fn plotall( - cli : &CommandLineArgs, - prefix : &String, - domain : &Cube, - sensor : &Sensor, - kernel : &Kernel, - spread : &Spread, - μ_hat : &RNDM, - op𝒟 : &𝒟, - opA : &A, - b_hat : &A::Observable, - b : &A::Observable, - kernel_plot_width : F, +fn plotall( + cli: &CommandLineArgs, + prefix: &String, + domain: &Cube, + sensor: &Sensor, + kernel: &Kernel, + spread: &Spread, + μ_hat: &RNDM, + op𝒟: &𝒟, + opA: &A, + b_hat: &A::Observable, + b: &A::Observable, + kernel_plot_width: F, ) -> DynError -where F : Float + ToNalgebraRealField, - Sensor : RealMapping + Support + Clone, - Spread : RealMapping + Support + Clone, - Kernel : RealMapping + Support, - Convolution : DifferentiableRealMapping + Support, - 𝒟 : DiscreteMeasureOp, F>, - 𝒟::Codomain : RealMapping, - A : ForwardModel, F>, - for<'a> &'a A::Observable : Instance, - A::PreadjointCodomain : DifferentiableRealMapping + Bounded, - PlotLookup : Plotting, - Cube : SetOrd { - +where + F: Float + ToNalgebraRealField, + Sensor: RealMapping + Support + Clone, + Spread: RealMapping + Support + Clone, + Kernel: RealMapping + Support, + Convolution: DifferentiableRealMapping + Support, + 𝒟: DiscreteMeasureOp, F>, + 𝒟::Codomain: RealMapping, + A: ForwardModel, F>, + for<'a> &'a A::Observable: Instance, + A::PreadjointCodomain: DifferentiableRealMapping + Bounded, + PlotLookup: Plotting, + Cube: SetOrd, +{ if cli.plot < PlotLevel::Data { - return Ok(()) + return Ok(()); } let base = Convolution(sensor.clone(), spread.clone()); - let resolution = if N==1 { 100 } else { 40 }; + let resolution = if N == 1 { 100 } else { 40 }; let pfx = |n| format!("{prefix}{n}"); - let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); + let plotgrid = lingrid( + &[[-kernel_plot_width, kernel_plot_width]; N].into(), + &[resolution; N], + ); PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor")); PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel")); @@ -1272,19 +1514,19 @@ let plotgrid2 = lingrid(&domain, &[resolution; N]); let ω_hat = op𝒟.apply(μ_hat); - let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); + let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat")); PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise")); - let preadj_b = opA.preadjoint().apply(b); - let preadj_b_hat = opA.preadjoint().apply(b_hat); + let preadj_b = opA.preadjoint().apply(b); + let preadj_b_hat = opA.preadjoint().apply(b_hat); //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); PlotLookup::plot_into_file_spikes( Some(&preadj_b), Some(&preadj_b_hat), plotgrid2, &μ_hat, - pfx("omega_b") + pfx("omega_b"), ); PlotLookup::plot_into_file(&preadj_b, plotgrid2, pfx("preadj_b")); PlotLookup::plot_into_file(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat"));