--- a/src/run.rs Sun Dec 11 23:19:17 2022 +0200 +++ b/src/run.rs Sun Dec 11 23:25:53 2022 +0200 @@ -34,7 +34,7 @@ use alg_tools::mapping::RealMapping; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::euclidean::Euclidean; -use alg_tools::norms::{Norm, L1}; +use alg_tools::norms::L1; use alg_tools::lingrid::lingrid; use alg_tools::sets::SetOrd; @@ -45,13 +45,14 @@ use crate::forward_model::*; use crate::fb::{ FBConfig, - pointsource_fb, - FBMetaAlgorithm, FBGenericConfig, + pointsource_fb_reg, + FBMetaAlgorithm, + FBGenericConfig, }; use crate::pdps::{ PDPSConfig, L2Squared, - pointsource_pdps, + pointsource_pdps_reg, }; use crate::frank_wolfe::{ FWConfig, @@ -65,6 +66,7 @@ use crate::plot::*; use crate::{AlgorithmOverrides, CommandLineArgs}; use crate::tolerance::Tolerance; +use crate::regularisation::{Regularisation, RadonRegTerm, NonnegRadonRegTerm}; /// Available algorithms and their configurations #[derive(Copy, Clone, Debug, Serialize, Deserialize)] @@ -276,7 +278,7 @@ /// Struct for experiment configurations #[derive(Debug, Clone, Serialize)] -pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize> +pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize> where F : Float, [usize; N] : Serialize, NoiseDistr : Distribution<F>, @@ -300,8 +302,8 @@ pub kernel : K, /// True point sources pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, - /// Regularisation parameter - pub α : F, + /// Regularisation term and parameter + pub regularisation : Regularisation<F>, /// For plotting : how wide should the kernels be plotted pub kernel_plot_width : F, /// Data term @@ -322,8 +324,12 @@ -> Named<AlgorithmConfig<F>>; } +// *** macro boilerplate *** +macro_rules! impl_experiment { +($type:ident, $reg_field:ident, $reg_convert:path) => { +// *** macro *** impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for -Named<Experiment<F, NoiseDistr, S, K, P, N>> +Named<$type<F, NoiseDistr, S, K, P, N>> where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, [usize; N] : Serialize, S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, @@ -356,12 +362,14 @@ // Get experiment configuration let &Named { name : ref experiment_name, - data : Experiment { + data : $type { domain, sensor_count, ref noise_distr, sensor, spread, kernel, - ref μ_hat, α, kernel_plot_width, dataterm, noise_seed, + ref μ_hat, /*regularisation,*/ kernel_plot_width, dataterm, noise_seed, .. } } = self; + #[allow(deprecated)] + let regularisation = $reg_convert(self.data.$reg_field); println!("{}\n{}", format!("Performing experiment {}…", experiment_name).cyan(), @@ -420,7 +428,12 @@ format!("{:?}", iterator_options).bright_black(), format!("{:?}", alg).bright_black()); }; - + let not_implemented = || { + let msg = format!("Algorithm “{alg_name}” not implemented for \ + dataterm {dataterm:?} and regularisation {regularisation:?}. \ + Skipping.").red(); + eprintln!("{}", msg); + }; // Create Logger and IteratorFactory let mut logger = Logger::new(); let findim_data = prepare_optimise_weights(&opA); @@ -437,20 +450,18 @@ this_iters, .. } = data; - let post_value = match postprocessing { - None => value, - Some(mut μ) => { - match dataterm { - DataTerm::L2Squared => { - optimise_weights( - &mut μ, &opA, &b, α, &findim_data, &inner_config, - inner_it - ); - dataterm.value_at_residual(opA.apply(&μ) - &b) + α * μ.norm(Radon) - }, - _ => value, - } - } + let post_value = match (postprocessing, dataterm, regularisation) { + (Some(mut μ), DataTerm::L2Squared, Regularisation::Radon(α)) => { + // Comparison postprocessing is only implemented for the case handled + // by the FW variants. + optimise_weights( + &mut μ, &opA, &b, α, &findim_data, &inner_config, + inner_it + ); + dataterm.value_at_residual(opA.apply(&μ) - &b) + + regularisation.apply(&μ) + }, + _ => value, }; CSVLog { iter, @@ -477,30 +488,72 @@ // Run the algorithm let start = Instant::now(); let start_cpu = ProcessTime::now(); - let μ : DiscreteMeasure<Loc<F, N>, F> = match (alg, dataterm) { - (AlgorithmConfig::FB(ref algconfig), DataTerm::L2Squared) => { - running(); - pointsource_fb(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter) + let μ = match alg { + AlgorithmConfig::FB(ref algconfig) => { + match (regularisation, dataterm) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { + running(); + pointsource_fb_reg( + &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, + iterator, plotter + ) + }, + (Regularisation::Radon(α), DataTerm::L2Squared) => { + running(); + pointsource_fb_reg( + &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, + iterator, plotter + ) + }, + _ => { + not_implemented(); + continue + } + } }, - (AlgorithmConfig::FW(ref algconfig), DataTerm::L2Squared) => { - running(); - pointsource_fw(&opA, &b, α, &algconfig, iterator, plotter) - }, - (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L2Squared) => { + AlgorithmConfig::PDPS(ref algconfig) => { running(); - pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L2Squared) - }, - (AlgorithmConfig::PDPS(ref algconfig), DataTerm::L1) => { - running(); - pointsource_pdps(&opA, &b, α, &op𝒟, &algconfig, iterator, plotter, L1) + match (regularisation, dataterm) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { + pointsource_pdps_reg( + &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, + iterator, plotter, L2Squared + ) + }, + (Regularisation::Radon(α),DataTerm::L2Squared) => { + pointsource_pdps_reg( + &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, + iterator, plotter, L2Squared + ) + }, + (Regularisation::NonnegRadon(α), DataTerm::L1) => { + pointsource_pdps_reg( + &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, + iterator, plotter, L1 + ) + }, + (Regularisation::Radon(α), DataTerm::L1) => { + pointsource_pdps_reg( + &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, + iterator, plotter, L1 + ) + }, + } }, - _ => { - let msg = format!("Algorithm “{alg_name}” not implemented for \ - dataterm {dataterm:?}. Skipping.").red(); - eprintln!("{}", msg); - continue + AlgorithmConfig::FW(ref algconfig) => { + match (regularisation, dataterm) { + (Regularisation::Radon(α), DataTerm::L2Squared) => { + running(); + pointsource_fw(&opA, &b, α, algconfig, iterator, plotter) + }, + _ => { + not_implemented(); + continue + } + } } }; + let elapsed = start.elapsed().as_secs_f64(); let cpu_time = start_cpu.elapsed().as_secs_f64(); @@ -520,6 +573,11 @@ Ok(()) } } +// *** macro end boiler plate *** +}} +// *** actual code *** + +impl_experiment!(ExperimentV2, regularisation, std::convert::identity); /// Plot experiment setup #[replace_float_literals(F::cast_from(literal))] @@ -589,3 +647,46 @@ opA.write_observable(&b, pfx("b_noisy")) } +// +// Deprecated interface +// + +/// Struct for experiment configurations +#[derive(Debug, Clone, Serialize)] +pub struct Experiment<F, NoiseDistr, S, K, P, const N : usize> +where F : Float, + [usize; N] : Serialize, + NoiseDistr : Distribution<F>, + S : Sensor<F, N>, + P : Spread<F, N>, + K : SimpleConvolutionKernel<F, N>, +{ + /// Domain $Ω$. + pub domain : Cube<F, N>, + /// Number of sensors along each dimension + pub sensor_count : [usize; N], + /// Noise distribution + pub noise_distr : NoiseDistr, + /// Seed for random noise generation (for repeatable experiments) + pub noise_seed : u64, + /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. + pub sensor : S, + /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. + pub spread : P, + /// Kernel $ρ$ of $𝒟$. + pub kernel : K, + /// True point sources + pub μ_hat : DiscreteMeasure<Loc<F, N>, F>, + /// Regularisation parameter + #[deprecated(note = "Use [`ExperimentV2`], which replaces `α` by more generic `regularisation`")] + pub α : F, + /// For plotting : how wide should the kernels be plotted + pub kernel_plot_width : F, + /// Data term + pub dataterm : DataTerm, + /// A map of default configurations for algorithms + #[serde(skip)] + pub algorithm_defaults : HashMap<DefaultAlgorithm, AlgorithmConfig<F>>, +} + +impl_experiment!(Experiment, α, Regularisation::NonnegRadon);