| 1 /*! |
1 /*! |
| 2 This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment. |
2 This module provides [`RunnableExperiment`] for running chosen algorithms on a chosen experiment. |
| 3 */ |
3 */ |
| 4 |
4 |
| 5 use numeric_literals::replace_float_literals; |
5 use crate::fb::{pointsource_fb_reg, pointsource_fista_reg, FBConfig, InsertionConfig}; |
| 6 use colored::Colorize; |
|
| 7 use serde::{Serialize, Deserialize}; |
|
| 8 use serde_json; |
|
| 9 use nalgebra::base::DVector; |
|
| 10 use std::hash::Hash; |
|
| 11 use chrono::{DateTime, Utc}; |
|
| 12 use cpu_time::ProcessTime; |
|
| 13 use clap::ValueEnum; |
|
| 14 use std::collections::HashMap; |
|
| 15 use std::time::Instant; |
|
| 16 |
|
| 17 use rand::prelude::{ |
|
| 18 StdRng, |
|
| 19 SeedableRng |
|
| 20 }; |
|
| 21 use rand_distr::Distribution; |
|
| 22 |
|
| 23 use alg_tools::bisection_tree::*; |
|
| 24 use alg_tools::iterate::{ |
|
| 25 Timed, |
|
| 26 AlgIteratorOptions, |
|
| 27 Verbose, |
|
| 28 AlgIteratorFactory, |
|
| 29 LoggingIteratorFactory, |
|
| 30 TimingIteratorFactory, |
|
| 31 BasicAlgIteratorFactory, |
|
| 32 }; |
|
| 33 use alg_tools::logger::Logger; |
|
| 34 use alg_tools::error::{ |
|
| 35 DynError, |
|
| 36 DynResult, |
|
| 37 }; |
|
| 38 use alg_tools::tabledump::TableDump; |
|
| 39 use alg_tools::sets::Cube; |
|
| 40 use alg_tools::mapping::{ |
|
| 41 RealMapping, |
|
| 42 DifferentiableMapping, |
|
| 43 DifferentiableRealMapping, |
|
| 44 Instance |
|
| 45 }; |
|
| 46 use alg_tools::nalgebra_support::ToNalgebraRealField; |
|
| 47 use alg_tools::euclidean::Euclidean; |
|
| 48 use alg_tools::lingrid::{lingrid, LinSpace}; |
|
| 49 use alg_tools::sets::SetOrd; |
|
| 50 use alg_tools::linops::{RowOp, IdOp /*, ZeroOp*/}; |
|
| 51 use alg_tools::discrete_gradient::{Grad, ForwardNeumann}; |
|
| 52 use alg_tools::convex::Zero; |
|
| 53 use alg_tools::maputil::map3; |
|
| 54 use alg_tools::direct_product::Pair; |
|
| 55 |
|
| 56 use crate::kernels::*; |
|
| 57 use crate::types::*; |
|
| 58 use crate::measures::*; |
|
| 59 use crate::measures::merging::{SpikeMerging,SpikeMergingMethod}; |
|
| 60 use crate::forward_model::*; |
|
| 61 use crate::forward_model::sensor_grid::{ |
6 use crate::forward_model::sensor_grid::{ |
| |
7 //SensorGridBTFN, |
| |
8 Sensor, |
| 62 SensorGrid, |
9 SensorGrid, |
| 63 SensorGridBT, |
10 SensorGridBT, |
| 64 //SensorGridBTFN, |
|
| 65 Sensor, |
|
| 66 Spread, |
11 Spread, |
| 67 }; |
12 }; |
| 68 |
13 use crate::forward_model::*; |
| 69 use crate::fb::{ |
14 use crate::forward_pdps::{pointsource_fb_pair, pointsource_forward_pdps_pair, ForwardPDPSConfig}; |
| 70 FBConfig, |
15 use crate::frank_wolfe::{pointsource_fw_reg, FWConfig, FWVariant, RegTermFW}; |
| 71 FBGenericConfig, |
16 use crate::kernels::*; |
| 72 pointsource_fb_reg, |
17 use crate::measures::merging::{SpikeMerging, SpikeMergingMethod}; |
| 73 pointsource_fista_reg, |
18 use crate::measures::*; |
| |
19 use crate::pdps::{pointsource_pdps_reg, PDPSConfig}; |
| |
20 use crate::plot::*; |
| |
21 use crate::prox_penalty::{ |
| |
22 ProxPenalty, ProxTerm, RadonSquared, StepLengthBound, StepLengthBoundPD, StepLengthBoundPair, |
| 74 }; |
23 }; |
| 75 use crate::sliding_fb::{ |
24 use crate::regularisation::{NonnegRadonRegTerm, RadonRegTerm, Regularisation, SlidingRegTerm}; |
| 76 SlidingFBConfig, |
25 use crate::seminorms::*; |
| 77 TransportConfig, |
26 use crate::sliding_fb::{pointsource_sliding_fb_reg, SlidingFBConfig, TransportConfig}; |
| 78 pointsource_sliding_fb_reg |
27 use crate::sliding_pdps::{ |
| |
28 pointsource_sliding_fb_pair, pointsource_sliding_pdps_pair, SlidingPDPSConfig, |
| 79 }; |
29 }; |
| 80 use crate::sliding_pdps::{ |
30 use crate::subproblem::{InnerMethod, InnerSettings}; |
| 81 SlidingPDPSConfig, |
31 use crate::tolerance::Tolerance; |
| 82 pointsource_sliding_pdps_pair |
32 use crate::types::*; |
| |
33 use crate::{AlgorithmOverrides, CommandLineArgs}; |
| |
34 use alg_tools::bisection_tree::*; |
| |
35 use alg_tools::bounds::{Bounded, MinMaxMapping}; |
| |
36 use alg_tools::convex::{Conjugable, Norm222, Prox, Zero}; |
| |
37 use alg_tools::direct_product::Pair; |
| |
38 use alg_tools::discrete_gradient::{ForwardNeumann, Grad}; |
| |
39 use alg_tools::error::{DynError, DynResult}; |
| |
40 use alg_tools::euclidean::{ClosedEuclidean, Euclidean}; |
| |
41 use alg_tools::iterate::{ |
| |
42 AlgIteratorFactory, AlgIteratorOptions, BasicAlgIteratorFactory, LoggingIteratorFactory, Timed, |
| |
43 TimingIteratorFactory, ValueIteratorFactory, Verbose, |
| 83 }; |
44 }; |
| 84 use crate::forward_pdps::{ |
45 use alg_tools::lingrid::lingrid; |
| 85 ForwardPDPSConfig, |
46 use alg_tools::linops::{IdOp, RowOp, AXPY}; |
| 86 pointsource_forward_pdps_pair |
47 use alg_tools::logger::Logger; |
| |
48 use alg_tools::mapping::{ |
| |
49 DataTerm, DifferentiableMapping, DifferentiableRealMapping, Instance, RealMapping, |
| 87 }; |
50 }; |
| 88 use crate::pdps::{ |
51 use alg_tools::maputil::map3; |
| 89 PDPSConfig, |
52 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 90 pointsource_pdps_reg, |
53 use alg_tools::norms::{NormExponent, L1, L2}; |
| 91 }; |
|
| 92 use crate::frank_wolfe::{ |
|
| 93 FWConfig, |
|
| 94 FWVariant, |
|
| 95 pointsource_fw_reg, |
|
| 96 //WeightOptim, |
|
| 97 }; |
|
| 98 use crate::subproblem::{InnerSettings, InnerMethod}; |
|
| 99 use crate::seminorms::*; |
|
| 100 use crate::plot::*; |
|
| 101 use crate::{AlgorithmOverrides, CommandLineArgs}; |
|
| 102 use crate::tolerance::Tolerance; |
|
| 103 use crate::regularisation::{ |
|
| 104 Regularisation, |
|
| 105 RadonRegTerm, |
|
| 106 NonnegRadonRegTerm |
|
| 107 }; |
|
| 108 use crate::dataterm::{ |
|
| 109 L1, |
|
| 110 L2Squared, |
|
| 111 }; |
|
| 112 use crate::prox_penalty::{ |
|
| 113 RadonSquared, |
|
| 114 //ProxPenalty, |
|
| 115 }; |
|
| 116 use alg_tools::norms::{L2, NormExponent}; |
|
| 117 use alg_tools::operator_arithmetic::Weighted; |
54 use alg_tools::operator_arithmetic::Weighted; |
| |
55 use alg_tools::sets::Cube; |
| |
56 use alg_tools::sets::SetOrd; |
| |
57 use alg_tools::tabledump::TableDump; |
| 118 use anyhow::anyhow; |
58 use anyhow::anyhow; |
| 119 |
59 use chrono::{DateTime, Utc}; |
| 120 /// Available proximal terms |
60 use clap::ValueEnum; |
| 121 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
61 use colored::Colorize; |
| 122 pub enum ProxTerm { |
62 use cpu_time::ProcessTime; |
| 123 /// Partial-to-wave operator 𝒟. |
63 use nalgebra::base::DVector; |
| 124 Wave, |
64 use numeric_literals::replace_float_literals; |
| 125 /// Radon-norm squared |
65 use rand::prelude::{SeedableRng, StdRng}; |
| 126 RadonSquared |
66 use rand_distr::Distribution; |
| 127 } |
67 use serde::{Deserialize, Serialize}; |
| |
68 use serde_json; |
| |
69 use std::collections::HashMap; |
| |
70 use std::hash::Hash; |
| |
71 use std::time::Instant; |
| |
72 use thiserror::Error; |
| |
73 |
| |
74 //#[cfg(feature = "pyo3")] |
| |
75 //use pyo3::pyclass; |
| 128 |
76 |
| 129 /// Available algorithms and their configurations |
77 /// Available algorithms and their configurations |
| 130 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
78 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
| 131 pub enum AlgorithmConfig<F : Float> { |
79 pub enum AlgorithmConfig<F: Float> { |
| 132 FB(FBConfig<F>, ProxTerm), |
80 FB(FBConfig<F>, ProxTerm), |
| 133 FISTA(FBConfig<F>, ProxTerm), |
81 FISTA(FBConfig<F>, ProxTerm), |
| 134 FW(FWConfig<F>), |
82 FW(FWConfig<F>), |
| 135 PDPS(PDPSConfig<F>, ProxTerm), |
83 PDPS(PDPSConfig<F>, ProxTerm), |
| 136 SlidingFB(SlidingFBConfig<F>, ProxTerm), |
84 SlidingFB(SlidingFBConfig<F>, ProxTerm), |
| 137 ForwardPDPS(ForwardPDPSConfig<F>, ProxTerm), |
85 ForwardPDPS(ForwardPDPSConfig<F>, ProxTerm), |
| 138 SlidingPDPS(SlidingPDPSConfig<F>, ProxTerm), |
86 SlidingPDPS(SlidingPDPSConfig<F>, ProxTerm), |
| 139 } |
87 } |
| 140 |
88 |
| 141 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { |
89 fn unpack_tolerance<F: Float>(v: &Vec<F>) -> Tolerance<F> { |
| 142 assert!(v.len() == 3); |
90 assert!(v.len() == 3); |
| 143 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } |
91 Tolerance::Power { initial: v[0], factor: v[1], exponent: v[2] } |
| 144 } |
92 } |
| 145 |
93 |
| 146 impl<F : ClapFloat> AlgorithmConfig<F> { |
94 impl<F: ClapFloat> AlgorithmConfig<F> { |
| 147 /// Override supported parameters based on the command line. |
95 /// Override supported parameters based on the command line. |
| 148 pub fn cli_override(self, cli : &AlgorithmOverrides<F>) -> Self { |
96 pub fn cli_override(self, cli: &AlgorithmOverrides<F>) -> Self { |
| 149 let override_merging = |g : SpikeMergingMethod<F>| { |
97 let override_merging = |g: SpikeMergingMethod<F>| SpikeMergingMethod { |
| 150 SpikeMergingMethod { |
98 enabled: cli.merge.unwrap_or(g.enabled), |
| 151 enabled : cli.merge.unwrap_or(g.enabled), |
99 radius: cli.merge_radius.unwrap_or(g.radius), |
| 152 radius : cli.merge_radius.unwrap_or(g.radius), |
100 interp: cli.merge_interp.unwrap_or(g.interp), |
| 153 interp : cli.merge_interp.unwrap_or(g.interp), |
|
| 154 } |
|
| 155 }; |
101 }; |
| 156 let override_fb_generic = |g : FBGenericConfig<F>| { |
102 let override_fb_generic = |g: InsertionConfig<F>| InsertionConfig { |
| 157 FBGenericConfig { |
103 bootstrap_insertions: cli |
| 158 bootstrap_insertions : cli.bootstrap_insertions |
104 .bootstrap_insertions |
| 159 .as_ref() |
105 .as_ref() |
| 160 .map_or(g.bootstrap_insertions, |
106 .map_or(g.bootstrap_insertions, |n| Some((n[0], n[1]))), |
| 161 |n| Some((n[0], n[1]))), |
107 merge_every: cli.merge_every.unwrap_or(g.merge_every), |
| 162 merge_every : cli.merge_every.unwrap_or(g.merge_every), |
108 merging: override_merging(g.merging), |
| 163 merging : override_merging(g.merging), |
109 final_merging: cli.final_merging.unwrap_or(g.final_merging), |
| 164 final_merging : cli.final_merging.unwrap_or(g.final_merging), |
110 fitness_merging: cli.fitness_merging.unwrap_or(g.fitness_merging), |
| 165 fitness_merging : cli.fitness_merging.unwrap_or(g.fitness_merging), |
111 tolerance: cli |
| 166 tolerance: cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(g.tolerance), |
112 .tolerance |
| 167 .. g |
113 .as_ref() |
| 168 } |
114 .map(unpack_tolerance) |
| |
115 .unwrap_or(g.tolerance), |
| |
116 ..g |
| 169 }; |
117 }; |
| 170 let override_transport = |g : TransportConfig<F>| { |
118 let override_transport = |g: TransportConfig<F>| TransportConfig { |
| 171 TransportConfig { |
119 θ0: cli.theta0.unwrap_or(g.θ0), |
| 172 θ0 : cli.theta0.unwrap_or(g.θ0), |
120 tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con), |
| 173 tolerance_mult_con: cli.transport_tolerance_pos.unwrap_or(g.tolerance_mult_con), |
121 adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), |
| 174 adaptation: cli.transport_adaptation.unwrap_or(g.adaptation), |
122 ..g |
| 175 .. g |
|
| 176 } |
|
| 177 }; |
123 }; |
| 178 |
124 |
| 179 use AlgorithmConfig::*; |
125 use AlgorithmConfig::*; |
| 180 match self { |
126 match self { |
| 181 FB(fb, prox) => FB(FBConfig { |
127 FB(fb, prox) => FB( |
| 182 τ0 : cli.tau0.unwrap_or(fb.τ0), |
128 FBConfig { |
| 183 generic : override_fb_generic(fb.generic), |
129 τ0: cli.tau0.unwrap_or(fb.τ0), |
| 184 .. fb |
130 σp0: cli.sigmap0.unwrap_or(fb.σp0), |
| 185 }, prox), |
131 insertion: override_fb_generic(fb.insertion), |
| 186 FISTA(fb, prox) => FISTA(FBConfig { |
132 ..fb |
| 187 τ0 : cli.tau0.unwrap_or(fb.τ0), |
133 }, |
| 188 generic : override_fb_generic(fb.generic), |
134 prox, |
| 189 .. fb |
135 ), |
| 190 }, prox), |
136 FISTA(fb, prox) => FISTA( |
| 191 PDPS(pdps, prox) => PDPS(PDPSConfig { |
137 FBConfig { |
| 192 τ0 : cli.tau0.unwrap_or(pdps.τ0), |
138 τ0: cli.tau0.unwrap_or(fb.τ0), |
| 193 σ0 : cli.sigma0.unwrap_or(pdps.σ0), |
139 σp0: cli.sigmap0.unwrap_or(fb.σp0), |
| 194 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
140 insertion: override_fb_generic(fb.insertion), |
| 195 generic : override_fb_generic(pdps.generic), |
141 ..fb |
| 196 .. pdps |
142 }, |
| 197 }, prox), |
143 prox, |
| |
144 ), |
| |
145 PDPS(pdps, prox) => PDPS( |
| |
146 PDPSConfig { |
| |
147 τ0: cli.tau0.unwrap_or(pdps.τ0), |
| |
148 σ0: cli.sigma0.unwrap_or(pdps.σ0), |
| |
149 acceleration: cli.acceleration.unwrap_or(pdps.acceleration), |
| |
150 generic: override_fb_generic(pdps.generic), |
| |
151 ..pdps |
| |
152 }, |
| |
153 prox, |
| |
154 ), |
| 198 FW(fw) => FW(FWConfig { |
155 FW(fw) => FW(FWConfig { |
| 199 merging : override_merging(fw.merging), |
156 merging: override_merging(fw.merging), |
| 200 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), |
157 tolerance: cli |
| 201 .. fw |
158 .tolerance |
| |
159 .as_ref() |
| |
160 .map(unpack_tolerance) |
| |
161 .unwrap_or(fw.tolerance), |
| |
162 ..fw |
| 202 }), |
163 }), |
| 203 SlidingFB(sfb, prox) => SlidingFB(SlidingFBConfig { |
164 SlidingFB(sfb, prox) => SlidingFB( |
| 204 τ0 : cli.tau0.unwrap_or(sfb.τ0), |
165 SlidingFBConfig { |
| 205 transport : override_transport(sfb.transport), |
166 τ0: cli.tau0.unwrap_or(sfb.τ0), |
| 206 insertion : override_fb_generic(sfb.insertion), |
167 σp0: cli.sigmap0.unwrap_or(sfb.σp0), |
| 207 .. sfb |
168 transport: override_transport(sfb.transport), |
| 208 }, prox), |
169 insertion: override_fb_generic(sfb.insertion), |
| 209 SlidingPDPS(spdps, prox) => SlidingPDPS(SlidingPDPSConfig { |
170 ..sfb |
| 210 τ0 : cli.tau0.unwrap_or(spdps.τ0), |
171 }, |
| 211 σp0 : cli.sigmap0.unwrap_or(spdps.σp0), |
172 prox, |
| 212 σd0 : cli.sigma0.unwrap_or(spdps.σd0), |
173 ), |
| 213 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
174 SlidingPDPS(spdps, prox) => SlidingPDPS( |
| 214 transport : override_transport(spdps.transport), |
175 SlidingPDPSConfig { |
| 215 insertion : override_fb_generic(spdps.insertion), |
176 τ0: cli.tau0.unwrap_or(spdps.τ0), |
| 216 .. spdps |
177 σp0: cli.sigmap0.unwrap_or(spdps.σp0), |
| 217 }, prox), |
178 σd0: cli.sigma0.unwrap_or(spdps.σd0), |
| 218 ForwardPDPS(fpdps, prox) => ForwardPDPS(ForwardPDPSConfig { |
179 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
| 219 τ0 : cli.tau0.unwrap_or(fpdps.τ0), |
180 transport: override_transport(spdps.transport), |
| 220 σp0 : cli.sigmap0.unwrap_or(fpdps.σp0), |
181 insertion: override_fb_generic(spdps.insertion), |
| 221 σd0 : cli.sigma0.unwrap_or(fpdps.σd0), |
182 ..spdps |
| 222 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
183 }, |
| 223 insertion : override_fb_generic(fpdps.insertion), |
184 prox, |
| 224 .. fpdps |
185 ), |
| 225 }, prox), |
186 ForwardPDPS(fpdps, prox) => ForwardPDPS( |
| |
187 ForwardPDPSConfig { |
| |
188 τ0: cli.tau0.unwrap_or(fpdps.τ0), |
| |
189 σp0: cli.sigmap0.unwrap_or(fpdps.σp0), |
| |
190 σd0: cli.sigma0.unwrap_or(fpdps.σd0), |
| |
191 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
| |
192 insertion: override_fb_generic(fpdps.insertion), |
| |
193 ..fpdps |
| |
194 }, |
| |
195 prox, |
| |
196 ), |
| 226 } |
197 } |
| 227 } |
198 } |
| 228 } |
199 } |
| 229 |
200 |
| 230 /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name. |
201 /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name. |
| 231 #[derive(Clone, Debug, Serialize, Deserialize)] |
202 #[derive(Clone, Debug, Serialize, Deserialize)] |
| 232 pub struct Named<Data> { |
203 pub struct Named<Data> { |
| 233 pub name : String, |
204 pub name: String, |
| 234 #[serde(flatten)] |
205 #[serde(flatten)] |
| 235 pub data : Data, |
206 pub data: Data, |
| 236 } |
207 } |
| 237 |
208 |
| 238 /// Shorthand algorithm configurations, to be used with the command line parser |
209 /// Shorthand algorithm configurations, to be used with the command line parser |
| 239 #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] |
210 #[derive(ValueEnum, Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] |
| |
211 //#[cfg_attr(feature = "pyo3", pyclass(module = "pointsource_algs"))] |
| 240 pub enum DefaultAlgorithm { |
212 pub enum DefaultAlgorithm { |
| 241 /// The μFB forward-backward method |
213 /// The μFB forward-backward method |
| 242 #[clap(name = "fb")] |
214 #[clap(name = "fb")] |
| 243 FB, |
215 FB, |
| 244 /// The μFISTA inertial forward-backward method |
216 /// The μFISTA inertial forward-backward method |
| 285 RadonForwardPDPS, |
256 RadonForwardPDPS, |
| 286 } |
257 } |
| 287 |
258 |
| 288 impl DefaultAlgorithm { |
259 impl DefaultAlgorithm { |
| 289 /// Returns the algorithm configuration corresponding to the algorithm shorthand |
260 /// Returns the algorithm configuration corresponding to the algorithm shorthand |
| 290 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { |
261 pub fn default_config<F: Float>(&self) -> AlgorithmConfig<F> { |
| 291 use DefaultAlgorithm::*; |
262 use DefaultAlgorithm::*; |
| 292 let radon_insertion = FBGenericConfig { |
263 let radon_insertion = InsertionConfig { |
| 293 merging : SpikeMergingMethod{ interp : false, .. Default::default() }, |
264 merging: SpikeMergingMethod { interp: false, ..Default::default() }, |
| 294 inner : InnerSettings { |
265 inner: InnerSettings { |
| 295 method : InnerMethod::PDPS, // SSN not implemented |
266 method: InnerMethod::PDPS, // SSN not implemented |
| 296 .. Default::default() |
267 ..Default::default() |
| 297 }, |
268 }, |
| 298 .. Default::default() |
269 ..Default::default() |
| 299 }; |
270 }; |
| 300 match *self { |
271 match *self { |
| 301 FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave), |
272 FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave), |
| 302 FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave), |
273 FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave), |
| 303 FW => AlgorithmConfig::FW(Default::default()), |
274 FW => AlgorithmConfig::FW(Default::default()), |
| 304 FWRelax => AlgorithmConfig::FW(FWConfig{ |
275 FWRelax => { |
| 305 variant : FWVariant::Relaxed, |
276 AlgorithmConfig::FW(FWConfig { variant: FWVariant::Relaxed, ..Default::default() }) |
| 306 .. Default::default() |
277 } |
| 307 }), |
|
| 308 PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave), |
278 PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave), |
| 309 SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave), |
279 SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave), |
| 310 SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave), |
280 SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave), |
| 311 ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave), |
281 ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave), |
| 312 |
282 |
| 313 // Radon variants |
283 // Radon variants |
| 314 |
|
| 315 RadonFB => AlgorithmConfig::FB( |
284 RadonFB => AlgorithmConfig::FB( |
| 316 FBConfig{ generic : radon_insertion, ..Default::default() }, |
285 FBConfig { insertion: radon_insertion, ..Default::default() }, |
| 317 ProxTerm::RadonSquared |
286 ProxTerm::RadonSquared, |
| 318 ), |
287 ), |
| 319 RadonFISTA => AlgorithmConfig::FISTA( |
288 RadonFISTA => AlgorithmConfig::FISTA( |
| 320 FBConfig{ generic : radon_insertion, ..Default::default() }, |
289 FBConfig { insertion: radon_insertion, ..Default::default() }, |
| 321 ProxTerm::RadonSquared |
290 ProxTerm::RadonSquared, |
| 322 ), |
291 ), |
| 323 RadonPDPS => AlgorithmConfig::PDPS( |
292 RadonPDPS => AlgorithmConfig::PDPS( |
| 324 PDPSConfig{ generic : radon_insertion, ..Default::default() }, |
293 PDPSConfig { generic: radon_insertion, ..Default::default() }, |
| 325 ProxTerm::RadonSquared |
294 ProxTerm::RadonSquared, |
| 326 ), |
295 ), |
| 327 RadonSlidingFB => AlgorithmConfig::SlidingFB( |
296 RadonSlidingFB => AlgorithmConfig::SlidingFB( |
| 328 SlidingFBConfig{ insertion : radon_insertion, ..Default::default() }, |
297 SlidingFBConfig { insertion: radon_insertion, ..Default::default() }, |
| 329 ProxTerm::RadonSquared |
298 ProxTerm::RadonSquared, |
| 330 ), |
299 ), |
| 331 RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS( |
300 RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS( |
| 332 SlidingPDPSConfig{ insertion : radon_insertion, ..Default::default() }, |
301 SlidingPDPSConfig { insertion: radon_insertion, ..Default::default() }, |
| 333 ProxTerm::RadonSquared |
302 ProxTerm::RadonSquared, |
| 334 ), |
303 ), |
| 335 RadonForwardPDPS => AlgorithmConfig::ForwardPDPS( |
304 RadonForwardPDPS => AlgorithmConfig::ForwardPDPS( |
| 336 ForwardPDPSConfig{ insertion : radon_insertion, ..Default::default() }, |
305 ForwardPDPSConfig { insertion: radon_insertion, ..Default::default() }, |
| 337 ProxTerm::RadonSquared |
306 ProxTerm::RadonSquared, |
| 338 ), |
307 ), |
| 339 } |
308 } |
| 340 } |
309 } |
| 341 |
310 |
| 342 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand |
311 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand |
| 343 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { |
312 pub fn get_named<F: Float>(&self) -> Named<AlgorithmConfig<F>> { |
| 344 self.to_named(self.default_config()) |
313 self.to_named(self.default_config()) |
| 345 } |
314 } |
| 346 |
315 |
| 347 pub fn to_named<F : Float>(self, alg : AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { |
316 pub fn to_named<F: Float>(self, alg: AlgorithmConfig<F>) -> Named<AlgorithmConfig<F>> { |
| 348 let name = self.to_possible_value().unwrap().get_name().to_string(); |
317 Named { name: self.name(), data: alg } |
| 349 Named{ name , data : alg } |
318 } |
| 350 } |
319 |
| 351 } |
320 pub fn name(self) -> String { |
| 352 |
321 self.to_possible_value().unwrap().get_name().to_string() |
| |
322 } |
| |
323 } |
| 353 |
324 |
| 354 // // Floats cannot be hashed directly, so just hash the debug formatting |
325 // // Floats cannot be hashed directly, so just hash the debug formatting |
| 355 // // for use as file identifier. |
326 // // for use as file identifier. |
| 356 // impl<F : Float> Hash for AlgorithmConfig<F> { |
327 // impl<F : Float> Hash for AlgorithmConfig<F> { |
| 357 // fn hash<H: Hasher>(&self, state: &mut H) { |
328 // fn hash<H: Hasher>(&self, state: &mut H) { |
| 377 fn default() -> Self { |
348 fn default() -> Self { |
| 378 Self::Data |
349 Self::Data |
| 379 } |
350 } |
| 380 } |
351 } |
| 381 |
352 |
| 382 type DefaultBT<F, const N : usize> = BT< |
353 type DefaultBT<F, const N: usize> = BT<DynamicDepth, F, usize, Bounds<F>, N>; |
| 383 DynamicDepth, |
354 type DefaultSeminormOp<F, K, const N: usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>; |
| 384 F, |
355 type DefaultSG<F, Sensor, Spread, const N: usize> = |
| 385 usize, |
356 SensorGrid<F, Sensor, Spread, DefaultBT<F, N>, N>; |
| 386 Bounds<F>, |
|
| 387 N |
|
| 388 >; |
|
| 389 type DefaultSeminormOp<F, K, const N : usize> = ConvolutionOp<F, K, DefaultBT<F, N>, N>; |
|
| 390 type DefaultSG<F, Sensor, Spread, const N : usize> = SensorGrid::< |
|
| 391 F, |
|
| 392 Sensor, |
|
| 393 Spread, |
|
| 394 DefaultBT<F, N>, |
|
| 395 N |
|
| 396 >; |
|
| 397 |
357 |
| 398 /// This is a dirty workaround to rust-csv not supporting struct flattening etc. |
358 /// This is a dirty workaround to rust-csv not supporting struct flattening etc. |
| 399 #[derive(Serialize)] |
359 #[derive(Serialize)] |
| 400 struct CSVLog<F> { |
360 struct CSVLog<F> { |
| 401 iter : usize, |
361 iter: usize, |
| 402 cpu_time : f64, |
362 cpu_time: f64, |
| 403 value : F, |
363 value: F, |
| 404 relative_value : F, |
364 relative_value: F, |
| 405 //post_value : F, |
365 //post_value : F, |
| 406 n_spikes : usize, |
366 n_spikes: usize, |
| 407 inner_iters : usize, |
367 inner_iters: usize, |
| 408 merged : usize, |
368 merged: usize, |
| 409 pruned : usize, |
369 pruned: usize, |
| 410 this_iters : usize, |
370 this_iters: usize, |
| |
371 epsilon: F, |
| 411 } |
372 } |
| 412 |
373 |
| 413 /// Collected experiment statistics |
374 /// Collected experiment statistics |
| 414 #[derive(Clone, Debug, Serialize)] |
375 #[derive(Clone, Debug, Serialize)] |
| 415 struct ExperimentStats<F : Float> { |
376 struct ExperimentStats<F: Float> { |
| 416 /// Signal-to-noise ratio in decibels |
377 /// Signal-to-noise ratio in decibels |
| 417 ssnr : F, |
378 ssnr: F, |
| 418 /// Proportion of noise in the signal as a number in $[0, 1]$. |
379 /// Proportion of noise in the signal as a number in $[0, 1]$. |
| 419 noise_ratio : F, |
380 noise_ratio: F, |
| 420 /// When the experiment was run (UTC) |
381 /// When the experiment was run (UTC) |
| 421 when : DateTime<Utc>, |
382 when: DateTime<Utc>, |
| 422 } |
383 } |
| 423 |
384 |
| 424 #[replace_float_literals(F::cast_from(literal))] |
385 #[replace_float_literals(F::cast_from(literal))] |
| 425 impl<F : Float> ExperimentStats<F> { |
386 impl<F: Float> ExperimentStats<F> { |
| 426 /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal. |
387 /// Calculate [`ExperimentStats`] based on a noisy `signal` and the separated `noise` signal. |
| 427 fn new<E : Euclidean<F>>(signal : &E, noise : &E) -> Self { |
388 fn new<E: Euclidean<F>>(signal: &E, noise: &E) -> Self { |
| 428 let s = signal.norm2_squared(); |
389 let s = signal.norm2_squared(); |
| 429 let n = noise.norm2_squared(); |
390 let n = noise.norm2_squared(); |
| 430 let noise_ratio = (n / s).sqrt(); |
391 let noise_ratio = (n / s).sqrt(); |
| 431 let ssnr = 10.0 * (s / n).log10(); |
392 let ssnr = 10.0 * (s / n).log10(); |
| 432 ExperimentStats { |
393 ExperimentStats { ssnr, noise_ratio, when: Utc::now() } |
| 433 ssnr, |
|
| 434 noise_ratio, |
|
| 435 when : Utc::now(), |
|
| 436 } |
|
| 437 } |
394 } |
| 438 } |
395 } |
| 439 /// Collected algorithm statistics |
396 /// Collected algorithm statistics |
| 440 #[derive(Clone, Debug, Serialize)] |
397 #[derive(Clone, Debug, Serialize)] |
| 441 struct AlgorithmStats<F : Float> { |
398 struct AlgorithmStats<F: Float> { |
| 442 /// Overall CPU time spent |
399 /// Overall CPU time spent |
| 443 cpu_time : F, |
400 cpu_time: F, |
| 444 /// Real time spent |
401 /// Real time spent |
| 445 elapsed : F |
402 elapsed: F, |
| 446 } |
403 } |
| 447 |
|
| 448 |
404 |
| 449 /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input |
405 /// A wrapper for [`serde_json::to_writer_pretty`] that takes a filename as input |
| 450 /// and outputs a [`DynError`]. |
406 /// and outputs a [`DynError`]. |
| 451 fn write_json<T : Serialize>(filename : String, data : &T) -> DynError { |
407 fn write_json<T: Serialize>(filename: String, data: &T) -> DynError { |
| 452 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; |
408 serde_json::to_writer_pretty(std::fs::File::create(filename)?, data)?; |
| 453 Ok(()) |
409 Ok(()) |
| 454 } |
410 } |
| 455 |
411 |
| 456 |
|
| 457 /// Struct for experiment configurations |
412 /// Struct for experiment configurations |
| 458 #[derive(Debug, Clone, Serialize)] |
413 #[derive(Debug, Clone, Serialize)] |
| 459 pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N : usize> |
414 pub struct ExperimentV2<F, NoiseDistr, S, K, P, const N: usize> |
| 460 where F : Float + ClapFloat, |
415 where |
| 461 [usize; N] : Serialize, |
416 F: Float + ClapFloat, |
| 462 NoiseDistr : Distribution<F>, |
417 [usize; N]: Serialize, |
| 463 S : Sensor<F, N>, |
418 NoiseDistr: Distribution<F>, |
| 464 P : Spread<F, N>, |
419 S: Sensor<N, F>, |
| 465 K : SimpleConvolutionKernel<F, N>, |
420 P: Spread<N, F>, |
| |
421 K: SimpleConvolutionKernel<N, F>, |
| 466 { |
422 { |
| 467 /// Domain $Ω$. |
423 /// Domain $Ω$. |
| 468 pub domain : Cube<F, N>, |
424 pub domain: Cube<N, F>, |
| 469 /// Number of sensors along each dimension |
425 /// Number of sensors along each dimension |
| 470 pub sensor_count : [usize; N], |
426 pub sensor_count: [usize; N], |
| 471 /// Noise distribution |
427 /// Noise distribution |
| 472 pub noise_distr : NoiseDistr, |
428 pub noise_distr: NoiseDistr, |
| 473 /// Seed for random noise generation (for repeatable experiments) |
429 /// Seed for random noise generation (for repeatable experiments) |
| 474 pub noise_seed : u64, |
430 pub noise_seed: u64, |
| 475 /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. |
431 /// Sensor $θ$; $θ * ψ$ forms the forward operator $𝒜$. |
| 476 pub sensor : S, |
432 pub sensor: S, |
| 477 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. |
433 /// Spread $ψ$; $θ * ψ$ forms the forward operator $𝒜$. |
| 478 pub spread : P, |
434 pub spread: P, |
| 479 /// Kernel $ρ$ of $𝒟$. |
435 /// Kernel $ρ$ of $𝒟$. |
| 480 pub kernel : K, |
436 pub kernel: K, |
| 481 /// True point sources |
437 /// True point sources |
| 482 pub μ_hat : RNDM<F, N>, |
438 pub μ_hat: RNDM<N, F>, |
| 483 /// Regularisation term and parameter |
439 /// Regularisation term and parameter |
| 484 pub regularisation : Regularisation<F>, |
440 pub regularisation: Regularisation<F>, |
| 485 /// For plotting : how wide should the kernels be plotted |
441 /// For plotting : how wide should the kernels be plotted |
| 486 pub kernel_plot_width : F, |
442 pub kernel_plot_width: F, |
| 487 /// Data term |
443 /// Data term |
| 488 pub dataterm : DataTerm, |
444 pub dataterm: DataTermType, |
| 489 /// A map of default configurations for algorithms |
445 /// A map of default configurations for algorithms |
| 490 pub algorithm_overrides : HashMap<DefaultAlgorithm, AlgorithmOverrides<F>>, |
446 pub algorithm_overrides: HashMap<DefaultAlgorithm, AlgorithmOverrides<F>>, |
| 491 /// Default merge radius |
447 /// Default merge radius |
| 492 pub default_merge_radius : F, |
448 pub default_merge_radius: F, |
| 493 } |
449 } |
| 494 |
450 |
| 495 #[derive(Debug, Clone, Serialize)] |
451 #[derive(Debug, Clone, Serialize)] |
| 496 pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N : usize> |
452 pub struct ExperimentBiased<F, NoiseDistr, S, K, P, B, const N: usize> |
| 497 where F : Float + ClapFloat, |
453 where |
| 498 [usize; N] : Serialize, |
454 F: Float + ClapFloat, |
| 499 NoiseDistr : Distribution<F>, |
455 [usize; N]: Serialize, |
| 500 S : Sensor<F, N>, |
456 NoiseDistr: Distribution<F>, |
| 501 P : Spread<F, N>, |
457 S: Sensor<N, F>, |
| 502 K : SimpleConvolutionKernel<F, N>, |
458 P: Spread<N, F>, |
| 503 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, |
459 K: SimpleConvolutionKernel<N, F>, |
| |
460 B: Mapping<Loc<N, F>, Codomain = F> + Serialize + std::fmt::Debug, |
| 504 { |
461 { |
| 505 /// Basic setup |
462 /// Basic setup |
| 506 pub base : ExperimentV2<F, NoiseDistr, S, K, P, N>, |
463 pub base: ExperimentV2<F, NoiseDistr, S, K, P, N>, |
| 507 /// Weight of TV term |
464 /// Weight of TV term |
| 508 pub λ : F, |
465 pub λ: F, |
| 509 /// Bias function |
466 /// Bias function |
| 510 pub bias : B, |
467 pub bias: B, |
| 511 } |
468 } |
| 512 |
469 |
| 513 /// Trait for runnable experiments |
470 /// Trait for runnable experiments |
| 514 pub trait RunnableExperiment<F : ClapFloat> { |
471 pub trait RunnableExperiment<F: ClapFloat> { |
| 515 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. |
472 /// Run all algorithms provided, or default algorithms if none provided, on the experiment. |
| 516 fn runall(&self, cli : &CommandLineArgs, |
473 fn runall( |
| 517 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError; |
474 &self, |
| |
475 cli: &CommandLineArgs, |
| |
476 algs: Option<Vec<Named<AlgorithmConfig<F>>>>, |
| |
477 ) -> DynError; |
| 518 |
478 |
| 519 /// Return algorithm default config |
479 /// Return algorithm default config |
| 520 fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F>; |
480 fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<F>; |
| 521 } |
481 |
| 522 |
482 /// Experiment name |
| 523 /// Helper function to print experiment start message and save setup. |
483 fn name(&self) -> &str; |
| 524 /// Returns saving prefix. |
484 } |
| 525 fn start_experiment<E, S>( |
485 |
| 526 experiment : &Named<E>, |
486 /// Error codes for running an algorithm on an experiment. |
| 527 cli : &CommandLineArgs, |
487 #[derive(Error, Debug)] |
| 528 stats : S, |
488 pub enum RunError { |
| 529 ) -> DynResult<String> |
489 /// Algorithm not implemented for this experiment |
| |
490 #[error("Algorithm not implemented for this experiment")] |
| |
491 NotImplemented, |
| |
492 } |
| |
493 |
| |
494 use RunError::*; |
| |
495 |
| |
496 type DoRunAllIt<'a, F, const N: usize> = LoggingIteratorFactory< |
| |
497 'a, |
| |
498 Timed<IterInfo<F>>, |
| |
499 TimingIteratorFactory<BasicAlgIteratorFactory<IterInfo<F>>>, |
| |
500 >; |
| |
501 |
| |
502 pub trait RunnableExperimentExtras<F: ClapFloat>: |
| |
503 RunnableExperiment<F> + Serialize + Sized |
| |
504 { |
| |
505 /// Helper function to print experiment start message and save setup. |
| |
506 /// Returns saving prefix. |
| |
507 fn start(&self, cli: &CommandLineArgs) -> DynResult<String> { |
| |
508 let experiment_name = self.name(); |
| |
509 let ser = serde_json::to_string(self); |
| |
510 |
| |
511 println!( |
| |
512 "{}\n{}", |
| |
513 format!("Performing experiment {}…", experiment_name).cyan(), |
| |
514 format!( |
| |
515 "Experiment settings: {}", |
| |
516 if let Ok(ref s) = ser { |
| |
517 s |
| |
518 } else { |
| |
519 "<serialisation failure>" |
| |
520 } |
| |
521 ) |
| |
522 .bright_black(), |
| |
523 ); |
| |
524 |
| |
525 // Set up output directory |
| |
526 let prefix = format!("{}/{}/", cli.outdir, experiment_name); |
| |
527 |
| |
528 // Save experiment configuration and statistics |
| |
529 std::fs::create_dir_all(&prefix)?; |
| |
530 write_json(format!("{prefix}experiment.json"), self)?; |
| |
531 write_json(format!("{prefix}config.json"), cli)?; |
| |
532 |
| |
533 Ok(prefix) |
| |
534 } |
| |
535 |
| |
536 /// Helper function to run all algorithms on an experiment. |
| |
537 fn do_runall<P, Z, Plot, const N: usize>( |
| |
538 &self, |
| |
539 prefix: &String, |
| |
540 cli: &CommandLineArgs, |
| |
541 algorithms: Vec<Named<AlgorithmConfig<F>>>, |
| |
542 mut make_plotter: impl FnMut(String) -> Plot, |
| |
543 mut save_extra: impl FnMut(String, Z) -> DynError, |
| |
544 init: P, |
| |
545 mut do_alg: impl FnMut( |
| |
546 (&AlgorithmConfig<F>, DoRunAllIt<F, N>, Plot, P, String), |
| |
547 ) -> DynResult<(RNDM<N, F>, Z)>, |
| |
548 ) -> DynError |
| |
549 where |
| |
550 F: for<'b> Deserialize<'b>, |
| |
551 PlotLookup: Plotting<N>, |
| |
552 P: Clone, |
| |
553 { |
| |
554 let experiment_name = self.name(); |
| |
555 |
| |
556 let mut logs = Vec::new(); |
| |
557 |
| |
558 let iterator_options = AlgIteratorOptions { |
| |
559 max_iter: cli.max_iter, |
| |
560 verbose_iter: cli |
| |
561 .verbose_iter |
| |
562 .map_or(Verbose::LogarithmicCap { base: 10, cap: 2 }, |n| { |
| |
563 Verbose::Every(n) |
| |
564 }), |
| |
565 quiet: cli.quiet, |
| |
566 }; |
| |
567 |
| |
568 // Run the algorithm(s) |
| |
569 for named @ Named { name: alg_name, data: alg } in algorithms.iter() { |
| |
570 let this_prefix = format!("{}{}/", prefix, alg_name); |
| |
571 |
| |
572 // Create Logger and IteratorFactory |
| |
573 let mut logger = Logger::new(); |
| |
574 let iterator = iterator_options.instantiate().timed().into_log(&mut logger); |
| |
575 |
| |
576 let running = if !cli.quiet { |
| |
577 format!( |
| |
578 "{}\n{}\n{}\n", |
| |
579 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), |
| |
580 format!( |
| |
581 "Iteration settings: {}", |
| |
582 serde_json::to_string(&iterator_options)? |
| |
583 ) |
| |
584 .bright_black(), |
| |
585 format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black() |
| |
586 ) |
| |
587 } else { |
| |
588 "".to_string() |
| |
589 }; |
| |
590 // |
| |
591 // The following is for postprocessing, which has been disabled anyway. |
| |
592 // |
| |
593 // let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation { |
| |
594 // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)), |
| |
595 // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)), |
| |
596 // }; |
| |
597 //let findim_data = reg.prepare_optimise_weights(&opA, &b); |
| |
598 //let inner_config : InnerSettings<F> = Default::default(); |
| |
599 //let inner_it = inner_config.iterator_options; |
| |
600 |
| |
601 // Create plotter and directory if needed. |
| |
602 let plotter = make_plotter(this_prefix); |
| |
603 |
| |
604 let start = Instant::now(); |
| |
605 let start_cpu = ProcessTime::now(); |
| |
606 |
| |
607 let (μ, z) = match do_alg((alg, iterator, plotter, init.clone(), running)) { |
| |
608 Ok(μ) => μ, |
| |
609 Err(e) => { |
| |
610 let msg = format!( |
| |
611 "Skipping algorithm “{alg_name}” for {experiment_name} due to error: {e}" |
| |
612 ) |
| |
613 .red(); |
| |
614 eprintln!("{}", msg); |
| |
615 continue; |
| |
616 } |
| |
617 }; |
| |
618 |
| |
619 let elapsed = start.elapsed().as_secs_f64(); |
| |
620 let cpu_time = start_cpu.elapsed().as_secs_f64(); |
| |
621 |
| |
622 println!( |
| |
623 "{}", |
| |
624 format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow() |
| |
625 ); |
| |
626 |
| |
627 // Save results |
| |
628 println!("{}", "Saving results …".green()); |
| |
629 |
| |
630 let mkname = |t| format!("{prefix}{alg_name}_{t}"); |
| |
631 |
| |
632 write_json(mkname("config.json"), &named)?; |
| |
633 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; |
| |
634 μ.write_csv(mkname("reco.txt"))?; |
| |
635 save_extra(mkname(""), z)?; |
| |
636 //logger.write_csv(mkname("log.txt"))?; |
| |
637 logs.push((mkname("log.txt"), logger)); |
| |
638 } |
| |
639 |
| |
640 save_logs( |
| |
641 logs, |
| |
642 format!("{prefix}valuerange.json"), |
| |
643 cli.load_valuerange, |
| |
644 ) |
| |
645 } |
| |
646 } |
| |
647 |
| |
648 impl<F, E> RunnableExperimentExtras<F> for E |
| 530 where |
649 where |
| 531 E : Serialize + std::fmt::Debug, |
650 F: ClapFloat, |
| 532 S : Serialize, |
651 Self: RunnableExperiment<F> + Serialize, |
| 533 { |
652 { |
| 534 let Named { name : experiment_name, data } = experiment; |
653 } |
| 535 |
654 |
| 536 println!("{}\n{}", |
655 #[replace_float_literals(F::cast_from(literal))] |
| 537 format!("Performing experiment {}…", experiment_name).cyan(), |
656 impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N: usize> RunnableExperiment<F> |
| 538 format!("Experiment settings: {}", serde_json::to_string(&data)?).bright_black()); |
657 for Named<ExperimentV2<F, NoiseDistr, S, K, P, N>> |
| 539 |
|
| 540 // Set up output directory |
|
| 541 let prefix = format!("{}/{}/", cli.outdir, experiment_name); |
|
| 542 |
|
| 543 // Save experiment configuration and statistics |
|
| 544 let mkname_e = |t| format!("{prefix}{t}.json", prefix = prefix, t = t); |
|
| 545 std::fs::create_dir_all(&prefix)?; |
|
| 546 write_json(mkname_e("experiment"), experiment)?; |
|
| 547 write_json(mkname_e("config"), cli)?; |
|
| 548 write_json(mkname_e("stats"), &stats)?; |
|
| 549 |
|
| 550 Ok(prefix) |
|
| 551 } |
|
| 552 |
|
| 553 /// Error codes for running an algorithm on an experiment. |
|
| 554 enum RunError { |
|
| 555 /// Algorithm not implemented for this experiment |
|
| 556 NotImplemented, |
|
| 557 } |
|
| 558 |
|
| 559 use RunError::*; |
|
| 560 |
|
| 561 type DoRunAllIt<'a, F, const N : usize> = LoggingIteratorFactory< |
|
| 562 'a, |
|
| 563 Timed<IterInfo<F, N>>, |
|
| 564 TimingIteratorFactory<BasicAlgIteratorFactory<IterInfo<F, N>>> |
|
| 565 >; |
|
| 566 |
|
| 567 /// Helper function to run all algorithms on an experiment. |
|
| 568 fn do_runall<F : Float + for<'b> Deserialize<'b>, Z, const N : usize>( |
|
| 569 experiment_name : &String, |
|
| 570 prefix : &String, |
|
| 571 cli : &CommandLineArgs, |
|
| 572 algorithms : Vec<Named<AlgorithmConfig<F>>>, |
|
| 573 plotgrid : LinSpace<Loc<F, N>, [usize; N]>, |
|
| 574 mut save_extra : impl FnMut(String, Z) -> DynError, |
|
| 575 mut do_alg : impl FnMut( |
|
| 576 &AlgorithmConfig<F>, |
|
| 577 DoRunAllIt<F, N>, |
|
| 578 SeqPlotter<F, N>, |
|
| 579 String, |
|
| 580 ) -> Result<(RNDM<F, N>, Z), RunError>, |
|
| 581 ) -> DynError |
|
| 582 where |
658 where |
| 583 PlotLookup : Plotting<N>, |
659 F: ClapFloat |
| |
660 + nalgebra::RealField |
| |
661 + ToNalgebraRealField<MixedType = F> |
| |
662 + Default |
| |
663 + for<'b> Deserialize<'b>, |
| |
664 [usize; N]: Serialize, |
| |
665 S: Sensor<N, F> + Copy + Serialize + std::fmt::Debug, |
| |
666 P: Spread<N, F> + Copy + Serialize + std::fmt::Debug, |
| |
667 Convolution<S, P>: Spread<N, F> |
| |
668 + Bounded<F> |
| |
669 + LocalAnalysis<F, Bounds<F>, N> |
| |
670 + Copy |
| |
671 // TODO: should not have differentiability as a requirement, but
| |
672 // decide availability of sliding based on it. |
| |
673 //+ for<'b> Differentiable<&'b Loc<N, F>, Output = Loc<N, F>>, |
| |
674 // TODO: very weird that rust only compiles with Differentiable |
| |
675 // instead of the above one on references, which is required by |
| |
676 // pointsource_sliding_fb_reg.
| |
677 + DifferentiableRealMapping<N, F> |
| |
678 + Lipschitz<L2, FloatType = F>, |
| |
679 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<N, F>>>::Differential<'b>: |
| |
680 Lipschitz<L2, FloatType = F>, // TODO: should not be required generally, only for sliding_fb. |
| |
681 AutoConvolution<P>: BoundedBy<F, K>, |
| |
682 K: SimpleConvolutionKernel<N, F> |
| |
683 + LocalAnalysis<F, Bounds<F>, N> |
| |
684 + Copy |
| |
685 + Serialize |
| |
686 + std::fmt::Debug, |
| |
687 Cube<N, F>: P2Minimise<Loc<N, F>, F> + SetOrd, |
| |
688 PlotLookup: Plotting<N>, |
| |
689 DefaultBT<F, N>: SensorGridBT<F, S, P, N, Depth = DynamicDepth> + BTSearch<N, F>, |
| |
690 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
| |
691 RNDM<N, F>: SpikeMerging<F>, |
| |
692 NoiseDistr: Distribution<F> + Serialize + std::fmt::Debug, |
| 584 { |
693 { |
| 585 let mut logs = Vec::new(); |
694 fn name(&self) -> &str { |
| 586 |
695 self.name.as_ref() |
| 587 let iterator_options = AlgIteratorOptions{ |
696 } |
| 588 max_iter : cli.max_iter, |
697 |
| 589 verbose_iter : cli.verbose_iter |
698 fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<F> { |
| 590 .map_or(Verbose::LogarithmicCap{base : 10, cap : 2}, |
|
| 591 |n| Verbose::Every(n)), |
|
| 592 quiet : cli.quiet, |
|
| 593 }; |
|
| 594 |
|
| 595 // Run the algorithm(s) |
|
| 596 for named @ Named { name : alg_name, data : alg } in algorithms.iter() { |
|
| 597 let this_prefix = format!("{}{}/", prefix, alg_name); |
|
| 598 |
|
| 599 // Create Logger and IteratorFactory |
|
| 600 let mut logger = Logger::new(); |
|
| 601 let iterator = iterator_options.instantiate() |
|
| 602 .timed() |
|
| 603 .into_log(&mut logger); |
|
| 604 |
|
| 605 let running = if !cli.quiet { |
|
| 606 format!("{}\n{}\n{}\n", |
|
| 607 format!("Running {} on experiment {}…", alg_name, experiment_name).cyan(), |
|
| 608 format!("Iteration settings: {}", serde_json::to_string(&iterator_options)?).bright_black(), |
|
| 609 format!("Algorithm settings: {}", serde_json::to_string(&alg)?).bright_black()) |
|
| 610 } else { |
|
| 611 "".to_string() |
|
| 612 }; |
|
| 613 // |
|
| 614 // The following is for postprocessing, which has been disabled anyway. |
|
| 615 // |
|
| 616 // let reg : Box<dyn WeightOptim<_, _, _, N>> = match regularisation { |
|
| 617 // Regularisation::Radon(α) => Box::new(RadonRegTerm(α)), |
|
| 618 // Regularisation::NonnegRadon(α) => Box::new(NonnegRadonRegTerm(α)), |
|
| 619 // }; |
|
| 620 //let findim_data = reg.prepare_optimise_weights(&opA, &b); |
|
| 621 //let inner_config : InnerSettings<F> = Default::default(); |
|
| 622 //let inner_it = inner_config.iterator_options; |
|
| 623 |
|
| 624 // Create plotter and directory if needed. |
|
| 625 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; |
|
| 626 let plotter = SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()); |
|
| 627 |
|
| 628 let start = Instant::now(); |
|
| 629 let start_cpu = ProcessTime::now(); |
|
| 630 |
|
| 631 let (μ, z) = match do_alg(alg, iterator, plotter, running) { |
|
| 632 Ok(μ) => μ, |
|
| 633 Err(RunError::NotImplemented) => { |
|
| 634 let msg = format!("Algorithm “{alg_name}” not implemented for {experiment_name}. \ |
|
| 635 Skipping.").red(); |
|
| 636 eprintln!("{}", msg); |
|
| 637 continue |
|
| 638 } |
|
| 639 }; |
|
| 640 |
|
| 641 let elapsed = start.elapsed().as_secs_f64(); |
|
| 642 let cpu_time = start_cpu.elapsed().as_secs_f64(); |
|
| 643 |
|
| 644 println!("{}", format!("Elapsed {elapsed}s (CPU time {cpu_time}s)… ").yellow()); |
|
| 645 |
|
| 646 // Save results |
|
| 647 println!("{}", "Saving results …".green()); |
|
| 648 |
|
| 649 let mkname = |t| format!("{prefix}{alg_name}_{t}"); |
|
| 650 |
|
| 651 write_json(mkname("config.json"), &named)?; |
|
| 652 write_json(mkname("stats.json"), &AlgorithmStats { cpu_time, elapsed })?; |
|
| 653 μ.write_csv(mkname("reco.txt"))?; |
|
| 654 save_extra(mkname(""), z)?; |
|
| 655 //logger.write_csv(mkname("log.txt"))?; |
|
| 656 logs.push((mkname("log.txt"), logger)); |
|
| 657 } |
|
| 658 |
|
| 659 save_logs(logs, format!("{prefix}valuerange.json"), cli.load_valuerange) |
|
| 660 } |
|
| 661 |
|
| 662 #[replace_float_literals(F::cast_from(literal))] |
|
| 663 impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N : usize> RunnableExperiment<F> for |
|
| 664 Named<ExperimentV2<F, NoiseDistr, S, K, P, N>> |
|
| 665 where |
|
| 666 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F> |
|
| 667 + Default + for<'b> Deserialize<'b>, |
|
| 668 [usize; N] : Serialize, |
|
| 669 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
|
| 670 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
|
| 671 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy |
|
| 672 // TODO: should not have differentiability as a requirement, but
|
| 673 // decide availability of sliding based on it. |
|
| 674 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
|
| 675 // TODO: very weird that rust only compiles with Differentiable |
|
| 676 // instead of the above one on references, which is required by |
|
| 677 // pointsource_sliding_fb_reg.
|
| 678 + DifferentiableRealMapping<F, N> |
|
| 679 + Lipschitz<L2, FloatType=F>, |
|
| 680 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb. |
|
| 681 AutoConvolution<P> : BoundedBy<F, K>, |
|
| 682 K : SimpleConvolutionKernel<F, N> |
|
| 683 + LocalAnalysis<F, Bounds<F>, N> |
|
| 684 + Copy + Serialize + std::fmt::Debug, |
|
| 685 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
|
| 686 PlotLookup : Plotting<N>, |
|
| 687 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
|
| 688 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
| 689 RNDM<F, N> : SpikeMerging<F>, |
|
| 690 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, |
|
| 691 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>, |
|
| 692 // PreadjointCodomain : Space + Bounded<F> + DifferentiableRealMapping<F, N>, |
|
| 693 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, |
|
| 694 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, |
|
| 695 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, |
|
| 696 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, |
|
| 697 { |
|
| 698 |
|
| 699 fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> { |
|
| 700 AlgorithmOverrides { |
699 AlgorithmOverrides { |
| 701 merge_radius : Some(self.data.default_merge_radius), |
700 merge_radius: Some(self.data.default_merge_radius), |
| 702 .. self.data.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) |
701 ..self |
| 703 } |
702 .data |
| 704 } |
703 .algorithm_overrides |
| 705 |
704 .get(&alg) |
| 706 fn runall(&self, cli : &CommandLineArgs, |
705 .cloned() |
| 707 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { |
706 .unwrap_or(Default::default()) |
| |
707 } |
| |
708 } |
| |
709 |
| |
710 fn runall( |
| |
711 &self, |
| |
712 cli: &CommandLineArgs, |
| |
713 algs: Option<Vec<Named<AlgorithmConfig<F>>>>, |
| |
714 ) -> DynError { |
| 708 // Get experiment configuration |
715 // Get experiment configuration |
| 709 let &Named { |
716 let &ExperimentV2 { |
| 710 name : ref experiment_name, |
717 domain, |
| 711 data : ExperimentV2 { |
718 sensor_count, |
| 712 domain, sensor_count, ref noise_distr, sensor, spread, kernel, |
719 ref noise_distr, |
| 713 ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, |
720 sensor, |
| 714 .. |
721 spread, |
| 715 } |
722 kernel, |
| 716 } = self; |
723 ref μ_hat, |
| |
724 regularisation, |
| |
725 kernel_plot_width, |
| |
726 dataterm, |
| |
727 noise_seed, |
| |
728 .. |
| |
729 } = &self.data; |
| 717 |
730 |
| 718 // Set up algorithms |
731 // Set up algorithms |
| 719 let algorithms = match (algs, dataterm) { |
732 let algorithms = match (algs, dataterm) { |
| 720 (Some(algs), _) => algs, |
733 (Some(algs), _) => algs, |
| 721 (None, DataTerm::L2Squared) => vec![DefaultAlgorithm::FB.get_named()], |
734 (None, DataTermType::L222) => vec![DefaultAlgorithm::FB.get_named()], |
| 722 (None, DataTerm::L1) => vec![DefaultAlgorithm::PDPS.get_named()], |
735 (None, DataTermType::L1) => vec![DefaultAlgorithm::PDPS.get_named()], |
| 723 }; |
736 }; |
| 724 |
737 |
| 725 // Set up operators |
738 // Set up operators |
| 726 let depth = DynamicDepth(8); |
739 let depth = DynamicDepth(8); |
| 727 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); |
740 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); |
| 736 let b = &b_hat + &noise; |
749 let b = &b_hat + &noise; |
| 737 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
750 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
| 738 // overloading log10 and conflicting with standard NumTraits one. |
751 // overloading log10 and conflicting with standard NumTraits one. |
| 739 let stats = ExperimentStats::new(&b, &noise); |
752 let stats = ExperimentStats::new(&b, &noise); |
| 740 |
753 |
| 741 let prefix = start_experiment(&self, cli, stats)?; |
754 let prefix = self.start(cli)?; |
| 742 |
755 write_json(format!("{prefix}stats.json"), &stats)?; |
| 743 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, |
756 |
| 744 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; |
757 plotall( |
| 745 |
758 cli, |
| 746 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); |
759 &prefix, |
| |
760 &domain, |
| |
761 &sensor, |
| |
762 &kernel, |
| |
763 &spread, |
| |
764 &μ_hat, |
| |
765 &op𝒟, |
| |
766 &opA, |
| |
767 &b_hat, |
| |
768 &b, |
| |
769 kernel_plot_width, |
| |
770 )?; |
| |
771 |
| |
772 let plotgrid = lingrid(&domain, &[if N == 1 { 1000 } else { 100 }; N]); |
| |
773 let make_plotter = |this_prefix| { |
| |
774 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; |
| |
775 SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()) |
| |
776 }; |
| 747 |
777 |
| 748 let save_extra = |_, ()| Ok(()); |
778 let save_extra = |_, ()| Ok(()); |
| 749 |
779 |
| 750 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, |
780 let μ0 = None; // Zero init |
| 751 |alg, iterator, plotter, running| |
781 |
| 752 { |
782 match (dataterm, regularisation) { |
| 753 let μ = match alg { |
783 (DataTermType::L1, Regularisation::Radon(α)) => { |
| 754 AlgorithmConfig::FB(ref algconfig, prox) => { |
784 let f = DataTerm::new(opA, b, L1.as_mapping()); |
| 755 match (regularisation, dataterm, prox) { |
785 let reg = RadonRegTerm(α); |
| 756 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
786 self.do_runall( |
| 757 print!("{running}"); |
787 &prefix, |
| 758 pointsource_fb_reg( |
788 cli, |
| 759 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
789 algorithms, |
| 760 iterator, plotter |
790 make_plotter, |
| 761 ) |
791 save_extra, |
| 762 }), |
792 μ0, |
| 763 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
793 |p| { |
| 764 print!("{running}"); |
794 run_pdps(&f, ®, &RadonSquared, p, |p| { |
| 765 pointsource_fb_reg( |
795 run_pdps(&f, ®, &op𝒟, p, |_| Err(NotImplemented.into())) |
| 766 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
796 }) |
| 767 iterator, plotter |
797 .map(|μ| (μ, ())) |
| 768 ) |
798 }, |
| 769 }), |
799 ) |
| 770 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
800 } |
| 771 print!("{running}"); |
801 (DataTermType::L1, Regularisation::NonnegRadon(α)) => { |
| 772 pointsource_fb_reg( |
802 let f = DataTerm::new(opA, b, L1.as_mapping()); |
| 773 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
803 let reg = NonnegRadonRegTerm(α); |
| 774 iterator, plotter |
804 self.do_runall( |
| 775 ) |
805 &prefix, |
| 776 }), |
806 cli, |
| 777 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
807 algorithms, |
| 778 print!("{running}"); |
808 make_plotter, |
| 779 pointsource_fb_reg( |
809 save_extra, |
| 780 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
810 μ0, |
| 781 iterator, plotter |
811 |p| { |
| 782 ) |
812 run_pdps(&f, ®, &RadonSquared, p, |p| { |
| 783 }), |
813 run_pdps(&f, ®, &op𝒟, p, |_| Err(NotImplemented.into())) |
| 784 _ => Err(NotImplemented) |
814 }) |
| 785 } |
815 .map(|μ| (μ, ())) |
| 786 }, |
816 }, |
| 787 AlgorithmConfig::FISTA(ref algconfig, prox) => { |
817 ) |
| 788 match (regularisation, dataterm, prox) { |
818 } |
| 789 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
819 (DataTermType::L222, Regularisation::Radon(α)) => { |
| 790 print!("{running}"); |
820 let f = DataTerm::new(opA, b, Norm222::new()); |
| 791 pointsource_fista_reg( |
821 let reg = RadonRegTerm(α); |
| 792 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
822 self.do_runall( |
| 793 iterator, plotter |
823 &prefix, |
| 794 ) |
824 cli, |
| 795 }), |
825 algorithms, |
| 796 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
826 make_plotter, |
| 797 print!("{running}"); |
827 save_extra, |
| 798 pointsource_fista_reg( |
828 μ0, |
| 799 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
829 |p| { |
| 800 iterator, plotter |
830 run_fb(&f, ®, &RadonSquared, p, |p| { |
| 801 ) |
831 run_pdps(&f, ®, &RadonSquared, p, |p| { |
| 802 }), |
832 run_fb(&f, ®, &op𝒟, p, |p| { |
| 803 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
833 run_pdps(&f, ®, &op𝒟, p, |p| { |
| 804 print!("{running}"); |
834 run_fw(&f, ®, p, |_| Err(NotImplemented.into())) |
| 805 pointsource_fista_reg( |
835 }) |
| 806 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
836 }) |
| 807 iterator, plotter |
837 }) |
| 808 ) |
838 }) |
| 809 }), |
839 .map(|μ| (μ, ())) |
| 810 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
840 }, |
| 811 print!("{running}"); |
841 ) |
| 812 pointsource_fista_reg( |
842 } |
| 813 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
843 (DataTermType::L222, Regularisation::NonnegRadon(α)) => { |
| 814 iterator, plotter |
844 let f = DataTerm::new(opA, b, Norm222::new()); |
| 815 ) |
845 let reg = NonnegRadonRegTerm(α); |
| 816 }), |
846 self.do_runall( |
| 817 _ => Err(NotImplemented), |
847 &prefix, |
| 818 } |
848 cli, |
| 819 }, |
849 algorithms, |
| 820 AlgorithmConfig::SlidingFB(ref algconfig, prox) => { |
850 make_plotter, |
| 821 match (regularisation, dataterm, prox) { |
851 save_extra, |
| 822 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
852 μ0, |
| 823 print!("{running}"); |
853 |p| { |
| 824 pointsource_sliding_fb_reg( |
854 run_fb(&f, ®, &RadonSquared, p, |p| { |
| 825 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
855 run_pdps(&f, ®, &RadonSquared, p, |p| { |
| 826 iterator, plotter |
856 run_fb(&f, ®, &op𝒟, p, |p| { |
| 827 ) |
857 run_pdps(&f, ®, &op𝒟, p, |p| { |
| 828 }), |
858 run_fw(&f, ®, p, |_| Err(NotImplemented.into())) |
| 829 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
859 }) |
| 830 print!("{running}"); |
860 }) |
| 831 pointsource_sliding_fb_reg( |
861 }) |
| 832 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
862 }) |
| 833 iterator, plotter |
863 .map(|μ| (μ, ())) |
| 834 ) |
864 }, |
| 835 }), |
865 ) |
| 836 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
866 } |
| 837 print!("{running}"); |
867 } |
| 838 pointsource_sliding_fb_reg( |
868 } |
| 839 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
869 } |
| 840 iterator, plotter |
870 |
| 841 ) |
871 /// Runs PDPS if `alg` so requests and `prox_penalty` matches. |
| 842 }), |
872 /// |
| 843 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
873 /// Due to the structure of the PDPS, the data term `f` has to have a specific form. |
| 844 print!("{running}"); |
874 /// |
| 845 pointsource_sliding_fb_reg( |
875 /// `cont` gives a continuation attempt to find the algorithm matching the description `alg`. |
| 846 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
876 pub fn run_pdps<'a, F, A, Phi, Reg, P, I, Plot, const N: usize>( |
| 847 iterator, plotter |
877 f: &'a DataTerm<F, RNDM<N, F>, A, Phi>, |
| 848 ) |
878 reg: &Reg, |
| 849 }), |
879 prox_penalty: &P, |
| 850 _ => Err(NotImplemented), |
880 (alg, iterator, plotter, μ0, running): ( |
| 851 } |
881 &AlgorithmConfig<F>, |
| 852 }, |
882 I, |
| 853 AlgorithmConfig::PDPS(ref algconfig, prox) => { |
883 Plot, |
| 854 print!("{running}"); |
884 Option<RNDM<N, F>>, |
| 855 match (regularisation, dataterm, prox) { |
885 String, |
| 856 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
886 ), |
| 857 pointsource_pdps_reg( |
887 cont: impl FnOnce( |
| 858 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
888 (&AlgorithmConfig<F>, I, Plot, Option<RNDM<N, F>>, String), |
| 859 iterator, plotter, L2Squared |
889 ) -> DynResult<RNDM<N, F>>, |
| 860 ) |
890 ) -> DynResult<RNDM<N, F>> |
| 861 }), |
891 where |
| 862 (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
892 F: Float + ToNalgebraRealField, |
| 863 pointsource_pdps_reg( |
893 A: ForwardModel<RNDM<N, F>, F>, |
| 864 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
894 Phi: Conjugable<A::Observable, F>, |
| 865 iterator, plotter, L2Squared |
895 for<'b> Phi::Conjugate<'b>: Prox<A::Observable>, |
| 866 ) |
896 for<'b> &'b A::Observable: Instance<A::Observable>, |
| 867 }), |
897 A::Observable: AXPY, |
| 868 (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ |
898 Reg: SlidingRegTerm<Loc<N, F>, F>, |
| 869 pointsource_pdps_reg( |
899 P: ProxPenalty<Loc<N, F>, A::PreadjointCodomain, Reg, F> + StepLengthBoundPD<F, A, RNDM<N, F>>, |
| 870 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
900 RNDM<N, F>: SpikeMerging<F>, |
| 871 iterator, plotter, L1 |
901 I: AlgIteratorFactory<IterInfo<F>>, |
| 872 ) |
902 Plot: Plotter<P::ReturnMapping, A::PreadjointCodomain, RNDM<N, F>>, |
| 873 }), |
903 { |
| 874 (Regularisation::Radon(α), DataTerm::L1, ProxTerm::Wave) => Ok({ |
904 match alg { |
| 875 pointsource_pdps_reg( |
905 &AlgorithmConfig::PDPS(ref algconfig, prox_type) if prox_type == P::prox_type() => { |
| 876 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
906 print!("{running}"); |
| 877 iterator, plotter, L1 |
907 pointsource_pdps_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) |
| 878 ) |
908 } |
| 879 }), |
909 _ => cont((alg, iterator, plotter, μ0, running)), |
| 880 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
910 } |
| 881 pointsource_pdps_reg( |
911 } |
| 882 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
912 |
| 883 iterator, plotter, L2Squared |
913 /// Runs FB-style algorithms if `alg` so requests and `prox_penalty` matches. |
| 884 ) |
914 /// |
| 885 }), |
915 /// `cont` gives a continuation attempt to find the algorithm matching the description `alg`. |
| 886 (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
916 pub fn run_fb<F, Dat, Reg, P, I, Plot, const N: usize>( |
| 887 pointsource_pdps_reg( |
917 f: &Dat, |
| 888 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
918 reg: &Reg, |
| 889 iterator, plotter, L2Squared |
919 prox_penalty: &P, |
| 890 ) |
920 (alg, iterator, plotter, μ0, running): ( |
| 891 }), |
921 &AlgorithmConfig<F>, |
| 892 (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ |
922 I, |
| 893 pointsource_pdps_reg( |
923 Plot, |
| 894 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
924 Option<RNDM<N, F>>, |
| 895 iterator, plotter, L1 |
925 String, |
| 896 ) |
926 ), |
| 897 }), |
927 cont: impl FnOnce( |
| 898 (Regularisation::Radon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({ |
928 (&AlgorithmConfig<F>, I, Plot, Option<RNDM<N, F>>, String), |
| 899 pointsource_pdps_reg( |
929 ) -> DynResult<RNDM<N, F>>, |
| 900 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
930 ) -> DynResult<RNDM<N, F>> |
| 901 iterator, plotter, L1 |
931 where |
| 902 ) |
932 F: Float + ToNalgebraRealField, |
| 903 }), |
933 I: AlgIteratorFactory<IterInfo<F>>, |
| 904 // _ => Err(NotImplemented), |
934 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>, |
| 905 } |
935 Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>, |
| 906 }, |
936 RNDM<N, F>: SpikeMerging<F>, |
| 907 AlgorithmConfig::FW(ref algconfig) => { |
937 Reg: SlidingRegTerm<Loc<N, F>, F>, |
| 908 match (regularisation, dataterm) { |
938 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
| 909 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ |
939 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, |
| 910 print!("{running}"); |
940 { |
| 911 pointsource_fw_reg(&opA, &b, RadonRegTerm(α), |
941 let pt = P::prox_type(); |
| 912 algconfig, iterator, plotter) |
942 |
| 913 }), |
943 match alg { |
| 914 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ |
944 &AlgorithmConfig::FB(ref algconfig, prox_type) if prox_type == pt => { |
| 915 print!("{running}"); |
945 print!("{running}"); |
| 916 pointsource_fw_reg(&opA, &b, NonnegRadonRegTerm(α), |
946 pointsource_fb_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) |
| 917 algconfig, iterator, plotter) |
947 } |
| 918 }), |
948 &AlgorithmConfig::FISTA(ref algconfig, prox_type) if prox_type == pt => { |
| 919 _ => Err(NotImplemented), |
949 print!("{running}"); |
| 920 } |
950 pointsource_fista_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) |
| 921 }, |
951 } |
| 922 _ => Err(NotImplemented), |
952 &AlgorithmConfig::SlidingFB(ref algconfig, prox_type) if prox_type == pt => { |
| 923 }?; |
953 print!("{running}"); |
| 924 Ok((μ, ())) |
954 pointsource_sliding_fb_reg(f, reg, prox_penalty, algconfig, iterator, plotter, μ0) |
| 925 }) |
955 } |
| 926 } |
956 _ => cont((alg, iterator, plotter, μ0, running)), |
| 927 } |
957 } |
| 928 |
958 } |
| |
959 |
| |
960 /// Runs the Frank–Wolfe algorithm if `alg` so requests.
| |
961 /// |
| |
962 /// For the moment, due to restrictions of the Frank–Wolfe implementation, only the |
| |
963 /// $L^2$-squared data term is enabled through the type signatures. |
| |
964 /// |
| |
965 /// `cont` gives a continuation attempt to find the algorithm matching the description `alg`. |
| |
966 pub fn run_fw<'a, F, A, Reg, I, Plot, const N: usize>( |
| |
967 f: &'a DataTerm<F, RNDM<N, F>, A, Norm222<F>>, |
| |
968 reg: &Reg, |
| |
969 (alg, iterator, plotter, μ0, running): ( |
| |
970 &AlgorithmConfig<F>, |
| |
971 I, |
| |
972 Plot, |
| |
973 Option<RNDM<N, F>>, |
| |
974 String, |
| |
975 ), |
| |
976 cont: impl FnOnce( |
| |
977 (&AlgorithmConfig<F>, I, Plot, Option<RNDM<N, F>>, String), |
| |
978 ) -> DynResult<RNDM<N, F>>, |
| |
979 ) -> DynResult<RNDM<N, F>> |
| |
980 where |
| |
981 F: Float + ToNalgebraRealField, |
| |
982 I: AlgIteratorFactory<IterInfo<F>>, |
| |
983 A: ForwardModel<RNDM<N, F>, F>, |
| |
984 A::PreadjointCodomain: MinMaxMapping<Loc<N, F>, F>, |
| |
985 for<'b> &'b A::PreadjointCodomain: Instance<A::PreadjointCodomain>, |
| |
986 Cube<N, F>: P2Minimise<Loc<N, F>, F>, |
| |
987 RNDM<N, F>: SpikeMerging<F>, |
| |
988 Reg: RegTermFW<F, A, ValueIteratorFactory<F, AlgIteratorOptions>, N>, |
| |
989 Plot: Plotter<A::PreadjointCodomain, A::PreadjointCodomain, RNDM<N, F>>, |
| |
990 { |
| |
991 match alg { |
| |
992 &AlgorithmConfig::FW(ref algconfig) => { |
| |
993 print!("{running}"); |
| |
994 pointsource_fw_reg(f, reg, algconfig, iterator, plotter, μ0) |
| |
995 } |
| |
996 _ => cont((alg, iterator, plotter, μ0, running)), |
| |
997 } |
| |
998 } |
| 929 |
999 |
| 930 #[replace_float_literals(F::cast_from(literal))] |
1000 #[replace_float_literals(F::cast_from(literal))] |
| 931 impl<F, NoiseDistr, S, K, P, B, /*PreadjointCodomain,*/ const N : usize> RunnableExperiment<F> for |
1001 impl<F, NoiseDistr, S, K, P, B, const N: usize> RunnableExperiment<F> |
| 932 Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>> |
1002 for Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>> |
| 933 where |
1003 where |
| 934 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F> |
1004 F: ClapFloat |
| 935 + Default + for<'b> Deserialize<'b>, |
1005 + nalgebra::RealField |
| 936 [usize; N] : Serialize, |
1006 + ToNalgebraRealField<MixedType = F> |
| 937 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
1007 + Default |
| 938 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
1008 + for<'b> Deserialize<'b>, |
| 939 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy |
1009 [usize; N]: Serialize, |
| 940 // TODO: should not have differentiability as a requirement, but |
1010 S: Sensor<N, F> + Copy + Serialize + std::fmt::Debug, |
| 941 // decide availability of sliding based on it. |
1011 P: Spread<N, F> + Copy + Serialize + std::fmt::Debug, |
| 942 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
1012 Convolution<S, P>: Spread<N, F> |
| 943 // TODO: very weird that rust only compiles with Differentiable |
1013 + Bounded<F> |
| 944 // instead of the above one on references, which is required by |
|
| 945 // pointsource_sliding_fb_reg. |
|
| 946 + DifferentiableRealMapping<F, N> |
|
| 947 + Lipschitz<L2, FloatType=F>, |
|
| 948 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<F,N>>>::Differential<'b> : Lipschitz<L2, FloatType=F>, // TODO: should not be required generally, only for sliding_fb. |
|
| 949 AutoConvolution<P> : BoundedBy<F, K>, |
|
| 950 K : SimpleConvolutionKernel<F, N> |
|
| 951 + LocalAnalysis<F, Bounds<F>, N> |
1014 + LocalAnalysis<F, Bounds<F>, N> |
| 952 + Copy + Serialize + std::fmt::Debug, |
1015 + Copy |
| 953 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
1016 // TODO: should not have differentiability as a requirement, but |
| 954 PlotLookup : Plotting<N>, |
1017 // decide availability of sliding based on it. |
| 955 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
1018 //+ for<'b> Differentiable<&'b Loc<N, F>, Output = Loc<N, F>>, |
| |
1019 // TODO: very weird that rust only compiles with Differentiable |
| |
1020 // instead of the above one on references, which is required by |
| |
1021 // pointsource_sliding_fb_reg. |
| |
1022 + DifferentiableRealMapping<N, F> |
| |
1023 + Lipschitz<L2, FloatType = F>, |
| |
1024 for<'b> <Convolution<S, P> as DifferentiableMapping<Loc<N, F>>>::Differential<'b>: |
| |
1025 Lipschitz<L2, FloatType = F>, // TODO: should not be required generally, only for sliding_fb. |
| |
1026 AutoConvolution<P>: BoundedBy<F, K>, |
| |
1027 K: SimpleConvolutionKernel<N, F> |
| |
1028 + LocalAnalysis<F, Bounds<F>, N> |
| |
1029 + Copy |
| |
1030 + Serialize |
| |
1031 + std::fmt::Debug, |
| |
1032 Cube<N, F>: P2Minimise<Loc<N, F>, F> + SetOrd, |
| |
1033 PlotLookup: Plotting<N>, |
| |
1034 DefaultBT<F, N>: SensorGridBT<F, S, P, N, Depth = DynamicDepth> + BTSearch<N, F>, |
| 956 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
1035 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
| 957 RNDM<F, N> : SpikeMerging<F>, |
1036 RNDM<N, F>: SpikeMerging<F>, |
| 958 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, |
1037 NoiseDistr: Distribution<F> + Serialize + std::fmt::Debug, |
| 959 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, |
1038 B: Mapping<Loc<N, F>, Codomain = F> + Serialize + std::fmt::Debug, |
| 960 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>, |
1039 nalgebra::DVector<F>: ClosedMul<F>, |
| 961 // PreadjointCodomain : Bounded<F> + DifferentiableRealMapping<F, N>, |
1040 // This is mainly required for the final Mul requirement to be defined |
| |
1041 // DefaultSG<F, S, P, N>: ForwardModel< |
| |
1042 // RNDM<N, F>, |
| |
1043 // F, |
| |
1044 // PreadjointCodomain = PreadjointCodomain, |
| |
1045 // Observable = DVector<F::MixedType>, |
| |
1046 // >, |
| |
1047 // PreadjointCodomain: Bounded<F> + DifferentiableRealMapping<N, F> + std::ops::Mul<F>, |
| |
1048 // Pair<PreadjointCodomain, DVector<F>>: std::ops::Mul<F>, |
| 962 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, |
1049 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, |
| 963 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, |
1050 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, |
| 964 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, |
1051 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>, |
| 965 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, |
1052 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>, |
| 966 { |
1053 { |
| 967 |
1054 fn name(&self) -> &str { |
| 968 fn algorithm_overrides(&self, alg : DefaultAlgorithm) -> AlgorithmOverrides<F> { |
1055 self.name.as_ref() |
| |
1056 } |
| |
1057 |
| |
1058 fn algorithm_overrides(&self, alg: DefaultAlgorithm) -> AlgorithmOverrides<F> { |
| 969 AlgorithmOverrides { |
1059 AlgorithmOverrides { |
| 970 merge_radius : Some(self.data.base.default_merge_radius), |
1060 merge_radius: Some(self.data.base.default_merge_radius), |
| 971 .. self.data.base.algorithm_overrides.get(&alg).cloned().unwrap_or(Default::default()) |
1061 ..self |
| 972 } |
1062 .data |
| 973 } |
1063 .base |
| 974 |
1064 .algorithm_overrides |
| 975 fn runall(&self, cli : &CommandLineArgs, |
1065 .get(&alg) |
| 976 algs : Option<Vec<Named<AlgorithmConfig<F>>>>) -> DynError { |
1066 .cloned() |
| |
1067 .unwrap_or(Default::default()) |
| |
1068 } |
| |
1069 } |
| |
1070 |
| |
1071 fn runall( |
| |
1072 &self, |
| |
1073 cli: &CommandLineArgs, |
| |
1074 algs: Option<Vec<Named<AlgorithmConfig<F>>>>, |
| |
1075 ) -> DynError { |
| 977 // Get experiment configuration |
1076 // Get experiment configuration |
| 978 let &Named { |
1077 let &ExperimentBiased { |
| 979 name : ref experiment_name, |
1078 λ, |
| 980 data : ExperimentBiased { |
1079 ref bias, |
| 981 λ, |
1080 base: |
| 982 ref bias, |
1081 ExperimentV2 { |
| 983 base : ExperimentV2 { |
1082 domain, |
| 984 domain, sensor_count, ref noise_distr, sensor, spread, kernel, |
1083 sensor_count, |
| 985 ref μ_hat, regularisation, kernel_plot_width, dataterm, noise_seed, |
1084 ref noise_distr, |
| |
1085 sensor, |
| |
1086 spread, |
| |
1087 kernel, |
| |
1088 ref μ_hat, |
| |
1089 regularisation, |
| |
1090 kernel_plot_width, |
| |
1091 dataterm, |
| |
1092 noise_seed, |
| 986 .. |
1093 .. |
| 987 } |
1094 }, |
| 988 } |
1095 } = &self.data; |
| 989 } = self; |
|
| 990 |
1096 |
| 991 // Set up algorithms |
1097 // Set up algorithms |
| 992 let algorithms = match (algs, dataterm) { |
1098 let algorithms = match (algs, dataterm) { |
| 993 (Some(algs), _) => algs, |
1099 (Some(algs), _) => algs, |
| 994 _ => vec![DefaultAlgorithm::SlidingPDPS.get_named()], |
1100 _ => vec![DefaultAlgorithm::SlidingPDPS.get_named()], |
| 998 let depth = DynamicDepth(8); |
1104 let depth = DynamicDepth(8); |
| 999 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); |
1105 let opA = DefaultSG::new(domain, sensor_count, sensor, spread, depth); |
| 1000 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); |
1106 let op𝒟 = DefaultSeminormOp::new(depth, domain, kernel); |
| 1001 let opAext = RowOp(opA.clone(), IdOp::new()); |
1107 let opAext = RowOp(opA.clone(), IdOp::new()); |
| 1002 let fnR = Zero::new(); |
1108 let fnR = Zero::new(); |
| 1003 let h = map3(domain.span_start(), domain.span_end(), sensor_count, |
1109 let h = map3( |
| 1004 |a, b, n| (b-a)/F::cast_from(n)) |
1110 domain.span_start(), |
| 1005 .into_iter() |
1111 domain.span_end(), |
| 1006 .reduce(NumTraitsFloat::max) |
1112 sensor_count, |
| 1007 .unwrap(); |
1113 |a, b, n| (b - a) / F::cast_from(n), |
| |
1114 ) |
| |
1115 .into_iter() |
| |
1116 .reduce(NumTraitsFloat::max) |
| |
1117 .unwrap(); |
| 1008 let z = DVector::zeros(sensor_count.iter().product()); |
1118 let z = DVector::zeros(sensor_count.iter().product()); |
| 1009 let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap(); |
1119 let opKz = Grad::new_for(&z, h, sensor_count, ForwardNeumann).unwrap(); |
| 1010 let y = opKz.apply(&z); |
1120 let y = opKz.apply(&z); |
| 1011 let fnH = Weighted{ base_fn : L1.as_mapping(), weight : λ}; // TODO: L_{2,1} |
1121 let fnH = Weighted { base_fn: L1.as_mapping(), weight: λ }; // TODO: L_{2,1} |
| 1012 // let zero_y = y.clone(); |
1122 // let zero_y = y.clone(); |
| 1013 // let zeroBTFN = opA.preadjoint().apply(&zero_y); |
1123 // let zeroBTFN = opA.preadjoint().apply(&zero_y); |
| 1014 // let opKμ = ZeroOp::new(&zero_y, zeroBTFN); |
1124 // let opKμ = ZeroOp::new(&zero_y, zeroBTFN); |
| 1015 |
1125 |
| 1016 // Set up random number generator. |
1126 // Set up random number generator. |
| 1017 let mut rng = StdRng::seed_from_u64(noise_seed); |
1127 let mut rng = StdRng::seed_from_u64(noise_seed); |
| 1018 |
1128 |
| 1019 // Generate the data and calculate SSNR statistic |
1129 // Generate the data and calculate SSNR statistic |
| 1020 let bias_vec = DVector::from_vec(opA.grid() |
1130 let bias_vec = DVector::from_vec( |
| 1021 .into_iter() |
1131 opA.grid() |
| 1022 .map(|v| bias.apply(v)) |
1132 .into_iter() |
| 1023 .collect::<Vec<F>>()); |
1133 .map(|v| bias.apply(v)) |
| 1024 let b_hat : DVector<_> = opA.apply(μ_hat) + &bias_vec; |
1134 .collect::<Vec<F>>(), |
| |
1135 ); |
| |
1136 let b_hat: DVector<_> = opA.apply(μ_hat) + &bias_vec; |
| 1025 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); |
1137 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); |
| 1026 let b = &b_hat + &noise; |
1138 let b = &b_hat + &noise; |
| 1027 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
1139 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField |
| 1028 // overloading log10 and conflicting with standard NumTraits one. |
1140 // overloading log10 and conflicting with standard NumTraits one. |
| 1029 let stats = ExperimentStats::new(&b, &noise); |
1141 let stats = ExperimentStats::new(&b, &noise); |
| 1030 |
1142 |
| 1031 let prefix = start_experiment(&self, cli, stats)?; |
1143 let prefix = self.start(cli)?; |
| 1032 |
1144 write_json(format!("{prefix}stats.json"), &stats)?; |
| 1033 plotall(cli, &prefix, &domain, &sensor, &kernel, &spread, |
1145 |
| 1034 &μ_hat, &op𝒟, &opA, &b_hat, &b, kernel_plot_width)?; |
1146 plotall( |
| |
1147 cli, |
| |
1148 &prefix, |
| |
1149 &domain, |
| |
1150 &sensor, |
| |
1151 &kernel, |
| |
1152 &spread, |
| |
1153 &μ_hat, |
| |
1154 &op𝒟, |
| |
1155 &opA, |
| |
1156 &b_hat, |
| |
1157 &b, |
| |
1158 kernel_plot_width, |
| |
1159 )?; |
| 1035 |
1160 |
| 1036 opA.write_observable(&bias_vec, format!("{prefix}bias"))?; |
1161 opA.write_observable(&bias_vec, format!("{prefix}bias"))?; |
| 1037 |
1162 |
| 1038 let plotgrid = lingrid(&domain, &[if N==1 { 1000 } else { 100 }; N]); |
1163 let plotgrid = lingrid(&domain, &[if N == 1 { 1000 } else { 100 }; N]); |
| |
1164 let make_plotter = |this_prefix| { |
| |
1165 let plot_count = if cli.plot >= PlotLevel::Iter { 2000 } else { 0 }; |
| |
1166 SeqPlotter::new(this_prefix, plot_count, plotgrid.clone()) |
| |
1167 }; |
| 1039 |
1168 |
| 1040 let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z")); |
1169 let save_extra = |prefix, z| opA.write_observable(&z, format!("{prefix}z")); |
| 1041 |
1170 |
| 1042 // Run the algorithms |
1171 let μ0 = None; // Zero init |
| 1043 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, |
1172 |
| 1044 |alg, iterator, plotter, running| |
1173 match (dataterm, regularisation) { |
| 1045 { |
1174 (DataTermType::L222, Regularisation::Radon(α)) => { |
| 1046 let Pair(μ, z) = match alg { |
1175 let f = DataTerm::new(opAext, b, Norm222::new()); |
| 1047 AlgorithmConfig::ForwardPDPS(ref algconfig, prox) => { |
1176 let reg = RadonRegTerm(α); |
| 1048 match (regularisation, dataterm, prox) { |
1177 self.do_runall( |
| 1049 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
1178 &prefix, |
| 1050 print!("{running}"); |
1179 cli, |
| 1051 pointsource_forward_pdps_pair( |
1180 algorithms, |
| 1052 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
1181 make_plotter, |
| 1053 iterator, plotter, |
1182 save_extra, |
| 1054 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
1183 (μ0, z, y), |
| 1055 ) |
1184 |p| { |
| 1056 }), |
1185 run_pdps_pair(&f, ®, &RadonSquared, &opKz, &fnR, &fnH, p, |q| { |
| 1057 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
1186 run_pdps_pair(&f, ®, &op𝒟, &opKz, &fnR, &fnH, q, |_| { |
| 1058 print!("{running}"); |
1187 Err(NotImplemented.into()) |
| 1059 pointsource_forward_pdps_pair( |
1188 }) |
| 1060 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, |
1189 }) |
| 1061 iterator, plotter, |
1190 .map(|Pair(μ, z)| (μ, z)) |
| 1062 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
1191 }, |
| 1063 ) |
1192 ) |
| 1064 }), |
1193 } |
| 1065 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
1194 (DataTermType::L222, Regularisation::NonnegRadon(α)) => { |
| 1066 print!("{running}"); |
1195 let f = DataTerm::new(opAext, b, Norm222::new()); |
| 1067 pointsource_forward_pdps_pair( |
1196 let reg = NonnegRadonRegTerm(α); |
| 1068 &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
1197 self.do_runall( |
| 1069 iterator, plotter, |
1198 &prefix, |
| 1070 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
1199 cli, |
| 1071 ) |
1200 algorithms, |
| 1072 }), |
1201 make_plotter, |
| 1073 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
1202 save_extra, |
| 1074 print!("{running}"); |
1203 (μ0, z, y), |
| 1075 pointsource_forward_pdps_pair( |
1204 |p| { |
| 1076 &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
1205 run_pdps_pair(&f, ®, &RadonSquared, &opKz, &fnR, &fnH, p, |q| { |
| 1077 iterator, plotter, |
1206 run_pdps_pair(&f, ®, &op𝒟, &opKz, &fnR, &fnH, q, |_| { |
| 1078 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
1207 Err(NotImplemented.into()) |
| 1079 ) |
1208 }) |
| 1080 }), |
1209 }) |
| 1081 _ => Err(NotImplemented) |
1210 .map(|Pair(μ, z)| (μ, z)) |
| 1082 } |
1211 }, |
| 1083 }, |
1212 ) |
| 1084 AlgorithmConfig::SlidingPDPS(ref algconfig, prox) => { |
1213 } |
| 1085 match (regularisation, dataterm, prox) { |
1214 _ => Err(NotImplemented.into()), |
| 1086 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
1215 } |
| 1087 print!("{running}"); |
1216 } |
| 1088 pointsource_sliding_pdps_pair( |
1217 } |
| 1089 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
1218 |
| 1090 iterator, plotter, |
1219 type MeasureZ<F, Z, const N: usize> = Pair<RNDM<N, F>, Z>; |
| 1091 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
1220 |
| 1092 ) |
1221 pub fn run_pdps_pair<F, S, Dat, Reg, Z, R, Y, KOpZ, H, P, I, Plot, const N: usize>( |
| 1093 }), |
1222 f: &Dat, |
| 1094 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({ |
1223 reg: &Reg, |
| 1095 print!("{running}"); |
1224 prox_penalty: &P, |
| 1096 pointsource_sliding_pdps_pair( |
1225 opKz: &KOpZ, |
| 1097 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, |
1226 fnR: &R, |
| 1098 iterator, plotter, |
1227 fnH: &H, |
| 1099 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
1228 (alg, iterator, plotter, μ0zy, running): ( |
| 1100 ) |
1229 &AlgorithmConfig<F>, |
| 1101 }), |
1230 I, |
| 1102 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
1231 Plot, |
| 1103 print!("{running}"); |
1232 (Option<RNDM<N, F>>, Z, Y), |
| 1104 pointsource_sliding_pdps_pair( |
1233 String, |
| 1105 &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig, |
1234 ), |
| 1106 iterator, plotter, |
1235 cont: impl FnOnce( |
| 1107 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
1236 ( |
| 1108 ) |
1237 &AlgorithmConfig<F>, |
| 1109 }), |
1238 I, |
| 1110 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({ |
1239 Plot, |
| 1111 print!("{running}"); |
1240 (Option<RNDM<N, F>>, Z, Y), |
| 1112 pointsource_sliding_pdps_pair( |
1241 String, |
| 1113 &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig, |
1242 ), |
| 1114 iterator, plotter, |
1243 ) -> DynResult<Pair<RNDM<N, F>, Z>>, |
| 1115 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), |
1244 ) -> DynResult<Pair<RNDM<N, F>, Z>> |
| 1116 ) |
1245 where |
| 1117 }), |
1246 F: Float + ToNalgebraRealField, |
| 1118 _ => Err(NotImplemented) |
1247 I: AlgIteratorFactory<IterInfo<F>>, |
| 1119 } |
1248 Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>> |
| 1120 }, |
1249 + BoundedCurvature<F>, |
| 1121 _ => Err(NotImplemented) |
1250 S: DifferentiableRealMapping<N, F> + ClosedMul<F>, |
| 1122 }?; |
1251 for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>, |
| 1123 Ok((μ, z)) |
1252 //Pair<S, Z>: ClosedMul<F>, |
| 1124 }) |
1253 RNDM<N, F>: SpikeMerging<F>, |
| |
1254 Reg: SlidingRegTerm<Loc<N, F>, F>, |
| |
1255 P: ProxPenalty<Loc<N, F>, S, Reg, F>, |
| |
1256 KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y> |
| |
1257 + GEMV<F, Z> |
| |
1258 + SimplyAdjointable<Z, Y, AdjointCodomain = Z>, |
| |
1259 KOpZ::SimpleAdjoint: GEMV<F, Y>, |
| |
1260 Y: ClosedEuclidean<F> + Clone, |
| |
1261 for<'b> &'b Y: Instance<Y>, |
| |
1262 Z: ClosedEuclidean<F> + Clone + ClosedMul<F>, |
| |
1263 for<'b> &'b Z: Instance<Z>, |
| |
1264 R: Prox<Z, Codomain = F>, |
| |
1265 H: Conjugable<Y, F, Codomain = F>, |
| |
1266 for<'b> H::Conjugate<'b>: Prox<Y>, |
| |
1267 Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>, |
| |
1268 { |
| |
1269 let pt = P::prox_type(); |
| |
1270 |
| |
1271 match alg { |
| |
1272 &AlgorithmConfig::ForwardPDPS(ref algconfig, prox_type) if prox_type == pt => { |
| |
1273 print!("{running}"); |
| |
1274 pointsource_forward_pdps_pair( |
| |
1275 f, |
| |
1276 reg, |
| |
1277 prox_penalty, |
| |
1278 algconfig, |
| |
1279 iterator, |
| |
1280 plotter, |
| |
1281 μ0zy, |
| |
1282 opKz, |
| |
1283 fnR, |
| |
1284 fnH, |
| |
1285 ) |
| |
1286 } |
| |
1287 &AlgorithmConfig::SlidingPDPS(ref algconfig, prox_type) if prox_type == pt => { |
| |
1288 print!("{running}"); |
| |
1289 pointsource_sliding_pdps_pair( |
| |
1290 f, |
| |
1291 reg, |
| |
1292 prox_penalty, |
| |
1293 algconfig, |
| |
1294 iterator, |
| |
1295 plotter, |
| |
1296 μ0zy, |
| |
1297 opKz, |
| |
1298 fnR, |
| |
1299 fnH, |
| |
1300 ) |
| |
1301 } |
| |
1302 _ => cont((alg, iterator, plotter, μ0zy, running)), |
| |
1303 } |
| |
1304 } |
| |
1305 |
| |
1306 pub fn run_fb_pair<F, S, Dat, Reg, Z, R, P, I, Plot, const N: usize>( |
| |
1307 f: &Dat, |
| |
1308 reg: &Reg, |
| |
1309 prox_penalty: &P, |
| |
1310 fnR: &R, |
| |
1311 (alg, iterator, plotter, μ0z, running): ( |
| |
1312 &AlgorithmConfig<F>, |
| |
1313 I, |
| |
1314 Plot, |
| |
1315 (Option<RNDM<N, F>>, Z), |
| |
1316 String, |
| |
1317 ), |
| |
1318 cont: impl FnOnce( |
| |
1319 ( |
| |
1320 &AlgorithmConfig<F>, |
| |
1321 I, |
| |
1322 Plot, |
| |
1323 (Option<RNDM<N, F>>, Z), |
| |
1324 String, |
| |
1325 ), |
| |
1326 ) -> DynResult<Pair<RNDM<N, F>, Z>>, |
| |
1327 ) -> DynResult<Pair<RNDM<N, F>, Z>> |
| |
1328 where |
| |
1329 F: Float + ToNalgebraRealField, |
| |
1330 I: AlgIteratorFactory<IterInfo<F>>, |
| |
1331 Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>> |
| |
1332 + BoundedCurvature<F>, |
| |
1333 S: DifferentiableRealMapping<N, F> + ClosedMul<F>, |
| |
1334 RNDM<N, F>: SpikeMerging<F>, |
| |
1335 Reg: SlidingRegTerm<Loc<N, F>, F>, |
| |
1336 P: ProxPenalty<Loc<N, F>, S, Reg, F>, |
| |
1337 for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>, |
| |
1338 Z: ClosedEuclidean<F> + AXPY + Clone, |
| |
1339 for<'b> &'b Z: Instance<Z>, |
| |
1340 R: Prox<Z, Codomain = F>, |
| |
1341 Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>, |
| |
1342 // We should not need to explicitly require this: |
| |
1343 for<'b> &'b Loc<0, F>: Instance<Loc<0, F>>, |
| |
1344 { |
| |
1345 let pt = P::prox_type(); |
| |
1346 |
| |
1347 match alg { |
| |
1348 &AlgorithmConfig::FB(ref algconfig, prox_type) if prox_type == pt => { |
| |
1349 print!("{running}"); |
| |
1350 pointsource_fb_pair(f, reg, prox_penalty, algconfig, iterator, plotter, μ0z, fnR) |
| |
1351 } |
| |
1352 &AlgorithmConfig::SlidingFB(ref algconfig, prox_type) if prox_type == pt => { |
| |
1353 print!("{running}"); |
| |
1354 pointsource_sliding_fb_pair( |
| |
1355 f, |
| |
1356 reg, |
| |
1357 prox_penalty, |
| |
1358 algconfig, |
| |
1359 iterator, |
| |
1360 plotter, |
| |
1361 μ0z, |
| |
1362 fnR, |
| |
1363 ) |
| |
1364 } |
| |
1365 _ => cont((alg, iterator, plotter, μ0z, running)), |
| 1125 } |
1366 } |
| 1126 } |
1367 } |
| 1127 |
1368 |
| 1128 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
1369 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
| 1129 struct ValueRange<F : Float> { |
1370 struct ValueRange<F: Float> { |
| 1130 ini : F, |
1371 ini: F, |
| 1131 min : F, |
1372 min: F, |
| 1132 } |
1373 } |
| 1133 |
1374 |
| 1134 impl<F : Float> ValueRange<F> { |
1375 impl<F: Float> ValueRange<F> { |
| 1135 fn expand_with(self, other : Self) -> Self { |
1376 fn expand_with(self, other: Self) -> Self { |
| 1136 ValueRange { |
1377 ValueRange { ini: self.ini.max(other.ini), min: self.min.min(other.min) } |
| 1137 ini : self.ini.max(other.ini), |
|
| 1138 min : self.min.min(other.min), |
|
| 1139 } |
|
| 1140 } |
1378 } |
| 1141 } |
1379 } |
| 1142 |
1380 |
| 1143 /// Calculate minimum and maximum values of all the `logs`, and save them into |
1381 /// Calculate minimum and maximum values of all the `logs`, and save them into |
| 1144 /// corresponding file names given as the first elements of the tuples in the vectors. |
1382 /// corresponding file names given as the first elements of the tuples in the vectors. |
| 1145 fn save_logs<F : Float + for<'b> Deserialize<'b>, const N : usize>( |
1383 fn save_logs<F: Float + for<'b> Deserialize<'b>>( |
| 1146 logs : Vec<(String, Logger<Timed<IterInfo<F, N>>>)>, |
1384 logs: Vec<(String, Logger<Timed<IterInfo<F>>>)>, |
| 1147 valuerange_file : String, |
1385 valuerange_file: String, |
| 1148 load_valuerange : bool, |
1386 load_valuerange: bool, |
| 1149 ) -> DynError { |
1387 ) -> DynError { |
| 1150 // Process logs for relative values |
1388 // Process logs for relative values |
| 1151 println!("{}", "Processing logs…"); |
1389 println!("{}", "Processing logs…"); |
| 1152 |
1390 |
| 1153 // Find minimum value and initial value within a single log |
1391 // Find minimum value and initial value within a single log |
| 1154 let proc_single_log = |log : &Logger<Timed<IterInfo<F, N>>>| { |
1392 let proc_single_log = |log: &Logger<Timed<IterInfo<F>>>| { |
| 1155 let d = log.data(); |
1393 let d = log.data(); |
| 1156 let mi = d.iter() |
1394 let mi = d.iter().map(|i| i.data.value).reduce(NumTraitsFloat::min); |
| 1157 .map(|i| i.data.value) |
|
| 1158 .reduce(NumTraitsFloat::min); |
|
| 1159 d.first() |
1395 d.first() |
| 1160 .map(|i| i.data.value) |
1396 .map(|i| i.data.value) |
| 1161 .zip(mi) |
1397 .zip(mi) |
| 1162 .map(|(ini, min)| ValueRange{ ini, min }) |
1398 .map(|(ini, min)| ValueRange { ini, min }) |
| 1163 }; |
1399 }; |
| 1164 |
1400 |
| 1165 // Find minimum and maximum value over all logs |
1401 // Find minimum and maximum value over all logs |
| 1166 let mut v = logs.iter() |
1402 let mut v = logs |
| 1167 .filter_map(|&(_, ref log)| proc_single_log(log)) |
1403 .iter() |
| 1168 .reduce(|v1, v2| v1.expand_with(v2)) |
1404 .filter_map(|&(_, ref log)| proc_single_log(log)) |
| 1169 .ok_or(anyhow!("No algorithms found"))?; |
1405 .reduce(|v1, v2| v1.expand_with(v2)) |
| |
1406 .ok_or(anyhow!("No algorithms found"))?; |
| 1170 |
1407 |
| 1171 // Load existing range |
1408 // Load existing range |
| 1172 if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() { |
1409 if load_valuerange && std::fs::metadata(&valuerange_file).is_ok() { |
| 1173 let data = std::fs::read_to_string(&valuerange_file)?; |
1410 let data = std::fs::read_to_string(&valuerange_file)?; |
| 1174 v = v.expand_with(serde_json::from_str(&data)?); |
1411 v = v.expand_with(serde_json::from_str(&data)?); |
| 1222 } |
1461 } |
| 1223 |
1462 |
| 1224 Ok(()) |
1463 Ok(()) |
| 1225 } |
1464 } |
| 1226 |
1465 |
| 1227 |
|
| 1228 /// Plot experiment setup |
1466 /// Plot experiment setup |
| 1229 #[replace_float_literals(F::cast_from(literal))] |
1467 #[replace_float_literals(F::cast_from(literal))] |
| 1230 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N : usize>( |
1468 fn plotall<F, Sensor, Kernel, Spread, 𝒟, A, const N: usize>( |
| 1231 cli : &CommandLineArgs, |
1469 cli: &CommandLineArgs, |
| 1232 prefix : &String, |
1470 prefix: &String, |
| 1233 domain : &Cube<F, N>, |
1471 domain: &Cube<N, F>, |
| 1234 sensor : &Sensor, |
1472 sensor: &Sensor, |
| 1235 kernel : &Kernel, |
1473 kernel: &Kernel, |
| 1236 spread : &Spread, |
1474 spread: &Spread, |
| 1237 μ_hat : &RNDM<F, N>, |
1475 μ_hat: &RNDM<N, F>, |
| 1238 op𝒟 : &𝒟, |
1476 op𝒟: &𝒟, |
| 1239 opA : &A, |
1477 opA: &A, |
| 1240 b_hat : &A::Observable, |
1478 b_hat: &A::Observable, |
| 1241 b : &A::Observable, |
1479 b: &A::Observable, |
| 1242 kernel_plot_width : F, |
1480 kernel_plot_width: F, |
| 1243 ) -> DynError |
1481 ) -> DynError |
| 1244 where F : Float + ToNalgebraRealField, |
1482 where |
| 1245 Sensor : RealMapping<F, N> + Support<F, N> + Clone, |
1483 F: Float + ToNalgebraRealField, |
| 1246 Spread : RealMapping<F, N> + Support<F, N> + Clone, |
1484 Sensor: RealMapping<N, F> + Support<N, F> + Clone, |
| 1247 Kernel : RealMapping<F, N> + Support<F, N>, |
1485 Spread: RealMapping<N, F> + Support<N, F> + Clone, |
| 1248 Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>, |
1486 Kernel: RealMapping<N, F> + Support<N, F>, |
| 1249 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, |
1487 Convolution<Sensor, Spread>: DifferentiableRealMapping<N, F> + Support<N, F>, |
| 1250 𝒟::Codomain : RealMapping<F, N>, |
1488 𝒟: DiscreteMeasureOp<Loc<N, F>, F>, |
| 1251 A : ForwardModel<RNDM<F, N>, F>, |
1489 𝒟::Codomain: RealMapping<N, F>, |
| 1252 for<'a> &'a A::Observable : Instance<A::Observable>, |
1490 A: ForwardModel<RNDM<N, F>, F>, |
| 1253 A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>, |
1491 for<'a> &'a A::Observable: Instance<A::Observable>, |
| 1254 PlotLookup : Plotting<N>, |
1492 A::PreadjointCodomain: DifferentiableRealMapping<N, F> + Bounded<F>, |
| 1255 Cube<F, N> : SetOrd { |
1493 PlotLookup: Plotting<N>, |
| 1256 |
1494 Cube<N, F>: SetOrd, |
| |
1495 { |
| 1257 if cli.plot < PlotLevel::Data { |
1496 if cli.plot < PlotLevel::Data { |
| 1258 return Ok(()) |
1497 return Ok(()); |
| 1259 } |
1498 } |
| 1260 |
1499 |
| 1261 let base = Convolution(sensor.clone(), spread.clone()); |
1500 let base = Convolution(sensor.clone(), spread.clone()); |
| 1262 |
1501 |
| 1263 let resolution = if N==1 { 100 } else { 40 }; |
1502 let resolution = if N == 1 { 100 } else { 40 }; |
| 1264 let pfx = |n| format!("{prefix}{n}"); |
1503 let pfx = |n| format!("{prefix}{n}"); |
| 1265 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); |
1504 let plotgrid = lingrid( |
| |
1505 &[[-kernel_plot_width, kernel_plot_width]; N].into(), |
| |
1506 &[resolution; N], |
| |
1507 ); |
| 1266 |
1508 |
| 1267 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor")); |
1509 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor")); |
| 1268 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel")); |
1510 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel")); |
| 1269 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread")); |
1511 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread")); |
| 1270 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor")); |
1512 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor")); |
| 1271 |
1513 |
| 1272 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
1514 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
| 1273 |
1515 |
| 1274 let ω_hat = op𝒟.apply(μ_hat); |
1516 let ω_hat = op𝒟.apply(μ_hat); |
| 1275 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
1517 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
| 1276 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat")); |
1518 PlotLookup::plot_into_file(&ω_hat, plotgrid2, pfx("omega_hat")); |
| 1277 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise")); |
1519 PlotLookup::plot_into_file(&noise, plotgrid2, pfx("omega_noise")); |
| 1278 |
1520 |
| 1279 let preadj_b = opA.preadjoint().apply(b); |
1521 let preadj_b = opA.preadjoint().apply(b); |
| 1280 let preadj_b_hat = opA.preadjoint().apply(b_hat); |
1522 let preadj_b_hat = opA.preadjoint().apply(b_hat); |
| 1281 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); |
1523 //let bounds = preadj_b.bounds().common(&preadj_b_hat.bounds()); |
| 1282 PlotLookup::plot_into_file_spikes( |
1524 PlotLookup::plot_into_file_spikes( |
| 1283 Some(&preadj_b), |
1525 Some(&preadj_b), |
| 1284 Some(&preadj_b_hat), |
1526 Some(&preadj_b_hat), |
| 1285 plotgrid2, |
1527 plotgrid2, |
| 1286 &μ_hat, |
1528 &μ_hat, |
| 1287 pfx("omega_b") |
1529 pfx("omega_b"), |
| 1288 ); |
1530 ); |
| 1289 PlotLookup::plot_into_file(&preadj_b, plotgrid2, pfx("preadj_b")); |
1531 PlotLookup::plot_into_file(&preadj_b, plotgrid2, pfx("preadj_b")); |
| 1290 PlotLookup::plot_into_file(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat")); |
1532 PlotLookup::plot_into_file(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat")); |
| 1291 |
1533 |
| 1292 // Save true solution and observables |
1534 // Save true solution and observables |