| 29 }; |
29 }; |
| 30 use alg_tools::logger::Logger; |
30 use alg_tools::logger::Logger; |
| 31 use alg_tools::error::DynError; |
31 use alg_tools::error::DynError; |
| 32 use alg_tools::tabledump::TableDump; |
32 use alg_tools::tabledump::TableDump; |
| 33 use alg_tools::sets::Cube; |
33 use alg_tools::sets::Cube; |
| 34 use alg_tools::mapping::{RealMapping, Differentiable}; |
34 use alg_tools::mapping::{ |
| |
35 RealMapping, |
| |
36 DifferentiableRealMapping |
| |
37 }; |
| 35 use alg_tools::nalgebra_support::ToNalgebraRealField; |
38 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 36 use alg_tools::euclidean::Euclidean; |
39 use alg_tools::euclidean::Euclidean; |
| 37 use alg_tools::lingrid::lingrid; |
40 use alg_tools::lingrid::lingrid; |
| 38 use alg_tools::sets::SetOrd; |
41 use alg_tools::sets::SetOrd; |
| 39 |
42 |
| 112 |
122 |
| 113 use AlgorithmConfig::*; |
123 use AlgorithmConfig::*; |
| 114 match self { |
124 match self { |
| 115 FB(fb) => FB(FBConfig { |
125 FB(fb) => FB(FBConfig { |
| 116 τ0 : cli.tau0.unwrap_or(fb.τ0), |
126 τ0 : cli.tau0.unwrap_or(fb.τ0), |
| 117 insertion : override_fb_generic(fb.insertion), |
127 generic : override_fb_generic(fb.generic), |
| 118 .. fb |
128 .. fb |
| 119 }), |
129 }), |
| 120 FISTA(fb) => FISTA(FBConfig { |
130 FISTA(fb) => FISTA(FBConfig { |
| 121 τ0 : cli.tau0.unwrap_or(fb.τ0), |
131 τ0 : cli.tau0.unwrap_or(fb.τ0), |
| 122 insertion : override_fb_generic(fb.insertion), |
132 generic : override_fb_generic(fb.generic), |
| 123 .. fb |
133 .. fb |
| 124 }), |
134 }), |
| 125 PDPS(pdps) => PDPS(PDPSConfig { |
135 PDPS(pdps) => PDPS(PDPSConfig { |
| 126 τ0 : cli.tau0.unwrap_or(pdps.τ0), |
136 τ0 : cli.tau0.unwrap_or(pdps.τ0), |
| 127 σ0 : cli.sigma0.unwrap_or(pdps.σ0), |
137 σ0 : cli.sigma0.unwrap_or(pdps.σ0), |
| 128 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
138 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
| 129 insertion : override_fb_generic(pdps.insertion), |
139 generic : override_fb_generic(pdps.generic), |
| 130 .. pdps |
140 .. pdps |
| 131 }), |
141 }), |
| 132 FW(fw) => FW(FWConfig { |
142 FW(fw) => FW(FWConfig { |
| 133 merging : cli.merging.clone().unwrap_or(fw.merging), |
143 merging : cli.merging.clone().unwrap_or(fw.merging), |
| 134 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), |
144 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), |
| 135 .. fw |
145 .. fw |
| 136 }), |
146 }), |
| |
147 RadonFB(fb) => RadonFB(RadonFBConfig { |
| |
148 τ0 : cli.tau0.unwrap_or(fb.τ0), |
| |
149 insertion : override_fb_generic(fb.insertion), |
| |
150 .. fb |
| |
151 }), |
| |
152 RadonFISTA(fb) => RadonFISTA(RadonFBConfig { |
| |
153 τ0 : cli.tau0.unwrap_or(fb.τ0), |
| |
154 insertion : override_fb_generic(fb.insertion), |
| |
155 .. fb |
| |
156 }), |
| 137 SlidingFB(sfb) => SlidingFB(SlidingFBConfig { |
157 SlidingFB(sfb) => SlidingFB(SlidingFBConfig { |
| 138 τ0 : cli.tau0.unwrap_or(sfb.τ0), |
158 τ0 : cli.tau0.unwrap_or(sfb.τ0), |
| |
159 θ0 : cli.theta0.unwrap_or(sfb.θ0), |
| |
160 transport_tolerance_ω: cli.transport_tolerance_omega.unwrap_or(sfb.transport_tolerance_ω), |
| |
161 transport_tolerance_dv: cli.transport_tolerance_dv.unwrap_or(sfb.transport_tolerance_dv), |
| 139 insertion : override_fb_generic(sfb.insertion), |
162 insertion : override_fb_generic(sfb.insertion), |
| 140 .. sfb |
163 .. sfb |
| 141 }), |
164 }), |
| 142 } |
165 } |
| 143 } |
166 } |
| 185 FWRelax => AlgorithmConfig::FW(FWConfig{ |
214 FWRelax => AlgorithmConfig::FW(FWConfig{ |
| 186 variant : FWVariant::Relaxed, |
215 variant : FWVariant::Relaxed, |
| 187 .. Default::default() |
216 .. Default::default() |
| 188 }), |
217 }), |
| 189 PDPS => AlgorithmConfig::PDPS(Default::default()), |
218 PDPS => AlgorithmConfig::PDPS(Default::default()), |
| |
219 RadonFB => AlgorithmConfig::RadonFB(Default::default()), |
| |
220 RadonFISTA => AlgorithmConfig::RadonFISTA(Default::default()), |
| 190 SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), |
221 SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), |
| 191 } |
222 } |
| 192 } |
223 } |
| 193 |
224 |
| 194 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand |
225 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand |
| 362 // decide availability of sliding based on it. |
393 // decide availability of sliding based on it. |
| 363 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
394 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
| 364 // TODO: very weird that rust only compiles with Differentiable |
395 // TODO: very weird that rust only compiles with Differentiable |
| 365 // instead of the above one on references, which is required by |
396 // instead of the above one on references, which is required by |
| 366 // pointsource_sliding_fb_reg. |
397 // pointsource_sliding_fb_reg. |
| 367 + Differentiable<Loc<F, N>, Output = Loc<F, N>> |
398 + DifferentiableRealMapping<F, N> |
| 368 + Lipschitz<L2>, |
399 + Lipschitz<L2, FloatType=F>, |
| 369 // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
400 // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
| 370 AutoConvolution<P> : BoundedBy<F, K>, |
401 AutoConvolution<P> : BoundedBy<F, K>, |
| 371 K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> |
402 K : SimpleConvolutionKernel<F, N> |
| |
403 + LocalAnalysis<F, Bounds<F>, N> |
| 372 + Copy + Serialize + std::fmt::Debug, |
404 + Copy + Serialize + std::fmt::Debug, |
| 373 //+ Differentiable<Loc<F, N>, Output = Loc<F, N>>, // TODO: shouldn't need to assume differentiability |
|
| 374 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
405 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
| 375 PlotLookup : Plotting<N>, |
406 PlotLookup : Plotting<N>, |
| 376 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
407 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
| 377 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
408 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
| 378 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
409 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
| 567 not_implemented(); |
598 not_implemented(); |
| 568 continue |
599 continue |
| 569 } |
600 } |
| 570 } |
601 } |
| 571 }, |
602 }, |
| |
603 AlgorithmConfig::RadonFB(ref algconfig) => { |
| |
604 match (regularisation, dataterm) { |
| |
605 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
| |
606 running(); |
| |
607 pointsource_radon_fb_reg( |
| |
608 &opA, &b, NonnegRadonRegTerm(α), algconfig, |
| |
609 iterator, plotter |
| |
610 ) |
| |
611 }, |
| |
612 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
| |
613 running(); |
| |
614 pointsource_radon_fb_reg( |
| |
615 &opA, &b, RadonRegTerm(α), algconfig, |
| |
616 iterator, plotter |
| |
617 ) |
| |
618 }, |
| |
619 _ => { |
| |
620 not_implemented(); |
| |
621 continue |
| |
622 } |
| |
623 } |
| |
624 }, |
| |
625 AlgorithmConfig::RadonFISTA(ref algconfig) => { |
| |
626 match (regularisation, dataterm) { |
| |
627 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
| |
628 running(); |
| |
629 pointsource_radon_fista_reg( |
| |
630 &opA, &b, NonnegRadonRegTerm(α), algconfig, |
| |
631 iterator, plotter |
| |
632 ) |
| |
633 }, |
| |
634 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
| |
635 running(); |
| |
636 pointsource_radon_fista_reg( |
| |
637 &opA, &b, RadonRegTerm(α), algconfig, |
| |
638 iterator, plotter |
| |
639 ) |
| |
640 }, |
| |
641 _ => { |
| |
642 not_implemented(); |
| |
643 continue |
| |
644 } |
| |
645 } |
| |
646 }, |
| 572 AlgorithmConfig::SlidingFB(ref algconfig) => { |
647 AlgorithmConfig::SlidingFB(ref algconfig) => { |
| 573 match (regularisation, dataterm) { |
648 match (regularisation, dataterm) { |
| 574 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
649 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
| 575 running(); |
650 running(); |
| 576 pointsource_sliding_fb_reg( |
651 pointsource_sliding_fb_reg( |
| 683 ) -> DynError |
758 ) -> DynError |
| 684 where F : Float + ToNalgebraRealField, |
759 where F : Float + ToNalgebraRealField, |
| 685 Sensor : RealMapping<F, N> + Support<F, N> + Clone, |
760 Sensor : RealMapping<F, N> + Support<F, N> + Clone, |
| 686 Spread : RealMapping<F, N> + Support<F, N> + Clone, |
761 Spread : RealMapping<F, N> + Support<F, N> + Clone, |
| 687 Kernel : RealMapping<F, N> + Support<F, N>, |
762 Kernel : RealMapping<F, N> + Support<F, N>, |
| 688 Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, |
763 Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>, |
| |
764 //Differential<Loc<F, N>, Convolution<Sensor, Spread>> : RealVectorField<F, N, N>, |
| 689 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, |
765 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, |
| 690 𝒟::Codomain : RealMapping<F, N>, |
766 𝒟::Codomain : RealMapping<F, N>, |
| 691 A : ForwardModel<Loc<F, N>, F>, |
767 A : ForwardModel<Loc<F, N>, F>, |
| 692 A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, |
768 A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>, |
| 693 PlotLookup : Plotting<N>, |
769 PlotLookup : Plotting<N>, |
| 694 Cube<F, N> : SetOrd { |
770 Cube<F, N> : SetOrd { |
| 695 |
771 |
| 696 if cli.plot < PlotLevel::Data { |
772 if cli.plot < PlotLevel::Data { |
| 697 return Ok(()) |
773 return Ok(()) |
| 704 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); |
780 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); |
| 705 |
781 |
| 706 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); |
782 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); |
| 707 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); |
783 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); |
| 708 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); |
784 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); |
| 709 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); |
785 PlotLookup::plot_into_file_diff(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); |
| 710 |
786 |
| 711 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
787 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
| 712 |
788 |
| 713 let ω_hat = op𝒟.apply(μ_hat); |
789 let ω_hat = op𝒟.apply(μ_hat); |
| 714 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
790 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
| 723 "Aᵀb".to_string(), &preadj_b, |
799 "Aᵀb".to_string(), &preadj_b, |
| 724 "Aᵀb̂".to_string(), Some(&preadj_b_hat), |
800 "Aᵀb̂".to_string(), Some(&preadj_b_hat), |
| 725 plotgrid2, None, &μ_hat, |
801 plotgrid2, None, &μ_hat, |
| 726 pfx("omega_b") |
802 pfx("omega_b") |
| 727 ); |
803 ); |
| |
804 PlotLookup::plot_into_file_diff(&preadj_b, plotgrid2, pfx("preadj_b"), |
| |
805 "preadj_b".to_string()); |
| |
806 PlotLookup::plot_into_file_diff(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat"), |
| |
807 "preadj_b_hat".to_string()); |
| 728 |
808 |
| 729 // Save true solution and observables |
809 // Save true solution and observables |
| 730 let pfx = |n| format!("{}{}", prefix, n); |
810 let pfx = |n| format!("{}{}", prefix, n); |
| 731 μ_hat.write_csv(pfx("orig.txt"))?; |
811 μ_hat.write_csv(pfx("orig.txt"))?; |
| 732 opA.write_observable(&b_hat, pfx("b_hat"))?; |
812 opA.write_observable(&b_hat, pfx("b_hat"))?; |