src/run.rs

branch
dev
changeset 61
4f468d35fa29
parent 46
f358958cc1a6
child 62
32328a74c790
child 63
7a8a55fd41c0
equal deleted inserted replaced
60:9738b51d90d7 61:4f468d35fa29
1 /*! 1 /*!
2 This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment. 2 This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment.
3 */ 3 */
4 4
5 use numeric_literals::replace_float_literals; 5 use crate::fb::{pointsource_fb_reg, pointsource_fista_reg, FBConfig, InsertionConfig};
6 use colored::Colorize;
7 use serde::{Serialize, Deserialize};
8 use serde_json;
9 use nalgebra::base::DVector;
10 use std::hash::Hash;
11 use chrono::{DateTime, Utc};
12 use cpu_time::ProcessTime;
13 use clap::ValueEnum;
14 use std::collections::HashMap;
15 use std::time::Instant;
16
17 use rand::prelude::{
18 StdRng,
19 SeedableRng
20 };
21 use rand_distr::Distribution;
22
23 use alg_tools::bisection_tree::*;
24 use alg_tools::iterate::{
25 Timed,
26 AlgIteratorOptions,
27 Verbose,
28 AlgIteratorFactory,
29 LoggingIteratorFactory,
30 TimingIteratorFactory,
31 BasicAlgIteratorFactory,
32 };
33 use alg_tools::logger::Logger;
34 use alg_tools::error::{
35 DynError,
36 DynResult,
37 };
38 use alg_tools::tabledump::TableDump;
39 use alg_tools::sets::Cube;
40 use alg_tools::mapping::{
41 RealMapping,
42 DifferentiableMapping,
43 DifferentiableRealMapping,
44 Instance
45 };
46 use alg_tools::nalgebra_support::ToNalgebraRealField;
47 use alg_tools::euclidean::Euclidean;
48 use alg_tools::lingrid::{lingrid, LinSpace};
49 use alg_tools::sets::SetOrd;
50 use alg_tools::linops::{RowOp, IdOp /*, ZeroOp*/};
51 use alg_tools::discrete_gradient::{Grad, ForwardNeumann};
52 use alg_tools::convex::Zero;
53 use alg_tools::maputil::map3;
54 use alg_tools::direct_product::Pair;
55
56 use crate::kernels::*;
57 use crate::types::*;
58 use crate::measures::*;
59 use crate::measures::merging::{SpikeMerging,SpikeMergingMethod};
60 use crate::forward_model::*;
61 use crate::forward_model::sensor_grid::{ 6 use crate::forward_model::sensor_grid::{
7 //SensorGridBTFN,
8 Sensor,
62 SensorGrid, 9 SensorGrid,
63 SensorGridBT, 10 SensorGridBT,
64 //SensorGridBTFN,
65 Sensor,
66 Spread, 11 Spread,
67 }; 12 };
68 13 use crate::forward_model::*;
69 use crate::fb::{ 14 use crate::forward_pdps::{pointsource_fb_pair, pointsource_forward_pdps_pair, ForwardPDPSConfig};
70 FBConfig, 15 use crate::frank_wolfe::{pointsource_fw_reg, FWConfig, FWVariant, RegTermFW};
71 FBGenericConfig, 16 use crate::kernels::*;
72 pointsource_fb_reg, 17 use crate::measures::merging::{SpikeMerging, SpikeMergingMethod};
73 pointsource_fista_reg, 18 use crate::measures::*;
19 use crate::pdps::{pointsource_pdps_reg, PDPSConfig};
20 use crate::plot::*;
21 use crate::prox_penalty::{
22 ProxPenalty, ProxTerm, RadonSquared, StepLengthBound, StepLengthBoundPD, StepLengthBoundPair,
74 }; 23 };
75 use crate::sliding_fb::{ 24 use crate::regularisation::{NonnegRadonRegTerm, RadonRegTerm, Regularisation, SlidingRegTerm};
76 SlidingFBConfig, 25 use crate::seminorms::*;
77 TransportConfig, 26 use crate::sliding_fb::{pointsource_sliding_fb_reg, SlidingFBConfig, TransportConfig};
78 pointsource_sliding_fb_reg 27 use crate::sliding_pdps::{
28 pointsource_sliding_fb_pair, pointsource_sliding_pdps_pair, SlidingPDPSConfig,
79 }; 29 };
80 use crate::sliding_pdps::{ 30 use crate::subproblem::{InnerMethod, InnerSettings};
81 SlidingPDPSConfig, 31 use crate::tolerance::Tolerance;
82 pointsource_sliding_pdps_pair 32 use crate::types::*;
33 use crate::{AlgorithmOverrides, CommandLineArgs};
34 use alg_tools::bisection_tree::*;
35 use alg_tools::bounds::{Bounded, MinMaxMapping};
36 use alg_tools::convex::{Conjugable, Norm222, Prox, Zero};
37 use alg_tools::direct_product::Pair;
38 use alg_tools::discrete_gradient::{ForwardNeumann, Grad};
39 use alg_tools::error::{DynError, DynResult};
40 use alg_tools::euclidean::{ClosedEuclidean, Euclidean};
41 use alg_tools::iterate::{
42 AlgIteratorFactory, AlgIteratorOptions, BasicAlgIteratorFactory, LoggingIteratorFactory, Timed,
43 TimingIteratorFactory, ValueIteratorFactory, Verbose,
83 }; 44 };
84 use crate::forward_pdps::{ 45 use alg_tools::lingrid::lingrid;
85 ForwardPDPSConfig, 46 use alg_tools::linops::{IdOp, RowOp, AXPY};
86 pointsource_forward_pdps_pair 47 use alg_tools::logger::Logger;
48 use alg_tools::mapping::{
49 DataTerm, DifferentiableMapping, DifferentiableRealMapping, Instance, RealMapping,
87 }; 50 };
88 use crate::pdps::{ 51 use alg_tools::maputil::map3;
89 PDPSConfig, 52 use alg_tools::nalgebra_support::ToNalgebraRealField;
90 pointsource_pdps_reg, 53 use alg_tools::norms::{NormExponent, L1, L2};
91 };
92 use crate::frank_wolfe::{
93 FWConfig,
94 FWVariant,
95 pointsource_fw_reg,
96 //WeightOptim,
97 };
98 use crate::subproblem::{InnerSettings, InnerMethod};
99 use crate::seminorms::*;
100 use crate::plot::*;
101 use crate::{AlgorithmOverrides, CommandLineArgs};
102 use crate::tolerance::Tolerance;
103 use crate::regularisation::{
104 Regularisation,
105 RadonRegTerm,
106 NonnegRadonRegTerm
107 };
108 use crate::dataterm::{
109 L1,
110 L2Squared,
111 };
112 use crate::prox_penalty::{
113 RadonSquared,
114 //ProxPenalty,
115 };
116 use alg_tools::norms::{L2, NormExponent};
117 use alg_tools::operator_arithmetic::Weighted; 54 use alg_tools::operator_arithmetic::Weighted;
55 use alg_tools::sets::Cube;
56 use alg_tools::sets::SetOrd;
57 use alg_tools::tabledump::TableDump;
118 use anyhow::anyhow; 58 use anyhow::anyhow;
119 59 use chrono::{DateTime, Utc};
120 /// Available proximal terms 60 use clap::ValueEnum;
121 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 61 use colored::Colorize;
122 pub enum ProxTerm { 62 use cpu_time::ProcessTime;
123 /// Partial-to-wave operator 𝒟. 63 use nalgebra::base::DVector;
124 Wave, 64 use numeric_literals::replace_float_literals;
125 /// Radon-norm squared 65 use rand::prelude::{SeedableRng, StdRng};
126 RadonSquared 66 use rand_distr::Distribution;
127 } 67 use serde::{Deserialize, Serialize};
68 use serde_json;
69 use std::collections::HashMap;
70 use std::hash::Hash;
71 use std::time::Instant;
72 use thiserror::Error;
73
74 //#[cfg(feature = "pyo3")]
75 //use pyo3::pyclass;
128 76
129 /// Available algorithms and their configurations 77 /// Available algorithms and their configurations
130 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 78 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
131 pub enum AlgorithmConfig<F : Float> { 79 pub enum AlgorithmConfig<F: Float> {
132 FB(FBConfig<F>, ProxTerm), 80 FB(FBConfig<F>, ProxTerm),
133 FISTA(FBConfig<F>, ProxTerm), 81 FISTA(FBConfig<F>, ProxTerm),
134 FW(FWConfig<F>), 82 FW(FWConfig<F>),
135 PDPS(PDPSConfig<F>, ProxTerm), 83 PDPS(PDPSConfig<F>, ProxTerm),
136 SlidingFB(SlidingFBConfig<F>, ProxTerm), 84 SlidingFB(SlidingFBConfig<F>, ProxTerm),
137 ForwardPDPS(ForwardPDPSConfig<F>, ProxTerm), 85 ForwardPDPS(ForwardPDPSConfig<F>, ProxTerm),
138 SlidingPDPS(SlidingPDPSConfig<F>, ProxTerm), 86 SlidingPDPS(SlidingPDPSConfig<F>, ProxTerm),
139 } 87 }
140 88
141 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { 89 fn unpack_tolerance<F: Float>(v: &Vec<F>) -> Tolerance<F> {
142 assert!(v.len() == 3); 90 assert!(v.len() == 3);
143 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } 91 Tolerance::Power { initial: v[0], factor: v[1], exponent: v[2] }
144 } 92 }
145 93
146 impl<F : ClapFloat> AlgorithmConfig<F> { 94 impl<F: ClapFloat> AlgorithmConfig<F> {
147 /// Override supported parameters based on the command line. 95 /// Override supported parameters based on the command line.
148 pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self { 96 pub fn cli_override(self, cli: &AlgorithmOverrides<F>) -> Self {
149 let override_merging = |g : SpikeMergingMethod<F>| { 97 let override_merging = |g: SpikeMergingMethod<F>| SpikeMergingMethod {
150 SpikeMergingMethod { 98 enabled: cli.merge.unwrap_or(g.enabled),
151 enabled : cli.merge.unwrap_or(g.enabled), 99 radius: cli.merge_radius.unwrap_or(g.radius),
152 radius : cli.merge_radius.unwrap_or(g.radius), 100 interp: cli.merge_interp.unwrap_or(g.interp),
153 interp : cli.merge_interp.unwrap_or(g.interp),
154 }
155 }; 101 };
156 let override_fb_generic = |g : FBGenericConfig<F>| { 102 let override_fb_generic = |g: InsertionConfig<F>| InsertionConfig {
157 FBGenericConfig { 103 bootstrap_insertions: cli
158 bootstrap_insertions : cli.bootstrap_insertions 104 .bootstrap_insertions
159 .as_ref() 105 .as_ref()
160 .map_or(g.bootstrap_insertions, 106 .map_or(g.bootstrap_insertions, |n| Some((n[0], n[1]))),
161 |n| Some((n[0], n[1]))), 107 merge_every: cli.merge_every.unwrap_or(g.merge_every),
162 merge_every : cli.merge_every.unwrap_or(g.merge_every), 108 merging: override_merging(g.merging),
163 merging : override_merging(g.merging), 109 final_merging: cli.final_merging.unwrap_or(g.final_merging),
164 final_merging : cli.final_merging.unwrap_or(g.final_merging), 110 fitness_merging: cli.fitness_merging.unwrap_or(g.fitness_merging),
165 fitness_merging : cli.fitness_merging.unwrap_or(g.fitness_merging), 111 tolerance: cli
166 tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance), 112 .tolerance
167 .. g 113 .as_ref()
168 } 114 .map(unpack_tolerance)
115 .unwrap_or(g.tolerance),
116 ..g
169 }; 117 };
170 let override_transport = |g : TransportConfig<F>| { 118 let override_transport = |g: TransportConfig<F>| TransportConfig {
171 TransportConfig { 119 θ0: cli.theta0.unwrap_or(g.θ0),
172 θ0 : cli.theta0.unwrap_or(g.θ0), 120 tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con),
173 tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con), 121 adaptation: cli.transport_adaptation.unwrap_or(g.adaptation),
174 adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), 122 ..g
175 .. g
176 }
177 }; 123 };
178 124
179 use AlgorithmConfig::*; 125 use AlgorithmConfig::*;
180 match self { 126 match self {
181 FB(fb, prox) => FB(FBConfig { 127 FB(fb, prox) => FB(
182 τ0 : cli.tau0.unwrap_or(fb.τ0), 128 FBConfig {
183 generic : override_fb_generic(fb.generic), 129 τ0: cli.tau0.unwrap_or(fb.τ0),
184 .. fb 130 σp0: cli.sigmap0.unwrap_or(fb.σp0),
185 }, prox), 131 insertion: override_fb_generic(fb.insertion),
186 FISTA(fb, prox) => FISTA(FBConfig { 132 ..fb
187 τ0 : cli.tau0.unwrap_or(fb.τ0), 133 },
188 generic : override_fb_generic(fb.generic), 134 prox,
189 .. fb 135 ),
190 }, prox), 136 FISTA(fb, prox) => FISTA(
191 PDPS(pdps, prox) => PDPS(PDPSConfig { 137 FBConfig {
192 τ0 : cli.tau0.unwrap_or(pdps.τ0), 138 τ0: cli.tau0.unwrap_or(fb.τ0),
193 σ0 : cli.sigma0.unwrap_or(pdps.σ0), 139 σp0: cli.sigmap0.unwrap_or(fb.σp0),
194 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), 140 insertion: override_fb_generic(fb.insertion),
195 generic : override_fb_generic(pdps.generic), 141 ..fb
196 .. pdps 142 },
197 }, prox), 143 prox,
144 ),
145 PDPS(pdps, prox) => PDPS(
146 PDPSConfig {
147 τ0: cli.tau0.unwrap_or(pdps.τ0),
148 σ0: cli.sigma0.unwrap_or(pdps.σ0),
149 acceleration: cli.acceleration.unwrap_or(pdps.acceleration),
150 generic: override_fb_generic(pdps.generic),
151 ..pdps
152 },
153 prox,
154 ),
198 FW(fw) => FW(FWConfig { 155 FW(fw) => FW(FWConfig {
199 merging : override_merging(fw.merging), 156 merging: override_merging(fw.merging),
200 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), 157 tolerance: cli
201 .. fw 158 .tolerance
159 .as_ref()
160 .map(unpack_tolerance)
161 .unwrap_or(fw.tolerance),
162 ..fw
202 }), 163 }),
203 SlidingFB(sfb, prox) => SlidingFB(SlidingFBConfig { 164 SlidingFB(sfb, prox) => SlidingFB(
204 τ0 : cli.tau0.unwrap_or(sfb.τ0), 165 SlidingFBConfig {
205 transport : override_transport(sfb.transport), 166 τ0: cli.tau0.unwrap_or(sfb.τ0),
206 insertion : override_fb_generic(sfb.insertion), 167 σp0: cli.sigmap0.unwrap_or(sfb.σp0),
207 .. sfb 168 transport: override_transport(sfb.transport),
208 }, prox), 169 insertion: override_fb_generic(sfb.insertion),
209 SlidingPDPS(spdps, prox) => SlidingPDPS(SlidingPDPSConfig { 170 ..sfb
210 τ0 : cli.tau0.unwrap_or(spdps.τ0), 171 },
211 σp0 : cli.sigmap0.unwrap_or(spdps.σp0), 172 prox,
212 σd0 : cli.sigma0.unwrap_or(spdps.σd0), 173 ),
213 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), 174 SlidingPDPS(spdps, prox) => SlidingPDPS(
214 transport : override_transport(spdps.transport), 175 SlidingPDPSConfig {
215 insertion : override_fb_generic(spdps.insertion), 176 τ0: cli.tau0.unwrap_or(spdps.τ0),
216 .. spdps 177 σp0: cli.sigmap0.unwrap_or(spdps.σp0),
217 }, prox), 178 σd0: cli.sigma0.unwrap_or(spdps.σd0),
218 ForwardPDPS(fpdps, prox) => ForwardPDPS(ForwardPDPSConfig { 179 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
219 τ0 : cli.tau0.unwrap_or(fpdps.τ0), 180 transport: override_transport(spdps.transport),
220 σp0 : cli.sigmap0.unwrap_or(fpdps.σp0), 181 insertion: override_fb_generic(spdps.insertion),
221 σd0 : cli.sigma0.unwrap_or(fpdps.σd0), 182 ..spdps
222 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), 183 },
223 insertion : override_fb_generic(fpdps.insertion), 184 prox,
224 .. fpdps 185 ),
225 }, prox), 186 ForwardPDPS(fpdps, prox) => ForwardPDPS(
187 ForwardPDPSConfig {
188 τ0: cli.tau0.unwrap_or(fpdps.τ0),
189 σp0: cli.sigmap0.unwrap_or(fpdps.σp0),
190 σd0: cli.sigma0.unwrap_or(fpdps.σd0),
191 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
192 insertion: override_fb_generic(fpdps.insertion),
193 ..fpdps
194 },
195 prox,
196 ),
226 } 197 }
227 } 198 }
228 } 199 }
229 200
230 /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name. 201 /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name.
231 #[derive(Clone, Debug, Serialize, Deserialize)] 202 #[derive(Clone, Debug, Serialize, Deserialize)]
232 pub struct Named<Data> { 203 pub struct Named<Data> {
233 pub name : String, 204 pub name: String,
234 #[serde(flatten)] 205 #[serde(flatten)]
235 pub data : Data, 206 pub data: Data,
236 } 207 }
237 208
238 /// Shorthand algorithm configurations, to be used with the command line parser 209 /// Shorthand algorithm configurations, to be used with the command line parser
239 #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] 210 #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
211 //#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))]
240 pub enum DefaultAlgorithm { 212 pub enum DefaultAlgorithm {
241 /// The μFB forward-backward method 213 /// The μFB forward-backward method
242 #[clap(name = "fb")] 214 #[clap(name = "fb")]
243 FB, 215 FB,
244 /// The μFISTA inertial forward-backward method 216 /// The μFISTA inertial forward-backward method
262 /// The PDPS method with a forward step for the smooth function 234 /// The PDPS method with a forward step for the smooth function
263 #[clap(name = "forward_pdps", alias = "fpdps")] 235 #[clap(name = "forward_pdps", alias = "fpdps")]
264 ForwardPDPS, 236 ForwardPDPS,
265 237
266 // Radon variants 238 // Radon variants
267
268 /// The μFB forward-backward method with radon-norm squared proximal term 239 /// The μFB forward-backward method with radon-norm squared proximal term
269 #[clap(name = "radon_fb")] 240 #[clap(name = "radon_fb")]
270 RadonFB, 241 RadonFB,
271 /// The μFISTA inertial forward-backward method with radon-norm squared proximal term 242 /// The μFISTA inertial forward-backward method with radon-norm squared proximal term
272 #[clap(name = "radon_fista")] 243 #[clap(name = "radon_fista")]
285 RadonForwardPDPS, 256 RadonForwardPDPS,
286 } 257 }
287 258
288 impl DefaultAlgorithm { 259 impl DefaultAlgorithm {
289 /// Returns the algorithm configuration corresponding to the algorithm shorthand 260 /// Returns the algorithm configuration corresponding to the algorithm shorthand
290 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { 261 pub fn default_config<F: Float>(&self) -> AlgorithmConfig<F> {
291 use DefaultAlgorithm::*; 262 use DefaultAlgorithm::*;
292 let radon_insertion = FBGenericConfig { 263 let radon_insertion = InsertionConfig {
293 merging : SpikeMergingMethod{ interp : false, .. Default::default() }, 264 merging: SpikeMergingMethod { interp: false, ..Default::default() },
294 inner : InnerSettings { 265 inner: InnerSettings {
295 method : InnerMethod::PDPS, // SSN not implemented 266 method: InnerMethod::PDPS, // SSN not implemented
296 .. Default::default() 267 ..Default::default()
297 }, 268 },
298 .. Default::default() 269 ..Default::default()
299 }; 270 };
300 match *self { 271 match *self {
301 FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave), 272 FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave),
302 FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave), 273 FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave),
303 FW => AlgorithmConfig::FW(Default::default()), 274 FW => AlgorithmConfig::FW(Default::default()),
304 FWRelax => AlgorithmConfig::FW(FWConfig{ 275 FWRelax => {
305 variant : FWVariant::Relaxed, 276 AlgorithmConfig::FW(FWConfig { variant: FWVariant::Relaxed, ..Default::default() })
306 .. Default::default() 277 }
307 }),
308 PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave), 278 PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave),
309 SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave), 279 SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave),
310 SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave), 280 SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave),
311 ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave), 281 ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave),
312 282
313 // Radon variants 283 // Radon variants
314
315 RadonFB => AlgorithmConfig::FB( 284 RadonFB => AlgorithmConfig::FB(
316 FBConfig{ generic : radon_insertion, ..Default::default() }, 285 FBConfig { insertion: radon_insertion, ..Default::default() },
317 ProxTerm::RadonSquared 286 ProxTerm::RadonSquared,
318 ), 287 ),
319 RadonFISTA => AlgorithmConfig::FISTA( 288 RadonFISTA => AlgorithmConfig::FISTA(
320 FBConfig{ generic : radon_insertion, ..Default::default() }, 289 FBConfig { insertion: radon_insertion, ..Default::default() },
321 ProxTerm::RadonSquared 290 ProxTerm::RadonSquared,
322 ), 291 ),
323 RadonPDPS => AlgorithmConfig::PDPS( 292 RadonPDPS => AlgorithmConfig::PDPS(
324 PDPSConfig{ generic : radon_insertion, ..Default::default() }, 293 PDPSConfig { generic: radon_insertion, ..Default::default() },
325 ProxTerm::RadonSquared 294 ProxTerm::RadonSquared,
326 ), 295 ),
327 RadonSlidingFB => AlgorithmConfig::SlidingFB( 296 RadonSlidingFB => AlgorithmConfig::SlidingFB(
328 SlidingFBConfig{ insertion : radon_insertion, ..Default::default() }, 297 SlidingFBConfig { insertion: radon_insertion, ..Default::default() },
329 ProxTerm::RadonSquared 298 ProxTerm::RadonSquared,
330 ), 299 ),
331 RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS( 300 RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS(
332 SlidingPDPSConfig{ insertion : radon_insertion, ..Default::default() }, 301 SlidingPDPSConfig { insertion: radon_insertion, ..Default::default() },
333 ProxTerm::RadonSquared 302 ProxTerm::RadonSquared,
334 ), 303 ),
335 RadonForwardPDPS => AlgorithmConfig::ForwardPDPS( 304 RadonForwardPDPS => AlgorithmConfig::ForwardPDPS(
336 ForwardPDPSConfig{ insertion : radon_insertion, ..Default::default() }, 305 ForwardPDPSConfig { insertion: radon_insertion, ..Default::default() },
337 ProxTerm::RadonSquared 306 ProxTerm::RadonSquared,
338 ), 307 ),
339 } 308 }
340 } 309 }
341 310
342 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand 311 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand
343 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { 312 pub fn get_named<F: Float>(&self) -> Named<AlgorithmConfig<F>> {
344 self.to_named(self.default_config()) 313 self.to_named(self.default_config())
345 } 314 }
346 315
347 pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { 316 pub fn to_named<F: Float>(self, alg: AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> {
348 let name = self.to_possible_value().unwrap().get_name().to_string(); 317 Named { name: self.name(), data: alg }
349 Named{ name , data : alg } 318 }
350 } 319
351 } 320 pub fn name(self) -> String {
352 321 self.to_possible_value().unwrap().get_name().to_string()
322 }
323 }
353 324
354 // // Floats cannot be hashed directly, so just hash the debug formatting 325 // // Floats cannot be hashed directly, so just hash the debug formatting
355 // // for use as file identifier. 326 // // for use as file identifier.
356 // impl<F : Float> Hash for AlgorithmConfig<F> { 327 // impl<F : Float> Hash for AlgorithmConfig<F> {
357 // fn hash<H: Hasher>(&self, state: &mut H) { 328 // fn hash<H: Hasher>(&self, state: &mut H) {
377 fn default() -> Self { 348 fn default() -> Self {
378 Self::Data 349 Self::Data
379 } 350 }
380 } 351 }
381 352
382 type DefaultBT<F, const N : usize> = BT< 353 type DefaultBT<F, const N: usize> = BT<DynamicDepth, F, usize, Bounds<F>, N>;
383 DynamicDepth, 354 type DefaultSeminormOp<F, K, const N: usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>;
384 F, 355 type DefaultSG<F, Sensor, Spread, const N: usize> =
385 usize, 356 SensorGrid<F, Sensor, Spread, DefaultBT<F, N>, N>;
386 Bounds<F>,
387 N
388 >;
389 type DefaultSeminormOp<F, K, const N : usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>;
390 type DefaultSG<F, Sensor, Spread, const N : usize> = SensorGrid::<
391 F,
392 Sensor,
393 Spread,
394 DefaultBT<F, N>,
395 N
396 >;
397 357
398 /// This is a dirty workaround to rust-csv not supporting struct flattening etc. 358 /// This is a dirty workaround to rust-csv not supporting struct flattening etc.
399 #[derive(Serialize)] 359 #[derive(Serialize)]
400 struct CSVLog<F> { 360 struct CSVLog<F> {
401 iter : usize, 361 iter: usize,
402 cpu_time : f64, 362 cpu_time: f64,
403 value : F, 363 value: F,
404 relative_value : F, 364 relative_value: F,
405 //post_value : F, 365 //post_value : F,
406 n_spikes : usize, 366 n_spikes: usize,
407 inner_iters : usize, 367 inner_iters: usize,
408 merged : usize, 368 merged: usize,
409 pruned : usize, 369 pruned: usize,
410 this_iters : usize, 370 this_iters: usize,
371 epsilon: F,
411 } 372 }
412 373
413 /// Collected experiment statistics 374 /// Collected experiment statistics
414 #[derive(Clone, Debug, Serialize)] 375 #[derive(Clone, Debug, Serialize)]
415 struct ExperimentStats<F : Float> { 376 struct ExperimentStats<F: Float> {
416 /// Signal-to-noise ratio in decibels 377 /// Signal-to-noise ratio in decibels
417 ssnr : F, 378 ssnr: F,
418 /// Proportion of noise in the signal as a number in $[0, 1]$. 379 /// Proportion of noise in the signal as a number in $[0, 1]$.
419 noise_ratio : F, 380 noise_ratio: F,
420 /// When the experiment was run (UTC) 381 /// When the experiment was run (UTC)
421 when : DateTime<Utc>, 382 when: DateTime<Utc>,
422 } 383 }
423 384
424 #[replace_float_literals(F::cast_from(literal))] 385 #[replace_float_literals(F::cast_from(literal))]
425 impl<F : Float> ExperimentStats<F> { 386 impl<F: Float> ExperimentStats<F> {
426 /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal. 387 /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal.
427 fn new<E : Euclidean<F>>(signal : &E, noise : &E) -> Self { 388 fn new<E: Euclidean<F>>(signal: &E, noise: &E) -> Self {
428 let s = signal.norm2_squared(); 389 let s = signal.norm2_squared();
429 let n = noise.norm2_squared(); 390 let n = noise.norm2_squared();
430 let noise_ratio = (n / s).sqrt(); 391 let noise_ratio = (n / s).sqrt();
431 let ssnr = 10.0 * (s / n).log10(); 392 let ssnr = 10.0 * (s / n).log10();
432 ExperimentStats { 393 ExperimentStats { ssnr, noise_ratio, when: Utc::now() }
433 ssnr,
434 noise_ratio,
435 when : Utc::now(),
436 }
437 } 394 }
438 } 395 }
439 /// Collected algorithm statistics 396 /// Collected algorithm statistics
440 #[derive(Clone, Debug, Serialize)] 397 #[derive(Clone, Debug, Serialize)]
441 struct AlgorithmStats<F : Float> { 398 struct AlgorithmStats<F: Float> {
442 /// Overall CPU time spent 399 /// Overall CPU time spent
443 cpu_time : F, 400 cpu_time: F,
444 /// Real time spent 401 /// Real time spent
445 elapsed : F 402 elapsed: F,
446 } 403 }
447
448 404
449 /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input 405 /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input
450 /// and outputs a [`DynError`]. 406 /// and outputs a [`DynError`].
451 fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { 407 fn write_json<T: Serialize>(filename: String, data: &T) -> DynError {
452 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; 408 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?;
453 Ok(()) 409 Ok(())
454 } 410 }
455 411
456
457 /// Struct for experiment configurations 412 /// Struct for experiment configurations
458 #[derive(Debug, Clone, Serialize)] 413 #[derive(Debug, Clone, Serialize)]
459 pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize> 414 pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N: usize>
460 where F : Float + ClapFloat, 415 where
461 [usize; N] : Serialize, 416 F: Float + ClapFloat,
462 NoiseDistr : Distribution<F>, 417 [usize; N]: Serialize,
463 S : Sensor<F, N>, 418 NoiseDistr: Distribution<F>,
464 P : Spread<F, N>, 419 S: Sensor<N, F>,
465 K : SimpleConvolutionKernel<F, N>, 420 P: Spread<N, F>,
421 K: SimpleConvolutionKernel<N, F>,
466 { 422 {
467 /// Domain $Ω$. 423 /// Domain $Ω$.
468 pub domain : Cube<F, N>, 424 pub domain: Cube<N, F>,
469 /// Number of sensors along each dimension 425 /// Number of sensors along each dimension
470 pub sensor_count : [usize; N], 426 pub sensor_count: [usize; N],
471 /// Noise distribution 427 /// Noise distribution
472 pub noise_distr : NoiseDistr, 428 pub noise_distr: NoiseDistr,
473 /// Seed for random noise generation (for repeatable experiments) 429 /// Seed for random noise generation (for repeatable experiments)
474 pub noise_seed : u64, 430 pub noise_seed: u64,
475 /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. 431 /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$.
476 pub sensor : S, 432 pub sensor: S,
477 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. 433 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$.
478 pub spread : P, 434 pub spread: P,
479 /// Kernel $ρ$ of $𝒟$. 435 /// Kernel $ρ$ of $𝒟$.
480 pub kernel : K, 436 pub kernel: K,
481 /// True point sources 437 /// True point sources
482 pub μ_hat : RNDM<F, N>, 438 pub μ_hat: RNDM<N, F>,
483 /// Regularisation term and parameter 439 /// Regularisation term and parameter
484 pub regularisation : Regularisation<F>, 440 pub regularisation: Regularisation<F>,
485 /// For plotting : how wide should the kernels be plotted 441 /// For plotting : how wide should the kernels be plotted
486 pub kernel_plot_width : F, 442 pub kernel_plot_width: F,
487 /// Data term 443 /// Data term
488 pub dataterm : DataTerm, 444 pub dataterm: DataTermType,
489 /// A map of default configurations for algorithms 445 /// A map of default configurations for algorithms
490 pub algorithm_overrides : HashMap<DefaultAlgorithm, AlgorithmOverrides<F>>, 446 pub algorithm_overrides: HashMap<DefaultAlgorithm, AlgorithmOverrides<F>>,
491 /// Default merge radius 447 /// Default merge radius
492 pub default_merge_radius : F, 448 pub default_merge_radius: F,
493 } 449 }
494 450
495 #[derive(Debug, Clone, Serialize)] 451 #[derive(Debug, Clone, Serialize)]
496 pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N : usize> 452 pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N: usize>
497 where F : Float + ClapFloat, 453 where
498 [usize; N] : Serialize, 454 F: Float + ClapFloat,
499 NoiseDistr : Distribution<F>, 455 [usize; N]: Serialize,
500 S : Sensor<F, N>, 456 NoiseDistr: Distribution<F>,
501 P : Spread<F, N>, 457 S: Sensor<N, F>,
502 K : SimpleConvolutionKernel<F, N>, 458 P: Spread<N, F>,
503 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, 459 K: SimpleConvolutionKernel<N, F>,
460 B: Mapping<Loc<N, F>, Codomain = F> + Serialize + std::fmt::Debug,
504 { 461 {
505 /// Basic setup 462 /// Basic setup
506 pub base : ExperimentV2<F, NoiseDistr, S, K, P, N>, 463 pub base: ExperimentV2<F, NoiseDistr, S, K, P, N>,
507 /// Weight of TV term 464 /// Weight of TV term
508 pub λ : F, 465 pub λ: F,
509 /// Bias function 466 /// Bias function
510 pub bias : B, 467 pub bias: B,
511 } 468 }
512 469
513 /// Trait for runnable experiments 470 /// Trait for runnable experiments
514 pub trait RunnableExperiment<F : ClapFloat> { 471 pub trait RunnableExperiment<F: ClapFloat> {
515 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. 472 /// Run all algorithms provided, or default algorithms if none provided, on the experiment.
516 fn runall(&self, cli : &CommandLineArgs, 473 fn runall(
517 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; 474 &self,
475 cli: &CommandLineArgs,
476 algs: Option<Vec<Named<AlgorithmConfig<F>>>>,
477 ) -> DynError;
518 478
519 /// Return algorithm default config 479 /// Return algorithm default config
520 fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F>; 480 fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<F>;
521 } 481
522 482 /// Experiment name
523 /// Helper function to print experiment start message and save setup. 483 fn name(&self) -> &str;
524 /// Returns saving prefix. 484 }
525 fn start_experiment<E, S>( 485
526 experiment : &Named<E>, 486 /// Error codes for running an algorithm on an experiment.
527 cli : &CommandLineArgs, 487 #[derive(Error, Debug)]
528 stats : S, 488 pub enum RunError {
529 ) -> DynResult<String> 489 /// Algorithm not implemented for this experiment
490 #[error("Algorithm not implemented for this experiment")]
491 NotImplemented,
492 }
493
494 use RunError::*;
495
496 type DoRunAllIt<'a, F, const N: usize> = LoggingIteratorFactory<
497 'a,
498 Timed<IterInfo<F>>,
499 TimingIteratorFactory<BasicAlgIteratorFactory<IterInfo<F>>>,
500 >;
501
502 pub trait RunnableExperimentExtras<F: ClapFloat>:
503 RunnableExperiment<F> + Serialize + Sized
504 {
505 /// Helper function to print experiment start message and save setup.
506 /// Returns saving prefix.
507 fn start(&self, cli: &CommandLineArgs) -> DynResult<String> {
508 let experiment_name = self.name();
509 let ser = serde_json::to_string(self);
510
511 println!(
512 "{}\n{}",
513 format!("Performing experiment {}…", experiment_name).cyan(),
514 format!(
515 "Experiment settings: {}",
516 if let Ok(ref s) = ser {
517 s
518 } else {
519 "<serialisation failure>"
520 }
521 )
522 .bright_black(),
523 );
524
525 // Set up output directory
526 let prefix = format!("{}/{}/", cli.outdir, experiment_name);
527
528 // Save experiment configuration and statistics
529 std::fs::create_dir_all(&prefix)?;
530 write_json(format!("{prefix}experiment.json"), self)?;
531 write_json(format!("{prefix}config.json"), cli)?;
532
533 Ok(prefix)
534 }
535
536 /// Helper function to run all algorithms on an experiment.
537 fn do_runall<P, Z, Plot, const N: usize>(
538 &self,
539 prefix: &String,
540 cli: &CommandLineArgs,
541 algorithms: Vec<Named<AlgorithmConfig<F>>>,
542 mut make_plotter: impl FnMut(String) -> Plot,
543 mut save_extra: impl FnMut(String, Z) -> DynError,
544 init: P,
545 mut do_alg: impl FnMut(
546 (&AlgorithmConfig<F>, DoRunAllIt<F, N>, Plot, P, String),
547 ) -> DynResult<(RNDM<N, F>, Z)>,
548 ) -> DynError
549 where
550 F: for<'b> Deserialize<'b>,
551 PlotLookup: Plotting<N>,
552 P: Clone,
553 {
554 let experiment_name = self.name();
555
556 let mut logs = Vec::new();
557
558 let iterator_options = AlgIteratorOptions {
559 max_iter: cli.max_iter,
560 verbose_iter: cli
561 .verbose_iter
562 .map_or(Verbose::LogarithmicCap { base: 10, cap: 2 }, |n| {
563 Verbose::Every(n)
564 }),
565 quiet: cli.quiet,
566 };
567
568 // Run the algorithm(s)
569 for named @ Named { name: alg_name, data: alg } in algorithms.iter() {
570 let this_prefix = format!("{}{}/", prefix, alg_name);
571
572 // Create Logger and IteratorFactory
573 let mut logger = Logger::new();
574 let iterator = iterator_options.instantiate().timed().into_log(&mut logger);
575
576 let running = if !cli.quiet {
577 format!(
578 "{}\n{}\n{}\n",
579 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
580 format!(
581 "Iteration settings: {}",
582 serde_json::to_string(&iterator_options)?
583 )
584 .bright_black(),
585 format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black()
586 )
587 } else {
588 "".to_string()
589 };
590 //
591 // The following is for postprocessing, which has been disabled anyway.
592 //
593 // let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation {
594 // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)),
595 // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)),
596 // };
597 //let findim_data = reg.prepare_optimise_weights(&opA, &b);
598 //let inner_config : InnerSettings<F> = Default::default();
599 //let inner_it = inner_config.iterator_options;
600
601 // Create plotter and directory if needed.
602 let plotter = make_plotter(this_prefix);
603
604 let start = Instant::now();
605 let start_cpu = ProcessTime::now();
606
607 let (μ, z) = match do_alg((alg, iterator, plotter, init.clone(), running)) {
608 Ok(μ) => μ,
609 Err(e) => {
610 let msg = format!(
611 "Skipping algorithm “{alg_name}” for {experiment_name} due to error: {e}"
612 )
613 .red();
614 eprintln!("{}", msg);
615 continue;
616 }
617 };
618
619 let elapsed = start.elapsed().as_secs_f64();
620 let cpu_time = start_cpu.elapsed().as_secs_f64();
621
622 println!(
623 "{}",
624 format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()
625 );
626
627 // Save results
628 println!("{}", "Saving results …".green());
629
630 let mkname = |t| format!("{prefix}{alg_name}_{t}");
631
632 write_json(mkname("config.json"), &named)?;
633 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
634 μ.write_csv(mkname("reco.txt"))?;
635 save_extra(mkname(""), z)?;
636 //logger.write_csv(mkname("log.txt"))?;
637 logs.push((mkname("log.txt"), logger));
638 }
639
640 save_logs(
641 logs,
642 format!("{prefix}valuerange.json"),
643 cli.load_valuerange,
644 )
645 }
646 }
647
648 impl<F, E> RunnableExperimentExtras<F> for E
530 where 649 where
531 E : Serialize + std::fmt::Debug, 650 F: ClapFloat,
532 S : Serialize, 651 Self: RunnableExperiment<F> + Serialize,
533 { 652 {
534 let Named { name : experiment_name, data } = experiment; 653 }
535 654
536 println!("{}\n{}", 655 #[replace_float_literals(F::cast_from(literal))]
537 format!("Performing experiment {}…", experiment_name).cyan(), 656 impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N: usize> RunnableExperiment<F>
538 format!("Experiment settings: {}", serde_json::to_string(&data)?).bright_black()); 657 for Named<ExperimentV2<F, NoiseDistr, S, K, P, N>>
539
540 // Set up output directory
541 let prefix = format!("{}/{}/", cli.outdir, experiment_name);
542
543 // Save experiment configuration and statistics
544 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t);
545 std::fs::create_dir_all(&prefix)?;
546 write_json(mkname_e("experiment"), experiment)?;
547 write_json(mkname_e("config"), cli)?;
548 write_json(mkname_e("stats"), &stats)?;
549
550 Ok(prefix)
551 }
552
553 /// Error codes for running an algorithm on an experiment.
554 enum RunError {
555 /// Algorithm not implemented for this experiment
556 NotImplemented,
557 }
558
559 use RunError::*;
560
561 type DoRunAllIt<'a, F, const N : usize> = LoggingIteratorFactory<
562 'a,
563 Timed<IterInfo<F, N>>,
564 TimingIteratorFactory<BasicAlgIteratorFactory<IterInfo<F, N>>>
565 >;
566
567 /// Helper function to run all algorithms on an experiment.
568 fn do_runall<F : Float + for<'b> Deserialize<'b>, Z, const N : usize>(
569 experiment_name : &String,
570 prefix : &String,
571 cli : &CommandLineArgs,
572 algorithms : Vec<Named<AlgorithmConfig<F>>>,
573 plotgrid : LinSpace<Loc<F, N>, [usize; N]>,
574 mut save_extra : impl FnMut(String, Z) -> DynError,
575 mut do_alg : impl FnMut(
576 &AlgorithmConfig<F>,
577 DoRunAllIt<F, N>,
578 SeqPlotter<F, N>,
579 String,
580 ) -> Result<(RNDM<F, N>, Z), RunError>,
581 ) -> DynError
582 where 658 where
583 PlotLookup : Plotting<N>, 659 F: ClapFloat
660 + nalgebra::RealField
661 + ToNalgebraRealField<MixedType = F>
662 + Default
663 + for<'b> Deserialize<'b>,
664 [usize; N]: Serialize,
665 S: Sensor<N, F> + Copy + Serialize + std::fmt::Debug,
666 P: Spread<N, F> + Copy + Serialize + std::fmt::Debug,
667 Convolution<S, P>: Spread<N, F>
668 + Bounded<F>
669 + LocalAnalysis<F, Bounds<F>, N>
670 + Copy
671 // TODO: shold not have differentiability as a requirement, but
672 // decide availability of sliding based on it.
673 //+ for<'b> Differentiable<&'b Loc<N, F>, Output = Loc<N, F>>,
674 // TODO: very weird that rust only compiles with Differentiable
675 // instead of the above one on references, which is required by
676 // poitsource_sliding_fb_reg.
677 + DifferentiableRealMapping<N, F>
678 + Lipschitz<L2, FloatType = F>,
679 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<N, F>>>::Differential<'b>:
680 Lipschitz<L2, FloatType = F>, // TODO: should not be required generally, only for sliding_fb.
681 AutoConvolution<P>: BoundedBy<F, K>,
682 K: SimpleConvolutionKernel<N, F>
683 + LocalAnalysis<F, Bounds<F>, N>
684 + Copy
685 + Serialize
686 + std::fmt::Debug,
687 Cube<N, F>: P2Minimise<Loc<N, F>, F> + SetOrd,
688 PlotLookup: Plotting<N>,
689 DefaultBT<F, N>: SensorGridBT<F, S, P, N, Depth = DynamicDepth> + BTSearch<N, F>,
690 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
691 RNDM<N, F>: SpikeMerging<F>,
692 NoiseDistr: Distribution<F> + Serialize + std::fmt::Debug,
584 { 693 {
585 let mut logs = Vec::new(); 694 fn name(&self) -> &str {
586 695 self.name.as_ref()
587 let iterator_options = AlgIteratorOptions{ 696 }
588 max_iter : cli.max_iter, 697
589 verbose_iter : cli.verbose_iter 698 fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<F> {
590 .map_or(Verbose::LogarithmicCap{base : 10, cap : 2},
591 |n| Verbose::Every(n)),
592 quiet : cli.quiet,
593 };
594
595 // Run the algorithm(s)
596 for named @ Named { name : alg_name, data : alg } in algorithms.iter() {
597 let this_prefix = format!("{}{}/", prefix, alg_name);
598
599 // Create Logger and IteratorFactory
600 let mut logger = Logger::new();
601 let iterator = iterator_options.instantiate()
602 .timed()
603 .into_log(&mut logger);
604
605 let running = if !cli.quiet {
606 format!("{}\n{}\n{}\n",
607 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(),
608 format!("Iteration settings: {}", serde_json::to_string(&iterator_options)?).bright_black(),
609 format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black())
610 } else {
611 "".to_string()
612 };
613 //
614 // The following is for postprocessing, which has been disabled anyway.
615 //
616 // let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation {
617 // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)),
618 // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)),
619 // };
620 //let findim_data = reg.prepare_optimise_weights(&opA, &b);
621 //let inner_config : InnerSettings<F> = Default::default();
622 //let inner_it = inner_config.iterator_options;
623
624 // Create plotter and directory if needed.
625 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };
626 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid.clone());
627
628 let start = Instant::now();
629 let start_cpu = ProcessTime::now();
630
631 let (μ, z) = match do_alg(alg, iterator, plotter, running) {
632 Ok(μ) => μ,
633 Err(RunError::NotImplemented) => {
634 let msg = format!("Algorithm “{alg_name}” not implemented for {experiment_name}. \
635 Skipping.").red();
636 eprintln!("{}", msg);
637 continue
638 }
639 };
640
641 let elapsed = start.elapsed().as_secs_f64();
642 let cpu_time = start_cpu.elapsed().as_secs_f64();
643
644 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow());
645
646 // Save results
647 println!("{}", "Saving results …".green());
648
649 let mkname = |t| format!("{prefix}{alg_name}_{t}");
650
651 write_json(mkname("config.json"), &named)?;
652 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?;
653 μ.write_csv(mkname("reco.txt"))?;
654 save_extra(mkname(""), z)?;
655 //logger.write_csv(mkname("log.txt"))?;
656 logs.push((mkname("log.txt"), logger));
657 }
658
659 save_logs(logs, format!("{prefix}valuerange.json"), cli.load_valuerange)
660 }
661
662 #[replace_float_literals(F::cast_from(literal))]
663 impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N : usize> RunnableExperiment<F> for
664 Named<ExperimentV2<F, NoiseDistr, S, K, P, N>>
665 where
666 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>
667 + Default + for<'b> Deserialize<'b>,
668 [usize; N] : Serialize,
669 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
670 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
671 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy
672 // TODO: shold not have differentiability as a requirement, but
673 // decide availability of sliding based on it.
674 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
675 // TODO: very weird that rust only compiles with Differentiable
676 // instead of the above one on references, which is required by
677 // poitsource_sliding_fb_reg.
678 + DifferentiableRealMapping<F, N>
679 + Lipschitz<L2, FloatType=F>,
680 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb.
681 AutoConvolution<P> : BoundedBy<F, K>,
682 K : SimpleConvolutionKernel<F, N>
683 + LocalAnalysis<F, Bounds<F>, N>
684 + Copy + Serialize + std::fmt::Debug,
685 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
686 PlotLookup : Plotting<N>,
687 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
688 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
689 RNDM<F, N> : SpikeMerging<F>,
690 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug,
691 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>,
692 // PreadjointCodomain : Space + Bounded<F> + DifferentiableRealMapping<F, N>,
693 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
694 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
695 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
696 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
697 {
698
699 fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> {
700 AlgorithmOverrides { 699 AlgorithmOverrides {
701 merge_radius : Some(self.data.default_merge_radius), 700 merge_radius: Some(self.data.default_merge_radius),
702 .. self.data.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) 701 ..self
703 } 702 .data
704 } 703 .algorithm_overrides
705 704 .get(&alg)
706 fn runall(&self, cli : &CommandLineArgs, 705 .cloned()
707 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { 706 .unwrap_or(Default::default())
707 }
708 }
709
710 fn runall(
711 &self,
712 cli: &CommandLineArgs,
713 algs: Option<Vec<Named<AlgorithmConfig<F>>>>,
714 ) -> DynError {
708 // Get experiment configuration 715 // Get experiment configuration
709 let &Named { 716 let &ExperimentV2 {
710 name : ref experiment_name, 717 domain,
711 data : ExperimentV2 { 718 sensor_count,
712 domain, sensor_count, ref noise_distr, sensor, spread, kernel, 719 ref noise_distr,
713 ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, 720 sensor,
714 .. 721 spread,
715 } 722 kernel,
716 } = self; 723 ref μ_hat,
724 regularisation,
725 kernel_plot_width,
726 dataterm,
727 noise_seed,
728 ..
729 } = &self.data;
717 730
718 // Set up algorithms 731 // Set up algorithms
719 let algorithms = match (algs, dataterm) { 732 let algorithms = match (algs, dataterm) {
720 (Some(algs), _) => algs, 733 (Some(algs), _) => algs,
721 (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], 734 (None, DataTermType::L222) => vec![DefaultAlgorithm::FB.get_named()],
722 (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], 735 (None, DataTermType::L1) => vec![DefaultAlgorithm::PDPS.get_named()],
723 }; 736 };
724 737
725 // Set up operators 738 // Set up operators
726 let depth = DynamicDepth(8); 739 let depth = DynamicDepth(8);
727 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); 740 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
736 let b = &b_hat + &noise; 749 let b = &b_hat + &noise;
737 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField 750 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
738 // overloading log10 and conflicting with standard NumTraits one. 751 // overloading log10 and conflicting with standard NumTraits one.
739 let stats = ExperimentStats::new(&b, &noise); 752 let stats = ExperimentStats::new(&b, &noise);
740 753
741 let prefix = start_experiment(&self, cli, stats)?; 754 let prefix = self.start(cli)?;
742 755 write_json(format!("{prefix}stats.json"), &stats)?;
743 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, 756
744 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; 757 plotall(
745 758 cli,
746 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); 759 &prefix,
760 &domain,
761 &sensor,
762 &kernel,
763 &spread,
764 &μ_hat,
765 &op𝒟,
766 &opA,
767 &b_hat,
768 &b,
769 kernel_plot_width,
770 )?;
771
772 let plotgrid = lingrid(&domain, &[if N == 1 { 1000 } else { 100 }; N]);
773 let make_plotter = |this_prefix| {
774 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };
775 SeqPlotter::new(this_prefix, plot_count, plotgrid.clone())
776 };
747 777
748 let save_extra = |_, ()| Ok(()); 778 let save_extra = |_, ()| Ok(());
749 779
750 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, 780 let μ0 = None; // Zero init
751 |alg, iterator, plotter, running| 781
752 { 782 match (dataterm, regularisation) {
753 let μ = match alg { 783 (DataTermType::L1, Regularisation::Radon(α)) => {
754 AlgorithmConfig::FB(ref algconfig, prox) => { 784 let f = DataTerm::new(opA, b, L1.as_mapping());
755 match (regularisation, dataterm, prox) { 785 let reg = RadonRegTerm(α);
756 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 786 self.do_runall(
757 print!("{running}"); 787 &prefix,
758 pointsource_fb_reg( 788 cli,
759 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 789 algorithms,
760 iterator, plotter 790 make_plotter,
761 ) 791 save_extra,
762 }), 792 μ0,
763 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 793 |p| {
764 print!("{running}"); 794 run_pdps(&f, &reg, &RadonSquared, p, |p| {
765 pointsource_fb_reg( 795 run_pdps(&f, &reg, &op𝒟, p, |_| Err(NotImplemented.into()))
766 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 796 })
767 iterator, plotter 797 .map(|μ| (μ, ()))
768 ) 798 },
769 }), 799 )
770 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 800 }
771 print!("{running}"); 801 (DataTermType::L1, Regularisation::NonnegRadon(α)) => {
772 pointsource_fb_reg( 802 let f = DataTerm::new(opA, b, L1.as_mapping());
773 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 803 let reg = NonnegRadonRegTerm(α);
774 iterator, plotter 804 self.do_runall(
775 ) 805 &prefix,
776 }), 806 cli,
777 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 807 algorithms,
778 print!("{running}"); 808 make_plotter,
779 pointsource_fb_reg( 809 save_extra,
780 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, 810 μ0,
781 iterator, plotter 811 |p| {
782 ) 812 run_pdps(&f, &reg, &RadonSquared, p, |p| {
783 }), 813 run_pdps(&f, &reg, &op𝒟, p, |_| Err(NotImplemented.into()))
784 _ => Err(NotImplemented) 814 })
785 } 815 .map(|μ| (μ, ()))
786 }, 816 },
787 AlgorithmConfig::FISTA(ref algconfig, prox) => { 817 )
788 match (regularisation, dataterm, prox) { 818 }
789 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 819 (DataTermType::L222, Regularisation::Radon(α)) => {
790 print!("{running}"); 820 let f = DataTerm::new(opA, b, Norm222::new());
791 pointsource_fista_reg( 821 let reg = RadonRegTerm(α);
792 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 822 self.do_runall(
793 iterator, plotter 823 &prefix,
794 ) 824 cli,
795 }), 825 algorithms,
796 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 826 make_plotter,
797 print!("{running}"); 827 save_extra,
798 pointsource_fista_reg( 828 μ0,
799 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 829 |p| {
800 iterator, plotter 830 run_fb(&f, &reg, &RadonSquared, p, |p| {
801 ) 831 run_pdps(&f, &reg, &RadonSquared, p, |p| {
802 }), 832 run_fb(&f, &reg, &op𝒟, p, |p| {
803 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 833 run_pdps(&f, &reg, &op𝒟, p, |p| {
804 print!("{running}"); 834 run_fw(&f, &reg, p, |_| Err(NotImplemented.into()))
805 pointsource_fista_reg( 835 })
806 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 836 })
807 iterator, plotter 837 })
808 ) 838 })
809 }), 839 .map(|μ| (μ, ()))
810 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 840 },
811 print!("{running}"); 841 )
812 pointsource_fista_reg( 842 }
813 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, 843 (DataTermType::L222, Regularisation::NonnegRadon(α)) => {
814 iterator, plotter 844 let f = DataTerm::new(opA, b, Norm222::new());
815 ) 845 let reg = NonnegRadonRegTerm(α);
816 }), 846 self.do_runall(
817 _ => Err(NotImplemented), 847 &prefix,
818 } 848 cli,
819 }, 849 algorithms,
820 AlgorithmConfig::SlidingFB(ref algconfig, prox) => { 850 make_plotter,
821 match (regularisation, dataterm, prox) { 851 save_extra,
822 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 852 μ0,
823 print!("{running}"); 853 |p| {
824 pointsource_sliding_fb_reg( 854 run_fb(&f, &reg, &RadonSquared, p, |p| {
825 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 855 run_pdps(&f, &reg, &RadonSquared, p, |p| {
826 iterator, plotter 856 run_fb(&f, &reg, &op𝒟, p, |p| {
827 ) 857 run_pdps(&f, &reg, &op𝒟, p, |p| {
828 }), 858 run_fw(&f, &reg, p, |_| Err(NotImplemented.into()))
829 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 859 })
830 print!("{running}"); 860 })
831 pointsource_sliding_fb_reg( 861 })
832 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 862 })
833 iterator, plotter 863 .map(|μ| (μ, ()))
834 ) 864 },
835 }), 865 )
836 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 866 }
837 print!("{running}"); 867 }
838 pointsource_sliding_fb_reg( 868 }
839 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 869 }
840 iterator, plotter 870
841 ) 871 /// Runs PDPS if `alg` so requests and `prox_penalty` matches.
842 }), 872 ///
843 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 873 /// Due to the structure of the PDPS, the data term `f` has to have a specific form.
844 print!("{running}"); 874 ///
845 pointsource_sliding_fb_reg( 875 /// `cont` gives a continuation attempt to find the algorithm matching the description `alg`.
846 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, 876 pub fn run_pdps<'a, F, A, Phi, Reg, P, I, Plot, const N: usize>(
847 iterator, plotter 877 f: &'a DataTerm<F, RNDM<N, F>, A, Phi>,
848 ) 878 reg: &Reg,
849 }), 879 prox_penalty: &P,
850 _ => Err(NotImplemented), 880 (alg, iterator, plotter, μ0, running): (
851 } 881 &AlgorithmConfig<F>,
852 }, 882 I,
853 AlgorithmConfig::PDPS(ref algconfig, prox) => { 883 Plot,
854 print!("{running}"); 884 Option<RNDM<N, F>>,
855 match (regularisation, dataterm, prox) { 885 String,
856 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 886 ),
857 pointsource_pdps_reg( 887 cont: impl FnOnce(
858 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 888 (&AlgorithmConfig<F>, I, Plot, Option<RNDM<N, F>>, String),
859 iterator, plotter, L2Squared 889 ) -> DynResult<RNDM<N, F>>,
860 ) 890 ) -> DynResult<RNDM<N, F>>
861 }), 891 where
862 (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 892 F: Float + ToNalgebraRealField,
863 pointsource_pdps_reg( 893 A: ForwardModel<RNDM<N, F>, F>,
864 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 894 Phi: Conjugable<A::Observable, F>,
865 iterator, plotter, L2Squared 895 for<'b> Phi::Conjugate<'b>: Prox<A::Observable>,
866 ) 896 for<'b> &'b A::Observable: Instance<A::Observable>,
867 }), 897 A::Observable: AXPY,
868 (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ 898 Reg: SlidingRegTerm<Loc<N, F>, F>,
869 pointsource_pdps_reg( 899 P: ProxPenalty<Loc<N, F>, A::PreadjointCodomain, Reg, F> + StepLengthBoundPD<F, A, RNDM<N, F>>,
870 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 900 RNDM<N, F>: SpikeMerging<F>,
871 iterator, plotter, L1 901 I: AlgIteratorFactory<IterInfo<F>>,
872 ) 902 Plot: Plotter<P::ReturnMapping, A::PreadjointCodomain, RNDM<N, F>>,
873 }), 903 {
874 (Regularisation::Radon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ 904 match alg {
875 pointsource_pdps_reg( 905 &AlgorithmConfig::PDPS(ref algconfig, prox_type) if prox_type == P::prox_type() => {
876 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 906 print!("{running}");
877 iterator, plotter, L1 907 pointsource_pdps_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0)
878 ) 908 }
879 }), 909 _ => cont((alg, iterator, plotter, μ0, running)),
880 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 910 }
881 pointsource_pdps_reg( 911 }
882 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 912
883 iterator, plotter, L2Squared 913 /// Runs FB-style algorithms if `alg` so requests and `prox_penalty` matches.
884 ) 914 ///
885 }), 915 /// `cont` gives a continuation attempt to find the algorithm matching the description `alg`.
886 (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 916 pub fn run_fb<F, Dat, Reg, P, I, Plot, const N: usize>(
887 pointsource_pdps_reg( 917 f: &Dat,
888 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, 918 reg: &Reg,
889 iterator, plotter, L2Squared 919 prox_penalty: &P,
890 ) 920 (alg, iterator, plotter, μ0, running): (
891 }), 921 &AlgorithmConfig<F>,
892 (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ 922 I,
893 pointsource_pdps_reg( 923 Plot,
894 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 924 Option<RNDM<N, F>>,
895 iterator, plotter, L1 925 String,
896 ) 926 ),
897 }), 927 cont: impl FnOnce(
898 (Regularisation::Radon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ 928 (&AlgorithmConfig<F>, I, Plot, Option<RNDM<N, F>>, String),
899 pointsource_pdps_reg( 929 ) -> DynResult<RNDM<N, F>>,
900 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, 930 ) -> DynResult<RNDM<N, F>>
901 iterator, plotter, L1 931 where
902 ) 932 F: Float + ToNalgebraRealField,
903 }), 933 I: AlgIteratorFactory<IterInfo<F>>,
904 // _ => Err(NotImplemented), 934 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>,
905 } 935 Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>,
906 }, 936 RNDM<N, F>: SpikeMerging<F>,
907 AlgorithmConfig::FW(ref algconfig) => { 937 Reg: SlidingRegTerm<Loc<N, F>, F>,
908 match (regularisation, dataterm) { 938 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
909 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ 939 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
910 print!("{running}"); 940 {
911 pointsource_fw_reg(&opA, &b, RadonRegTerm(α), 941 let pt = P::prox_type();
912 algconfig, iterator, plotter) 942
913 }), 943 match alg {
914 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ 944 &AlgorithmConfig::FB(ref algconfig, prox_type) if prox_type == pt => {
915 print!("{running}"); 945 print!("{running}");
916 pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α), 946 pointsource_fb_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0)
917 algconfig, iterator, plotter) 947 }
918 }), 948 &AlgorithmConfig::FISTA(ref algconfig, prox_type) if prox_type == pt => {
919 _ => Err(NotImplemented), 949 print!("{running}");
920 } 950 pointsource_fista_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0)
921 }, 951 }
922 _ => Err(NotImplemented), 952 &AlgorithmConfig::SlidingFB(ref algconfig, prox_type) if prox_type == pt => {
923 }?; 953 print!("{running}");
924 Ok((μ, ())) 954 pointsource_sliding_fb_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0)
925 }) 955 }
926 } 956 _ => cont((alg, iterator, plotter, μ0, running)),
927 } 957 }
928 958 }
959
960 /// Runs FB-style algorithms if `alg` so requests and `prox_penalty` matches.
961 ///
962 /// For the moment, due to restrictions of the Frank–Wolfe implementation, only the
963 /// $L^2$-squared data term is enabled through the type signatures.
964 ///
965 /// `cont` gives a continuation attempt to find the algorithm matching the description `alg`.
966 pub fn run_fw<'a, F, A, Reg, I, Plot, const N: usize>(
967 f: &'a DataTerm<F, RNDM<N, F>, A, Norm222<F>>,
968 reg: &Reg,
969 (alg, iterator, plotter, μ0, running): (
970 &AlgorithmConfig<F>,
971 I,
972 Plot,
973 Option<RNDM<N, F>>,
974 String,
975 ),
976 cont: impl FnOnce(
977 (&AlgorithmConfig<F>, I, Plot, Option<RNDM<N, F>>, String),
978 ) -> DynResult<RNDM<N, F>>,
979 ) -> DynResult<RNDM<N, F>>
980 where
981 F: Float + ToNalgebraRealField,
982 I: AlgIteratorFactory<IterInfo<F>>,
983 A: ForwardModel<RNDM<N, F>, F>,
984 A::PreadjointCodomain: MinMaxMapping<Loc<N, F>, F>,
985 for<'b> &'b A::PreadjointCodomain: Instance<A::PreadjointCodomain>,
986 Cube<N, F>: P2Minimise<Loc<N, F>, F>,
987 RNDM<N, F>: SpikeMerging<F>,
988 Reg: RegTermFW<F, A, ValueIteratorFactory<F, AlgIteratorOptions>, N>,
989 Plot: Plotter<A::PreadjointCodomain, A::PreadjointCodomain, RNDM<N, F>>,
990 {
991 match alg {
992 &AlgorithmConfig::FW(ref algconfig) => {
993 print!("{running}");
994 pointsource_fw_reg(f, reg, algconfig, iterator, plotter, μ0)
995 }
996 _ => cont((alg, iterator, plotter, μ0, running)),
997 }
998 }
929 999
930 #[replace_float_literals(F::cast_from(literal))] 1000 #[replace_float_literals(F::cast_from(literal))]
931 impl<F, NoiseDistr, S, K, P, B, /*PreadjointCodomain,*/ const N : usize> RunnableExperiment<F> for 1001 impl<F, NoiseDistr, S, K, P, B, const N: usize> RunnableExperiment<F>
932 Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>> 1002 for Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>>
933 where 1003 where
934 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F> 1004 F: ClapFloat
935 + Default + for<'b> Deserialize<'b>, 1005 + nalgebra::RealField
936 [usize; N] : Serialize, 1006 + ToNalgebraRealField<MixedType = F>
937 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, 1007 + Default
938 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, 1008 + for<'b> Deserialize<'b>,
939 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy 1009 [usize; N]: Serialize,
940 // TODO: shold not have differentiability as a requirement, but 1010 S: Sensor<N, F> + Copy + Serialize + std::fmt::Debug,
941 // decide availability of sliding based on it. 1011 P: Spread<N, F> + Copy + Serialize + std::fmt::Debug,
942 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, 1012 Convolution<S, P>: Spread<N, F>
943 // TODO: very weird that rust only compiles with Differentiable 1013 + Bounded<F>
944 // instead of the above one on references, which is required by
945 // poitsource_sliding_fb_reg.
946 + DifferentiableRealMapping<F, N>
947 + Lipschitz<L2, FloatType=F>,
948 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb.
949 AutoConvolution<P> : BoundedBy<F, K>,
950 K : SimpleConvolutionKernel<F, N>
951 + LocalAnalysis<F, Bounds<F>, N> 1014 + LocalAnalysis<F, Bounds<F>, N>
952 + Copy + Serialize + std::fmt::Debug, 1015 + Copy
953 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, 1016 // TODO: shold not have differentiability as a requirement, but
954 PlotLookup : Plotting<N>, 1017 // decide availability of sliding based on it.
955 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, 1018 //+ for<'b> Differentiable<&'b Loc<N, F>, Output = Loc<N, F>>,
1019 // TODO: very weird that rust only compiles with Differentiable
1020 // instead of the above one on references, which is required by
1021 // poitsource_sliding_fb_reg.
1022 + DifferentiableRealMapping<N, F>
1023 + Lipschitz<L2, FloatType = F>,
1024 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<N, F>>>::Differential<'b>:
1025 Lipschitz<L2, FloatType = F>, // TODO: should not be required generally, only for sliding_fb.
1026 AutoConvolution<P>: BoundedBy<F, K>,
1027 K: SimpleConvolutionKernel<N, F>
1028 + LocalAnalysis<F, Bounds<F>, N>
1029 + Copy
1030 + Serialize
1031 + std::fmt::Debug,
1032 Cube<N, F>: P2Minimise<Loc<N, F>, F> + SetOrd,
1033 PlotLookup: Plotting<N>,
1034 DefaultBT<F, N>: SensorGridBT<F, S, P, N, Depth = DynamicDepth> + BTSearch<N, F>,
956 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 1035 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
957 RNDM<F, N> : SpikeMerging<F>, 1036 RNDM<N, F>: SpikeMerging<F>,
958 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, 1037 NoiseDistr: Distribution<F> + Serialize + std::fmt::Debug,
959 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, 1038 B: Mapping<Loc<N, F>, Codomain = F> + Serialize + std::fmt::Debug,
960 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>, 1039 nalgebra::DVector<F>: ClosedMul<F>,
961 // PreadjointCodomain : Bounded<F> + DifferentiableRealMapping<F, N>, 1040 // This is mainly required for the final Mul requirement to be defined
1041 // DefaultSG<F, S, P, N>: ForwardModel<
1042 // RNDM<N, F>,
1043 // F,
1044 // PreadjointCodomain = PreadjointCodomain,
1045 // Observable = DVector<F::MixedType>,
1046 // >,
1047 // PreadjointCodomain: Bounded<F> + DifferentiableRealMapping<N, F> + std::ops::Mul<F>,
1048 // Pair<PreadjointCodomain, DVector<F>>: std::ops::Mul<F>,
962 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, 1049 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
963 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, 1050 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
964 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, 1051 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
965 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, 1052 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
966 { 1053 {
967 1054 fn name(&self) -> &str {
968 fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> { 1055 self.name.as_ref()
1056 }
1057
1058 fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<F> {
969 AlgorithmOverrides { 1059 AlgorithmOverrides {
970 merge_radius : Some(self.data.base.default_merge_radius), 1060 merge_radius: Some(self.data.base.default_merge_radius),
971 .. self.data.base.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) 1061 ..self
972 } 1062 .data
973 } 1063 .base
974 1064 .algorithm_overrides
975 fn runall(&self, cli : &CommandLineArgs, 1065 .get(&alg)
976 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { 1066 .cloned()
1067 .unwrap_or(Default::default())
1068 }
1069 }
1070
1071 fn runall(
1072 &self,
1073 cli: &CommandLineArgs,
1074 algs: Option<Vec<Named<AlgorithmConfig<F>>>>,
1075 ) -> DynError {
977 // Get experiment configuration 1076 // Get experiment configuration
978 let &Named { 1077 let &ExperimentBiased {
979 name : ref experiment_name, 1078 λ,
980 data : ExperimentBiased { 1079 ref bias,
981 λ, 1080 base:
982 ref bias, 1081 ExperimentV2 {
983 base : ExperimentV2 { 1082 domain,
984 domain, sensor_count, ref noise_distr, sensor, spread, kernel, 1083 sensor_count,
985 ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, 1084 ref noise_distr,
1085 sensor,
1086 spread,
1087 kernel,
1088 ref μ_hat,
1089 regularisation,
1090 kernel_plot_width,
1091 dataterm,
1092 noise_seed,
986 .. 1093 ..
987 } 1094 },
988 } 1095 } = &self.data;
989 } = self;
990 1096
991 // Set up algorithms 1097 // Set up algorithms
992 let algorithms = match (algs, dataterm) { 1098 let algorithms = match (algs, dataterm) {
993 (Some(algs), _) => algs, 1099 (Some(algs), _) => algs,
994 _ => vec![DefaultAlgorithm::SlidingPDPS.get_named()], 1100 _ => vec![DefaultAlgorithm::SlidingPDPS.get_named()],
998 let depth = DynamicDepth(8); 1104 let depth = DynamicDepth(8);
999 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); 1105 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth);
1000 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); 1106 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel);
1001 let opAext = RowOp(opA.clone(), IdOp::new()); 1107 let opAext = RowOp(opA.clone(), IdOp::new());
1002 let fnR = Zero::new(); 1108 let fnR = Zero::new();
1003 let h = map3(domain.span_start(), domain.span_end(), sensor_count, 1109 let h = map3(
1004 |a, b, n| (b-a)/F::cast_from(n)) 1110 domain.span_start(),
1005 .into_iter() 1111 domain.span_end(),
1006 .reduce(NumTraitsFloat::max) 1112 sensor_count,
1007 .unwrap(); 1113 |a, b, n| (b - a) / F::cast_from(n),
1114 )
1115 .into_iter()
1116 .reduce(NumTraitsFloat::max)
1117 .unwrap();
1008 let z = DVector::zeros(sensor_count.iter().product()); 1118 let z = DVector::zeros(sensor_count.iter().product());
1009 let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap(); 1119 let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap();
1010 let y = opKz.apply(&z); 1120 let y = opKz.apply(&z);
1011 let fnH = Weighted{ base_fn : L1.as_mapping(), weight : λ}; // TODO: L_{2,1} 1121 let fnH = Weighted { base_fn: L1.as_mapping(), weight: λ }; // TODO: L_{2,1}
1012 // let zero_y = y.clone(); 1122 // let zero_y = y.clone();
1013 // let zeroBTFN = opA.preadjoint().apply(&zero_y); 1123 // let zeroBTFN = opA.preadjoint().apply(&zero_y);
1014 // let opKμ = ZeroOp::new(&zero_y, zeroBTFN); 1124 // let opKμ = ZeroOp::new(&zero_y, zeroBTFN);
1015 1125
1016 // Set up random number generator. 1126 // Set up random number generator.
1017 let mut rng = StdRng::seed_from_u64(noise_seed); 1127 let mut rng = StdRng::seed_from_u64(noise_seed);
1018 1128
1019 // Generate the data and calculate SSNR statistic 1129 // Generate the data and calculate SSNR statistic
1020 let bias_vec = DVector::from_vec(opA.grid() 1130 let bias_vec = DVector::from_vec(
1021 .into_iter() 1131 opA.grid()
1022 .map(|v| bias.apply(v)) 1132 .into_iter()
1023 .collect::<Vec<F>>()); 1133 .map(|v| bias.apply(v))
1024 let b_hat : DVector<_> = opA.apply(μ_hat) + &bias_vec; 1134 .collect::<Vec<F>>(),
1135 );
1136 let b_hat: DVector<_> = opA.apply(μ_hat) + &bias_vec;
1025 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); 1137 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
1026 let b = &b_hat + &noise; 1138 let b = &b_hat + &noise;
1027 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField 1139 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
1028 // overloading log10 and conflicting with standard NumTraits one. 1140 // overloading log10 and conflicting with standard NumTraits one.
1029 let stats = ExperimentStats::new(&b, &noise); 1141 let stats = ExperimentStats::new(&b, &noise);
1030 1142
1031 let prefix = start_experiment(&self, cli, stats)?; 1143 let prefix = self.start(cli)?;
1032 1144 write_json(format!("{prefix}stats.json"), &stats)?;
1033 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, 1145
1034 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; 1146 plotall(
1147 cli,
1148 &prefix,
1149 &domain,
1150 &sensor,
1151 &kernel,
1152 &spread,
1153 &μ_hat,
1154 &op𝒟,
1155 &opA,
1156 &b_hat,
1157 &b,
1158 kernel_plot_width,
1159 )?;
1035 1160
1036 opA.write_observable(&bias_vec, format!("{prefix}bias"))?; 1161 opA.write_observable(&bias_vec, format!("{prefix}bias"))?;
1037 1162
1038 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); 1163 let plotgrid = lingrid(&domain, &[if N == 1 { 1000 } else { 100 }; N]);
1164 let make_plotter = |this_prefix| {
1165 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 };
1166 SeqPlotter::new(this_prefix, plot_count, plotgrid.clone())
1167 };
1039 1168
1040 let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z")); 1169 let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z"));
1041 1170
1042 // Run the algorithms 1171 let μ0 = None; // Zero init
1043 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, 1172
1044 |alg, iterator, plotter, running| 1173 match (dataterm, regularisation) {
1045 { 1174 (DataTermType::L222, Regularisation::Radon(α)) => {
1046 let Pair(μ, z) = match alg { 1175 let f = DataTerm::new(opAext, b, Norm222::new());
1047 AlgorithmConfig::ForwardPDPS(ref algconfig, prox) => { 1176 let reg = RadonRegTerm(α);
1048 match (regularisation, dataterm, prox) { 1177 self.do_runall(
1049 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 1178 &prefix,
1050 print!("{running}"); 1179 cli,
1051 pointsource_forward_pdps_pair( 1180 algorithms,
1052 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 1181 make_plotter,
1053 iterator, plotter, 1182 save_extra,
1054 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1183 (μ0, z, y),
1055 ) 1184 |p| {
1056 }), 1185 run_pdps_pair(&f, &reg, &RadonSquared, &opKz, &fnR, &fnH, p, |q| {
1057 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 1186 run_pdps_pair(&f, &reg, &op𝒟, &opKz, &fnR, &fnH, q, |_| {
1058 print!("{running}"); 1187 Err(NotImplemented.into())
1059 pointsource_forward_pdps_pair( 1188 })
1060 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, 1189 })
1061 iterator, plotter, 1190 .map(|Pair(μ, z)| (μ, z))
1062 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1191 },
1063 ) 1192 )
1064 }), 1193 }
1065 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 1194 (DataTermType::L222, Regularisation::NonnegRadon(α)) => {
1066 print!("{running}"); 1195 let f = DataTerm::new(opAext, b, Norm222::new());
1067 pointsource_forward_pdps_pair( 1196 let reg = NonnegRadonRegTerm(α);
1068 &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 1197 self.do_runall(
1069 iterator, plotter, 1198 &prefix,
1070 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1199 cli,
1071 ) 1200 algorithms,
1072 }), 1201 make_plotter,
1073 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 1202 save_extra,
1074 print!("{running}"); 1203 (μ0, z, y),
1075 pointsource_forward_pdps_pair( 1204 |p| {
1076 &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, 1205 run_pdps_pair(&f, &reg, &RadonSquared, &opKz, &fnR, &fnH, p, |q| {
1077 iterator, plotter, 1206 run_pdps_pair(&f, &reg, &op𝒟, &opKz, &fnR, &fnH, q, |_| {
1078 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1207 Err(NotImplemented.into())
1079 ) 1208 })
1080 }), 1209 })
1081 _ => Err(NotImplemented) 1210 .map(|Pair(μ, z)| (μ, z))
1082 } 1211 },
1083 }, 1212 )
1084 AlgorithmConfig::SlidingPDPS(ref algconfig, prox) => { 1213 }
1085 match (regularisation, dataterm, prox) { 1214 _ => Err(NotImplemented.into()),
1086 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 1215 }
1087 print!("{running}"); 1216 }
1088 pointsource_sliding_pdps_pair( 1217 }
1089 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 1218
1090 iterator, plotter, 1219 type MeasureZ<F, Z, const N: usize> = Pair<RNDM<N, F>, Z>;
1091 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1220
1092 ) 1221 pub fn run_pdps_pair<F, S, Dat, Reg, Z, R, Y, KOpZ, H, P, I, Plot, const N: usize>(
1093 }), 1222 f: &Dat,
1094 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ 1223 reg: &Reg,
1095 print!("{running}"); 1224 prox_penalty: &P,
1096 pointsource_sliding_pdps_pair( 1225 opKz: &KOpZ,
1097 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, 1226 fnR: &R,
1098 iterator, plotter, 1227 fnH: &H,
1099 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1228 (alg, iterator, plotter, μ0zy, running): (
1100 ) 1229 &AlgorithmConfig<F>,
1101 }), 1230 I,
1102 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 1231 Plot,
1103 print!("{running}"); 1232 (Option<RNDM<N, F>>, Z, Y),
1104 pointsource_sliding_pdps_pair( 1233 String,
1105 &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, 1234 ),
1106 iterator, plotter, 1235 cont: impl FnOnce(
1107 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1236 (
1108 ) 1237 &AlgorithmConfig<F>,
1109 }), 1238 I,
1110 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ 1239 Plot,
1111 print!("{running}"); 1240 (Option<RNDM<N, F>>, Z, Y),
1112 pointsource_sliding_pdps_pair( 1241 String,
1113 &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, 1242 ),
1114 iterator, plotter, 1243 ) -> DynResult<Pair<RNDM<N, F>, Z>>,
1115 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1244 ) -> DynResult<Pair<RNDM<N, F>, Z>>
1116 ) 1245 where
1117 }), 1246 F: Float + ToNalgebraRealField,
1118 _ => Err(NotImplemented) 1247 I: AlgIteratorFactory<IterInfo<F>>,
1119 } 1248 Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>
1120 }, 1249 + BoundedCurvature<F>,
1121 _ => Err(NotImplemented) 1250 S: DifferentiableRealMapping<N, F> + ClosedMul<F>,
1122 }?; 1251 for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>,
1123 Ok((μ, z)) 1252 //Pair<S, Z>: ClosedMul<F>,
1124 }) 1253 RNDM<N, F>: SpikeMerging<F>,
1254 Reg: SlidingRegTerm<Loc<N, F>, F>,
1255 P: ProxPenalty<Loc<N, F>, S, Reg, F>,
1256 KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y>
1257 + GEMV<F, Z>
1258 + SimplyAdjointable<Z, Y, AdjointCodomain = Z>,
1259 KOpZ::SimpleAdjoint: GEMV<F, Y>,
1260 Y: ClosedEuclidean<F> + Clone,
1261 for<'b> &'b Y: Instance<Y>,
1262 Z: ClosedEuclidean<F> + Clone + ClosedMul<F>,
1263 for<'b> &'b Z: Instance<Z>,
1264 R: Prox<Z, Codomain = F>,
1265 H: Conjugable<Y, F, Codomain = F>,
1266 for<'b> H::Conjugate<'b>: Prox<Y>,
1267 Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>,
1268 {
1269 let pt = P::prox_type();
1270
1271 match alg {
1272 &AlgorithmConfig::ForwardPDPS(ref algconfig, prox_type) if prox_type == pt => {
1273 print!("{running}");
1274 pointsource_forward_pdps_pair(
1275 f,
1276 reg,
1277 prox_penalty,
1278 algconfig,
1279 iterator,
1280 plotter,
1281 μ0zy,
1282 opKz,
1283 fnR,
1284 fnH,
1285 )
1286 }
1287 &AlgorithmConfig::SlidingPDPS(ref algconfig, prox_type) if prox_type == pt => {
1288 print!("{running}");
1289 pointsource_sliding_pdps_pair(
1290 f,
1291 reg,
1292 prox_penalty,
1293 algconfig,
1294 iterator,
1295 plotter,
1296 μ0zy,
1297 opKz,
1298 fnR,
1299 fnH,
1300 )
1301 }
1302 _ => cont((alg, iterator, plotter, μ0zy, running)),
1303 }
1304 }
1305
1306 pub fn run_fb_pair<F, S, Dat, Reg, Z, R, P, I, Plot, const N: usize>(
1307 f: &Dat,
1308 reg: &Reg,
1309 prox_penalty: &P,
1310 fnR: &R,
1311 (alg, iterator, plotter, μ0z, running): (
1312 &AlgorithmConfig<F>,
1313 I,
1314 Plot,
1315 (Option<RNDM<N, F>>, Z),
1316 String,
1317 ),
1318 cont: impl FnOnce(
1319 (
1320 &AlgorithmConfig<F>,
1321 I,
1322 Plot,
1323 (Option<RNDM<N, F>>, Z),
1324 String,
1325 ),
1326 ) -> DynResult<Pair<RNDM<N, F>, Z>>,
1327 ) -> DynResult<Pair<RNDM<N, F>, Z>>
1328 where
1329 F: Float + ToNalgebraRealField,
1330 I: AlgIteratorFactory<IterInfo<F>>,
1331 Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>
1332 + BoundedCurvature<F>,
1333 S: DifferentiableRealMapping<N, F> + ClosedMul<F>,
1334 RNDM<N, F>: SpikeMerging<F>,
1335 Reg: SlidingRegTerm<Loc<N, F>, F>,
1336 P: ProxPenalty<Loc<N, F>, S, Reg, F>,
1337 for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>,
1338 Z: ClosedEuclidean<F> + AXPY + Clone,
1339 for<'b> &'b Z: Instance<Z>,
1340 R: Prox<Z, Codomain = F>,
1341 Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>,
1342 // We should not need to explicitly require this:
1343 for<'b> &'b Loc<0, F>: Instance<Loc<0, F>>,
1344 {
1345 let pt = P::prox_type();
1346
1347 match alg {
1348 &AlgorithmConfig::FB(ref algconfig, prox_type) if prox_type == pt => {
1349 print!("{running}");
1350 pointsource_fb_pair(f, reg, prox_penalty, algconfig, iterator, plotter, μ0z, fnR)
1351 }
1352 &AlgorithmConfig::SlidingFB(ref algconfig, prox_type) if prox_type == pt => {
1353 print!("{running}");
1354 pointsource_sliding_fb_pair(
1355 f,
1356 reg,
1357 prox_penalty,
1358 algconfig,
1359 iterator,
1360 plotter,
1361 μ0z,
1362 fnR,
1363 )
1364 }
1365 _ => cont((alg, iterator, plotter, μ0z, running)),
1125 } 1366 }
1126 } 1367 }
1127 1368
1128 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 1369 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
1129 struct ValueRange<F : Float> { 1370 struct ValueRange<F: Float> {
1130 ini : F, 1371 ini: F,
1131 min : F, 1372 min: F,
1132 } 1373 }
1133 1374
1134 impl<F : Float> ValueRange<F> { 1375 impl<F: Float> ValueRange<F> {
1135 fn expand_with(self, other : Self) -> Self { 1376 fn expand_with(self, other: Self) -> Self {
1136 ValueRange { 1377 ValueRange { ini: self.ini.max(other.ini), min: self.min.min(other.min) }
1137 ini : self.ini.max(other.ini),
1138 min : self.min.min(other.min),
1139 }
1140 } 1378 }
1141 } 1379 }
1142 1380
1143 /// Calculative minimum and maximum values of all the `logs`, and save them into 1381 /// Calculative minimum and maximum values of all the `logs`, and save them into
1144 /// corresponding file names given as the first elements of the tuples in the vectors. 1382 /// corresponding file names given as the first elements of the tuples in the vectors.
1145 fn save_logs<F : Float + for<'b> Deserialize<'b>, const N : usize>( 1383 fn save_logs<F: Float + for<'b> Deserialize<'b>>(
1146 logs : Vec<(String, Logger<Timed<IterInfo<F, N>>>)>, 1384 logs: Vec<(String, Logger<Timed<IterInfo<F>>>)>,
1147 valuerange_file : String, 1385 valuerange_file: String,
1148 load_valuerange : bool, 1386 load_valuerange: bool,
1149 ) -> DynError { 1387 ) -> DynError {
1150 // Process logs for relative values 1388 // Process logs for relative values
1151 println!("{}", "Processing logs…"); 1389 println!("{}", "Processing logs…");
1152 1390
1153 // Find minimum value and initial value within a single log 1391 // Find minimum value and initial value within a single log
1154 let proc_single_log = |log : &Logger<Timed<IterInfo<F, N>>>| { 1392 let proc_single_log = |log: &Logger<Timed<IterInfo<F>>>| {
1155 let d = log.data(); 1393 let d = log.data();
1156 let mi = d.iter() 1394 let mi = d.iter().map(|i| i.data.value).reduce(NumTraitsFloat::min);
1157 .map(|i| i.data.value)
1158 .reduce(NumTraitsFloat::min);
1159 d.first() 1395 d.first()
1160 .map(|i| i.data.value) 1396 .map(|i| i.data.value)
1161 .zip(mi) 1397 .zip(mi)
1162 .map(|(ini, min)| ValueRange{ ini, min }) 1398 .map(|(ini, min)| ValueRange { ini, min })
1163 }; 1399 };
1164 1400
1165 // Find minimum and maximum value over all logs 1401 // Find minimum and maximum value over all logs
1166 let mut v = logs.iter() 1402 let mut v = logs
1167 .filter_map(|&(_, ref log)| proc_single_log(log)) 1403 .iter()
1168 .reduce(|v1, v2| v1.expand_with(v2)) 1404 .filter_map(|&(_, ref log)| proc_single_log(log))
1169 .ok_or(anyhow!("No algorithms found"))?; 1405 .reduce(|v1, v2| v1.expand_with(v2))
1406 .ok_or(anyhow!("No algorithms found"))?;
1170 1407
1171 // Load existing range 1408 // Load existing range
1172 if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() { 1409 if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() {
1173 let data = std::fs::read_to_string(&valuerange_file)?; 1410 let data = std::fs::read_to_string(&valuerange_file)?;
1174 v = v.expand_with(serde_json::from_str(&data)?); 1411 v = v.expand_with(serde_json::from_str(&data)?);
1181 inner_iters, 1418 inner_iters,
1182 merged, 1419 merged,
1183 pruned, 1420 pruned,
1184 //postprocessing, 1421 //postprocessing,
1185 this_iters, 1422 this_iters,
1423 ε,
1186 .. 1424 ..
1187 } = data; 1425 } = data;
1188 // let post_value = match (postprocessing, dataterm) { 1426 // let post_value = match (postprocessing, dataterm) {
1189 // (Some(mut μ), DataTerm::L2Squared) => { 1427 // (Some(mut μ), DataTermType::L222) => {
1190 // // Comparison postprocessing is only implemented for the case handled 1428 // // Comparison postprocessing is only implemented for the case handled
1191 // // by the FW variants. 1429 // // by the FW variants.
1192 // reg.optimise_weights( 1430 // reg.optimise_weights(
1193 // &mut μ, &opA, &b, &findim_data, &inner_config, 1431 // &mut μ, &opA, &b, &findim_data, &inner_config,
1194 // inner_it 1432 // inner_it
1196 // dataterm.value_at_residual(opA.apply(&μ) - &b) 1434 // dataterm.value_at_residual(opA.apply(&μ) - &b)
1197 // + regularisation.apply(&μ) 1435 // + regularisation.apply(&μ)
1198 // }, 1436 // },
1199 // _ => value, 1437 // _ => value,
1200 // }; 1438 // };
1201 let relative_value = (value - v.min)/(v.ini - v.min); 1439 let relative_value = (value - v.min) / (v.ini - v.min);
1202 CSVLog { 1440 CSVLog {
1203 iter, 1441 iter,
1204 value, 1442 value,
1205 relative_value, 1443 relative_value,
1206 //post_value, 1444 //post_value,
1207 n_spikes, 1445 n_spikes,
1208 cpu_time : cpu_time.as_secs_f64(), 1446 cpu_time: cpu_time.as_secs_f64(),
1209 inner_iters, 1447 inner_iters,
1210 merged, 1448 merged,
1211 pruned, 1449 pruned,
1212 this_iters 1450 this_iters,
1451 epsilon: ε,
1213 } 1452 }
1214 }; 1453 };
1215 1454
1216 println!("{}", "Saving logs …".green()); 1455 println!("{}", "Saving logs …".green());
1217 1456
1222 } 1461 }
1223 1462
1224 Ok(()) 1463 Ok(())
1225 } 1464 }
1226 1465
1227
1228 /// Plot experiment setup 1466 /// Plot experiment setup
1229 #[replace_float_literals(F::cast_from(literal))] 1467 #[replace_float_literals(F::cast_from(literal))]
1230 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( 1468 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N: usize>(
1231 cli : &CommandLineArgs, 1469 cli: &CommandLineArgs,
1232 prefix : &String, 1470 prefix: &String,
1233 domain : &Cube<F, N>, 1471 domain: &Cube<N, F>,
1234 sensor : &Sensor, 1472 sensor: &Sensor,
1235 kernel : &Kernel, 1473 kernel: &Kernel,
1236 spread : &Spread, 1474 spread: &Spread,
1237 μ_hat : &RNDM<F, N>, 1475 μ_hat: &RNDM<N, F>,
1238 op𝒟 : &𝒟, 1476 op𝒟: &𝒟,
1239 opA : &A, 1477 opA: &A,
1240 b_hat : &A::Observable, 1478 b_hat: &A::Observable,
1241 b : &A::Observable, 1479 b: &A::Observable,
1242 kernel_plot_width : F, 1480 kernel_plot_width: F,
1243 ) -> DynError 1481 ) -> DynError
1244 where F : Float + ToNalgebraRealField, 1482 where
1245 Sensor : RealMapping<F, N> + Support<F, N> + Clone, 1483 F: Float + ToNalgebraRealField,
1246 Spread : RealMapping<F, N> + Support<F, N> + Clone, 1484 Sensor: RealMapping<N, F> + Support<N, F> + Clone,
1247 Kernel : RealMapping<F, N> + Support<F, N>, 1485 Spread: RealMapping<N, F> + Support<N, F> + Clone,
1248 Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>, 1486 Kernel: RealMapping<N, F> + Support<N, F>,
1249 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, 1487 Convolution<Sensor, Spread>: DifferentiableRealMapping<N, F> + Support<N, F>,
1250 𝒟::Codomain : RealMapping<F, N>, 1488 𝒟: DiscreteMeasureOp<Loc<N, F>, F>,
1251 A : ForwardModel<RNDM<F, N>, F>, 1489 𝒟::Codomain: RealMapping<N, F>,
1252 for<'a> &'a A::Observable : Instance<A::Observable>, 1490 A: ForwardModel<RNDM<N, F>, F>,
1253 A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>, 1491 for<'a> &'a A::Observable: Instance<A::Observable>,
1254 PlotLookup : Plotting<N>, 1492 A::PreadjointCodomain: DifferentiableRealMapping<N, F> + Bounded<F>,
1255 Cube<F, N> : SetOrd { 1493 PlotLookup: Plotting<N>,
1256 1494 Cube<N, F>: SetOrd,
1495 {
1257 if cli.plot < PlotLevel::Data { 1496 if cli.plot < PlotLevel::Data {
1258 return Ok(()) 1497 return Ok(());
1259 } 1498 }
1260 1499
1261 let base = Convolution(sensor.clone(), spread.clone()); 1500 let base = Convolution(sensor.clone(), spread.clone());
1262 1501
1263 let resolution = if N==1 { 100 } else { 40 }; 1502 let resolution = if N == 1 { 100 } else { 40 };
1264 let pfx = |n| format!("{prefix}{n}"); 1503 let pfx = |n| format!("{prefix}{n}");
1265 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); 1504 let plotgrid = lingrid(
1505 &[[-kernel_plot_width, kernel_plot_width]; N].into(),
1506 &[resolution; N],
1507 );
1266 1508
1267 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor")); 1509 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"));
1268 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel")); 1510 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"));
1269 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread")); 1511 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"));
1270 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor")); 1512 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"));
1271 1513
1272 let plotgrid2 = lingrid(&domain, &[resolution; N]); 1514 let plotgrid2 = lingrid(&domain, &[resolution; N]);
1273 1515
1274 let ω_hat = op𝒟.apply(μ_hat); 1516 let ω_hat = op𝒟.apply(μ_hat);
1275 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); 1517 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b);
1276 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat")); 1518 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat"));
1277 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise")); 1519 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise"));
1278 1520
1279 let preadj_b = opA.preadjoint().apply(b); 1521 let preadj_b = opA.preadjoint().apply(b);
1280 let preadj_b_hat = opA.preadjoint().apply(b_hat); 1522 let preadj_b_hat = opA.preadjoint().apply(b_hat);
1281 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); 1523 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds());
1282 PlotLookup::plot_into_file_spikes( 1524 PlotLookup::plot_into_file_spikes(
1283 Some(&preadj_b), 1525 Some(&preadj_b),
1284 Some(&preadj_b_hat), 1526 Some(&preadj_b_hat),
1285 plotgrid2, 1527 plotgrid2,
1286 &μ_hat, 1528 &μ_hat,
1287 pfx("omega_b") 1529 pfx("omega_b"),
1288 ); 1530 );
1289 PlotLookup::plot_into_file(&preadj_b, plotgrid2, pfx("preadj_b")); 1531 PlotLookup::plot_into_file(&preadj_b, plotgrid2, pfx("preadj_b"));
1290 PlotLookup::plot_into_file(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat")); 1532 PlotLookup::plot_into_file(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat"));
1291 1533
1292 // Save true solution and observables 1534 // Save true solution and observables

mercurial