src/run.rs

branch
dev
changeset 37
c5d8bd1a7728
parent 35
b087e3eab191
child 38
0f59c0d02e13
child 39
6316d68b58af
equal deleted inserted replaced
36:fb911f72e698 37:c5d8bd1a7728
70 FBConfig, 70 FBConfig,
71 FBGenericConfig, 71 FBGenericConfig,
72 pointsource_fb_reg, 72 pointsource_fb_reg,
73 pointsource_fista_reg, 73 pointsource_fista_reg,
74 }; 74 };
75 use crate::radon_fb::{
76 RadonFBConfig,
77 pointsource_radon_fb_reg,
78 pointsource_radon_fista_reg,
79 };
80 use crate::sliding_fb::{ 75 use crate::sliding_fb::{
81 SlidingFBConfig, 76 SlidingFBConfig,
82 TransportConfig, 77 TransportConfig,
83 pointsource_sliding_fb_reg 78 pointsource_sliding_fb_reg
84 }; 79 };
112 }; 107 };
113 use crate::dataterm::{ 108 use crate::dataterm::{
114 L1, 109 L1,
115 L2Squared, 110 L2Squared,
116 }; 111 };
112 use crate::prox_penalty::{
113 RadonSquared,
114 //ProxPenalty,
115 };
117 use alg_tools::norms::{L2, NormExponent}; 116 use alg_tools::norms::{L2, NormExponent};
118 use alg_tools::operator_arithmetic::Weighted; 117 use alg_tools::operator_arithmetic::Weighted;
119 use anyhow::anyhow; 118 use anyhow::anyhow;
120 119
120 /// Available proximal terms
121 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
122 pub enum ProxTerm {
123 /// Partial-to-wave operator 𝒟.
124 Wave,
125 /// Radon-norm squared
126 RadonSquared
127 }
128
121 /// Available algorithms and their configurations 129 /// Available algorithms and their configurations
122 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 130 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
123 pub enum AlgorithmConfig<F : Float> { 131 pub enum AlgorithmConfig<F : Float> {
124 FB(FBConfig<F>), 132 FB(FBConfig<F>, ProxTerm),
125 FISTA(FBConfig<F>), 133 FISTA(FBConfig<F>, ProxTerm),
126 FW(FWConfig<F>), 134 FW(FWConfig<F>),
127 PDPS(PDPSConfig<F>), 135 PDPS(PDPSConfig<F>, ProxTerm),
128 RadonFB(RadonFBConfig<F>), 136 SlidingFB(SlidingFBConfig<F>, ProxTerm),
129 RadonFISTA(RadonFBConfig<F>), 137 ForwardPDPS(ForwardPDPSConfig<F>, ProxTerm),
130 SlidingFB(SlidingFBConfig<F>), 138 SlidingPDPS(SlidingPDPSConfig<F>, ProxTerm),
131 ForwardPDPS(ForwardPDPSConfig<F>),
132 SlidingPDPS(SlidingPDPSConfig<F>),
133 } 139 }
134 140
135 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { 141 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> {
136 assert!(v.len() == 3); 142 assert!(v.len() == 3);
137 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } 143 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] }
163 } 169 }
164 }; 170 };
165 171
166 use AlgorithmConfig::*; 172 use AlgorithmConfig::*;
167 match self { 173 match self {
168 FB(fb) => FB(FBConfig { 174 FB(fb, prox) => FB(FBConfig {
169 τ0 : cli.tau0.unwrap_or(fb.τ0), 175 τ0 : cli.tau0.unwrap_or(fb.τ0),
170 generic : override_fb_generic(fb.generic), 176 generic : override_fb_generic(fb.generic),
171 .. fb 177 .. fb
172 }), 178 }, prox),
173 FISTA(fb) => FISTA(FBConfig { 179 FISTA(fb, prox) => FISTA(FBConfig {
174 τ0 : cli.tau0.unwrap_or(fb.τ0), 180 τ0 : cli.tau0.unwrap_or(fb.τ0),
175 generic : override_fb_generic(fb.generic), 181 generic : override_fb_generic(fb.generic),
176 .. fb 182 .. fb
177 }), 183 }, prox),
178 PDPS(pdps) => PDPS(PDPSConfig { 184 PDPS(pdps, prox) => PDPS(PDPSConfig {
179 τ0 : cli.tau0.unwrap_or(pdps.τ0), 185 τ0 : cli.tau0.unwrap_or(pdps.τ0),
180 σ0 : cli.sigma0.unwrap_or(pdps.σ0), 186 σ0 : cli.sigma0.unwrap_or(pdps.σ0),
181 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), 187 acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
182 generic : override_fb_generic(pdps.generic), 188 generic : override_fb_generic(pdps.generic),
183 .. pdps 189 .. pdps
184 }), 190 }, prox),
185 FW(fw) => FW(FWConfig { 191 FW(fw) => FW(FWConfig {
186 merging : cli.merging.clone().unwrap_or(fw.merging), 192 merging : cli.merging.clone().unwrap_or(fw.merging),
187 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), 193 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance),
188 .. fw 194 .. fw
189 }), 195 }),
190 RadonFB(fb) => RadonFB(RadonFBConfig { 196 SlidingFB(sfb, prox) => SlidingFB(SlidingFBConfig {
191 τ0 : cli.tau0.unwrap_or(fb.τ0),
192 insertion : override_fb_generic(fb.insertion),
193 .. fb
194 }),
195 RadonFISTA(fb) => RadonFISTA(RadonFBConfig {
196 τ0 : cli.tau0.unwrap_or(fb.τ0),
197 insertion : override_fb_generic(fb.insertion),
198 .. fb
199 }),
200 SlidingFB(sfb) => SlidingFB(SlidingFBConfig {
201 τ0 : cli.tau0.unwrap_or(sfb.τ0), 197 τ0 : cli.tau0.unwrap_or(sfb.τ0),
202 transport : override_transport(sfb.transport), 198 transport : override_transport(sfb.transport),
203 insertion : override_fb_generic(sfb.insertion), 199 insertion : override_fb_generic(sfb.insertion),
204 .. sfb 200 .. sfb
205 }), 201 }, prox),
206 SlidingPDPS(spdps) => SlidingPDPS(SlidingPDPSConfig { 202 SlidingPDPS(spdps, prox) => SlidingPDPS(SlidingPDPSConfig {
207 τ0 : cli.tau0.unwrap_or(spdps.τ0), 203 τ0 : cli.tau0.unwrap_or(spdps.τ0),
208 σp0 : cli.sigmap0.unwrap_or(spdps.σp0), 204 σp0 : cli.sigmap0.unwrap_or(spdps.σp0),
209 σd0 : cli.sigma0.unwrap_or(spdps.σd0), 205 σd0 : cli.sigma0.unwrap_or(spdps.σd0),
210 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), 206 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
211 transport : override_transport(spdps.transport), 207 transport : override_transport(spdps.transport),
212 insertion : override_fb_generic(spdps.insertion), 208 insertion : override_fb_generic(spdps.insertion),
213 .. spdps 209 .. spdps
214 }), 210 }, prox),
215 ForwardPDPS(fpdps) => ForwardPDPS(ForwardPDPSConfig { 211 ForwardPDPS(fpdps, prox) => ForwardPDPS(ForwardPDPSConfig {
216 τ0 : cli.tau0.unwrap_or(fpdps.τ0), 212 τ0 : cli.tau0.unwrap_or(fpdps.τ0),
217 σp0 : cli.sigmap0.unwrap_or(fpdps.σp0), 213 σp0 : cli.sigmap0.unwrap_or(fpdps.σp0),
218 σd0 : cli.sigma0.unwrap_or(fpdps.σd0), 214 σd0 : cli.sigma0.unwrap_or(fpdps.σd0),
219 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration), 215 //acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
220 insertion : override_fb_generic(fpdps.insertion), 216 insertion : override_fb_generic(fpdps.insertion),
221 .. fpdps 217 .. fpdps
222 }), 218 }, prox),
223 } 219 }
224 } 220 }
225 } 221 }
226 222
227 /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name. 223 /// Helper struct for tagging and [`AlgorithmConfig`] or [`ExperimentV2`] with a name.
248 #[clap(name = "fwrelax")] 244 #[clap(name = "fwrelax")]
249 FWRelax, 245 FWRelax,
250 /// The μPDPS primal-dual proximal splitting method 246 /// The μPDPS primal-dual proximal splitting method
251 #[clap(name = "pdps")] 247 #[clap(name = "pdps")]
252 PDPS, 248 PDPS,
253 /// The RadonFB forward-backward method
254 #[clap(name = "radon_fb")]
255 RadonFB,
256 /// The RadonFISTA inertial forward-backward method
257 #[clap(name = "radon_fista")]
258 RadonFISTA,
259 /// The sliding FB method 249 /// The sliding FB method
260 #[clap(name = "sliding_fb", alias = "sfb")] 250 #[clap(name = "sliding_fb", alias = "sfb")]
261 SlidingFB, 251 SlidingFB,
262 /// The sliding PDPS method 252 /// The sliding PDPS method
263 #[clap(name = "sliding_pdps", alias = "spdps")] 253 #[clap(name = "sliding_pdps", alias = "spdps")]
264 SlidingPDPS, 254 SlidingPDPS,
265 /// The PDPS method with a forward step for the smooth function 255 /// The PDPS method with a forward step for the smooth function
266 #[clap(name = "forward_pdps", alias = "fpdps")] 256 #[clap(name = "forward_pdps", alias = "fpdps")]
267 ForwardPDPS, 257 ForwardPDPS,
258
259 // Radon variants
260
261 /// The μFB forward-backward method with radon-norm squared proximal term
262 #[clap(name = "radon_fb")]
263 RadonFB,
264 /// The μFISTA inertial forward-backward method with radon-norm squared proximal term
265 #[clap(name = "radon_fista")]
266 RadonFISTA,
267 /// The μPDPS primal-dual proximal splitting method with radon-norm squared proximal term
268 #[clap(name = "radon_pdps")]
269 RadonPDPS,
270 /// The sliding FB method with radon-norm squared proximal term
271 #[clap(name = "radon_sliding_fb", alias = "radon_sfb")]
272 RadonSlidingFB,
273 /// The sliding PDPS method with radon-norm squared proximal term
274 #[clap(name = "radon_sliding_pdps", alias = "radon_spdps")]
275 RadonSlidingPDPS,
276 /// The PDPS method with a forward step for the smooth function with radon-norm squared proximal term
277 #[clap(name = "radon_forward_pdps", alias = "radon_fpdps")]
278 RadonForwardPDPS,
268 } 279 }
269 280
270 impl DefaultAlgorithm { 281 impl DefaultAlgorithm {
271 /// Returns the algorithm configuration corresponding to the algorithm shorthand 282 /// Returns the algorithm configuration corresponding to the algorithm shorthand
272 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { 283 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> {
273 use DefaultAlgorithm::*; 284 use DefaultAlgorithm::*;
274 match *self { 285 match *self {
275 FB => AlgorithmConfig::FB(Default::default()), 286 FB => AlgorithmConfig::FB(Default::default(), ProxTerm::Wave),
276 FISTA => AlgorithmConfig::FISTA(Default::default()), 287 FISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::Wave),
277 FW => AlgorithmConfig::FW(Default::default()), 288 FW => AlgorithmConfig::FW(Default::default()),
278 FWRelax => AlgorithmConfig::FW(FWConfig{ 289 FWRelax => AlgorithmConfig::FW(FWConfig{
279 variant : FWVariant::Relaxed, 290 variant : FWVariant::Relaxed,
280 .. Default::default() 291 .. Default::default()
281 }), 292 }),
282 PDPS => AlgorithmConfig::PDPS(Default::default()), 293 PDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::Wave),
283 RadonFB => AlgorithmConfig::RadonFB(Default::default()), 294 SlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::Wave),
284 RadonFISTA => AlgorithmConfig::RadonFISTA(Default::default()), 295 SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::Wave),
285 SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), 296 ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::Wave),
286 SlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default()), 297
287 ForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default()), 298 // Radon variants
299
300 RadonFB => AlgorithmConfig::FB(Default::default(), ProxTerm::RadonSquared),
301 RadonFISTA => AlgorithmConfig::FISTA(Default::default(), ProxTerm::RadonSquared),
302 RadonPDPS => AlgorithmConfig::PDPS(Default::default(), ProxTerm::RadonSquared),
303 RadonSlidingFB => AlgorithmConfig::SlidingFB(Default::default(), ProxTerm::RadonSquared),
304 RadonSlidingPDPS => AlgorithmConfig::SlidingPDPS(Default::default(), ProxTerm::RadonSquared),
305 RadonForwardPDPS => AlgorithmConfig::ForwardPDPS(Default::default(), ProxTerm::RadonSquared),
288 } 306 }
289 } 307 }
290 308
291 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand 309 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand
292 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { 310 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> {
600 618
601 save_logs(logs) 619 save_logs(logs)
602 } 620 }
603 621
604 #[replace_float_literals(F::cast_from(literal))] 622 #[replace_float_literals(F::cast_from(literal))]
605 impl<F, NoiseDistr, S, K, P, const N : usize> RunnableExperiment<F> for 623 impl<F, NoiseDistr, S, K, P, /*PreadjointCodomain, */ const N : usize> RunnableExperiment<F> for
606 Named<ExperimentV2<F, NoiseDistr, S, K, P, N>> 624 Named<ExperimentV2<F, NoiseDistr, S, K, P, N>>
607 where 625 where
608 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, 626 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
609 [usize; N] : Serialize, 627 [usize; N] : Serialize,
610 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, 628 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
626 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, 644 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
627 PlotLookup : Plotting<N>, 645 PlotLookup : Plotting<N>,
628 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, 646 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
629 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 647 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
630 RNDM<F, N> : SpikeMerging<F>, 648 RNDM<F, N> : SpikeMerging<F>,
631 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug 649 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug,
650 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>,
651 // PreadjointCodomain : Space + Bounded<F> + DifferentiableRealMapping<F, N>,
652 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
653 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
654 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
655 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
632 { 656 {
633 657
634 fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> { 658 fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> {
635 self.data.algorithm_defaults.get(&alg).cloned() 659 self.data.algorithm_defaults.get(&alg).cloned()
636 } 660 }
661 685
662 // Set up random number generator. 686 // Set up random number generator.
663 let mut rng = StdRng::seed_from_u64(noise_seed); 687 let mut rng = StdRng::seed_from_u64(noise_seed);
664 688
665 // Generate the data and calculate SSNR statistic 689 // Generate the data and calculate SSNR statistic
666 let b_hat : DVector<_> = opA.apply(μ_hat); 690 let b_hat = opA.apply(μ_hat);
667 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng); 691 let noise = DVector::from_distribution(b_hat.len(), &noise_distr, &mut rng);
668 let b = &b_hat + &noise; 692 let b = &b_hat + &noise;
669 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField 693 // Need to wrap calc_ssnr into a function to hide ultra-lame nalgebra::RealField
670 // overloading log10 and conflicting with standard NumTraits one. 694 // overloading log10 and conflicting with standard NumTraits one.
671 let stats = ExperimentStats::new(&b, &noise); 695 let stats = ExperimentStats::new(&b, &noise);
681 705
682 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, 706 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra,
683 |alg, iterator, plotter, running| 707 |alg, iterator, plotter, running|
684 { 708 {
685 let μ = match alg { 709 let μ = match alg {
686 AlgorithmConfig::FB(ref algconfig) => { 710 AlgorithmConfig::FB(ref algconfig, prox) => {
687 match (regularisation, dataterm) { 711 match (regularisation, dataterm, prox) {
688 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ 712 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
689 print!("{running}"); 713 print!("{running}");
690 pointsource_fb_reg( 714 pointsource_fb_reg(
691 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 715 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
692 iterator, plotter 716 iterator, plotter
693 ) 717 )
694 }), 718 }),
695 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ 719 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
696 print!("{running}"); 720 print!("{running}");
697 pointsource_fb_reg( 721 pointsource_fb_reg(
698 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 722 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
699 iterator, plotter 723 iterator, plotter
700 ) 724 )
701 }), 725 }),
726 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
727 print!("{running}");
728 pointsource_fb_reg(
729 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
730 iterator, plotter
731 )
732 }),
733 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
734 print!("{running}");
735 pointsource_fb_reg(
736 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig,
737 iterator, plotter
738 )
739 }),
702 _ => Err(NotImplemented) 740 _ => Err(NotImplemented)
703 } 741 }
704 }, 742 },
705 AlgorithmConfig::FISTA(ref algconfig) => { 743 AlgorithmConfig::FISTA(ref algconfig, prox) => {
706 match (regularisation, dataterm) { 744 match (regularisation, dataterm, prox) {
707 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ 745 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
708 print!("{running}"); 746 print!("{running}");
709 pointsource_fista_reg( 747 pointsource_fista_reg(
710 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 748 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
711 iterator, plotter 749 iterator, plotter
712 ) 750 )
713 }), 751 }),
714 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ 752 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
715 print!("{running}"); 753 print!("{running}");
716 pointsource_fista_reg( 754 pointsource_fista_reg(
717 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 755 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
718 iterator, plotter 756 iterator, plotter
719 ) 757 )
720 }), 758 }),
759 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
760 print!("{running}");
761 pointsource_fista_reg(
762 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
763 iterator, plotter
764 )
765 }),
766 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
767 print!("{running}");
768 pointsource_fista_reg(
769 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig,
770 iterator, plotter
771 )
772 }),
721 _ => Err(NotImplemented), 773 _ => Err(NotImplemented),
722 } 774 }
723 }, 775 },
724 AlgorithmConfig::RadonFB(ref algconfig) => { 776 AlgorithmConfig::SlidingFB(ref algconfig, prox) => {
725 match (regularisation, dataterm) { 777 match (regularisation, dataterm, prox) {
726 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ 778 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
727 print!("{running}"); 779 print!("{running}");
728 pointsource_radon_fb_reg( 780 pointsource_sliding_fb_reg(
729 &opA, &b, NonnegRadonRegTerm(α), algconfig, 781 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
730 iterator, plotter 782 iterator, plotter
731 ) 783 )
732 }), 784 }),
733 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ 785 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
734 print!("{running}"); 786 print!("{running}");
735 pointsource_radon_fb_reg( 787 pointsource_sliding_fb_reg(
736 &opA, &b, RadonRegTerm(α), algconfig, 788 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
789 iterator, plotter
790 )
791 }),
792 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
793 print!("{running}");
794 pointsource_sliding_fb_reg(
795 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
796 iterator, plotter
797 )
798 }),
799 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
800 print!("{running}");
801 pointsource_sliding_fb_reg(
802 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig,
737 iterator, plotter 803 iterator, plotter
738 ) 804 )
739 }), 805 }),
740 _ => Err(NotImplemented), 806 _ => Err(NotImplemented),
741 } 807 }
742 }, 808 },
743 AlgorithmConfig::RadonFISTA(ref algconfig) => { 809 AlgorithmConfig::PDPS(ref algconfig, prox) => {
744 match (regularisation, dataterm) {
745 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({
746 print!("{running}");
747 pointsource_radon_fista_reg(
748 &opA, &b, NonnegRadonRegTerm(α), algconfig,
749 iterator, plotter
750 )
751 }),
752 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({
753 print!("{running}");
754 pointsource_radon_fista_reg(
755 &opA, &b, RadonRegTerm(α), algconfig,
756 iterator, plotter
757 )
758 }),
759 _ => Err(NotImplemented),
760 }
761 },
762 AlgorithmConfig::SlidingFB(ref algconfig) => {
763 match (regularisation, dataterm) {
764 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({
765 print!("{running}");
766 pointsource_sliding_fb_reg(
767 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
768 iterator, plotter
769 )
770 }),
771 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({
772 print!("{running}");
773 pointsource_sliding_fb_reg(
774 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
775 iterator, plotter
776 )
777 }),
778 _ => Err(NotImplemented),
779 }
780 },
781 AlgorithmConfig::PDPS(ref algconfig) => {
782 print!("{running}"); 810 print!("{running}");
783 match (regularisation, dataterm) { 811 match (regularisation, dataterm, prox) {
784 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ 812 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
785 pointsource_pdps_reg( 813 pointsource_pdps_reg(
786 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 814 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
787 iterator, plotter, L2Squared 815 iterator, plotter, L2Squared
788 ) 816 )
789 }), 817 }),
790 (Regularisation::Radon(α),DataTerm::L2Squared) => Ok({ 818 (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::Wave) => Ok({
791 pointsource_pdps_reg( 819 pointsource_pdps_reg(
792 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 820 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
793 iterator, plotter, L2Squared 821 iterator, plotter, L2Squared
794 ) 822 )
795 }), 823 }),
796 (Regularisation::NonnegRadon(α), DataTerm::L1) => Ok({ 824 (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::Wave) => Ok({
797 pointsource_pdps_reg( 825 pointsource_pdps_reg(
798 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 826 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
799 iterator, plotter, L1 827 iterator, plotter, L1
800 ) 828 )
801 }), 829 }),
802 (Regularisation::Radon(α), DataTerm::L1) => Ok({ 830 (Regularisation::Radon(α), DataTerm::L1, ProxTerm::Wave) => Ok({
803 pointsource_pdps_reg( 831 pointsource_pdps_reg(
804 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, 832 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
805 iterator, plotter, L1 833 iterator, plotter, L1
806 ) 834 )
807 }), 835 }),
836 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
837 pointsource_pdps_reg(
838 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
839 iterator, plotter, L2Squared
840 )
841 }),
842 (Regularisation::Radon(α),DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
843 pointsource_pdps_reg(
844 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig,
845 iterator, plotter, L2Squared
846 )
847 }),
848 (Regularisation::NonnegRadon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({
849 pointsource_pdps_reg(
850 &opA, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
851 iterator, plotter, L1
852 )
853 }),
854 (Regularisation::Radon(α), DataTerm::L1, ProxTerm::RadonSquared) => Ok({
855 pointsource_pdps_reg(
856 &opA, &b, RadonRegTerm(α), &RadonSquared, algconfig,
857 iterator, plotter, L1
858 )
859 }),
860 // _ => Err(NotImplemented),
808 } 861 }
809 }, 862 },
810 AlgorithmConfig::FW(ref algconfig) => { 863 AlgorithmConfig::FW(ref algconfig) => {
811 match (regularisation, dataterm) { 864 match (regularisation, dataterm) {
812 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ 865 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({
829 } 882 }
830 } 883 }
831 884
832 885
833 #[replace_float_literals(F::cast_from(literal))] 886 #[replace_float_literals(F::cast_from(literal))]
834 impl<F, NoiseDistr, S, K, P, B, const N : usize> RunnableExperiment<F> for 887 impl<F, NoiseDistr, S, K, P, B, /*PreadjointCodomain,*/ const N : usize> RunnableExperiment<F> for
835 Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>> 888 Named<ExperimentBiased<F, NoiseDistr, S, K, P, B, N>>
836 where 889 where
837 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, 890 F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
838 [usize; N] : Serialize, 891 [usize; N] : Serialize,
839 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, 892 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
857 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, 910 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
858 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 911 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
859 RNDM<F, N> : SpikeMerging<F>, 912 RNDM<F, N> : SpikeMerging<F>,
860 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug, 913 NoiseDistr : Distribution<F> + Serialize + std::fmt::Debug,
861 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug, 914 B : Mapping<Loc<F, N>, Codomain = F> + Serialize + std::fmt::Debug,
915 // DefaultSG<F, S, P, N> : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = PreadjointCodomain, Observable=DVector<F::MixedType>>,
916 // PreadjointCodomain : Bounded<F> + DifferentiableRealMapping<F, N>,
917 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
918 // DefaultSeminormOp<F, K, N> : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
919 // RadonSquared : ProxPenalty<F, PreadjointCodomain, RadonRegTerm<F>, N>,
920 // RadonSquared : ProxPenalty<F, PreadjointCodomain, NonnegRadonRegTerm<F>, N>,
862 { 921 {
863 922
864 fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> { 923 fn algorithm_defaults(&self, alg : DefaultAlgorithm) -> Option<AlgorithmConfig<F>> {
865 self.data.base.algorithm_defaults.get(&alg).cloned() 924 self.data.base.algorithm_defaults.get(&alg).cloned()
866 } 925 }
935 // Run the algorithms 994 // Run the algorithms
936 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra, 995 do_runall(experiment_name, &prefix, cli, algorithms, plotgrid, save_extra,
937 |alg, iterator, plotter, running| 996 |alg, iterator, plotter, running|
938 { 997 {
939 let Pair(μ, z) = match alg { 998 let Pair(μ, z) = match alg {
940 AlgorithmConfig::ForwardPDPS(ref algconfig) => { 999 AlgorithmConfig::ForwardPDPS(ref algconfig, prox) => {
941 match (regularisation, dataterm) { 1000 match (regularisation, dataterm, prox) {
942 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ 1001 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
943 print!("{running}"); 1002 print!("{running}");
944 pointsource_forward_pdps_pair( 1003 pointsource_forward_pdps_pair(
945 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 1004 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
946 iterator, plotter, 1005 iterator, plotter,
947 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1006 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
948 ) 1007 )
949 }), 1008 }),
950 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ 1009 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
951 print!("{running}"); 1010 print!("{running}");
952 pointsource_forward_pdps_pair( 1011 pointsource_forward_pdps_pair(
953 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, 1012 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig,
954 iterator, plotter, 1013 iterator, plotter,
955 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1014 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
956 ) 1015 )
957 }), 1016 }),
1017 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
1018 print!("{running}");
1019 pointsource_forward_pdps_pair(
1020 &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
1021 iterator, plotter,
1022 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
1023 )
1024 }),
1025 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
1026 print!("{running}");
1027 pointsource_forward_pdps_pair(
1028 &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig,
1029 iterator, plotter,
1030 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
1031 )
1032 }),
958 _ => Err(NotImplemented) 1033 _ => Err(NotImplemented)
959 } 1034 }
960 }, 1035 },
961 AlgorithmConfig::SlidingPDPS(ref algconfig) => { 1036 AlgorithmConfig::SlidingPDPS(ref algconfig, prox) => {
962 match (regularisation, dataterm) { 1037 match (regularisation, dataterm, prox) {
963 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => Ok({ 1038 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
964 print!("{running}"); 1039 print!("{running}");
965 pointsource_sliding_pdps_pair( 1040 pointsource_sliding_pdps_pair(
966 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, 1041 &opAext, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
967 iterator, plotter, 1042 iterator, plotter,
968 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1043 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
969 ) 1044 )
970 }), 1045 }),
971 (Regularisation::Radon(α), DataTerm::L2Squared) => Ok({ 1046 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::Wave) => Ok({
972 print!("{running}"); 1047 print!("{running}");
973 pointsource_sliding_pdps_pair( 1048 pointsource_sliding_pdps_pair(
974 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig, 1049 &opAext, &b, RadonRegTerm(α), &op𝒟, algconfig,
1050 iterator, plotter,
1051 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
1052 )
1053 }),
1054 (Regularisation::NonnegRadon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
1055 print!("{running}");
1056 pointsource_sliding_pdps_pair(
1057 &opAext, &b, NonnegRadonRegTerm(α), &RadonSquared, algconfig,
1058 iterator, plotter,
1059 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
1060 )
1061 }),
1062 (Regularisation::Radon(α), DataTerm::L2Squared, ProxTerm::RadonSquared) => Ok({
1063 print!("{running}");
1064 pointsource_sliding_pdps_pair(
1065 &opAext, &b, RadonRegTerm(α), &RadonSquared, algconfig,
975 iterator, plotter, 1066 iterator, plotter,
976 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(), 1067 /* opKμ, */ &opKz, &fnR, &fnH, z.clone(), y.clone(),
977 ) 1068 )
978 }), 1069 }),
979 _ => Err(NotImplemented) 1070 _ => Err(NotImplemented)

mercurial