src/run.rs

changeset 70
ed16d0f10d08
parent 63
7a8a55fd41c0
equal deleted inserted replaced
58:6099ba025aac 70:ed16d0f10d08
1 /*! 1 /*!
2 This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment. 2 This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment.
3 */ 3 */
4 4
5 use numeric_literals::replace_float_literals; 5 use crate::fb::{pointsource_fb_reg, pointsource_fista_reg, FBConfig, InsertionConfig};
6 use colored::Colorize;
7 use serde::{Serialize, Deserialize};
8 use serde_json;
9 use nalgebra::base::DVector;
10 use std::hash::Hash;
11 use chrono::{DateTime, Utc};
12 use cpu_time::ProcessTime;
13 use clap::ValueEnum;
14 use std::collections::HashMap;
15 use std::time::Instant;
16
17 use rand::prelude::{
18 StdRng,
19 SeedableRng
20 };
21 use rand_distr::Distribution;
22
23 use alg_tools::bisection_tree::*;
24 use alg_tools::iterate::{
25 Timed,
26 AlgIteratorOptions,
27 Verbose,
28 AlgIteratorFactory,
29 LoggingIteratorFactory,
30 TimingIteratorFactory,
31 BasicAlgIteratorFactory,
32 };
33 use alg_tools::logger::Logger;
34 use alg_tools::error::{
35 DynError,
36 DynResult,
37 };
38 use alg_tools::tabledump::TableDump;
39 use alg_tools::sets::Cube;
40 use alg_tools::mapping::{
41 RealMapping,
42 DifferentiableMapping,
43 DifferentiableRealMapping,
44 Instance
45 };
46 use alg_tools::nalgebra_support::ToNalgebraRealField;
47 use alg_tools::euclidean::Euclidean;
48 use alg_tools::lingrid::{lingrid, LinSpace};
49 use alg_tools::sets::SetOrd;
50 use alg_tools::linops::{RowOp, IdOp /*, ZeroOp*/};
51 use alg_tools::discrete_gradient::{Grad, ForwardNeumann};
52 use alg_tools::convex::Zero;
53 use alg_tools::maputil::map3;
54 use alg_tools::direct_product::Pair;
55
56 use crate::kernels::*;
57 use crate::types::*;
58 use crate::measures::*;
59 use crate::measures::merging::{SpikeMerging,SpikeMergingMethod};
60 use crate::forward_model::*;
61 use crate::forward_model::sensor_grid::{ 6 use crate::forward_model::sensor_grid::{
7 //SensorGridBTFN,
8 Sensor,
62 SensorGrid, 9 SensorGrid,
63 SensorGridBT, 10 SensorGridBT,
64 //SensorGridBTFN,
65 Sensor,
66 Spread, 11 Spread,
67 }; 12 };
68 13 use crate::forward_model::*;
69 use crate::fb::{ 14 use crate::forward_pdps::{pointsource_fb_pair, pointsource_forward_pdps_pair, ForwardPDPSConfig};
70 FBConfig, 15 use crate::frank_wolfe::{pointsource_fw_reg, FWConfig, FWVariant, RegTermFW};
71 FBGenericConfig, 16 use crate::kernels::*;
72 pointsource_fb_reg, 17 use crate::measures::merging::{SpikeMerging, SpikeMergingMethod};
73 pointsource_fista_reg, 18 use crate::measures::*;
19 use crate::pdps::{pointsource_pdps_reg, PDPSConfig};
20 use crate::plot::*;
21 use crate::prox_penalty::{
22 ProxPenalty, ProxTerm, RadonSquared, StepLengthBound, StepLengthBoundPD, StepLengthBoundPair,
74 }; 23 };
75 use crate::sliding_fb::{ 24 use crate::regularisation::{NonnegRadonRegTerm, RadonRegTerm, Regularisation, SlidingRegTerm};
76 SlidingFBConfig, 25 use crate::seminorms::*;
77 TransportConfig, 26 use crate::sliding_fb::{pointsource_sliding_fb_reg, SlidingFBConfig, TransportConfig};
78 pointsource_sliding_fb_reg 27 use crate::sliding_pdps::{
28 pointsource_sliding_fb_pair, pointsource_sliding_pdps_pair, SlidingPDPSConfig,
79 }; 29 };
80 use crate::sliding_pdps::{ 30 use crate::subproblem::{InnerMethod, InnerSettings};
81 SlidingPDPSConfig, 31 use crate::tolerance::Tolerance;
82 pointsource_sliding_pdps_pair 32 use crate::types::*;
33 use crate::{AlgorithmOverrides, CommandLineArgs};
34 use alg_tools::bisection_tree::*;
35 use alg_tools::bounds::{Bounded, MinMaxMapping};
36 use alg_tools::convex::{Conjugable, Norm222, Prox, Zero};
37 use alg_tools::direct_product::Pair;
38 use alg_tools::discrete_gradient::{ForwardNeumann, Grad};
39 use alg_tools::error::{DynError, DynResult};
40 use alg_tools::euclidean::{ClosedEuclidean, Euclidean};
41 use alg_tools::iterate::{
42 AlgIteratorFactory, AlgIteratorOptions, BasicAlgIteratorFactory, LoggingIteratorFactory, Timed,
43 TimingIteratorFactory, ValueIteratorFactory, Verbose,
83 }; 44 };
84 use crate::forward_pdps::{ 45 use alg_tools::lingrid::lingrid;
85 ForwardPDPSConfig, 46 use alg_tools::linops::{IdOp, RowOp, AXPY};
86 pointsource_forward_pdps_pair 47 use alg_tools::logger::Logger;
48 use alg_tools::mapping::{
49 DataTerm, DifferentiableMapping, DifferentiableRealMapping, Instance, RealMapping,
87 }; 50 };
88 use crate::pdps::{ 51 use alg_tools::maputil::map3;
89 PDPSConfig, 52 use alg_tools::nalgebra_support::ToNalgebraRealField;
90 pointsource_pdps_reg, 53 use alg_tools::norms::{NormExponent, L1, L2};
91 };
92 use crate::frank_wolfe::{
93 FWConfig,
94 FWVariant,
95 pointsource_fw_reg,
96 //WeightOptim,
97 };
98 use crate::subproblem::{InnerSettings, InnerMethod};
99 use crate::seminorms::*;
100 use crate::plot::*;
101 use crate::{AlgorithmOverrides, CommandLineArgs};
102 use crate::tolerance::Tolerance;
103 use crate::regularisation::{
104 Regularisation,
105 RadonRegTerm,
106 NonnegRadonRegTerm
107 };
108 use crate::dataterm::{
109 L1,
110 L2Squared,
111 };
112 use crate::prox_penalty::{
113 RadonSquared,
114 //ProxPenalty,
115 };
116 use alg_tools::norms::{L2, NormExponent};
117 use alg_tools::operator_arithmetic::Weighted; 54 use alg_tools::operator_arithmetic::Weighted;
55 use alg_tools::sets::Cube;
56 use alg_tools::sets::SetOrd;
57 use alg_tools::tabledump::TableDump;
118 use anyhow::anyhow; 58 use anyhow::anyhow;
119 59 use chrono::{DateTime, Utc};
120 /// Available proximal terms 60 use clap::ValueEnum;
121 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 61 use colored::Colorize;
122 pub enum ProxTerm { 62 use cpu_time::ProcessTime;
123 /// Partial-to-wave operator 𝒟. 63 use nalgebra::base::DVector;
124 Wave, 64 use numeric_literals::replace_float_literals;
125 /// Radon-norm squared 65 use rand::prelude::{SeedableRng, StdRng};
126 RadonSquared 66 use rand_distr::Distribution;
127 } 67 use serde::{Deserialize, Serialize};
68 use serde_json;
69 use std::collections::HashMap;
70 use std::hash::Hash;
71 use std::time::Instant;
72 use thiserror::Error;
73
74 //#[cfg(feature = "pyo3")]
75 //use pyo3::pyclass;
128 76
129 /// Available algorithms and their configurations 77 /// Available algorithms and their configurations
130 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 78 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
131 pub enum AlgorithmConfig<F : Float> { 79 pub enum AlgorithmConfig<F: Float> {
132 FB(FBConfig<F>, ProxTerm), 80 FB(FBConfig<F>, ProxTerm),
133 FISTA(FBConfig<F>, ProxTerm), 81 FISTA(FBConfig<F>, ProxTerm),
134 FW(FWConfig<F>), 82 FW(FWConfig<F>),
135 PDPS(PDPSConfig<F>, ProxTerm), 83 PDPS(PDPSConfig<F>, ProxTerm),
136 SlidingFB(SlidingFBConfig<F>, ProxTerm), 84 SlidingFB(SlidingFBConfig<F>, ProxTerm),
137 ForwardPDPS(ForwardPDPSConfig<F>, ProxTerm), 85 ForwardPDPS(ForwardPDPSConfig<F>, ProxTerm),
138 SlidingPDPS(SlidingPDPSConfig<F>, ProxTerm), 86 SlidingPDPS(SlidingPDPSConfig<F>, ProxTerm),
139 } 87 }
140 88
141 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { 89 fn unpack_tolerance<F: Float>(v: &Vec<F>) -> Tolerance<F> {
142 assert!(v.len() == 3); 90 assert!(v.len() == 3);
143 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } 91 Tolerance::Power { initial: v[0], factor: v[1], exponent: v[2] }
144 } 92 }
145 93
146 impl<F : ClapFloat> AlgorithmConfig<F> { 94 impl<F: ClapFloat> AlgorithmConfig<F> {
147 /// Override supported parameters based on the command line. 95 /// Override supported parameters based on the command line.
148 pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self { 96 pub fn cli_override(self, cli: &AlgorithmOverrides<F>) -> Self {
149 let override_merging = |g : SpikeMergingMethod<F>| { 97 let override_merging = |g: SpikeMergingMethod<F>| SpikeMergingMethod {
150 SpikeMergingMethod { 98 enabled: cli.merge.unwrap_or(g.enabled),
151 enabled : cli.merge.unwrap_or(g.enabled), 99 radius: cli.merge_radius.unwrap_or(g.radius),
152 radius : cli.merge_radius.unwrap_or(g.radius), 100 interp: cli.merge_interp.unwrap_or(g.interp),
153 interp : cli.merge_interp.unwrap_or(g.interp),
154 }
155 }; 101 };
156 let override_fb_generic = |g : FBGenericConfig<F>| { 102 let override_inner = |g: InnerSettings<F>| InnerSettings {
157 FBGenericConfig { 103 method: cli.inner_method.unwrap_or(g.method),
158 bootstrap_insertions : cli.bootstrap_insertions 104 fb_τ0: cli.inner_τ0.unwrap_or(g.fb_τ0),
159 .as_ref() 105 pdps_τσ0: cli
160 .map_or(g.bootstrap_insertions, 106 .inner_pdps_τσ0
161 |n| Some((n[0], n[1]))), 107 .as_ref()
162 merge_every : cli.merge_every.unwrap_or(g.merge_every), 108 .map_or(g.pdps_τσ0, |τσ0| (τσ0[0], τσ0[1])),
163 merging : override_merging(g.merging), 109 pp_τ: cli.inner_pp_τ.as_ref().map_or(g.pp_τ, |τ| (τ[0], τ[1])),
164 final_merging : cli.final_merging.unwrap_or(g.final_merging), 110 tolerance_mult: cli.inner_tol.unwrap_or(g.tolerance_mult),
165 fitness_merging : cli.fitness_merging.unwrap_or(g.fitness_merging), 111 ..g
166 tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance),
167 .. g
168 }
169 }; 112 };
170 let override_transport = |g : TransportConfig<F>| { 113 let override_fb_generic = |g: InsertionConfig<F>| InsertionConfig {
171 TransportConfig { 114 bootstrap_insertions: cli
172 θ0 : cli.theta0.unwrap_or(g.θ0), 115 .bootstrap_insertions
173 tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con), 116 .as_ref()
174 adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), 117 .map_or(g.bootstrap_insertions, |n| Some((n[0], n[1]))),
175 .. g 118 merge_every: cli.merge_every.unwrap_or(g.merge_every),
176 } 119 merging: override_merging(g.merging),
120 final_merging: cli.final_merging.unwrap_or(g.final_merging),
121 fitness_merging: cli.fitness_merging.unwrap_or(g.fitness_merging),
122 inner: override_inner(g.inner),
123 tolerance: cli
124 .tolerance
125 .as_ref()
126 .map(unpack_tolerance)
127 .unwrap_or(g.tolerance),
128 ..g
129 };
130 let override_transport = |g: TransportConfig<F>| TransportConfig {
131 θ0: cli.theta0.unwrap_or(g.θ0),
132 tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con),
133 adaptation: cli.transport_adaptation.unwrap_or(g.adaptation),
134 ..g
177 }; 135 };
178 136
179 use AlgorithmConfig::*; 137 use AlgorithmConfig::*;
180 match self { 138 match self {
181 FB(fb, prox) => FB(FBConfig { 139 FB(fb, prox) => FB(
182 τ0 : cli.tau0.unwrap_or(fb.τ0), 140 FBConfig {
183 generic : override_fb_generic(fb.generic), 141 τ0: cli.tau0.unwrap_or(fb.τ0),
184 .. fb 142 σp0: cli.sigmap0.unwrap_or(fb.σp0),
185 }, prox), 143 insertion: override_fb_generic(fb.insertion),
186 FISTA(fb, prox) => FISTA(FBConfig { 144 ..fb
187 τ0 : cli.tau0.unwrap_or(fb.τ0), 145 },
188 generic : override_fb_generic(fb.generic), 146 prox,
189 .. fb 147 ),
190 }, prox), 148 FISTA(fb, prox) => FISTA(
191 PDPS(pdps, prox) => PDPS(PDPSConfig { 149 FBConfig {
192 τ0 : cli.tau0.unwrap_or(pdps.τ0), 150 τ0: cli.tau0.unwrap_or(fb.τ0),
193 σ0 : cli.sigma0.unwrap_or(pdps.σ0), 151 σp0: cli.sigmap0.unwrap_or(fb.σp0),
194 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), 152 insertion: override_fb_generic(fb.insertion),
195 generic : override_fb_generic(pdps.generic), 153 ..fb
196 .. pdps 154 },
197 }, prox), 155 prox,
156 ),
157 PDPS(pdps, prox) => PDPS(
158 PDPSConfig {
159 τ0: cli.tau0.unwrap_or(pdps.τ0),
160 σ0: cli.sigma0.unwrap_or(pdps.σ0),
161 acceleration: cli.acceleration.unwrap_or(pdps.acceleration),
162 generic: override_fb_generic(pdps.generic),
163 ..pdps
164 },
165 prox,
166 ),
198 FW(fw) => FW(FWConfig { 167 FW(fw) => FW(FWConfig {
199 merging : override_merging(fw.merging), 168 merging: override_merging(fw.merging),
200 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), 169 tolerance: cli
201 .. fw 170 .tolerance
171 .as_ref()
172 .map(unpack_tolerance)
173 .unwrap_or(fw.tolerance),
174 ..fw
202 }), 175 }),
203 SlidingFB(sfb, prox) => SlidingFB(SlidingFBConfig { 176 SlidingFB(sfb, prox) => SlidingFB(
204 τ0 : cli.tau0.unwrap_or(sfb.τ0), 177 SlidingFBConfig {
205 transport : override_transport(sfb.transport), 178 τ0: cli.tau0.unwrap_or(sfb.τ0),
206 insertion : override_fb_generic(sfb.insertion), 179 σp0: cli.sigmap0.unwrap_or(sfb.σp0),
207 .. sfb 180 transport: override_transport(sfb.transport),
208 }, prox), 181 insertion: override_fb_generic(sfb.insertion),
209 SlidingPDPS(spdps, prox) => SlidingPDPS(SlidingPDPSConfig { 182 ..sfb
210 τ0 : cli.tau0.unwrap_or(spdps.τ0), 183 },
211 σp0 : cli.sigmap0.unwrap_or(spdps.σp0), 184 prox,
212 σd0 : cli.sigma0.unwrap_or(spdps.σd0), 185 ),
213 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), 186 SlidingPDPS(spdps, prox) => SlidingPDPS(
214 transport : override_transport(spdps.transport), 187 SlidingPDPSConfig {
215 insertion : override_fb_generic(spdps.insertion), 188 τ0: cli.tau0.unwrap_or(spdps.τ0),
216 .. spdps 189 σp0: cli.sigmap0.unwrap_or(spdps.σp0),
217 }, prox), 190 σd0: cli.sigma0.unwrap_or(spdps.σd0),
218 ForwardPDPS(fpdps, prox) => ForwardPDPS(ForwardPDPSConfig { 191 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
219 τ0 : cli.tau0.unwrap_or(fpdps.τ0), 192 transport: override_transport(spdps.transport),
220 σp0 : cli.sigmap0.unwrap_or(fpdps.σp0), 193 insertion: override_fb_generic(spdps.insertion),
221 σd0 : cli.sigma0.unwrap_or(fpdps.σd0), 194 ..spdps
222 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), 195 },
223 insertion : override_fb_generic(fpdps.insertion), 196 prox,
224 .. fpdps 197 ),
225 }, prox), 198 ForwardPDPS(fpdps, prox) => ForwardPDPS(
199 ForwardPDPSConfig {
200 τ0: cli.tau0.unwrap_or(fpdps.τ0),
201 σp0: cli.sigmap0.unwrap_or(fpdps.σp0),
202 σd0: cli.sigma0.unwrap_or(fpdps.σd0),
203 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
204 insertion: override_fb_generic(fpdps.insertion),
205 ..fpdps
206 },
207 prox,
208 ),
226 } 209 }
227 } 210 }
228 } 211 }
229 212
230 /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name. 213 /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name.
231 #[derive(Clone, Debug, Serialize, Deserialize)] 214 #[derive(Clone, Debug, Serialize, Deserialize)]
232 pub struct Named<Data> { 215 pub struct Named<Data> {
233 pub name : String, 216 pub name: String,
234 #[serde(flatten)] 217 #[serde(flatten)]
235 pub data : Data, 218 pub data: Data,
236 } 219 }
237 220
238 /// Shorthand algorithm configurations, to be used with the command line parser 221 /// Shorthand algorithm configurations, to be used with the command line parser
239 #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] 222 #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
223 //#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))]
240 pub enum DefaultAlgorithm { 224 pub enum DefaultAlgorithm {
241 /// The μFB forward-backward method 225 /// The μFB forward-backward method
242 #[clap(name = "fb")] 226 #[clap(name = "fb")]
243 FB, 227 FB,
244 /// The μFISTA inertial forward-backward method 228 /// The μFISTA inertial forward-backward method
262 /// The PDPS method with a forward step for the smooth function 246 /// The PDPS method with a forward step for the smooth function
263 #[clap(name = "forward_pdps", alias = "fpdps")] 247 #[clap(name = "forward_pdps", alias = "fpdps")]
264 ForwardPDPS, 248 ForwardPDPS,
265 249
266 // Radon variants 250 // Radon variants
267
268 /// The μFB forward-backward method with radon-norm squared proximal term 251 /// The μFB forward-backward method with radon-norm squared proximal term
269 #[clap(name = "radon_fb")] 252 #[clap(name = "radon_fb")]
270 RadonFB, 253 RadonFB,
271 /// The μFISTA inertial forward-backward method with radon-norm squared proximal term 254 /// The μFISTA inertial forward-backward method with radon-norm squared proximal term
272 #[clap(name = "radon_fista")] 255 #[clap(name = "radon_fista")]
285 RadonForwardPDPS, 268 RadonForwardPDPS,
286 } 269 }
287 270
288 impl DefaultAlgorithm { 271 impl DefaultAlgorithm {
289 /// Returns the algorithm configuration corresponding to the algorithm shorthand 272 /// Returns the algorithm configuration corresponding to the algorithm shorthand
290 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { 273 pub fn default_config<F: Float>(&self) -> AlgorithmConfig<F> {
291 use DefaultAlgorithm::*; 274 use DefaultAlgorithm::*;
292 let radon_insertion = FBGenericConfig { 275 let radon_insertion = InsertionConfig {
293 merging : SpikeMergingMethod{ interp : false, .. Default::default() }, 276 merging: SpikeMergingMethod { interp: false, ..Default::default() },
294 inner : InnerSettings { 277 inner: InnerSettings {
295 method : InnerMethod::PDPS, // SSN not implemented 278 method: InnerMethod::PDPS, // SSN not implemented
296 .. Default::default() 279 ..Default::default()
297 }, 280 },
298 .. Default::default() 281 ..Default::default()
299 }; 282 };
300 match *self { 283 match *self {
301 FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave), 284 FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave),
302 FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave), 285 FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave),
303 FW => AlgorithmConfig::FW(Default::default()), 286 FW => AlgorithmConfig::FW(Default::default()),
304 FWRelax => AlgorithmConfig::FW(FWConfig{ 287 FWRelax => {
305 variant : FWVariant::Relaxed, 288 AlgorithmConfig::FW(FWConfig { variant: FWVariant::Relaxed, ..Default::default() })
306 .. Default::default() 289 }
307 }),
308 PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave), 290 PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave),
309 SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave), 291 SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave),
310 SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave), 292 SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave),
311 ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave), 293 ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave),
312 294
313 // Radon variants 295 // Radon variants
314
315 RadonFB => AlgorithmConfig::FB( 296 RadonFB => AlgorithmConfig::FB(
316 FBConfig{ generic : radon_insertion, ..Default::default() }, 297 FBConfig { insertion: radon_insertion, ..Default::default() },
317 ProxTerm::RadonSquared 298 ProxTerm::RadonSquared,
318 ), 299 ),
319 RadonFISTA => AlgorithmConfig::FISTA( 300 RadonFISTA => AlgorithmConfig::FISTA(
320 FBConfig{ generic : radon_insertion, ..Default::default() }, 301 FBConfig { insertion: radon_insertion, ..Default::default() },
321 ProxTerm::RadonSquared 302 ProxTerm::RadonSquared,
322 ), 303 ),
323 RadonPDPS => AlgorithmConfig::PDPS( 304 RadonPDPS => AlgorithmConfig::PDPS(
324 PDPSConfig{ generic : radon_insertion, ..Default::default() }, 305 PDPSConfig { generic: radon_insertion, ..Default::default() },
325 ProxTerm::RadonSquared 306 ProxTerm::RadonSquared,
326 ), 307 ),
327 RadonSlidingFB => AlgorithmConfig::SlidingFB( 308 RadonSlidingFB => AlgorithmConfig::SlidingFB(
328 SlidingFBConfig{ insertion : radon_insertion, ..Default::default() }, 309 SlidingFBConfig { insertion: radon_insertion, ..Default::default() },
329 ProxTerm::RadonSquared 310 ProxTerm::RadonSquared,
330 ), 311 ),
331 RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS( 312 RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS(
332 SlidingPDPSConfig{ insertion : radon_insertion, ..Default::default() }, 313 SlidingPDPSConfig { insertion: radon_insertion, ..Default::default() },
333 ProxTerm::RadonSquared 314 ProxTerm::RadonSquared,
334 ), 315 ),
335 RadonForwardPDPS => AlgorithmConfig::ForwardPDPS( 316 RadonForwardPDPS => AlgorithmConfig::ForwardPDPS(
336 ForwardPDPSConfig{ insertion : radon_insertion, ..Default::default() }, 317 ForwardPDPSConfig { insertion: radon_insertion, ..Default::default() },
337 ProxTerm::RadonSquared 318 ProxTerm::RadonSquared,
338 ), 319 ),
339 } 320 }
340 } 321 }
341 322
342 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand 323 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand
343 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { 324 pub fn get_named<F: Float>(&self) -> Named<AlgorithmConfig<F>> {
344 self.to_named(self.default_config()) 325 self.to_named(self.default_config())
345 } 326 }
346 327
347 pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { 328 pub fn to_named<F: Float>(self, alg: AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> {
348 let name = self.to_possible_value().unwrap().get_name().to_string(); 329 Named { name: self.name(), data: alg }
349 Named{ name , data : alg } 330 }
350 } 331
351 } 332 pub fn name(self) -> String {
352 333 self.to_possible_value().unwrap().get_name().to_string()
334 }
335 }
353 336
354 // // Floats cannot be hashed directly, so just hash the debug formatting 337 // // Floats cannot be hashed directly, so just hash the debug formatting
355 // // for use as file identifier. 338 // // for use as file identifier.
356 // impl<F : Float> Hash for AlgorithmConfig<F> { 339 // impl<F : Float> Hash for AlgorithmConfig<F> {
357 // fn hash<H: Hasher>(&self, state: &mut H) { 340 // fn hash<H: Hasher>(&self, state: &mut H) {
377 fn default() -> Self { 360 fn default() -> Self {
378 Self::Data 361 Self::Data
379 } 362 }
380 } 363 }
381 364
382 type DefaultBT<F, const N : usize> = BT< 365 type DefaultBT<F, const N: usize> = BT<DynamicDepth, F, usize, Bounds<F>, N>;
383 DynamicDepth, 366 type DefaultSeminormOp<F, K, const N: usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>;
384 F, 367 type DefaultSG<F, Sensor, Spread, const N: usize> =
385 usize, 368 SensorGrid<F, Sensor, Spread, DefaultBT<F, N>, N>;
386 Bounds<F>,
387 N
388 >;
389 type DefaultSeminormOp<F, K, const N : usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>;
390 type DefaultSG<F, Sensor, Spread, const N : usize> = SensorGrid::<
391 F,
392 Sensor,
393 Spread,
394 DefaultBT<F, N>,
395 N
396 >;
397 369
398 /// This is a dirty workaround to rust-csv not supporting struct flattening etc. 370 /// This is a dirty workaround to rust-csv not supporting struct flattening etc.
399 #[derive(Serialize)] 371 #[derive(Serialize)]
400 struct CSVLog<F> { 372 struct CSVLog<F> {
401 iter : usize, 373 iter: usize,
402 cpu_time : f64, 374 cpu_time: f64,
403 value : F, 375 value: F,
404 relative_value : F, 376 relative_value: F,
405 //post_value : F, 377 //post_value : F,
406 n_spikes : usize, 378 n_spikes: usize,
407 inner_iters : usize, 379 inner_iters: usize,
408 merged : usize, 380 merged: usize,
409 pruned : usize, 381 pruned: usize,
410 this_iters : usize, 382 this_iters: usize,
383 epsilon: F,
411 } 384 }
412 385
413 /// Collected experiment statistics 386 /// Collected experiment statistics
414 #[derive(Clone, Debug, Serialize)] 387 #[derive(Clone, Debug, Serialize)]
415 struct ExperimentStats<F : Float> { 388 struct ExperimentStats<F: Float> {
416 /// Signal-to-noise ratio in decibels 389 /// Signal-to-noise ratio in decibels
417 ssnr : F, 390 ssnr: F,
418 /// Proportion of noise in the signal as a number in $[0, 1]$. 391 /// Proportion of noise in the signal as a number in $[0, 1]$.
419 noise_ratio : F, 392 noise_ratio: F,
420 /// When the experiment was run (UTC) 393 /// When the experiment was run (UTC)
421 when : DateTime<Utc>, 394 when: DateTime<Utc>,
422 } 395 }
423 396
424 #[replace_float_literals(F::cast_from(literal))] 397 #[replace_float_literals(F::cast_from(literal))]
425 impl<F : Float> ExperimentStats<F> { 398 impl<F: Float> ExperimentStats<F> {
426 /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal. 399 /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal.
427 fn new<E : Euclidean<F>>(signal : &E, noise : &E) -> Self { 400 fn new<E: Euclidean<F>>(signal: &E, noise: &E) -> Self {
428 let s = signal.norm2_squared(); 401 let s = signal.norm2_squared();
429 let n = noise.norm2_squared(); 402 let n = noise.norm2_squared();
430 let noise_ratio = (n / s).sqrt(); 403 let noise_ratio = (n / s).sqrt();
431 let ssnr = 10.0 * (s / n).log10(); 404 let ssnr = 10.0 * (s / n).log10();
432 ExperimentStats { 405 ExperimentStats { ssnr, noise_ratio, when: Utc::now() }
433 ssnr,
434 noise_ratio,
435 when : Utc::now(),
436 }
437 } 406 }
438 } 407 }
439 /// Collected algorithm statistics 408 /// Collected algorithm statistics
440 #[derive(Clone, Debug, Serialize)] 409 #[derive(Clone, Debug, Serialize)]
441 struct AlgorithmStats<F : Float> { 410 struct AlgorithmStats<F: Float> {
442 /// Overall CPU time spent 411 /// Overall CPU time spent
443 cpu_time : F, 412 cpu_time: F,
444 /// Real time spent 413 /// Real time spent
445 elapsed : F 414 elapsed: F,
446 } 415 }
447
448 416
449 /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input 417 /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input
450 /// and outputs a [`DynError`]. 418 /// and outputs a [`DynError`].
451 fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { 419 fn write_json<T: Serialize>(filename: String, data: &T) -> DynError {
452 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; 420 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?;
453 Ok(()) 421 Ok(())
454 } 422 }
455 423
456
457 /// Struct for experiment configurations 424 /// Struct for experiment configurations
458 #[derive(Debug, Clone, Serialize)] 425 #[derive(Debug, Clone, Serialize)]
459 pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize> 426 pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N: usize>
460 where F : Float + ClapFloat, 427 where
461 [usize; N] : Serialize, 428 F: Float + ClapFloat,
462 NoiseDistr : Distribution<F>, 429 [usize; N]: Serialize,
463 S : Sensor<F, N>, 430 NoiseDistr: Distribution<F>,
464 P : Spread<F, N>, 431 S: Sensor<N, F>,
465 K : SimpleConvolutionKernel<F, N>, 432 P: Spread<N, F>,
433 K: SimpleConvolutionKernel<N, F>,
466 { 434 {
467 /// Domain $Ω$. 435 /// Domain $Ω$.
468 pub domain : Cube<F, N>, 436 pub domain: Cube<N, F>,
469 /// Number of sensors along each dimension 437 /// Number of sensors along each dimension
470 pub sensor_count : [usize; N], 438 pub sensor_count: [usize; N],
471 /// Noise distribution 439 /// Noise distribution
472 pub noise_distr : NoiseDistr, 440 pub noise_distr: NoiseDistr,
473 /// Seed for random noise generation (for repeatable experiments) 441 /// Seed for random noise generation (for repeatable experiments)
474 pub noise_seed : u64, 442 pub noise_seed: u64,
475 /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. 443 /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$.
476 pub sensor : S, 444 pub sensor: S,
477 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. 445 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$.
478 pub spread : P, 446 pub spread: P,
479 /// Kernel $ρ$ of $𝒟$. 447 /// Kernel $ρ$ of $𝒟$.
480 pub kernel : K, 448 pub kernel: K,
481 /// True point sources 449 /// True point sources
482 pub μ_hat : RNDM<F, N>, 450 pub μ_hat: RNDM<N, F>,
483 /// Regularisation term and parameter 451 /// Regularisation term and parameter
484 pub regularisation : Regularisation<F>, 452 pub regularisation: Regularisation<F>,
485 /// For plotting : how wide should the kernels be plotted 453 /// For plotting : how wide should the kernels be plotted
486 pub kernel_plot_width : F, 454 pub kernel_plot_width: F,
487 /// Data term 455 /// Data term
488 pub dataterm : DataTerm, 456 pub dataterm: DataTermType,
489 /// A map of default configurations for algorithms 457 /// A map of default configurations for algorithms
490 pub algorithm_overrides : HashMap<DefaultAlgorithm, AlgorithmOverrides<F>>, 458 pub algorithm_overrides: HashMap<DefaultAlgorithm, AlgorithmOverrides<F>>,
491 /// Default merge radius 459 /// Default merge radius
492 pub default_merge_radius : F, 460 pub default_merge_radius: F,
493 } 461 }
494 462
495 #[derive(Debug, Clone, Serialize)] 463 #[derive(Debug, Clone, Serialize)]
496 pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N : usize> 464 pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N: usize>
497 where F : Float + ClapFloat, 465 where
498 [usize; N] : Serialize, 466 F: Float + ClapFloat,
499 NoiseDistr : Distribution<F>, 467 [usize; N]: Serialize,
500 S : Sensor<F, N>, 468 NoiseDistr: Distribution<F>,
501 P : Spread<F, N>, 469 S: Sensor<N, F>,
502 K : SimpleConvolutionKernel<F, N>, 470 P: Spread<N, F>,
503 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, 471 K: SimpleConvolutionKernel<N, F>,
472 B: Mapping<Loc<N, F>, Codomain = F> + Serialize + std::fmt::Debug,
504 { 473 {
505 /// Basic setup 474 /// Basic setup
506 pub base : ExperimentV2<F, NoiseDistr, S, K, P, N>, 475 pub base: ExperimentV2<F, NoiseDistr, S, K, P, N>,
507 /// Weight of TV term 476 /// Weight of TV term
508 pub λ : F, 477 pub λ: F,
509 /// Bias function 478 /// Bias function
510 pub bias : B, 479 pub bias: B,
511 } 480 }
512 481
513 /// Trait for runnable experiments 482 /// Trait for runnable experiments
514 pub trait RunnableExperiment<F : ClapFloat> { 483 pub trait RunnableExperiment<F: ClapFloat> {
515 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. 484 /// Run all algorithms provided, or default algorithms if none provided, on the experiment.
516 fn runall(&self, cli : &CommandLineArgs, 485 fn runall(
517 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; 486 &self,
487 cli: &CommandLineArgs,
488 algs: Option<Vec<Named<AlgorithmConfig<F>>>>,
489 ) -> DynError;
518 490
519 /// Return algorithm default config 491 /// Return algorithm default config
520 fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F>; 492 fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<F>;
521 } 493
522 494 /// Experiment name
523 /// Helper function to print experiment start message and save setup. 495 fn name(&self) -> &str;
524 /// Returns saving prefix. 496 }
525 fn start_experiment<E, S>( 497
526 experiment : &Named<E>, 498 /// Error codes for running an algorithm on an experiment.
527 cli : &CommandLineArgs, 499 #[derive(Error, Debug)]
528 stats : S, 500 pub enum RunError {
529 ) -> DynResult<String> 501 /// Algorithm not implemented for this experiment
502 #[error("Algorithm not implemented for this experiment")]
503 NotImplemented,
504 }
505
506 use RunError::*;
507
508 type DoRunAllIt<'a, F, const N: usize> = LoggingIteratorFactory<
509 'a,
510 Timed<IterInfo<F>>,
511 TimingIteratorFactory<BasicAlgIteratorFactory<IterInfo<F>>>,
512 >;
513
514 pub trait RunnableExperimentExtras<F: ClapFloat>:
515 RunnableExperiment<F> + Serialize + Sized
516 {
517 /// Helper function to print experiment start message and save setup.
518 /// Returns saving prefix.
519 fn start(&self, cli: &CommandLineArgs) -> DynResult<String> {
520 let experiment_name = self.name();
521 let ser = serde_json::to_string(self);
522
523 println!(
524 "{}\n{}",
525 format!("Performing experiment {}…", experiment_name).cyan(),
526 format!(
527 "Experiment settings: {}",
528 if let Ok(ref s) = ser {
529 s
530 } else {
531 "<serialisation failure>"
532 }
533 )
534 .bright_black(),
535 );
536
537 // Set up output directory
538 let prefix = format!("{}/{}/", cli.outdir, experiment_name);
539
540 // Save experiment configuration and statistics
541 std::fs::create_dir_all(&prefix)?;
542 write_json(format!("{prefix}experiment.json"), self)?;
543 write_json(format!("{prefix}config.json"), cli)?;
544
545 Ok(prefix)
546 }
547
548 /// Helper function to run all algorithms on an experiment.
549 fn do_runall<P, Z, Plot, const N: usize>(
550 &self,
551 prefix: &String,
552 cli: &CommandLineArgs,
553 algorithms: Vec<Named<AlgorithmConfig<F>>>,
554 mut make_plotter: impl FnMut(String) -> Plot,
555 mut save_extra: impl FnMut(String, Z) -> DynError,
556 init: P,
557 mut do_alg: impl FnMut(
558 (&AlgorithmConfig<F>, DoRunAllIt<F, N>, Plot, P, String),
559 ) -> DynResult<(RNDM<N, F>, Z)>,
560 ) -> DynError
561 where
562 F: for<'b> Deserialize<'b>,
563 PlotLookup: Plotting<N>,
564 P: Clone,
565 {
566 let experiment_name = self.name();
567
568 let mut logs = Vec::new();
569
570 let iterator_options = AlgIteratorOptions {
571 max_iter: cli.max_iter,
572 verbose_iter: cli
573 .verbose_iter
574 .map_or(Verbose::LogarithmicCap { base: 10, cap: 2 }, |n| {
575 Verbose::Every(n)
576 }),
577 quiet: cli.quiet,
578 };
579
580 // Run the algorithm(s)
581 for named @ Named { name: alg_name, data: alg } in algorithms.iter() {
582 let this_prefix = format!("{}{}/", prefix, alg_name);
583
584 // Create Logger and IteratorFactory
585 let mut logger = Logger::new();
586 let iterator = iterator_options.instantiate().timed().into_log(&mut logger);
587
588 let running = if !cli.quiet {
589 format!(
590 "{}\n{}\n{}\n",
591 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
592 format!(
593 "Iteration settings: {}",
594 serde_json::to_string(&iterator_options)?
595 )
596 .bright_black(),
597 format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black()
598 )
599 } else {
600 "".to_string()
601 };
602 //
603 // The following is for postprocessing, which has been disabled anyway.
604 //
605 // let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation {
606 // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)),
607 // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)),
608 // };
609 //let findim_data = reg.prepare_optimise_weights(&opA, &b);
610 //let inner_config : InnerSettings<F> = Default::default();
611 //let inner_it = inner_config.iterator_options;
612
613 // Create plotter and directory if needed.
614 let plotter = make_plotter(this_prefix);
615
616 let start = Instant::now();
617 let start_cpu = ProcessTime::now();
618
619 let (μ, z) = match do_alg((alg, iterator, plotter, init.clone(), running)) {
620 Ok(μ) => μ,
621 Err(e) => {
622 let msg = format!(
623 "Skipping algorithm “{alg_name}” for {experiment_name} due to error: {e}"
624 )
625 .red();
626 eprintln!("{}", msg);
627 continue;
628 }
629 };
630
631 let elapsed = start.elapsed().as_secs_f64();
632 let cpu_time = start_cpu.elapsed().as_secs_f64();
633
634 println!(
635 "{}",
636 format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()
637 );
638
639 // Save results
640 println!("{}", "Saving results …".green());
641
642 let mkname = |t| format!("{prefix}{alg_name}_{t}");
643
644 write_json(mkname("config.json"), &named)?;
645 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
646 μ.write_csv(mkname("reco.txt"))?;
647 save_extra(mkname(""), z)?;
648 //logger.write_csv(mkname("log.txt"))?;
649 logs.push((mkname("log.txt"), logger));
650 }
651
652 save_logs(
653 logs,
654 format!("{prefix}valuerange.json"),
655 cli.load_valuerange,
656 )
657 }
658 }
659
660 impl<F, E> RunnableExperimentExtras<F> for E
530 where 661 where
531 E : Serialize + std::fmt::Debug, 662 F: ClapFloat,
532 S : Serialize, 663 Self: RunnableExperiment<F> + Serialize,
533 { 664 {
534 let Named { name : experiment_name, data } = experiment; 665 }
535 666
536 println!("{}\n{}", 667 #[replace_float_literals(F::cast_from(literal))]
537 format!("Performing experiment {}…", experiment_name).cyan(), 668 impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N: usize> RunnableExperiment<F>
538 format!("Experiment settings: {}", serde_json::to_string(&data)?).bright_black()); 669 for Named<ExperimentV2<F, NoiseDistr, S, K, P, N>>
539
540 // Set up output directory
541 let prefix = format!("{}/{}/", cli.outdir, experiment_name);
542
543 // Save experiment configuration and statistics
544 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t);
545 std::fs::create_dir_all(&prefix)?;
546 write_json(mkname_e("experiment"), experiment)?;
547 write_json(mkname_e("config"), cli)?;
548 write_json(mkname_e("stats"), &stats)?;
549
550 Ok(prefix)
551 }
552
553 /// Error codes for running an algorithm on an experiment.
554 enum RunError {
555 /// Algorithm not implemented for this experiment
556 NotImplemented,
557 }
558
559 use RunError::*;
560
561 type DoRunAllIt<'a, F, const N : usize> = LoggingIteratorFactory<
562 'a,
563 Timed<IterInfo<F, N>>,
564 TimingIteratorFactory<BasicAlgIteratorFactory<IterInfo<F, N>>>
565 >;
566
567 /// Helper function to run all algorithms on an experiment.
568 fn do_runall<F : Float + for<'b> Deserialize<'b>, Z, const N : usize>(
569 experiment_name : &String,
570 prefix : &String,
571 cli : &CommandLineArgs,
572 algorithms : Vec<Named<AlgorithmConfig<F>>>,
573 plotgrid : LinSpace<Loc<F, N>, [usize; N]>,
574 mut save_extra : impl FnMut(String, Z) -> DynError,
575 mut do_alg : impl FnMut(
576 &AlgorithmConfig<F>,
577 DoRunAllIt<F, N>,
578 SeqPlotter<F, N>,
579 String,
580 ) -> Result<(RNDM<F, N>, Z), RunError>,
581 ) -> DynError
582 where 670 where
583 PlotLookup : Plotting<N>, 671 F: ClapFloat
672 + nalgebra::RealField
673 + ToNalgebraRealField<MixedType = F>
674 + Default
675 + for<'b> Deserialize<'b>,
676 [usize; N]: Serialize,
677 S: Sensor<N, F> + Copy + Serialize + std::fmt::Debug,
678 P: Spread<N, F> + Copy + Serialize + std::fmt::Debug,
679 Convolution<S, P>: Spread<N, F>
680 + Bounded<F>
681 + LocalAnalysis<F, Bounds<F>, N>
682 + Copy
683 // TODO: shold not have differentiability as a requirement, but
684 // decide availability of sliding based on it.
685 //+ for<'b> Differentiable<&'b Loc<N, F>, Output = Loc<N, F>>,
686 // TODO: very weird that rust only compiles with Differentiable
687 // instead of the above one on references, which is required by
688 // poitsource_sliding_fb_reg.
689 + DifferentiableRealMapping<N, F>
690 + Lipschitz<L2, FloatType = F>,
691 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<N, F>>>::Differential<'b>:
692 Lipschitz<L2, FloatType = F>, // TODO: should not be required generally, only for sliding_fb.
693 AutoConvolution<P>: BoundedBy<F, K>,
694 K: SimpleConvolutionKernel<N, F>
695 + LocalAnalysis<F, Bounds<F>, N>
696 + Copy
697 + Serialize
698 + std::fmt::Debug,
699 Cube<N, F>: P2Minimise<Loc<N, F>, F> + SetOrd,
700 PlotLookup: Plotting<N>,
701 DefaultBT<F, N>: SensorGridBT<F, S, P, N, Depth = DynamicDepth> + BTSearch<N, F>,
702 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
703 RNDM<N, F>: SpikeMerging<F>,
704 NoiseDistr: Distribution<F> + Serialize + std::fmt::Debug,
584 { 705 {
585 let mut logs = Vec::new(); 706 fn name(&self) -> &str {
586 707 self.name.as_ref()
587 let iterator_options = AlgIteratorOptions{ 708 }
588 max_iter : cli.max_iter, 709
589 verbose_iter : cli.verbose_iter 710 fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<F> {
590 .map_or(Verbose::LogarithmicCap{base : 10, cap : 2},
591 |n| Verbose::Every(n)),
592 quiet : cli.quiet,
593 };
594
595 // Run the algorithm(s)
596 for named @ Named { name : alg_name, data : alg } in algorithms.iter() {
597 let this_prefix = format!("{}{}/", prefix, alg_name);
598
599 // Create Logger and IteratorFactory
600 let mut logger = Logger::new();
601 let iterator = iterator_options.instantiate()
602 .timed()
603 .into_log(&mut logger);
604
605 let running = if !cli.quiet {
606 format!("{}\n{}\n{}\n",
607 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
608 format!("Iteration settings: {}", serde_json::to_string(&iterator_options)?).bright_black(),
609 format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black())
610 } else {
611 "".to_string()
612 };
613 //
614 // The following is for postprocessing, which has been disabled anyway.
615 //
616 // let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation {
617 // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)),
618 // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)),
619 // };
620 //let findim_data = reg.prepare_optimise_weights(&opA, &b);
621 //let inner_config : InnerSettings<F> = Default::default();
622 //let inner_it = inner_config.iterator_options;
623
624 // Create plotter and directory if needed.
625 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };
626 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid.clone());
627
628 let start = Instant::now();
629 let start_cpu = ProcessTime::now();
630
631 let (μ, z) = match do_alg(alg, iterator, plotter, running) {
632 Ok(μ) => μ,
633 Err(RunError::NotImplemented) => {
634 let msg = format!("Algorithm “{alg_name}” not implemented for {experiment_name}. \
635 Skipping.").red();
636 eprintln!("{}", msg);
637 continue
638 }
639 };
640
641 let elapsed = start.elapsed().as_secs_f64();
642 let cpu_time = start_cpu.elapsed().as_secs_f64();
643
644 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow());
645
646 // Save results
647 println!("{}", "Saving results …".green());
648
649 let mkname = |t| format!("{prefix}{alg_name}_{t}");
650
651 write_json(mkname("config.json"), &named)?;
652 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
653 μ.write_csv(mkname("reco.txt"))?;
654 save_extra(mkname(""), z)?;
655 //logger.write_csv(mkname("log.txt"))?;
656 logs.push((mkname("log.txt"), logger));
657 }
658
659 save_logs(logs, format!("{prefix}valuerange.json"), cli.load_valuerange)
660 }
661
662 #[replace_float_literals(F::cast_from(literal))]
663 impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N : usize> RunnableExperiment<F> for
664 Named<ExperimentV2<F, NoiseDistr, S, K, P, N>>
665 where
666 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>
667 + Default + for<'b> Deserialize<'b>,
668 [usize; N] : Serialize,
669 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
670 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
671 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy
672 // TODO: shold not have differentiability as a requirement, but
673 // decide availability of sliding based on it.
674 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
675 // TODO: very weird that rust only compiles with Differentiable
676 // instead of the above one on references, which is required by
677 // poitsource_sliding_fb_reg.
678 + DifferentiableRealMapping<F, N>
679 + Lipschitz<L2, FloatType=F>,
680 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb.
681 AutoConvolution<P> : BoundedBy<F, K>,
682 K : SimpleConvolutionKernel<F, N>
683 + LocalAnalysis<F, Bounds<F>, N>
684 + Copy + Serialize + std::fmt::Debug,
685 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
686 PlotLookup : Plotting<N>,
687 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
688 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
689 RNDM<F, N> : SpikeMerging<F>,
690 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug,
691 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>,
692 // PreadjointCodomain : Space + Bounded<F> + DifferentiableRealMapping<F, N>,
693 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
694 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
695 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
696 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
697 {
698
699 fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> {
700 AlgorithmOverrides { 711 AlgorithmOverrides {
701 merge_radius : Some(self.data.default_merge_radius), 712 merge_radius: Some(self.data.default_merge_radius),
702 .. self.data.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) 713 ..self
703 } 714 .data
704 } 715 .algorithm_overrides
705 716 .get(&alg)
706 fn runall(&self, cli : &CommandLineArgs, 717 .cloned()
707 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { 718 .unwrap_or(Default::default())
719 }
720 }
721
722 fn runall(
723 &self,
724 cli: &CommandLineArgs,
725 algs: Option<Vec<Named<AlgorithmConfig<F>>>>,
726 ) -> DynError {
708 // Get experiment configuration 727 // Get experiment configuration
709 let &Named { 728 let &ExperimentV2 {
710 name : ref experiment_name, 729 domain,
711 data : ExperimentV2 { 730 sensor_count,
712 domain, sensor_count, ref noise_distr, sensor, spread, kernel, 731 ref noise_distr,
713 ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, 732 sensor,
714 .. 733 spread,
715 } 734 kernel,
716 } = self; 735 ref μ_hat,
736 regularisation,
737 kernel_plot_width,
738 dataterm,
739 noise_seed,
740 ..
741 } = &self.data;
717 742
718 // Set up algorithms 743 // Set up algorithms
719 let algorithms = match (algs, dataterm) { 744 let algorithms = match (algs, dataterm) {
720 (Some(algs), _) => algs, 745 (Some(algs), _) => algs,
721 (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], 746 (None, DataTermType::L222) => vec![DefaultAlgorithm::FB.get_named()],
722 (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], 747 (None, DataTermType::L1) => vec![DefaultAlgorithm::PDPS.get_named()],
723 }; 748 };
724 749
725 // Set up operators 750 // Set up operators
726 let depth = DynamicDepth(8); 751 let depth = DynamicDepth(8);
727 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); 752 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
736 let b = &b_hat + &noise; 761 let b = &b_hat + &noise;
737 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField 762 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
738 // overloading log10 and conflicting with standard NumTraits one. 763 // overloading log10 and conflicting with standard NumTraits one.
739 let stats = ExperimentStats::new(&b, &noise); 764 let stats = ExperimentStats::new(&b, &noise);
740 765
741 let prefix = start_experiment(&self, cli, stats)?; 766 let prefix = self.start(cli)?;
742 767 write_json(format!("{prefix}stats.json"), &stats)?;
743 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, 768
744 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; 769 plotall(
745 770 cli,
746 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); 771 &prefix,
772 &domain,
773 &sensor,
774 &kernel,
775 &spread,
776 &μ_hat,
777 &op𝒟,
778 &opA,
779 &b_hat,
780 &b,
781 kernel_plot_width,
782 )?;
783
784 let plotgrid = lingrid(&domain, &[if N == 1 { 1000 } else { 100 }; N]);
785 let make_plotter = |this_prefix| {
786 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };
787 SeqPlotter::new(this_prefix, plot_count, plotgrid.clone())
788 };
747 789
748 let save_extra = |_, ()| Ok(()); 790 let save_extra = |_, ()| Ok(());
749 791
750 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, 792 let μ0 = None; // Zero init
751 |alg, iterator, plotter, running| 793
752 { 794 match (dataterm, regularisation) {
753 let μ = match alg { 795 (DataTermType::L1, Regularisation::Radon(α)) => {
754 AlgorithmConfig::FB(ref algconfig, prox) => { 796 let f = DataTerm::new(opA, b, L1.as_mapping());
755 match (regularisation, dataterm, prox) { 797 let reg = RadonRegTerm(α);
756 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 798 self.do_runall(
757 print!("{running}"); 799 &prefix,
758 pointsource_fb_reg( 800 cli,
759 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 801 algorithms,
760 iterator, plotter 802 make_plotter,
761 ) 803 save_extra,
762 }), 804 μ0,
763 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 805 |p| {
764 print!("{running}"); 806 run_pdps(&f, &reg, &RadonSquared, p, |p| {
765 pointsource_fb_reg( 807 run_pdps(&f, &reg, &op𝒟, p, |_| Err(NotImplemented.into()))
766 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 808 })
767 iterator, plotter 809 .map(|μ| (μ, ()))
768 ) 810 },
769 }), 811 )
770 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 812 }
771 print!("{running}"); 813 (DataTermType::L1, Regularisation::NonnegRadon(α)) => {
772 pointsource_fb_reg( 814 let f = DataTerm::new(opA, b, L1.as_mapping());
773 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 815 let reg = NonnegRadonRegTerm(α);
774 iterator, plotter 816 self.do_runall(
775 ) 817 &prefix,
776 }), 818 cli,
777 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 819 algorithms,
778 print!("{running}"); 820 make_plotter,
779 pointsource_fb_reg( 821 save_extra,
780 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, 822 μ0,
781 iterator, plotter 823 |p| {
782 ) 824 run_pdps(&f, &reg, &RadonSquared, p, |p| {
783 }), 825 run_pdps(&f, &reg, &op𝒟, p, |_| Err(NotImplemented.into()))
784 _ => Err(NotImplemented) 826 })
785 } 827 .map(|μ| (μ, ()))
786 }, 828 },
787 AlgorithmConfig::FISTA(ref algconfig, prox) => { 829 )
788 match (regularisation, dataterm, prox) { 830 }
789 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 831 (DataTermType::L222, Regularisation::Radon(α)) => {
790 print!("{running}"); 832 let f = DataTerm::new(opA, b, Norm222::new());
791 pointsource_fista_reg( 833 let reg = RadonRegTerm(α);
792 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 834 self.do_runall(
793 iterator, plotter 835 &prefix,
794 ) 836 cli,
795 }), 837 algorithms,
796 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 838 make_plotter,
797 print!("{running}"); 839 save_extra,
798 pointsource_fista_reg( 840 μ0,
799 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 841 |p| {
800 iterator, plotter 842 run_fb(&f, &reg, &RadonSquared, p, |p| {
801 ) 843 run_pdps(&f, &reg, &RadonSquared, p, |p| {
802 }), 844 run_fb(&f, &reg, &op𝒟, p, |p| {
803 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 845 run_pdps(&f, &reg, &op𝒟, p, |p| {
804 print!("{running}"); 846 run_fw(&f, &reg, p, |_| Err(NotImplemented.into()))
805 pointsource_fista_reg( 847 })
806 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 848 })
807 iterator, plotter 849 })
808 ) 850 })
809 }), 851 .map(|μ| (μ, ()))
810 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 852 },
811 print!("{running}"); 853 )
812 pointsource_fista_reg( 854 }
813 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, 855 (DataTermType::L222, Regularisation::NonnegRadon(α)) => {
814 iterator, plotter 856 let f = DataTerm::new(opA, b, Norm222::new());
815 ) 857 let reg = NonnegRadonRegTerm(α);
816 }), 858 self.do_runall(
817 _ => Err(NotImplemented), 859 &prefix,
818 } 860 cli,
819 }, 861 algorithms,
820 AlgorithmConfig::SlidingFB(ref algconfig, prox) => { 862 make_plotter,
821 match (regularisation, dataterm, prox) { 863 save_extra,
822 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 864 μ0,
823 print!("{running}"); 865 |p| {
824 pointsource_sliding_fb_reg( 866 run_fb(&f, &reg, &RadonSquared, p, |p| {
825 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 867 run_pdps(&f, &reg, &RadonSquared, p, |p| {
826 iterator, plotter 868 run_fb(&f, &reg, &op𝒟, p, |p| {
827 ) 869 run_pdps(&f, &reg, &op𝒟, p, |p| {
828 }), 870 run_fw(&f, &reg, p, |_| Err(NotImplemented.into()))
829 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 871 })
830 print!("{running}"); 872 })
831 pointsource_sliding_fb_reg( 873 })
832 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 874 })
833 iterator, plotter 875 .map(|μ| (μ, ()))
834 ) 876 },
835 }), 877 )
836 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 878 }
837 print!("{running}"); 879 }
838 pointsource_sliding_fb_reg( 880 }
839 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 881 }
840 iterator, plotter 882
841 ) 883 /// Runs PDPS if `alg` so requests and `prox_penalty` matches.
842 }), 884 ///
843 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 885 /// Due to the structure of the PDPS, the data term `f` has to have a specific form.
844 print!("{running}"); 886 ///
845 pointsource_sliding_fb_reg( 887 /// `cont` gives a continuation attempt to find the algorithm matching the description `alg`.
846 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, 888 pub fn run_pdps<'a, F, A, Phi, Reg, P, I, Plot, const N: usize>(
847 iterator, plotter 889 f: &'a DataTerm<F, RNDM<N, F>, A, Phi>,
848 ) 890 reg: &Reg,
849 }), 891 prox_penalty: &P,
850 _ => Err(NotImplemented), 892 (alg, iterator, plotter, μ0, running): (
851 } 893 &AlgorithmConfig<F>,
852 }, 894 I,
853 AlgorithmConfig::PDPS(ref algconfig, prox) => { 895 Plot,
854 print!("{running}"); 896 Option<RNDM<N, F>>,
855 match (regularisation, dataterm, prox) { 897 String,
856 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 898 ),
857 pointsource_pdps_reg( 899 cont: impl FnOnce(
858 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 900 (&AlgorithmConfig<F>, I, Plot, Option<RNDM<N, F>>, String),
859 iterator, plotter, L2Squared 901 ) -> DynResult<RNDM<N, F>>,
860 ) 902 ) -> DynResult<RNDM<N, F>>
861 }), 903 where
862 (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 904 F: Float + ToNalgebraRealField,
863 pointsource_pdps_reg( 905 A: ForwardModel<RNDM<N, F>, F>,
864 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 906 Phi: Conjugable<A::Observable, F>,
865 iterator, plotter, L2Squared 907 for<'b> Phi::Conjugate<'b>: Prox<A::Observable>,
866 ) 908 for<'b> &'b A::Observable: Instance<A::Observable>,
867 }), 909 A::Observable: AXPY,
868 (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ 910 Reg: SlidingRegTerm<Loc<N, F>, F>,
869 pointsource_pdps_reg( 911 P: ProxPenalty<Loc<N, F>, A::PreadjointCodomain, Reg, F> + StepLengthBoundPD<F, A, RNDM<N, F>>,
870 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 912 RNDM<N, F>: SpikeMerging<F>,
871 iterator, plotter, L1 913 I: AlgIteratorFactory<IterInfo<F>>,
872 ) 914 Plot: Plotter<P::ReturnMapping, A::PreadjointCodomain, RNDM<N, F>>,
873 }), 915 {
874 (Regularisation::Radon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ 916 match alg {
875 pointsource_pdps_reg( 917 &AlgorithmConfig::PDPS(ref algconfig, prox_type) if prox_type == P::prox_type() => {
876 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 918 print!("{running}");
877 iterator, plotter, L1 919 pointsource_pdps_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0)
878 ) 920 }
879 }), 921 _ => cont((alg, iterator, plotter, μ0, running)),
880 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 922 }
881 pointsource_pdps_reg( 923 }
882 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 924
883 iterator, plotter, L2Squared 925 /// Runs FB-style algorithms if `alg` so requests and `prox_penalty` matches.
884 ) 926 ///
885 }), 927 /// `cont` gives a continuation attempt to find the algorithm matching the description `alg`.
886 (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 928 pub fn run_fb<F, Dat, Reg, P, I, Plot, const N: usize>(
887 pointsource_pdps_reg( 929 f: &Dat,
888 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, 930 reg: &Reg,
889 iterator, plotter, L2Squared 931 prox_penalty: &P,
890 ) 932 (alg, iterator, plotter, μ0, running): (
891 }), 933 &AlgorithmConfig<F>,
892 (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ 934 I,
893 pointsource_pdps_reg( 935 Plot,
894 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 936 Option<RNDM<N, F>>,
895 iterator, plotter, L1 937 String,
896 ) 938 ),
897 }), 939 cont: impl FnOnce(
898 (Regularisation::Radon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ 940 (&AlgorithmConfig<F>, I, Plot, Option<RNDM<N, F>>, String),
899 pointsource_pdps_reg( 941 ) -> DynResult<RNDM<N, F>>,
900 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, 942 ) -> DynResult<RNDM<N, F>>
901 iterator, plotter, L1 943 where
902 ) 944 F: Float + ToNalgebraRealField,
903 }), 945 I: AlgIteratorFactory<IterInfo<F>>,
904 // _ => Err(NotImplemented), 946 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>,
905 } 947 Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>,
906 }, 948 RNDM<N, F>: SpikeMerging<F>,
907 AlgorithmConfig::FW(ref algconfig) => { 949 Reg: SlidingRegTerm<Loc<N, F>, F>,
908 match (regularisation, dataterm) { 950 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
909 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ 951 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
910 print!("{running}"); 952 {
911 pointsource_fw_reg(&opA, &b, RadonRegTerm(α), 953 let pt = P::prox_type();
912 algconfig, iterator, plotter) 954
913 }), 955 match alg {
914 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ 956 &AlgorithmConfig::FB(ref algconfig, prox_type) if prox_type == pt => {
915 print!("{running}"); 957 print!("{running}");
916 pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α), 958 pointsource_fb_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0)
917 algconfig, iterator, plotter) 959 }
918 }), 960 &AlgorithmConfig::FISTA(ref algconfig, prox_type) if prox_type == pt => {
919 _ => Err(NotImplemented), 961 print!("{running}");
920 } 962 pointsource_fista_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0)
921 }, 963 }
922 _ => Err(NotImplemented), 964 &AlgorithmConfig::SlidingFB(ref algconfig, prox_type) if prox_type == pt => {
923 }?; 965 print!("{running}");
924 Ok((μ, ())) 966 pointsource_sliding_fb_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0)
925 }) 967 }
926 } 968 _ => cont((alg, iterator, plotter, μ0, running)),
927 } 969 }
928 970 }
971
972 /// Runs FB-style algorithms if `alg` so requests and `prox_penalty` matches.
973 ///
974 /// For the moment, due to restrictions of the Frank–Wolfe implementation, only the
975 /// $L^2$-squared data term is enabled through the type signatures.
976 ///
977 /// `cont` gives a continuation attempt to find the algorithm matching the description `alg`.
978 pub fn run_fw<'a, F, A, Reg, I, Plot, const N: usize>(
979 f: &'a DataTerm<F, RNDM<N, F>, A, Norm222<F>>,
980 reg: &Reg,
981 (alg, iterator, plotter, μ0, running): (
982 &AlgorithmConfig<F>,
983 I,
984 Plot,
985 Option<RNDM<N, F>>,
986 String,
987 ),
988 cont: impl FnOnce(
989 (&AlgorithmConfig<F>, I, Plot, Option<RNDM<N, F>>, String),
990 ) -> DynResult<RNDM<N, F>>,
991 ) -> DynResult<RNDM<N, F>>
992 where
993 F: Float + ToNalgebraRealField,
994 I: AlgIteratorFactory<IterInfo<F>>,
995 A: ForwardModel<RNDM<N, F>, F>,
996 A::PreadjointCodomain: MinMaxMapping<Loc<N, F>, F>,
997 for<'b> &'b A::PreadjointCodomain: Instance<A::PreadjointCodomain>,
998 Cube<N, F>: P2Minimise<Loc<N, F>, F>,
999 RNDM<N, F>: SpikeMerging<F>,
1000 Reg: RegTermFW<F, A, ValueIteratorFactory<F, AlgIteratorOptions>, N>,
1001 Plot: Plotter<A::PreadjointCodomain, A::PreadjointCodomain, RNDM<N, F>>,
1002 {
1003 match alg {
1004 &AlgorithmConfig::FW(ref algconfig) => {
1005 print!("{running}");
1006 pointsource_fw_reg(f, reg, algconfig, iterator, plotter, μ0)
1007 }
1008 _ => cont((alg, iterator, plotter, μ0, running)),
1009 }
1010 }
929 1011
930 #[replace_float_literals(F::cast_from(literal))] 1012 #[replace_float_literals(F::cast_from(literal))]
931 impl<F, NoiseDistr, S, K, P, B, /*PreadjointCodomain,*/ const N : usize> RunnableExperiment<F> for 1013 impl<F, NoiseDistr, S, K, P, B, const N: usize> RunnableExperiment<F>
932 Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>> 1014 for Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>>
933 where 1015 where
934 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F> 1016 F: ClapFloat
935 + Default + for<'b> Deserialize<'b>, 1017 + nalgebra::RealField
936 [usize; N] : Serialize, 1018 + ToNalgebraRealField<MixedType = F>
937 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, 1019 + Default
938 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, 1020 + for<'b> Deserialize<'b>,
939 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy 1021 [usize; N]: Serialize,
940 // TODO: shold not have differentiability as a requirement, but 1022 S: Sensor<N, F> + Copy + Serialize + std::fmt::Debug,
941 // decide availability of sliding based on it. 1023 P: Spread<N, F> + Copy + Serialize + std::fmt::Debug,
942 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, 1024 Convolution<S, P>: Spread<N, F>
943 // TODO: very weird that rust only compiles with Differentiable 1025 + Bounded<F>
944 // instead of the above one on references, which is required by
945 // poitsource_sliding_fb_reg.
946 + DifferentiableRealMapping<F, N>
947 + Lipschitz<L2, FloatType=F>,
948 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb.
949 AutoConvolution<P> : BoundedBy<F, K>,
950 K : SimpleConvolutionKernel<F, N>
951 + LocalAnalysis<F, Bounds<F>, N> 1026 + LocalAnalysis<F, Bounds<F>, N>
952 + Copy + Serialize + std::fmt::Debug, 1027 + Copy
953 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, 1028 // TODO: shold not have differentiability as a requirement, but
954 PlotLookup : Plotting<N>, 1029 // decide availability of sliding based on it.
955 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, 1030 //+ for<'b> Differentiable<&'b Loc<N, F>, Output = Loc<N, F>>,
1031 // TODO: very weird that rust only compiles with Differentiable
1032 // instead of the above one on references, which is required by
1033 // poitsource_sliding_fb_reg.
1034 + DifferentiableRealMapping<N, F>
1035 + Lipschitz<L2, FloatType = F>,
1036 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<N, F>>>::Differential<'b>:
1037 Lipschitz<L2, FloatType = F>, // TODO: should not be required generally, only for sliding_fb.
1038 AutoConvolution<P>: BoundedBy<F, K>,
1039 K: SimpleConvolutionKernel<N, F>
1040 + LocalAnalysis<F, Bounds<F>, N>
1041 + Copy
1042 + Serialize
1043 + std::fmt::Debug,
1044 Cube<N, F>: P2Minimise<Loc<N, F>, F> + SetOrd,
1045 PlotLookup: Plotting<N>,
1046 DefaultBT<F, N>: SensorGridBT<F, S, P, N, Depth = DynamicDepth> + BTSearch<N, F>,
956 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 1047 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
957 RNDM<F, N> : SpikeMerging<F>, 1048 RNDM<N, F>: SpikeMerging<F>,
958 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, 1049 NoiseDistr: Distribution<F> + Serialize + std::fmt::Debug,
959 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, 1050 B: Mapping<Loc<N, F>, Codomain = F> + Serialize + std::fmt::Debug,
960 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>, 1051 nalgebra::DVector<F>: ClosedMul<F>,
961 // PreadjointCodomain : Bounded<F> + DifferentiableRealMapping<F, N>, 1052 // This is mainly required for the final Mul requirement to be defined
1053 // DefaultSG<F, S, P, N>: ForwardModel<
1054 // RNDM<N, F>,
1055 // F,
1056 // PreadjointCodomain = PreadjointCodomain,
1057 // Observable = DVector<F::MixedType>,
1058 // >,
1059 // PreadjointCodomain: Bounded<F> + DifferentiableRealMapping<N, F> + std::ops::Mul<F>,
1060 // Pair<PreadjointCodomain, DVector<F>>: std::ops::Mul<F>,
962 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, 1061 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
963 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, 1062 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
964 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, 1063 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
965 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, 1064 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
966 { 1065 {
967 1066 fn name(&self) -> &str {
968 fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> { 1067 self.name.as_ref()
1068 }
1069
1070 fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<F> {
969 AlgorithmOverrides { 1071 AlgorithmOverrides {
970 merge_radius : Some(self.data.base.default_merge_radius), 1072 merge_radius: Some(self.data.base.default_merge_radius),
971 .. self.data.base.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) 1073 ..self
972 } 1074 .data
973 } 1075 .base
974 1076 .algorithm_overrides
975 fn runall(&self, cli : &CommandLineArgs, 1077 .get(&alg)
976 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { 1078 .cloned()
1079 .unwrap_or(Default::default())
1080 }
1081 }
1082
1083 fn runall(
1084 &self,
1085 cli: &CommandLineArgs,
1086 algs: Option<Vec<Named<AlgorithmConfig<F>>>>,
1087 ) -> DynError {
977 // Get experiment configuration 1088 // Get experiment configuration
978 let &Named { 1089 let &ExperimentBiased {
979 name : ref experiment_name, 1090 λ,
980 data : ExperimentBiased { 1091 ref bias,
981 λ, 1092 base:
982 ref bias, 1093 ExperimentV2 {
983 base : ExperimentV2 { 1094 domain,
984 domain, sensor_count, ref noise_distr, sensor, spread, kernel, 1095 sensor_count,
985 ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, 1096 ref noise_distr,
1097 sensor,
1098 spread,
1099 kernel,
1100 ref μ_hat,
1101 regularisation,
1102 kernel_plot_width,
1103 dataterm,
1104 noise_seed,
986 .. 1105 ..
987 } 1106 },
988 } 1107 } = &self.data;
989 } = self;
990 1108
991 // Set up algorithms 1109 // Set up algorithms
992 let algorithms = match (algs, dataterm) { 1110 let algorithms = match (algs, dataterm) {
993 (Some(algs), _) => algs, 1111 (Some(algs), _) => algs,
994 _ => vec![DefaultAlgorithm::SlidingPDPS.get_named()], 1112 _ => vec![DefaultAlgorithm::SlidingPDPS.get_named()],
998 let depth = DynamicDepth(8); 1116 let depth = DynamicDepth(8);
999 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); 1117 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
1000 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); 1118 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel);
1001 let opAext = RowOp(opA.clone(), IdOp::new()); 1119 let opAext = RowOp(opA.clone(), IdOp::new());
1002 let fnR = Zero::new(); 1120 let fnR = Zero::new();
1003 let h = map3(domain.span_start(), domain.span_end(), sensor_count, 1121 let h = map3(
1004 |a, b, n| (b-a)/F::cast_from(n)) 1122 domain.span_start(),
1005 .into_iter() 1123 domain.span_end(),
1006 .reduce(NumTraitsFloat::max) 1124 sensor_count,
1007 .unwrap(); 1125 |a, b, n| (b - a) / F::cast_from(n),
1126 )
1127 .into_iter()
1128 .reduce(NumTraitsFloat::max)
1129 .unwrap();
1008 let z = DVector::zeros(sensor_count.iter().product()); 1130 let z = DVector::zeros(sensor_count.iter().product());
1009 let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap(); 1131 let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap();
1010 let y = opKz.apply(&z); 1132 let y = opKz.apply(&z);
1011 let fnH = Weighted{ base_fn : L1.as_mapping(), weight : λ}; // TODO: L_{2,1} 1133 let fnH = Weighted { base_fn: L1.as_mapping(), weight: λ }; // TODO: L_{2,1}
1012 // let zero_y = y.clone(); 1134 // let zero_y = y.clone();
1013 // let zeroBTFN = opA.preadjoint().apply(&zero_y); 1135 // let zeroBTFN = opA.preadjoint().apply(&zero_y);
1014 // let opKμ = ZeroOp::new(&zero_y, zeroBTFN); 1136 // let opKμ = ZeroOp::new(&zero_y, zeroBTFN);
1015 1137
1016 // Set up random number generator. 1138 // Set up random number generator.
1017 let mut rng = StdRng::seed_from_u64(noise_seed); 1139 let mut rng = StdRng::seed_from_u64(noise_seed);
1018 1140
1019 // Generate the data and calculate SSNR statistic 1141 // Generate the data and calculate SSNR statistic
1020 let bias_vec = DVector::from_vec(opA.grid() 1142 let bias_vec = DVector::from_vec(
1021 .into_iter() 1143 opA.grid()
1022 .map(|v| bias.apply(v)) 1144 .into_iter()
1023 .collect::<Vec<F>>()); 1145 .map(|v| bias.apply(v))
1024 let b_hat : DVector<_> = opA.apply(μ_hat) + &bias_vec; 1146 .collect::<Vec<F>>(),
1147 );
1148 let b_hat: DVector<_> = opA.apply(μ_hat) + &bias_vec;
1025 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); 1149 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
1026 let b = &b_hat + &noise; 1150 let b = &b_hat + &noise;
1027 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField 1151 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
1028 // overloading log10 and conflicting with standard NumTraits one. 1152 // overloading log10 and conflicting with standard NumTraits one.
1029 let stats = ExperimentStats::new(&b, &noise); 1153 let stats = ExperimentStats::new(&b, &noise);
1030 1154
1031 let prefix = start_experiment(&self, cli, stats)?; 1155 let prefix = self.start(cli)?;
1032 1156 write_json(format!("{prefix}stats.json"), &stats)?;
1033 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, 1157
1034 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; 1158 plotall(
1159 cli,
1160 &prefix,
1161 &domain,
1162 &sensor,
1163 &kernel,
1164 &spread,
1165 &μ_hat,
1166 &op𝒟,
1167 &opA,
1168 &b_hat,
1169 &b,
1170 kernel_plot_width,
1171 )?;
1035 1172
1036 opA.write_observable(&bias_vec, format!("{prefix}bias"))?; 1173 opA.write_observable(&bias_vec, format!("{prefix}bias"))?;
1037 1174
1038 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); 1175 let plotgrid = lingrid(&domain, &[if N == 1 { 1000 } else { 100 }; N]);
1176 let make_plotter = |this_prefix| {
1177 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };
1178 SeqPlotter::new(this_prefix, plot_count, plotgrid.clone())
1179 };
1039 1180
1040 let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z")); 1181 let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z"));
1041 1182
1042 // Run the algorithms 1183 let μ0 = None; // Zero init
1043 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, 1184
1044 |alg, iterator, plotter, running| 1185 match (dataterm, regularisation) {
1045 { 1186 (DataTermType::L222, Regularisation::Radon(α)) => {
1046 let Pair(μ, z) = match alg { 1187 let f = DataTerm::new(opAext, b, Norm222::new());
1047 AlgorithmConfig::ForwardPDPS(ref algconfig, prox) => { 1188 let reg = RadonRegTerm(α);
1048 match (regularisation, dataterm, prox) { 1189 self.do_runall(
1049 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 1190 &prefix,
1050 print!("{running}"); 1191 cli,
1051 pointsource_forward_pdps_pair( 1192 algorithms,
1052 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 1193 make_plotter,
1053 iterator, plotter, 1194 save_extra,
1054 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1195 (μ0, z, y),
1055 ) 1196 |p| {
1056 }), 1197 run_pdps_pair(&f, &reg, &RadonSquared, &opKz, &fnR, &fnH, p, |q| {
1057 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 1198 run_pdps_pair(&f, &reg, &op𝒟, &opKz, &fnR, &fnH, q, |_| {
1058 print!("{running}"); 1199 Err(NotImplemented.into())
1059 pointsource_forward_pdps_pair( 1200 })
1060 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, 1201 })
1061 iterator, plotter, 1202 .map(|Pair(μ, z)| (μ, z))
1062 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1203 },
1063 ) 1204 )
1064 }), 1205 }
1065 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 1206 (DataTermType::L222, Regularisation::NonnegRadon(α)) => {
1066 print!("{running}"); 1207 let f = DataTerm::new(opAext, b, Norm222::new());
1067 pointsource_forward_pdps_pair( 1208 let reg = NonnegRadonRegTerm(α);
1068 &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 1209 self.do_runall(
1069 iterator, plotter, 1210 &prefix,
1070 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1211 cli,
1071 ) 1212 algorithms,
1072 }), 1213 make_plotter,
1073 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 1214 save_extra,
1074 print!("{running}"); 1215 (μ0, z, y),
1075 pointsource_forward_pdps_pair( 1216 |p| {
1076 &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, 1217 run_pdps_pair(&f, &reg, &RadonSquared, &opKz, &fnR, &fnH, p, |q| {
1077 iterator, plotter, 1218 run_pdps_pair(&f, &reg, &op𝒟, &opKz, &fnR, &fnH, q, |_| {
1078 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1219 Err(NotImplemented.into())
1079 ) 1220 })
1080 }), 1221 })
1081 _ => Err(NotImplemented) 1222 .map(|Pair(μ, z)| (μ, z))
1082 } 1223 },
1083 }, 1224 )
1084 AlgorithmConfig::SlidingPDPS(ref algconfig, prox) => { 1225 }
1085 match (regularisation, dataterm, prox) { 1226 _ => Err(NotImplemented.into()),
1086 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 1227 }
1087 print!("{running}"); 1228 }
1088 pointsource_sliding_pdps_pair( 1229 }
1089 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 1230
1090 iterator, plotter, 1231 type MeasureZ<F, Z, const N: usize> = Pair<RNDM<N, F>, Z>;
1091 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1232
1092 ) 1233 pub fn run_pdps_pair<F, S, Dat, Reg, Z, R, Y, KOpZ, H, P, I, Plot, const N: usize>(
1093 }), 1234 f: &Dat,
1094 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 1235 reg: &Reg,
1095 print!("{running}"); 1236 prox_penalty: &P,
1096 pointsource_sliding_pdps_pair( 1237 opKz: &KOpZ,
1097 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, 1238 fnR: &R,
1098 iterator, plotter, 1239 fnH: &H,
1099 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1240 (alg, iterator, plotter, μ0zy, running): (
1100 ) 1241 &AlgorithmConfig<F>,
1101 }), 1242 I,
1102 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 1243 Plot,
1103 print!("{running}"); 1244 (Option<RNDM<N, F>>, Z, Y),
1104 pointsource_sliding_pdps_pair( 1245 String,
1105 &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 1246 ),
1106 iterator, plotter, 1247 cont: impl FnOnce(
1107 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1248 (
1108 ) 1249 &AlgorithmConfig<F>,
1109 }), 1250 I,
1110 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 1251 Plot,
1111 print!("{running}"); 1252 (Option<RNDM<N, F>>, Z, Y),
1112 pointsource_sliding_pdps_pair( 1253 String,
1113 &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, 1254 ),
1114 iterator, plotter, 1255 ) -> DynResult<Pair<RNDM<N, F>, Z>>,
1115 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1256 ) -> DynResult<Pair<RNDM<N, F>, Z>>
1116 ) 1257 where
1117 }), 1258 F: Float + ToNalgebraRealField,
1118 _ => Err(NotImplemented) 1259 I: AlgIteratorFactory<IterInfo<F>>,
1119 } 1260 Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>
1120 }, 1261 + BoundedCurvature<F>,
1121 _ => Err(NotImplemented) 1262 S: DifferentiableRealMapping<N, F> + ClosedMul<F>,
1122 }?; 1263 for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>,
1123 Ok((μ, z)) 1264 //Pair<S, Z>: ClosedMul<F>,
1124 }) 1265 RNDM<N, F>: SpikeMerging<F>,
1266 Reg: SlidingRegTerm<Loc<N, F>, F>,
1267 P: ProxPenalty<Loc<N, F>, S, Reg, F>,
1268 KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y>
1269 + GEMV<F, Z>
1270 + SimplyAdjointable<Z, Y, AdjointCodomain = Z>,
1271 KOpZ::SimpleAdjoint: GEMV<F, Y>,
1272 Y: ClosedEuclidean<F> + Clone,
1273 for<'b> &'b Y: Instance<Y>,
1274 Z: ClosedEuclidean<F> + Clone + ClosedMul<F>,
1275 for<'b> &'b Z: Instance<Z>,
1276 R: Prox<Z, Codomain = F>,
1277 H: Conjugable<Y, F, Codomain = F>,
1278 for<'b> H::Conjugate<'b>: Prox<Y>,
1279 Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>,
1280 {
1281 let pt = P::prox_type();
1282
1283 match alg {
1284 &AlgorithmConfig::ForwardPDPS(ref algconfig, prox_type) if prox_type == pt => {
1285 print!("{running}");
1286 pointsource_forward_pdps_pair(
1287 f,
1288 reg,
1289 prox_penalty,
1290 algconfig,
1291 iterator,
1292 plotter,
1293 μ0zy,
1294 opKz,
1295 fnR,
1296 fnH,
1297 )
1298 }
1299 &AlgorithmConfig::SlidingPDPS(ref algconfig, prox_type) if prox_type == pt => {
1300 print!("{running}");
1301 pointsource_sliding_pdps_pair(
1302 f,
1303 reg,
1304 prox_penalty,
1305 algconfig,
1306 iterator,
1307 plotter,
1308 μ0zy,
1309 opKz,
1310 fnR,
1311 fnH,
1312 )
1313 }
1314 _ => cont((alg, iterator, plotter, μ0zy, running)),
1315 }
1316 }
1317
1318 pub fn run_fb_pair<F, S, Dat, Reg, Z, R, P, I, Plot, const N: usize>(
1319 f: &Dat,
1320 reg: &Reg,
1321 prox_penalty: &P,
1322 fnR: &R,
1323 (alg, iterator, plotter, μ0z, running): (
1324 &AlgorithmConfig<F>,
1325 I,
1326 Plot,
1327 (Option<RNDM<N, F>>, Z),
1328 String,
1329 ),
1330 cont: impl FnOnce(
1331 (
1332 &AlgorithmConfig<F>,
1333 I,
1334 Plot,
1335 (Option<RNDM<N, F>>, Z),
1336 String,
1337 ),
1338 ) -> DynResult<Pair<RNDM<N, F>, Z>>,
1339 ) -> DynResult<Pair<RNDM<N, F>, Z>>
1340 where
1341 F: Float + ToNalgebraRealField,
1342 I: AlgIteratorFactory<IterInfo<F>>,
1343 Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>
1344 + BoundedCurvature<F>,
1345 S: DifferentiableRealMapping<N, F> + ClosedMul<F>,
1346 RNDM<N, F>: SpikeMerging<F>,
1347 Reg: SlidingRegTerm<Loc<N, F>, F>,
1348 P: ProxPenalty<Loc<N, F>, S, Reg, F>,
1349 for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>,
1350 Z: ClosedEuclidean<F> + AXPY + Clone,
1351 for<'b> &'b Z: Instance<Z>,
1352 R: Prox<Z, Codomain = F>,
1353 Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>,
1354 // We should not need to explicitly require this:
1355 for<'b> &'b Loc<0, F>: Instance<Loc<0, F>>,
1356 {
1357 let pt = P::prox_type();
1358
1359 match alg {
1360 &AlgorithmConfig::FB(ref algconfig, prox_type) if prox_type == pt => {
1361 print!("{running}");
1362 pointsource_fb_pair(f, reg, prox_penalty, algconfig, iterator, plotter, μ0z, fnR)
1363 }
1364 &AlgorithmConfig::SlidingFB(ref algconfig, prox_type) if prox_type == pt => {
1365 print!("{running}");
1366 pointsource_sliding_fb_pair(
1367 f,
1368 reg,
1369 prox_penalty,
1370 algconfig,
1371 iterator,
1372 plotter,
1373 μ0z,
1374 fnR,
1375 )
1376 }
1377 _ => cont((alg, iterator, plotter, μ0z, running)),
1125 } 1378 }
1126 } 1379 }
1127 1380
1128 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 1381 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
1129 struct ValueRange<F : Float> { 1382 struct ValueRange<F: Float> {
1130 ini : F, 1383 ini: F,
1131 min : F, 1384 min: F,
1132 } 1385 }
1133 1386
1134 impl<F : Float> ValueRange<F> { 1387 impl<F: Float> ValueRange<F> {
1135 fn expand_with(self, other : Self) -> Self { 1388 fn expand_with(self, other: Self) -> Self {
1136 ValueRange { 1389 ValueRange { ini: self.ini.max(other.ini), min: self.min.min(other.min) }
1137 ini : self.ini.max(other.ini),
1138 min : self.min.min(other.min),
1139 }
1140 } 1390 }
1141 } 1391 }
1142 1392
1143 /// Calculative minimum and maximum values of all the `logs`, and save them into 1393 /// Calculative minimum and maximum values of all the `logs`, and save them into
1144 /// corresponding file names given as the first elements of the tuples in the vectors. 1394 /// corresponding file names given as the first elements of the tuples in the vectors.
1145 fn save_logs<F : Float + for<'b> Deserialize<'b>, const N : usize>( 1395 fn save_logs<F: Float + for<'b> Deserialize<'b>>(
1146 logs : Vec<(String, Logger<Timed<IterInfo<F, N>>>)>, 1396 logs: Vec<(String, Logger<Timed<IterInfo<F>>>)>,
1147 valuerange_file : String, 1397 valuerange_file: String,
1148 load_valuerange : bool, 1398 load_valuerange: bool,
1149 ) -> DynError { 1399 ) -> DynError {
1150 // Process logs for relative values 1400 // Process logs for relative values
1151 println!("{}", "Processing logs…"); 1401 println!("{}", "Processing logs…");
1152 1402
1153 // Find minimum value and initial value within a single log 1403 // Find minimum value and initial value within a single log
1154 let proc_single_log = |log : &Logger<Timed<IterInfo<F, N>>>| { 1404 let proc_single_log = |log: &Logger<Timed<IterInfo<F>>>| {
1155 let d = log.data(); 1405 let d = log.data();
1156 let mi = d.iter() 1406 let mi = d.iter().map(|i| i.data.value).reduce(NumTraitsFloat::min);
1157 .map(|i| i.data.value)
1158 .reduce(NumTraitsFloat::min);
1159 d.first() 1407 d.first()
1160 .map(|i| i.data.value) 1408 .map(|i| i.data.value)
1161 .zip(mi) 1409 .zip(mi)
1162 .map(|(ini, min)| ValueRange{ ini, min }) 1410 .map(|(ini, min)| ValueRange { ini, min })
1163 }; 1411 };
1164 1412
1165 // Find minimum and maximum value over all logs 1413 // Find minimum and maximum value over all logs
1166 let mut v = logs.iter() 1414 let mut v = logs
1167 .filter_map(|&(_, ref log)| proc_single_log(log)) 1415 .iter()
1168 .reduce(|v1, v2| v1.expand_with(v2)) 1416 .filter_map(|&(_, ref log)| proc_single_log(log))
1169 .ok_or(anyhow!("No algorithms found"))?; 1417 .reduce(|v1, v2| v1.expand_with(v2))
1418 .ok_or(anyhow!("No algorithms found"))?;
1170 1419
1171 // Load existing range 1420 // Load existing range
1172 if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() { 1421 if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() {
1173 let data = std::fs::read_to_string(&valuerange_file)?; 1422 let data = std::fs::read_to_string(&valuerange_file)?;
1174 v = v.expand_with(serde_json::from_str(&data)?); 1423 v = v.expand_with(serde_json::from_str(&data)?);
1181 inner_iters, 1430 inner_iters,
1182 merged, 1431 merged,
1183 pruned, 1432 pruned,
1184 //postprocessing, 1433 //postprocessing,
1185 this_iters, 1434 this_iters,
1435 ε,
1186 .. 1436 ..
1187 } = data; 1437 } = data;
1188 // let post_value = match (postprocessing, dataterm) { 1438 // let post_value = match (postprocessing, dataterm) {
1189 // (Some(mut μ), DataTerm::L2Squared) => { 1439 // (Some(mut μ), DataTermType::L222) => {
1190 // // Comparison postprocessing is only implemented for the case handled 1440 // // Comparison postprocessing is only implemented for the case handled
1191 // // by the FW variants. 1441 // // by the FW variants.
1192 // reg.optimise_weights( 1442 // reg.optimise_weights(
1193 // &mut μ, &opA, &b, &findim_data, &inner_config, 1443 // &mut μ, &opA, &b, &findim_data, &inner_config,
1194 // inner_it 1444 // inner_it
1196 // dataterm.value_at_residual(opA.apply(&μ) - &b) 1446 // dataterm.value_at_residual(opA.apply(&μ) - &b)
1197 // + regularisation.apply(&μ) 1447 // + regularisation.apply(&μ)
1198 // }, 1448 // },
1199 // _ => value, 1449 // _ => value,
1200 // }; 1450 // };
1201 let relative_value = (value - v.min)/(v.ini - v.min); 1451 let relative_value = (value - v.min) / (v.ini - v.min);
1202 CSVLog { 1452 CSVLog {
1203 iter, 1453 iter,
1204 value, 1454 value,
1205 relative_value, 1455 relative_value,
1206 //post_value, 1456 //post_value,
1207 n_spikes, 1457 n_spikes,
1208 cpu_time : cpu_time.as_secs_f64(), 1458 cpu_time: cpu_time.as_secs_f64(),
1209 inner_iters, 1459 inner_iters,
1210 merged, 1460 merged,
1211 pruned, 1461 pruned,
1212 this_iters 1462 this_iters,
1463 epsilon: ε,
1213 } 1464 }
1214 }; 1465 };
1215 1466
1216 println!("{}", "Saving logs …".green()); 1467 println!("{}", "Saving logs …".green());
1217 1468
1222 } 1473 }
1223 1474
1224 Ok(()) 1475 Ok(())
1225 } 1476 }
1226 1477
1227
1228 /// Plot experiment setup 1478 /// Plot experiment setup
1229 #[replace_float_literals(F::cast_from(literal))] 1479 #[replace_float_literals(F::cast_from(literal))]
1230 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( 1480 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N: usize>(
1231 cli : &CommandLineArgs, 1481 cli: &CommandLineArgs,
1232 prefix : &String, 1482 prefix: &String,
1233 domain : &Cube<F, N>, 1483 domain: &Cube<N, F>,
1234 sensor : &Sensor, 1484 sensor: &Sensor,
1235 kernel : &Kernel, 1485 kernel: &Kernel,
1236 spread : &Spread, 1486 spread: &Spread,
1237 μ_hat : &RNDM<F, N>, 1487 μ_hat: &RNDM<N, F>,
1238 op𝒟 : &𝒟, 1488 op𝒟: &𝒟,
1239 opA : &A, 1489 opA: &A,
1240 b_hat : &A::Observable, 1490 b_hat: &A::Observable,
1241 b : &A::Observable, 1491 b: &A::Observable,
1242 kernel_plot_width : F, 1492 kernel_plot_width: F,
1243 ) -> DynError 1493 ) -> DynError
1244 where F : Float + ToNalgebraRealField, 1494 where
1245 Sensor : RealMapping<F, N> + Support<F, N> + Clone, 1495 F: Float + ToNalgebraRealField,
1246 Spread : RealMapping<F, N> + Support<F, N> + Clone, 1496 Sensor: RealMapping<N, F> + Support<N, F> + Clone,
1247 Kernel : RealMapping<F, N> + Support<F, N>, 1497 Spread: RealMapping<N, F> + Support<N, F> + Clone,
1248 Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>, 1498 Kernel: RealMapping<N, F> + Support<N, F>,
1249 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, 1499 Convolution<Sensor, Spread>: DifferentiableRealMapping<N, F> + Support<N, F>,
1250 𝒟::Codomain : RealMapping<F, N>, 1500 𝒟: DiscreteMeasureOp<Loc<N, F>, F>,
1251 A : ForwardModel<RNDM<F, N>, F>, 1501 𝒟::Codomain: RealMapping<N, F>,
1252 for<'a> &'a A::Observable : Instance<A::Observable>, 1502 A: ForwardModel<RNDM<N, F>, F>,
1253 A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>, 1503 for<'a> &'a A::Observable: Instance<A::Observable>,
1254 PlotLookup : Plotting<N>, 1504 A::PreadjointCodomain: DifferentiableRealMapping<N, F> + Bounded<F>,
1255 Cube<F, N> : SetOrd { 1505 PlotLookup: Plotting<N>,
1256 1506 Cube<N, F>: SetOrd,
1507 {
1257 if cli.plot < PlotLevel::Data { 1508 if cli.plot < PlotLevel::Data {
1258 return Ok(()) 1509 return Ok(());
1259 } 1510 }
1260 1511
1261 let base = Convolution(sensor.clone(), spread.clone()); 1512 let base = Convolution(sensor.clone(), spread.clone());
1262 1513
1263 let resolution = if N==1 { 100 } else { 40 }; 1514 let resolution = if N == 1 { 100 } else { 40 };
1264 let pfx = |n| format!("{prefix}{n}"); 1515 let pfx = |n| format!("{prefix}{n}");
1265 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); 1516 let plotgrid = lingrid(
1517 &[[-kernel_plot_width, kernel_plot_width]; N].into(),
1518 &[resolution; N],
1519 );
1266 1520
1267 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor")); 1521 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"));
1268 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel")); 1522 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"));
1269 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread")); 1523 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"));
1270 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor")); 1524 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"));
1271 1525
1272 let plotgrid2 = lingrid(&domain, &[resolution; N]); 1526 let plotgrid2 = lingrid(&domain, &[resolution; N]);
1273 1527
1274 let ω_hat = op𝒟.apply(μ_hat); 1528 let ω_hat = op𝒟.apply(μ_hat);
1275 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); 1529 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b);
1276 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat")); 1530 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"));
1277 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise")); 1531 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"));
1278 1532
1279 let preadj_b = opA.preadjoint().apply(b); 1533 let preadj_b = opA.preadjoint().apply(b);
1280 let preadj_b_hat = opA.preadjoint().apply(b_hat); 1534 let preadj_b_hat = opA.preadjoint().apply(b_hat);
1281 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); 1535 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds());
1282 PlotLookup::plot_into_file_spikes( 1536 PlotLookup::plot_into_file_spikes(
1283 Some(&preadj_b), 1537 Some(&preadj_b),
1284 Some(&preadj_b_hat), 1538 Some(&preadj_b_hat),
1285 plotgrid2, 1539 plotgrid2,
1286 &μ_hat, 1540 &μ_hat,
1287 pfx("omega_b") 1541 pfx("omega_b"),
1288 ); 1542 );
1289 PlotLookup::plot_into_file(&preadj_b, plotgrid2, pfx("preadj_b")); 1543 PlotLookup::plot_into_file(&preadj_b, plotgrid2, pfx("preadj_b"));
1290 PlotLookup::plot_into_file(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat")); 1544 PlotLookup::plot_into_file(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat"));
1291 1545
1292 // Save true solution and observables 1546 // Save true solution and observables

mercurial