--- a/src/run.rs Sun Apr 27 15:03:51 2025 -0500 +++ b/src/run.rs Thu Feb 26 11:38:43 2026 -0500 @@ -2,133 +2,81 @@ This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment. */ -use numeric_literals::replace_float_literals; -use colored::Colorize; -use serde::{Serialize, Deserialize}; -use serde_json; -use nalgebra::base::DVector; -use std::hash::Hash; -use chrono::{DateTime, Utc}; -use cpu_time::ProcessTime; -use clap::ValueEnum; -use std::collections::HashMap; -use std::time::Instant; - -use rand::prelude::{ - StdRng, - SeedableRng -}; -use rand_distr::Distribution; - -use alg_tools::bisection_tree::*; -use alg_tools::iterate::{ - Timed, - AlgIteratorOptions, - Verbose, - AlgIteratorFactory, - LoggingIteratorFactory, - TimingIteratorFactory, - BasicAlgIteratorFactory, -}; -use alg_tools::logger::Logger; -use alg_tools::error::{ - DynError, - DynResult, -}; -use alg_tools::tabledump::TableDump; -use alg_tools::sets::Cube; -use alg_tools::mapping::{ - RealMapping, - DifferentiableMapping, - DifferentiableRealMapping, - Instance -}; -use alg_tools::nalgebra_support::ToNalgebraRealField; -use alg_tools::euclidean::Euclidean; -use alg_tools::lingrid::{lingrid, LinSpace}; -use alg_tools::sets::SetOrd; -use alg_tools::linops::{RowOp, IdOp /*, ZeroOp*/}; -use alg_tools::discrete_gradient::{Grad, ForwardNeumann}; -use alg_tools::convex::Zero; -use alg_tools::maputil::map3; -use alg_tools::direct_product::Pair; - -use crate::kernels::*; -use crate::types::*; -use crate::measures::*; -use crate::measures::merging::{SpikeMerging,SpikeMergingMethod}; -use crate::forward_model::*; +use crate::fb::{pointsource_fb_reg, pointsource_fista_reg, FBConfig, InsertionConfig}; use crate::forward_model::sensor_grid::{ - SensorGrid, - SensorGridBT, //SensorGridBTFN, Sensor, + SensorGrid, + SensorGridBT, Spread, }; - -use crate::fb::{ - FBConfig, - FBGenericConfig, - pointsource_fb_reg, - pointsource_fista_reg, +use crate::forward_model::*; +use crate::forward_pdps::{pointsource_fb_pair, pointsource_forward_pdps_pair, ForwardPDPSConfig}; +use crate::frank_wolfe::{pointsource_fw_reg, FWConfig, FWVariant, RegTermFW}; +use crate::kernels::*; +use crate::measures::merging::{SpikeMerging, SpikeMergingMethod}; +use crate::measures::*; +use crate::pdps::{pointsource_pdps_reg, PDPSConfig}; +use crate::plot::*; +use crate::prox_penalty::{ + ProxPenalty, ProxTerm, RadonSquared, StepLengthBound, StepLengthBoundPD, StepLengthBoundPair, }; -use crate::sliding_fb::{ - SlidingFBConfig, - TransportConfig, - pointsource_sliding_fb_reg -}; +use crate::regularisation::{NonnegRadonRegTerm, RadonRegTerm, Regularisation, SlidingRegTerm}; +use crate::seminorms::*; +use crate::sliding_fb::{pointsource_sliding_fb_reg, SlidingFBConfig, TransportConfig}; use crate::sliding_pdps::{ - SlidingPDPSConfig, - pointsource_sliding_pdps_pair -}; -use crate::forward_pdps::{ - ForwardPDPSConfig, - pointsource_forward_pdps_pair + pointsource_sliding_fb_pair, pointsource_sliding_pdps_pair, SlidingPDPSConfig, }; -use crate::pdps::{ - PDPSConfig, - pointsource_pdps_reg, -}; -use crate::frank_wolfe::{ - FWConfig, - FWVariant, - pointsource_fw_reg, - //WeightOptim, +use crate::subproblem::{InnerMethod, InnerSettings}; +use crate::tolerance::Tolerance; +use crate::types::*; +use crate::{AlgorithmOverrides, CommandLineArgs}; +use alg_tools::bisection_tree::*; +use alg_tools::bounds::{Bounded, MinMaxMapping}; +use alg_tools::convex::{Conjugable, Norm222, Prox, Zero}; +use alg_tools::direct_product::Pair; +use alg_tools::discrete_gradient::{ForwardNeumann, Grad}; +use alg_tools::error::{DynError, DynResult}; +use alg_tools::euclidean::{ClosedEuclidean, Euclidean}; +use alg_tools::iterate::{ + AlgIteratorFactory, AlgIteratorOptions, BasicAlgIteratorFactory, LoggingIteratorFactory, Timed, + TimingIteratorFactory, ValueIteratorFactory, Verbose, }; -use crate::subproblem::{InnerSettings, InnerMethod}; -use crate::seminorms::*; -use crate::plot::*; -use crate::{AlgorithmOverrides, CommandLineArgs}; -use crate::tolerance::Tolerance; -use crate::regularisation::{ - Regularisation, - RadonRegTerm, - NonnegRadonRegTerm -}; -use crate::dataterm::{ - L1, - L2Squared, +use alg_tools::lingrid::lingrid; +use alg_tools::linops::{IdOp, RowOp, AXPY}; +use alg_tools::logger::Logger; +use alg_tools::mapping::{ + DataTerm, DifferentiableMapping, DifferentiableRealMapping, Instance, RealMapping, }; -use crate::prox_penalty::{ - RadonSquared, - //ProxPenalty, -}; -use alg_tools::norms::{L2, NormExponent}; +use alg_tools::maputil::map3; +use alg_tools::nalgebra_support::ToNalgebraRealField; +use alg_tools::norms::{NormExponent, L1, L2}; use alg_tools::operator_arithmetic::Weighted; +use alg_tools::sets::Cube; +use alg_tools::sets::SetOrd; +use alg_tools::tabledump::TableDump; use anyhow::anyhow; +use chrono::{DateTime, Utc}; +use clap::ValueEnum; +use colored::Colorize; +use cpu_time::ProcessTime; +use nalgebra::base::DVector; +use numeric_literals::replace_float_literals; +use rand::prelude::{SeedableRng, StdRng}; +use rand_distr::Distribution; +use serde::{Deserialize, Serialize}; +use serde_json; +use std::collections::HashMap; +use std::hash::Hash; +use std::time::Instant; +use thiserror::Error; -/// Available proximal terms -#[derive(Copy, Clone, Debug, Serialize, Deserialize)] -pub enum ProxTerm { - /// Partial-to-wave operator 𝒟. - Wave, - /// Radon-norm squared - RadonSquared -} +//#[cfg(feature = "pyo3")] +//use pyo3::pyclass; /// Available algorithms and their configurations #[derive(Copy, Clone, Debug, Serialize, Deserialize)] -pub enum AlgorithmConfig<F : Float> { +pub enum AlgorithmConfig<F: Float> { FB(FBConfig<F>, ProxTerm), FISTA(FBConfig<F>, ProxTerm), FW(FWConfig<F>), @@ -138,91 +86,114 @@ SlidingPDPS(SlidingPDPSConfig<F>, ProxTerm), } -fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { +fn unpack_tolerance<F: Float>(v: &Vec<F>) -> Tolerance<F> { assert!(v.len() == 3); - Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } + Tolerance::Power { initial: v[0], factor: v[1], exponent: v[2] } } -impl<F : ClapFloat> AlgorithmConfig<F> { +impl<F: ClapFloat> AlgorithmConfig<F> { /// Override supported parameters based on the command line. - pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self { - let override_merging = |g : SpikeMergingMethod<F>| { - SpikeMergingMethod { - enabled : cli.merge.unwrap_or(g.enabled), - radius : cli.merge_radius.unwrap_or(g.radius), - interp : cli.merge_interp.unwrap_or(g.interp), - } + pub fn cli_override(self, cli: &AlgorithmOverrides<F>) -> Self { + let override_merging = |g: SpikeMergingMethod<F>| SpikeMergingMethod { + enabled: cli.merge.unwrap_or(g.enabled), + radius: cli.merge_radius.unwrap_or(g.radius), + interp: cli.merge_interp.unwrap_or(g.interp), }; - let override_fb_generic = |g : FBGenericConfig<F>| { - FBGenericConfig { - bootstrap_insertions : cli.bootstrap_insertions - .as_ref() - .map_or(g.bootstrap_insertions, - |n| Some((n[0], n[1]))), - merge_every : cli.merge_every.unwrap_or(g.merge_every), - merging : override_merging(g.merging), - final_merging : cli.final_merging.unwrap_or(g.final_merging), - fitness_merging : cli.fitness_merging.unwrap_or(g.fitness_merging), - tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance), - .. g - } + let override_fb_generic = |g: InsertionConfig<F>| InsertionConfig { + bootstrap_insertions: cli + .bootstrap_insertions + .as_ref() + .map_or(g.bootstrap_insertions, |n| Some((n[0], n[1]))), + merge_every: cli.merge_every.unwrap_or(g.merge_every), + merging: override_merging(g.merging), + final_merging: cli.final_merging.unwrap_or(g.final_merging), + fitness_merging: cli.fitness_merging.unwrap_or(g.fitness_merging), + tolerance: cli + .tolerance + .as_ref() + .map(unpack_tolerance) + .unwrap_or(g.tolerance), + ..g }; - let override_transport = |g : TransportConfig<F>| { - TransportConfig { - θ0 : cli.theta0.unwrap_or(g.θ0), - tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con), - adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), - .. g - } + let override_transport = |g: TransportConfig<F>| TransportConfig { + θ0: cli.theta0.unwrap_or(g.θ0), + tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con), + adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), + ..g }; use AlgorithmConfig::*; match self { - FB(fb, prox) => FB(FBConfig { - τ0 : cli.tau0.unwrap_or(fb.τ0), - generic : override_fb_generic(fb.generic), - .. fb - }, prox), - FISTA(fb, prox) => FISTA(FBConfig { - τ0 : cli.tau0.unwrap_or(fb.τ0), - generic : override_fb_generic(fb.generic), - .. fb - }, prox), - PDPS(pdps, prox) => PDPS(PDPSConfig { - τ0 : cli.tau0.unwrap_or(pdps.τ0), - σ0 : cli.sigma0.unwrap_or(pdps.σ0), - acceleration : cli.acceleration.unwrap_or(pdps.acceleration), - generic : override_fb_generic(pdps.generic), - .. pdps - }, prox), + FB(fb, prox) => FB( + FBConfig { + τ0: cli.tau0.unwrap_or(fb.τ0), + σp0: cli.sigmap0.unwrap_or(fb.σp0), + insertion: override_fb_generic(fb.insertion), + ..fb + }, + prox, + ), + FISTA(fb, prox) => FISTA( + FBConfig { + τ0: cli.tau0.unwrap_or(fb.τ0), + σp0: cli.sigmap0.unwrap_or(fb.σp0), + insertion: override_fb_generic(fb.insertion), + ..fb + }, + prox, + ), + PDPS(pdps, prox) => PDPS( + PDPSConfig { + τ0: cli.tau0.unwrap_or(pdps.τ0), + σ0: cli.sigma0.unwrap_or(pdps.σ0), + acceleration: cli.acceleration.unwrap_or(pdps.acceleration), + generic: override_fb_generic(pdps.generic), + ..pdps + }, + prox, + ), FW(fw) => FW(FWConfig { - merging : override_merging(fw.merging), - tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), - .. fw + merging: override_merging(fw.merging), + tolerance: cli + .tolerance + .as_ref() + .map(unpack_tolerance) + .unwrap_or(fw.tolerance), + ..fw }), - SlidingFB(sfb, prox) => SlidingFB(SlidingFBConfig { - τ0 : cli.tau0.unwrap_or(sfb.τ0), - transport : override_transport(sfb.transport), - insertion : override_fb_generic(sfb.insertion), - .. sfb - }, prox), - SlidingPDPS(spdps, prox) => SlidingPDPS(SlidingPDPSConfig { - τ0 : cli.tau0.unwrap_or(spdps.τ0), - σp0 : cli.sigmap0.unwrap_or(spdps.σp0), - σd0 : cli.sigma0.unwrap_or(spdps.σd0), - //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), - transport : override_transport(spdps.transport), - insertion : override_fb_generic(spdps.insertion), - .. spdps - }, prox), - ForwardPDPS(fpdps, prox) => ForwardPDPS(ForwardPDPSConfig { - τ0 : cli.tau0.unwrap_or(fpdps.τ0), - σp0 : cli.sigmap0.unwrap_or(fpdps.σp0), - σd0 : cli.sigma0.unwrap_or(fpdps.σd0), - //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), - insertion : override_fb_generic(fpdps.insertion), - .. fpdps - }, prox), + SlidingFB(sfb, prox) => SlidingFB( + SlidingFBConfig { + τ0: cli.tau0.unwrap_or(sfb.τ0), + σp0: cli.sigmap0.unwrap_or(sfb.σp0), + transport: override_transport(sfb.transport), + insertion: override_fb_generic(sfb.insertion), + ..sfb + }, + prox, + ), + SlidingPDPS(spdps, prox) => SlidingPDPS( + SlidingPDPSConfig { + τ0: cli.tau0.unwrap_or(spdps.τ0), + σp0: cli.sigmap0.unwrap_or(spdps.σp0), + σd0: cli.sigma0.unwrap_or(spdps.σd0), + //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), + transport: override_transport(spdps.transport), + insertion: override_fb_generic(spdps.insertion), + ..spdps + }, + prox, + ), + ForwardPDPS(fpdps, prox) => ForwardPDPS( + ForwardPDPSConfig { + τ0: cli.tau0.unwrap_or(fpdps.τ0), + σp0: cli.sigmap0.unwrap_or(fpdps.σp0), + σd0: cli.sigma0.unwrap_or(fpdps.σd0), + //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), + insertion: override_fb_generic(fpdps.insertion), + ..fpdps + }, + prox, + ), } } } @@ -230,13 +201,14 @@ /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Named<Data> { - pub name : String, + pub name: String, #[serde(flatten)] - pub data : Data, + pub data: Data, } /// Shorthand algorithm configurations, to be used with the command line parser #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] +//#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))] pub enum DefaultAlgorithm { /// The μFB forward-backward method #[clap(name = "fb")] @@ -264,7 +236,6 @@ ForwardPDPS, // Radon variants - /// The μFB forward-backward method with radon-norm squared proximal term #[clap(name = "radon_fb")] RadonFB, @@ -287,70 +258,70 @@ impl DefaultAlgorithm { /// Returns the algorithm configuration corresponding to the algorithm shorthand - pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { + pub fn default_config<F: Float>(&self) -> AlgorithmConfig<F> { use DefaultAlgorithm::*; - let radon_insertion = FBGenericConfig { - merging : SpikeMergingMethod{ interp : false, .. Default::default() }, - inner : InnerSettings { - method : InnerMethod::PDPS, // SSN not implemented - .. Default::default() + let radon_insertion = InsertionConfig { + merging: SpikeMergingMethod { interp: false, ..Default::default() }, + inner: InnerSettings { + method: InnerMethod::PDPS, // SSN not implemented + ..Default::default() }, - .. Default::default() + ..Default::default() }; match *self { FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave), FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave), FW => AlgorithmConfig::FW(Default::default()), - FWRelax => AlgorithmConfig::FW(FWConfig{ - variant : FWVariant::Relaxed, - .. Default::default() - }), + FWRelax => { + AlgorithmConfig::FW(FWConfig { variant: FWVariant::Relaxed, ..Default::default() }) + } PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave), SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave), SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave), ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave), // Radon variants - RadonFB => AlgorithmConfig::FB( - FBConfig{ generic : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + FBConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonFISTA => AlgorithmConfig::FISTA( - FBConfig{ generic : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + FBConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonPDPS => AlgorithmConfig::PDPS( - PDPSConfig{ generic : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + PDPSConfig { generic: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonSlidingFB => AlgorithmConfig::SlidingFB( - SlidingFBConfig{ insertion : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + SlidingFBConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS( - SlidingPDPSConfig{ insertion : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + SlidingPDPSConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), RadonForwardPDPS => AlgorithmConfig::ForwardPDPS( - ForwardPDPSConfig{ insertion : radon_insertion, ..Default::default() }, - ProxTerm::RadonSquared + ForwardPDPSConfig { insertion: radon_insertion, ..Default::default() }, + ProxTerm::RadonSquared, ), } } /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand - pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { + pub fn get_named<F: Float>(&self) -> Named<AlgorithmConfig<F>> { self.to_named(self.default_config()) } - pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { - let name = self.to_possible_value().unwrap().get_name().to_string(); - Named{ name , data : alg } + pub fn to_named<F: Float>(self, alg: AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { + Named { name: self.name(), data: alg } + } + + pub fn name(self) -> String { + self.to_possible_value().unwrap().get_name().to_string() } } - // // Floats cannot be hashed directly, so just hash the debug formatting // // for use as file identifier. // impl<F : Float> Hash for AlgorithmConfig<F> { @@ -379,347 +350,389 @@ } } -type DefaultBT<F, const N : usize> = BT< - DynamicDepth, - F, - usize, - Bounds<F>, - N ->; -type DefaultSeminormOp<F, K, const N : usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>; -type DefaultSG<F, Sensor, Spread, const N : usize> = SensorGrid::< - F, - Sensor, - Spread, - DefaultBT<F, N>, - N ->; +type DefaultBT<F, const N: usize> = BT<DynamicDepth, F, usize, Bounds<F>, N>; +type DefaultSeminormOp<F, K, const N: usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>; +type DefaultSG<F, Sensor, Spread, const N: usize> = + SensorGrid<F, Sensor, Spread, DefaultBT<F, N>, N>; /// This is a dirty workaround to rust-csv not supporting struct flattening etc. #[derive(Serialize)] struct CSVLog<F> { - iter : usize, - cpu_time : f64, - value : F, - relative_value : F, + iter: usize, + cpu_time: f64, + value: F, + relative_value: F, //post_value : F, - n_spikes : usize, - inner_iters : usize, - merged : usize, - pruned : usize, - this_iters : usize, + n_spikes: usize, + inner_iters: usize, + merged: usize, + pruned: usize, + this_iters: usize, + epsilon: F, } /// Collected experiment statistics #[derive(Clone, Debug, Serialize)] -struct ExperimentStats<F : Float> { +struct ExperimentStats<F: Float> { /// Signal-to-noise ratio in decibels - ssnr : F, + ssnr: F, /// Proportion of noise in the signal as a number in $[0, 1]$. - noise_ratio : F, + noise_ratio: F, /// When the experiment was run (UTC) - when : DateTime<Utc>, + when: DateTime<Utc>, } #[replace_float_literals(F::cast_from(literal))] -impl<F : Float> ExperimentStats<F> { +impl<F: Float> ExperimentStats<F> { /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal. - fn new<E : Euclidean<F>>(signal : &E, noise : &E) -> Self { + fn new<E: Euclidean<F>>(signal: &E, noise: &E) -> Self { let s = signal.norm2_squared(); let n = noise.norm2_squared(); let noise_ratio = (n / s).sqrt(); - let ssnr = 10.0 * (s / n).log10(); - ExperimentStats { - ssnr, - noise_ratio, - when : Utc::now(), - } + let ssnr = 10.0 * (s / n).log10(); + ExperimentStats { ssnr, noise_ratio, when: Utc::now() } } } /// Collected algorithm statistics #[derive(Clone, Debug, Serialize)] -struct AlgorithmStats<F : Float> { +struct AlgorithmStats<F: Float> { /// Overall CPU time spent - cpu_time : F, + cpu_time: F, /// Real time spent - elapsed : F + elapsed: F, } - /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input /// and outputs a [`DynError`]. -fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { +fn write_json<T: Serialize>(filename: String, data: &T) -> DynError { serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; Ok(()) } - /// Struct for experiment configurations #[derive(Debug, Clone, Serialize)] -pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize> -where F : Float + ClapFloat, - [usize; N] : Serialize, - NoiseDistr : Distribution<F>, - S : Sensor<F, N>, - P : Spread<F, N>, - K : SimpleConvolutionKernel<F, N>, +pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N: usize> +where + F: Float + ClapFloat, + [usize; N]: Serialize, + NoiseDistr: Distribution<F>, + S: Sensor<N, F>, + P: Spread<N, F>, + K: SimpleConvolutionKernel<N, F>, { /// Domain $Ω$. - pub domain : Cube<F, N>, + pub domain: Cube<N, F>, /// Number of sensors along each dimension - pub sensor_count : [usize; N], + pub sensor_count: [usize; N], /// Noise distribution - pub noise_distr : NoiseDistr, + pub noise_distr: NoiseDistr, /// Seed for random noise generation (for repeatable experiments) - pub noise_seed : u64, + pub noise_seed: u64, /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. - pub sensor : S, + pub sensor: S, /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. - pub spread : P, + pub spread: P, /// Kernel $ρ$ of $𝒟$. - pub kernel : K, + pub kernel: K, /// True point sources - pub μ_hat : RNDM<F, N>, + pub μ_hat: RNDM<N, F>, /// Regularisation term and parameter - pub regularisation : Regularisation<F>, + pub regularisation: Regularisation<F>, /// For plotting : how wide should the kernels be plotted - pub kernel_plot_width : F, + pub kernel_plot_width: F, /// Data term - pub dataterm : DataTerm, + pub dataterm: DataTermType, /// A map of default configurations for algorithms - pub algorithm_overrides : HashMap<DefaultAlgorithm, AlgorithmOverrides<F>>, + pub algorithm_overrides: HashMap<DefaultAlgorithm, AlgorithmOverrides<F>>, /// Default merge radius - pub default_merge_radius : F, + pub default_merge_radius: F, } #[derive(Debug, Clone, Serialize)] -pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N : usize> -where F : Float + ClapFloat, - [usize; N] : Serialize, - NoiseDistr : Distribution<F>, - S : Sensor<F, N>, - P : Spread<F, N>, - K : SimpleConvolutionKernel<F, N>, - B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, +pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N: usize> +where + F: Float + ClapFloat, + [usize; N]: Serialize, + NoiseDistr: Distribution<F>, + S: Sensor<N, F>, + P: Spread<N, F>, + K: SimpleConvolutionKernel<N, F>, + B: Mapping<Loc<N, F>, Codomain = F> + Serialize + std::fmt::Debug, { /// Basic setup - pub base : ExperimentV2<F, NoiseDistr, S, K, P, N>, + pub base: ExperimentV2<F, NoiseDistr, S, K, P, N>, /// Weight of TV term - pub λ : F, + pub λ: F, /// Bias function - pub bias : B, + pub bias: B, } /// Trait for runnable experiments -pub trait RunnableExperiment<F : ClapFloat> { +pub trait RunnableExperiment<F: ClapFloat> { /// Run all algorithms provided, or default algorithms if none provided, on the experiment. - fn runall(&self, cli : &CommandLineArgs, - algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; + fn runall( + &self, + cli: &CommandLineArgs, + algs: Option<Vec<Named<AlgorithmConfig<F>>>>, + ) -> DynError; /// Return algorithm default config - fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F>; -} - -/// Helper function to print experiment start message and save setup. -/// Returns saving prefix. -fn start_experiment<E, S>( - experiment : &Named<E>, - cli : &CommandLineArgs, - stats : S, -) -> DynResult<String> -where - E : Serialize + std::fmt::Debug, - S : Serialize, -{ - let Named { name : experiment_name, data } = experiment; + fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<F>; - println!("{}\n{}", - format!("Performing experiment {}…", experiment_name).cyan(), - format!("Experiment settings: {}", serde_json::to_string(&data)?).bright_black()); - - // Set up output directory - let prefix = format!("{}/{}/", cli.outdir, experiment_name); - - // Save experiment configuration and statistics - let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); - std::fs::create_dir_all(&prefix)?; - write_json(mkname_e("experiment"), experiment)?; - write_json(mkname_e("config"), cli)?; - write_json(mkname_e("stats"), &stats)?; - - Ok(prefix) + /// Experiment name + fn name(&self) -> &str; } /// Error codes for running an algorithm on an experiment. -enum RunError { +#[derive(Error, Debug)] +pub enum RunError { /// Algorithm not implemented for this experiment + #[error("Algorithm not implemented for this experiment")] NotImplemented, } use RunError::*; -type DoRunAllIt<'a, F, const N : usize> = LoggingIteratorFactory< +type DoRunAllIt<'a, F, const N: usize> = LoggingIteratorFactory< 'a, - Timed<IterInfo<F, N>>, - TimingIteratorFactory<BasicAlgIteratorFactory<IterInfo<F, N>>> + Timed<IterInfo<F>>, + TimingIteratorFactory<BasicAlgIteratorFactory<IterInfo<F>>>, >; -/// Helper function to run all algorithms on an experiment. -fn do_runall<F : Float + for<'b> Deserialize<'b>, Z, const N : usize>( - experiment_name : &String, - prefix : &String, - cli : &CommandLineArgs, - algorithms : Vec<Named<AlgorithmConfig<F>>>, - plotgrid : LinSpace<Loc<F, N>, [usize; N]>, - mut save_extra : impl FnMut(String, Z) -> DynError, - mut do_alg : impl FnMut( - &AlgorithmConfig<F>, - DoRunAllIt<F, N>, - SeqPlotter<F, N>, - String, - ) -> Result<(RNDM<F, N>, Z), RunError>, -) -> DynError -where - PlotLookup : Plotting<N>, +pub trait RunnableExperimentExtras<F: ClapFloat>: + RunnableExperiment<F> + Serialize + Sized { - let mut logs = Vec::new(); + /// Helper function to print experiment start message and save setup. + /// Returns saving prefix. + fn start(&self, cli: &CommandLineArgs) -> DynResult<String> { + let experiment_name = self.name(); + let ser = serde_json::to_string(self); - let iterator_options = AlgIteratorOptions{ - max_iter : cli.max_iter, - verbose_iter : cli.verbose_iter - .map_or(Verbose::LogarithmicCap{base : 10, cap : 2}, - |n| Verbose::Every(n)), - quiet : cli.quiet, - }; + println!( + "{}\n{}", + format!("Performing experiment {}…", experiment_name).cyan(), + format!( + "Experiment settings: {}", + if let Ok(ref s) = ser { + s + } else { + "<serialisation failure>" + } + ) + .bright_black(), + ); - // Run the algorithm(s) - for named @ Named { name : alg_name, data : alg } in algorithms.iter() { - let this_prefix = format!("{}{}/", prefix, alg_name); + // Set up output directory + let prefix = format!("{}/{}/", cli.outdir, experiment_name); + + // Save experiment configuration and statistics + std::fs::create_dir_all(&prefix)?; + write_json(format!("{prefix}experiment.json"), self)?; + write_json(format!("{prefix}config.json"), cli)?; - // Create Logger and IteratorFactory - let mut logger = Logger::new(); - let iterator = iterator_options.instantiate() - .timed() - .into_log(&mut logger); + Ok(prefix) + } - let running = if !cli.quiet { - format!("{}\n{}\n{}\n", - format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), - format!("Iteration settings: {}", serde_json::to_string(&iterator_options)?).bright_black(), - format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black()) - } else { - "".to_string() - }; - // - // The following is for postprocessing, which has been disabled anyway. - // - // let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation { - // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)), - // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)), - // }; - //let findim_data = reg.prepare_optimise_weights(&opA, &b); - //let inner_config : InnerSettings<F> = Default::default(); - //let inner_it = inner_config.iterator_options; + /// Helper function to run all algorithms on an experiment. + fn do_runall<P, Z, Plot, const N: usize>( + &self, + prefix: &String, + cli: &CommandLineArgs, + algorithms: Vec<Named<AlgorithmConfig<F>>>, + mut make_plotter: impl FnMut(String) -> Plot, + mut save_extra: impl FnMut(String, Z) -> DynError, + init: P, + mut do_alg: impl FnMut( + (&AlgorithmConfig<F>, DoRunAllIt<F, N>, Plot, P, String), + ) -> DynResult<(RNDM<N, F>, Z)>, + ) -> DynError + where + F: for<'b> Deserialize<'b>, + PlotLookup: Plotting<N>, + P: Clone, + { + let experiment_name = self.name(); - // Create plotter and directory if needed. - let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; - let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()); - - let start = Instant::now(); - let start_cpu = ProcessTime::now(); + let mut logs = Vec::new(); - let (μ, z) = match do_alg(alg, iterator, plotter, running) { - Ok(μ) => μ, - Err(RunError::NotImplemented) => { - let msg = format!("Algorithm “{alg_name}” not implemented for {experiment_name}. \ - Skipping.").red(); - eprintln!("{}", msg); - continue - } + let iterator_options = AlgIteratorOptions { + max_iter: cli.max_iter, + verbose_iter: cli + .verbose_iter + .map_or(Verbose::LogarithmicCap { base: 10, cap: 2 }, |n| { + Verbose::Every(n) + }), + quiet: cli.quiet, }; - let elapsed = start.elapsed().as_secs_f64(); - let cpu_time = start_cpu.elapsed().as_secs_f64(); + // Run the algorithm(s) + for named @ Named { name: alg_name, data: alg } in algorithms.iter() { + let this_prefix = format!("{}{}/", prefix, alg_name); - println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); + // Create Logger and IteratorFactory + let mut logger = Logger::new(); + let iterator = iterator_options.instantiate().timed().into_log(&mut logger); - // Save results - println!("{}", "Saving results …".green()); + let running = if !cli.quiet { + format!( + "{}\n{}\n{}\n", + format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), + format!( + "Iteration settings: {}", + serde_json::to_string(&iterator_options)? + ) + .bright_black(), + format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black() + ) + } else { + "".to_string() + }; + // + // The following is for postprocessing, which has been disabled anyway. + // + // let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation { + // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)), + // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)), + // }; + //let findim_data = reg.prepare_optimise_weights(&opA, &b); + //let inner_config : InnerSettings<F> = Default::default(); + //let inner_it = inner_config.iterator_options; - let mkname = |t| format!("{prefix}{alg_name}_{t}"); + // Create plotter and directory if needed. + let plotter = make_plotter(this_prefix); + + let start = Instant::now(); + let start_cpu = ProcessTime::now(); - write_json(mkname("config.json"), &named)?; - write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; - μ.write_csv(mkname("reco.txt"))?; - save_extra(mkname(""), z)?; - //logger.write_csv(mkname("log.txt"))?; - logs.push((mkname("log.txt"), logger)); + let (μ, z) = match do_alg((alg, iterator, plotter, init.clone(), running)) { + Ok(μ) => μ, + Err(e) => { + let msg = format!( + "Skipping algorithm “{alg_name}” for {experiment_name} due to error: {e}" + ) + .red(); + eprintln!("{}", msg); + continue; + } + }; + + let elapsed = start.elapsed().as_secs_f64(); + let cpu_time = start_cpu.elapsed().as_secs_f64(); + + println!( + "{}", + format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow() + ); + + // Save results + println!("{}", "Saving results …".green()); + + let mkname = |t| format!("{prefix}{alg_name}_{t}"); + + write_json(mkname("config.json"), &named)?; + write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; + μ.write_csv(mkname("reco.txt"))?; + save_extra(mkname(""), z)?; + //logger.write_csv(mkname("log.txt"))?; + logs.push((mkname("log.txt"), logger)); + } + + save_logs( + logs, + format!("{prefix}valuerange.json"), + cli.load_valuerange, + ) } +} - save_logs(logs, format!("{prefix}valuerange.json"), cli.load_valuerange) +impl<F, E> RunnableExperimentExtras<F> for E +where + F: ClapFloat, + Self: RunnableExperiment<F> + Serialize, +{ } #[replace_float_literals(F::cast_from(literal))] -impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N : usize> RunnableExperiment<F> for -Named<ExperimentV2<F, NoiseDistr, S, K, P, N>> +impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N: usize> RunnableExperiment<F> + for Named<ExperimentV2<F, NoiseDistr, S, K, P, N>> where - F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F> - + Default + for<'b> Deserialize<'b>, - [usize; N] : Serialize, - S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, - P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, - Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy - // TODO: shold not have differentiability as a requirement, but - // decide availability of sliding based on it. - //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, - // TODO: very weird that rust only compiles with Differentiable - // instead of the above one on references, which is required by - // poitsource_sliding_fb_reg. - + DifferentiableRealMapping<F, N> - + Lipschitz<L2, FloatType=F>, - for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb. - AutoConvolution<P> : BoundedBy<F, K>, - K : SimpleConvolutionKernel<F, N> + F: ClapFloat + + nalgebra::RealField + + ToNalgebraRealField<MixedType = F> + + Default + + for<'b> Deserialize<'b>, + [usize; N]: Serialize, + S: Sensor<N, F> + Copy + Serialize + std::fmt::Debug, + P: Spread<N, F> + Copy + Serialize + std::fmt::Debug, + Convolution<S, P>: Spread<N, F> + + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> - + Copy + Serialize + std::fmt::Debug, - Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, - PlotLookup : Plotting<N>, - DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, + + Copy + // TODO: shold not have differentiability as a requirement, but + // decide availability of sliding based on it. + //+ for<'b> Differentiable<&'b Loc<N, F>, Output = Loc<N, F>>, + // TODO: very weird that rust only compiles with Differentiable + // instead of the above one on references, which is required by + // poitsource_sliding_fb_reg. + + DifferentiableRealMapping<N, F> + + Lipschitz<L2, FloatType = F>, + for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<N, F>>>::Differential<'b>: + Lipschitz<L2, FloatType = F>, // TODO: should not be required generally, only for sliding_fb. + AutoConvolution<P>: BoundedBy<F, K>, + K: SimpleConvolutionKernel<N, F> + + LocalAnalysis<F, Bounds<F>, N> + + Copy + + Serialize + + std::fmt::Debug, + Cube<N, F>: P2Minimise<Loc<N, F>, F> + SetOrd, + PlotLookup: Plotting<N>, + DefaultBT<F, N>: SensorGridBT<F, S, P, N, Depth = DynamicDepth> + BTSearch<N, F>, BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - RNDM<F, N> : SpikeMerging<F>, - NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, - // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>, - // PreadjointCodomain : Space + Bounded<F> + DifferentiableRealMapping<F, N>, - // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, - // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, - // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, - // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, + RNDM<N, F>: SpikeMerging<F>, + NoiseDistr: Distribution<F> + Serialize + std::fmt::Debug, { + fn name(&self) -> &str { + self.name.as_ref() + } - fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> { + fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<F> { AlgorithmOverrides { - merge_radius : Some(self.data.default_merge_radius), - .. self.data.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) + merge_radius: Some(self.data.default_merge_radius), + ..self + .data + .algorithm_overrides + .get(&alg) + .cloned() + .unwrap_or(Default::default()) } } - fn runall(&self, cli : &CommandLineArgs, - algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { + fn runall( + &self, + cli: &CommandLineArgs, + algs: Option<Vec<Named<AlgorithmConfig<F>>>>, + ) -> DynError { // Get experiment configuration - let &Named { - name : ref experiment_name, - data : ExperimentV2 { - domain, sensor_count, ref noise_distr, sensor, spread, kernel, - ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, - .. - } - } = self; + let &ExperimentV2 { + domain, + sensor_count, + ref noise_distr, + sensor, + spread, + kernel, + ref μ_hat, + regularisation, + kernel_plot_width, + dataterm, + noise_seed, + .. + } = &self.data; // Set up algorithms let algorithms = match (algs, dataterm) { (Some(algs), _) => algs, - (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], - (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], + (None, DataTermType::L222) => vec![DefaultAlgorithm::FB.get_named()], + (None, DataTermType::L1) => vec![DefaultAlgorithm::PDPS.get_named()], }; // Set up operators @@ -738,255 +751,348 @@ // overloading log10 and conflicting with standard NumTraits one. let stats = ExperimentStats::new(&b, &noise); - let prefix = start_experiment(&self, cli, stats)?; + let prefix = self.start(cli)?; + write_json(format!("{prefix}stats.json"), &stats)?; - plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, - &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; + plotall( + cli, + &prefix, + &domain, + &sensor, + &kernel, + &spread, + &μ_hat, + &op𝒟, + &opA, + &b_hat, + &b, + kernel_plot_width, + )?; - let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); + let plotgrid = lingrid(&domain, &[if N == 1 { 1000 } else { 100 }; N]); + let make_plotter = |this_prefix| { + let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; + SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()) + }; let save_extra = |_, ()| Ok(()); - do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, - |alg, iterator, plotter, running| - { - let μ = match alg { - AlgorithmConfig::FB(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_fb_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_fb_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - _ => Err(NotImplemented) - } - }, - AlgorithmConfig::FISTA(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_fista_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_fista_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_fista_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_fista_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - _ => Err(NotImplemented), - } - }, - AlgorithmConfig::SlidingFB(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_sliding_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_sliding_fb_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_sliding_fb_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_sliding_fb_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter - ) - }), - _ => Err(NotImplemented), - } - }, - AlgorithmConfig::PDPS(ref algconfig, prox) => { - print!("{running}"); - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - pointsource_pdps_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, L2Squared - ) - }), - (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - pointsource_pdps_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, L2Squared - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ - pointsource_pdps_reg( - &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, L1 - ) - }), - (Regularisation::Radon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ - pointsource_pdps_reg( - &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, L1 - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - pointsource_pdps_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, L2Squared - ) - }), - (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - pointsource_pdps_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, L2Squared - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ - pointsource_pdps_reg( - &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, L1 - ) - }), - (Regularisation::Radon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ - pointsource_pdps_reg( - &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, L1 - ) - }), - // _ => Err(NotImplemented), - } - }, - AlgorithmConfig::FW(ref algconfig) => { - match (regularisation, dataterm) { - (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ - print!("{running}"); - pointsource_fw_reg(&opA, &b, RadonRegTerm(α), - algconfig, iterator, plotter) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ - print!("{running}"); - pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α), - algconfig, iterator, plotter) - }), - _ => Err(NotImplemented), - } - }, - _ => Err(NotImplemented), - }?; - Ok((μ, ())) - }) + let μ0 = None; // Zero init + + match (dataterm, regularisation) { + (DataTermType::L1, Regularisation::Radon(α)) => { + let f = DataTerm::new(opA, b, L1.as_mapping()); + let reg = RadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| { + run_pdps(&f, ®, &RadonSquared, p, |p| { + run_pdps(&f, ®, &op𝒟, p, |_| Err(NotImplemented.into())) + }) + .map(|μ| (μ, ())) + }, + ) + } + (DataTermType::L1, Regularisation::NonnegRadon(α)) => { + let f = DataTerm::new(opA, b, L1.as_mapping()); + let reg = NonnegRadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| { + run_pdps(&f, ®, &RadonSquared, p, |p| { + run_pdps(&f, ®, &op𝒟, p, |_| Err(NotImplemented.into())) + }) + .map(|μ| (μ, ())) + }, + ) + } + (DataTermType::L222, Regularisation::Radon(α)) => { + let f = DataTerm::new(opA, b, Norm222::new()); + let reg = RadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| { + run_fb(&f, ®, &RadonSquared, p, |p| { + run_pdps(&f, ®, &RadonSquared, p, |p| { + run_fb(&f, ®, &op𝒟, p, |p| { + run_pdps(&f, ®, &op𝒟, p, |p| { + run_fw(&f, ®, p, |_| Err(NotImplemented.into())) + }) + }) + }) + }) + .map(|μ| (μ, ())) + }, + ) + } + (DataTermType::L222, Regularisation::NonnegRadon(α)) => { + let f = DataTerm::new(opA, b, Norm222::new()); + let reg = NonnegRadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + μ0, + |p| { + run_fb(&f, ®, &RadonSquared, p, |p| { + run_pdps(&f, ®, &RadonSquared, p, |p| { + run_fb(&f, ®, &op𝒟, p, |p| { + run_pdps(&f, ®, &op𝒟, p, |p| { + run_fw(&f, ®, p, |_| Err(NotImplemented.into())) + }) + }) + }) + }) + .map(|μ| (μ, ())) + }, + ) + } + } + } +} + +/// Runs PDPS if `alg` so requests and `prox_penalty` matches. +/// +/// Due to the structure of the PDPS, the data term `f` has to have a specific form. +/// +/// `cont` gives a continuation attempt to find the algorithm matching the description `alg`. +pub fn run_pdps<'a, F, A, Phi, Reg, P, I, Plot, const N: usize>( + f: &'a DataTerm<F, RNDM<N, F>, A, Phi>, + reg: &Reg, + prox_penalty: &P, + (alg, iterator, plotter, μ0, running): ( + &AlgorithmConfig<F>, + I, + Plot, + Option<RNDM<N, F>>, + String, + ), + cont: impl FnOnce( + (&AlgorithmConfig<F>, I, Plot, Option<RNDM<N, F>>, String), + ) -> DynResult<RNDM<N, F>>, +) -> DynResult<RNDM<N, F>> +where + F: Float + ToNalgebraRealField, + A: ForwardModel<RNDM<N, F>, F>, + Phi: Conjugable<A::Observable, F>, + for<'b> Phi::Conjugate<'b>: Prox<A::Observable>, + for<'b> &'b A::Observable: Instance<A::Observable>, + A::Observable: AXPY, + Reg: SlidingRegTerm<Loc<N, F>, F>, + P: ProxPenalty<Loc<N, F>, A::PreadjointCodomain, Reg, F> + StepLengthBoundPD<F, A, RNDM<N, F>>, + RNDM<N, F>: SpikeMerging<F>, + I: AlgIteratorFactory<IterInfo<F>>, + Plot: Plotter<P::ReturnMapping, A::PreadjointCodomain, RNDM<N, F>>, +{ + match alg { + &AlgorithmConfig::PDPS(ref algconfig, prox_type) if prox_type == P::prox_type() => { + print!("{running}"); + pointsource_pdps_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) + } + _ => cont((alg, iterator, plotter, μ0, running)), } } +/// Runs FB-style algorithms if `alg` so requests and `prox_penalty` matches. +/// +/// `cont` gives a continuation attempt to find the algorithm matching the description `alg`. +pub fn run_fb<F, Dat, Reg, P, I, Plot, const N: usize>( + f: &Dat, + reg: &Reg, + prox_penalty: &P, + (alg, iterator, plotter, μ0, running): ( + &AlgorithmConfig<F>, + I, + Plot, + Option<RNDM<N, F>>, + String, + ), + cont: impl FnOnce( + (&AlgorithmConfig<F>, I, Plot, Option<RNDM<N, F>>, String), + ) -> DynResult<RNDM<N, F>>, +) -> DynResult<RNDM<N, F>> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory<IterInfo<F>>, + Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>, + Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>, + RNDM<N, F>: SpikeMerging<F>, + Reg: SlidingRegTerm<Loc<N, F>, F>, + P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, + Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, +{ + let pt = P::prox_type(); + + match alg { + &AlgorithmConfig::FB(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_fb_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) + } + &AlgorithmConfig::FISTA(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_fista_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) + } + &AlgorithmConfig::SlidingFB(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_sliding_fb_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) + } + _ => cont((alg, iterator, plotter, μ0, running)), + } +} + +/// Runs FB-style algorithms if `alg` so requests and `prox_penalty` matches. +/// +/// For the moment, due to restrictions of the Frank–Wolfe implementation, only the +/// $L^2$-squared data term is enabled through the type signatures. +/// +/// `cont` gives a continuation attempt to find the algorithm matching the description `alg`. +pub fn run_fw<'a, F, A, Reg, I, Plot, const N: usize>( + f: &'a DataTerm<F, RNDM<N, F>, A, Norm222<F>>, + reg: &Reg, + (alg, iterator, plotter, μ0, running): ( + &AlgorithmConfig<F>, + I, + Plot, + Option<RNDM<N, F>>, + String, + ), + cont: impl FnOnce( + (&AlgorithmConfig<F>, I, Plot, Option<RNDM<N, F>>, String), + ) -> DynResult<RNDM<N, F>>, +) -> DynResult<RNDM<N, F>> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory<IterInfo<F>>, + A: ForwardModel<RNDM<N, F>, F>, + A::PreadjointCodomain: MinMaxMapping<Loc<N, F>, F>, + for<'b> &'b A::PreadjointCodomain: Instance<A::PreadjointCodomain>, + Cube<N, F>: P2Minimise<Loc<N, F>, F>, + RNDM<N, F>: SpikeMerging<F>, + Reg: RegTermFW<F, A, ValueIteratorFactory<F, AlgIteratorOptions>, N>, + Plot: Plotter<A::PreadjointCodomain, A::PreadjointCodomain, RNDM<N, F>>, +{ + match alg { + &AlgorithmConfig::FW(ref algconfig) => { + print!("{running}"); + pointsource_fw_reg(f, reg, algconfig, iterator, plotter, μ0) + } + _ => cont((alg, iterator, plotter, μ0, running)), + } +} #[replace_float_literals(F::cast_from(literal))] -impl<F, NoiseDistr, S, K, P, B, /*PreadjointCodomain,*/ const N : usize> RunnableExperiment<F> for -Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>> +impl<F, NoiseDistr, S, K, P, B, const N: usize> RunnableExperiment<F> + for Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>> where - F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F> - + Default + for<'b> Deserialize<'b>, - [usize; N] : Serialize, - S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, - P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, - Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy - // TODO: shold not have differentiability as a requirement, but - // decide availability of sliding based on it. - //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, - // TODO: very weird that rust only compiles with Differentiable - // instead of the above one on references, which is required by - // poitsource_sliding_fb_reg. - + DifferentiableRealMapping<F, N> - + Lipschitz<L2, FloatType=F>, - for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb. - AutoConvolution<P> : BoundedBy<F, K>, - K : SimpleConvolutionKernel<F, N> + F: ClapFloat + + nalgebra::RealField + + ToNalgebraRealField<MixedType = F> + + Default + + for<'b> Deserialize<'b>, + [usize; N]: Serialize, + S: Sensor<N, F> + Copy + Serialize + std::fmt::Debug, + P: Spread<N, F> + Copy + Serialize + std::fmt::Debug, + Convolution<S, P>: Spread<N, F> + + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> - + Copy + Serialize + std::fmt::Debug, - Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, - PlotLookup : Plotting<N>, - DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, + + Copy + // TODO: shold not have differentiability as a requirement, but + // decide availability of sliding based on it. + //+ for<'b> Differentiable<&'b Loc<N, F>, Output = Loc<N, F>>, + // TODO: very weird that rust only compiles with Differentiable + // instead of the above one on references, which is required by + // poitsource_sliding_fb_reg. + + DifferentiableRealMapping<N, F> + + Lipschitz<L2, FloatType = F>, + for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<N, F>>>::Differential<'b>: + Lipschitz<L2, FloatType = F>, // TODO: should not be required generally, only for sliding_fb. + AutoConvolution<P>: BoundedBy<F, K>, + K: SimpleConvolutionKernel<N, F> + + LocalAnalysis<F, Bounds<F>, N> + + Copy + + Serialize + + std::fmt::Debug, + Cube<N, F>: P2Minimise<Loc<N, F>, F> + SetOrd, + PlotLookup: Plotting<N>, + DefaultBT<F, N>: SensorGridBT<F, S, P, N, Depth = DynamicDepth> + BTSearch<N, F>, BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, - RNDM<F, N> : SpikeMerging<F>, - NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, - B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, - // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>, - // PreadjointCodomain : Bounded<F> + DifferentiableRealMapping<F, N>, + RNDM<N, F>: SpikeMerging<F>, + NoiseDistr: Distribution<F> + Serialize + std::fmt::Debug, + B: Mapping<Loc<N, F>, Codomain = F> + Serialize + std::fmt::Debug, + nalgebra::DVector<F>: ClosedMul<F>, + // This is mainly required for the final Mul requirement to be defined + // DefaultSG<F, S, P, N>: ForwardModel< + // RNDM<N, F>, + // F, + // PreadjointCodomain = PreadjointCodomain, + // Observable = DVector<F::MixedType>, + // >, + // PreadjointCodomain: Bounded<F> + DifferentiableRealMapping<N, F> + std::ops::Mul<F>, + // Pair<PreadjointCodomain, DVector<F>>: std::ops::Mul<F>, // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, { + fn name(&self) -> &str { + self.name.as_ref() + } - fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> { + fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<F> { AlgorithmOverrides { - merge_radius : Some(self.data.base.default_merge_radius), - .. self.data.base.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) + merge_radius: Some(self.data.base.default_merge_radius), + ..self + .data + .base + .algorithm_overrides + .get(&alg) + .cloned() + .unwrap_or(Default::default()) } } - fn runall(&self, cli : &CommandLineArgs, - algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { + fn runall( + &self, + cli: &CommandLineArgs, + algs: Option<Vec<Named<AlgorithmConfig<F>>>>, + ) -> DynError { // Get experiment configuration - let &Named { - name : ref experiment_name, - data : ExperimentBiased { - λ, - ref bias, - base : ExperimentV2 { - domain, sensor_count, ref noise_distr, sensor, spread, kernel, - ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, + let &ExperimentBiased { + λ, + ref bias, + base: + ExperimentV2 { + domain, + sensor_count, + ref noise_distr, + sensor, + spread, + kernel, + ref μ_hat, + regularisation, + kernel_plot_width, + dataterm, + noise_seed, .. - } - } - } = self; + }, + } = &self.data; // Set up algorithms let algorithms = match (algs, dataterm) { @@ -1000,173 +1106,304 @@ let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); let opAext = RowOp(opA.clone(), IdOp::new()); let fnR = Zero::new(); - let h = map3(domain.span_start(), domain.span_end(), sensor_count, - |a, b, n| (b-a)/F::cast_from(n)) - .into_iter() - .reduce(NumTraitsFloat::max) - .unwrap(); + let h = map3( + domain.span_start(), + domain.span_end(), + sensor_count, + |a, b, n| (b - a) / F::cast_from(n), + ) + .into_iter() + .reduce(NumTraitsFloat::max) + .unwrap(); let z = DVector::zeros(sensor_count.iter().product()); let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap(); let y = opKz.apply(&z); - let fnH = Weighted{ base_fn : L1.as_mapping(), weight : λ}; // TODO: L_{2,1} - // let zero_y = y.clone(); - // let zeroBTFN = opA.preadjoint().apply(&zero_y); - // let opKμ = ZeroOp::new(&zero_y, zeroBTFN); + let fnH = Weighted { base_fn: L1.as_mapping(), weight: λ }; // TODO: L_{2,1} + // let zero_y = y.clone(); + // let zeroBTFN = opA.preadjoint().apply(&zero_y); + // let opKμ = ZeroOp::new(&zero_y, zeroBTFN); // Set up random number generator. let mut rng = StdRng::seed_from_u64(noise_seed); // Generate the data and calculate SSNR statistic - let bias_vec = DVector::from_vec(opA.grid() - .into_iter() - .map(|v| bias.apply(v)) - .collect::<Vec<F>>()); - let b_hat : DVector<_> = opA.apply(μ_hat) + &bias_vec; + let bias_vec = DVector::from_vec( + opA.grid() + .into_iter() + .map(|v| bias.apply(v)) + .collect::<Vec<F>>(), + ); + let b_hat: DVector<_> = opA.apply(μ_hat) + &bias_vec; let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); let b = &b_hat + &noise; // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField // overloading log10 and conflicting with standard NumTraits one. let stats = ExperimentStats::new(&b, &noise); - let prefix = start_experiment(&self, cli, stats)?; + let prefix = self.start(cli)?; + write_json(format!("{prefix}stats.json"), &stats)?; - plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, - &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; + plotall( + cli, + &prefix, + &domain, + &sensor, + &kernel, + &spread, + &μ_hat, + &op𝒟, + &opA, + &b_hat, + &b, + kernel_plot_width, + )?; opA.write_observable(&bias_vec, format!("{prefix}bias"))?; - let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); + let plotgrid = lingrid(&domain, &[if N == 1 { 1000 } else { 100 }; N]); + let make_plotter = |this_prefix| { + let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; + SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()) + }; let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z")); - // Run the algorithms - do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, - |alg, iterator, plotter, running| - { - let Pair(μ, z) = match alg { - AlgorithmConfig::ForwardPDPS(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_forward_pdps_pair( - &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_forward_pdps_pair( - &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_forward_pdps_pair( - &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_forward_pdps_pair( - &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - _ => Err(NotImplemented) - } - }, - AlgorithmConfig::SlidingPDPS(ref algconfig, prox) => { - match (regularisation, dataterm, prox) { - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_sliding_pdps_pair( - &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ - print!("{running}"); - pointsource_sliding_pdps_pair( - &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_sliding_pdps_pair( - &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ - print!("{running}"); - pointsource_sliding_pdps_pair( - &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, - iterator, plotter, - /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), - ) - }), - _ => Err(NotImplemented) - } - }, - _ => Err(NotImplemented) - }?; - Ok((μ, z)) - }) + let μ0 = None; // Zero init + + match (dataterm, regularisation) { + (DataTermType::L222, Regularisation::Radon(α)) => { + let f = DataTerm::new(opAext, b, Norm222::new()); + let reg = RadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + (μ0, z, y), + |p| { + run_pdps_pair(&f, ®, &RadonSquared, &opKz, &fnR, &fnH, p, |q| { + run_pdps_pair(&f, ®, &op𝒟, &opKz, &fnR, &fnH, q, |_| { + Err(NotImplemented.into()) + }) + }) + .map(|Pair(μ, z)| (μ, z)) + }, + ) + } + (DataTermType::L222, Regularisation::NonnegRadon(α)) => { + let f = DataTerm::new(opAext, b, Norm222::new()); + let reg = NonnegRadonRegTerm(α); + self.do_runall( + &prefix, + cli, + algorithms, + make_plotter, + save_extra, + (μ0, z, y), + |p| { + run_pdps_pair(&f, ®, &RadonSquared, &opKz, &fnR, &fnH, p, |q| { + run_pdps_pair(&f, ®, &op𝒟, &opKz, &fnR, &fnH, q, |_| { + Err(NotImplemented.into()) + }) + }) + .map(|Pair(μ, z)| (μ, z)) + }, + ) + } + _ => Err(NotImplemented.into()), + } + } +} + +type MeasureZ<F, Z, const N: usize> = Pair<RNDM<N, F>, Z>; + +pub fn run_pdps_pair<F, S, Dat, Reg, Z, R, Y, KOpZ, H, P, I, Plot, const N: usize>( + f: &Dat, + reg: &Reg, + prox_penalty: &P, + opKz: &KOpZ, + fnR: &R, + fnH: &H, + (alg, iterator, plotter, μ0zy, running): ( + &AlgorithmConfig<F>, + I, + Plot, + (Option<RNDM<N, F>>, Z, Y), + String, + ), + cont: impl FnOnce( + ( + &AlgorithmConfig<F>, + I, + Plot, + (Option<RNDM<N, F>>, Z, Y), + String, + ), + ) -> DynResult<Pair<RNDM<N, F>, Z>>, +) -> DynResult<Pair<RNDM<N, F>, Z>> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory<IterInfo<F>>, + Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>> + + BoundedCurvature<F>, + S: DifferentiableRealMapping<N, F> + ClosedMul<F>, + for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>, + //Pair<S, Z>: ClosedMul<F>, + RNDM<N, F>: SpikeMerging<F>, + Reg: SlidingRegTerm<Loc<N, F>, F>, + P: ProxPenalty<Loc<N, F>, S, Reg, F>, + KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y> + + GEMV<F, Z> + + SimplyAdjointable<Z, Y, AdjointCodomain = Z>, + KOpZ::SimpleAdjoint: GEMV<F, Y>, + Y: ClosedEuclidean<F> + Clone, + for<'b> &'b Y: Instance<Y>, + Z: ClosedEuclidean<F> + Clone + ClosedMul<F>, + for<'b> &'b Z: Instance<Z>, + R: Prox<Z, Codomain = F>, + H: Conjugable<Y, F, Codomain = F>, + for<'b> H::Conjugate<'b>: Prox<Y>, + Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>, +{ + let pt = P::prox_type(); + + match alg { + &AlgorithmConfig::ForwardPDPS(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_forward_pdps_pair( + f, + reg, + prox_penalty, + algconfig, + iterator, + plotter, + μ0zy, + opKz, + fnR, + fnH, + ) + } + &AlgorithmConfig::SlidingPDPS(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_sliding_pdps_pair( + f, + reg, + prox_penalty, + algconfig, + iterator, + plotter, + μ0zy, + opKz, + fnR, + fnH, + ) + } + _ => cont((alg, iterator, plotter, μ0zy, running)), + } +} + +pub fn run_fb_pair<F, S, Dat, Reg, Z, R, P, I, Plot, const N: usize>( + f: &Dat, + reg: &Reg, + prox_penalty: &P, + fnR: &R, + (alg, iterator, plotter, μ0z, running): ( + &AlgorithmConfig<F>, + I, + Plot, + (Option<RNDM<N, F>>, Z), + String, + ), + cont: impl FnOnce( + ( + &AlgorithmConfig<F>, + I, + Plot, + (Option<RNDM<N, F>>, Z), + String, + ), + ) -> DynResult<Pair<RNDM<N, F>, Z>>, +) -> DynResult<Pair<RNDM<N, F>, Z>> +where + F: Float + ToNalgebraRealField, + I: AlgIteratorFactory<IterInfo<F>>, + Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>> + + BoundedCurvature<F>, + S: DifferentiableRealMapping<N, F> + ClosedMul<F>, + RNDM<N, F>: SpikeMerging<F>, + Reg: SlidingRegTerm<Loc<N, F>, F>, + P: ProxPenalty<Loc<N, F>, S, Reg, F>, + for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>, + Z: ClosedEuclidean<F> + AXPY + Clone, + for<'b> &'b Z: Instance<Z>, + R: Prox<Z, Codomain = F>, + Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>, + // We should not need to explicitly require this: + for<'b> &'b Loc<0, F>: Instance<Loc<0, F>>, +{ + let pt = P::prox_type(); + + match alg { + &AlgorithmConfig::FB(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_fb_pair(f, reg, prox_penalty, algconfig, iterator, plotter, μ0z, fnR) + } + &AlgorithmConfig::SlidingFB(ref algconfig, prox_type) if prox_type == pt => { + print!("{running}"); + pointsource_sliding_fb_pair( + f, + reg, + prox_penalty, + algconfig, + iterator, + plotter, + μ0z, + fnR, + ) + } + _ => cont((alg, iterator, plotter, μ0z, running)), } } #[derive(Copy, Clone, Debug, Serialize, Deserialize)] -struct ValueRange<F : Float> { - ini : F, - min : F, +struct ValueRange<F: Float> { + ini: F, + min: F, } -impl<F : Float> ValueRange<F> { - fn expand_with(self, other : Self) -> Self { - ValueRange { - ini : self.ini.max(other.ini), - min : self.min.min(other.min), - } +impl<F: Float> ValueRange<F> { + fn expand_with(self, other: Self) -> Self { + ValueRange { ini: self.ini.max(other.ini), min: self.min.min(other.min) } } } /// Calculative minimum and maximum values of all the `logs`, and save them into /// corresponding file names given as the first elements of the tuples in the vectors. -fn save_logs<F : Float + for<'b> Deserialize<'b>, const N : usize>( - logs : Vec<(String, Logger<Timed<IterInfo<F, N>>>)>, - valuerange_file : String, - load_valuerange : bool, +fn save_logs<F: Float + for<'b> Deserialize<'b>>( + logs: Vec<(String, Logger<Timed<IterInfo<F>>>)>, + valuerange_file: String, + load_valuerange: bool, ) -> DynError { // Process logs for relative values println!("{}", "Processing logs…"); // Find minimum value and initial value within a single log - let proc_single_log = |log : &Logger<Timed<IterInfo<F, N>>>| { + let proc_single_log = |log: &Logger<Timed<IterInfo<F>>>| { let d = log.data(); - let mi = d.iter() - .map(|i| i.data.value) - .reduce(NumTraitsFloat::min); + let mi = d.iter().map(|i| i.data.value).reduce(NumTraitsFloat::min); d.first() - .map(|i| i.data.value) - .zip(mi) - .map(|(ini, min)| ValueRange{ ini, min }) + .map(|i| i.data.value) + .zip(mi) + .map(|(ini, min)| ValueRange { ini, min }) }; // Find minimum and maximum value over all logs - let mut v = logs.iter() - .filter_map(|&(_, ref log)| proc_single_log(log)) - .reduce(|v1, v2| v1.expand_with(v2)) - .ok_or(anyhow!("No algorithms found"))?; + let mut v = logs + .iter() + .filter_map(|&(_, ref log)| proc_single_log(log)) + .reduce(|v1, v2| v1.expand_with(v2)) + .ok_or(anyhow!("No algorithms found"))?; // Load existing range if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() { @@ -1183,10 +1420,11 @@ pruned, //postprocessing, this_iters, + ε, .. } = data; // let post_value = match (postprocessing, dataterm) { - // (Some(mut μ), DataTerm::L2Squared) => { + // (Some(mut μ), DataTermType::L222) => { // // Comparison postprocessing is only implemented for the case handled // // by the FW variants. // reg.optimise_weights( @@ -1198,18 +1436,19 @@ // }, // _ => value, // }; - let relative_value = (value - v.min)/(v.ini - v.min); + let relative_value = (value - v.min) / (v.ini - v.min); CSVLog { iter, value, relative_value, //post_value, n_spikes, - cpu_time : cpu_time.as_secs_f64(), + cpu_time: cpu_time.as_secs_f64(), inner_iters, merged, pruned, - this_iters + this_iters, + epsilon: ε, } }; @@ -1224,45 +1463,48 @@ Ok(()) } - /// Plot experiment setup #[replace_float_literals(F::cast_from(literal))] -fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( - cli : &CommandLineArgs, - prefix : &String, - domain : &Cube<F, N>, - sensor : &Sensor, - kernel : &Kernel, - spread : &Spread, - μ_hat : &RNDM<F, N>, - op𝒟 : &𝒟, - opA : &A, - b_hat : &A::Observable, - b : &A::Observable, - kernel_plot_width : F, +fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N: usize>( + cli: &CommandLineArgs, + prefix: &String, + domain: &Cube<N, F>, + sensor: &Sensor, + kernel: &Kernel, + spread: &Spread, + μ_hat: &RNDM<N, F>, + op𝒟: &𝒟, + opA: &A, + b_hat: &A::Observable, + b: &A::Observable, + kernel_plot_width: F, ) -> DynError -where F : Float + ToNalgebraRealField, - Sensor : RealMapping<F, N> + Support<F, N> + Clone, - Spread : RealMapping<F, N> + Support<F, N> + Clone, - Kernel : RealMapping<F, N> + Support<F, N>, - Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>, - 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, - 𝒟::Codomain : RealMapping<F, N>, - A : ForwardModel<RNDM<F, N>, F>, - for<'a> &'a A::Observable : Instance<A::Observable>, - A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>, - PlotLookup : Plotting<N>, - Cube<F, N> : SetOrd { - +where + F: Float + ToNalgebraRealField, + Sensor: RealMapping<N, F> + Support<N, F> + Clone, + Spread: RealMapping<N, F> + Support<N, F> + Clone, + Kernel: RealMapping<N, F> + Support<N, F>, + Convolution<Sensor, Spread>: DifferentiableRealMapping<N, F> + Support<N, F>, + 𝒟: DiscreteMeasureOp<Loc<N, F>, F>, + 𝒟::Codomain: RealMapping<N, F>, + A: ForwardModel<RNDM<N, F>, F>, + for<'a> &'a A::Observable: Instance<A::Observable>, + A::PreadjointCodomain: DifferentiableRealMapping<N, F> + Bounded<F>, + PlotLookup: Plotting<N>, + Cube<N, F>: SetOrd, +{ if cli.plot < PlotLevel::Data { - return Ok(()) + return Ok(()); } let base = Convolution(sensor.clone(), spread.clone()); - let resolution = if N==1 { 100 } else { 40 }; + let resolution = if N == 1 { 100 } else { 40 }; let pfx = |n| format!("{prefix}{n}"); - let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); + let plotgrid = lingrid( + &[[-kernel_plot_width, kernel_plot_width]; N].into(), + &[resolution; N], + ); PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor")); PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel")); @@ -1272,19 +1514,19 @@ let plotgrid2 = lingrid(&domain, &[resolution; N]); let ω_hat = op𝒟.apply(μ_hat); - let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); + let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat")); PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise")); - let preadj_b = opA.preadjoint().apply(b); - let preadj_b_hat = opA.preadjoint().apply(b_hat); + let preadj_b = opA.preadjoint().apply(b); + let preadj_b_hat = opA.preadjoint().apply(b_hat); //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); PlotLookup::plot_into_file_spikes( Some(&preadj_b), Some(&preadj_b_hat), plotgrid2, &μ_hat, - pfx("omega_b") + pfx("omega_b"), ); PlotLookup::plot_into_file(&preadj_b, plotgrid2, pfx("preadj_b")); PlotLookup::plot_into_file(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat"));