29 }; |
29 }; |
30 use alg_tools::logger::Logger; |
30 use alg_tools::logger::Logger; |
31 use alg_tools::error::DynError; |
31 use alg_tools::error::DynError; |
32 use alg_tools::tabledump::TableDump; |
32 use alg_tools::tabledump::TableDump; |
33 use alg_tools::sets::Cube; |
33 use alg_tools::sets::Cube; |
34 use alg_tools::mapping::{RealMapping, Differentiable}; |
34 use alg_tools::mapping::{ |
|
35 RealMapping, |
|
36 DifferentiableRealMapping |
|
37 }; |
35 use alg_tools::nalgebra_support::ToNalgebraRealField; |
38 use alg_tools::nalgebra_support::ToNalgebraRealField; |
36 use alg_tools::euclidean::Euclidean; |
39 use alg_tools::euclidean::Euclidean; |
37 use alg_tools::lingrid::lingrid; |
40 use alg_tools::lingrid::lingrid; |
38 use alg_tools::sets::SetOrd; |
41 use alg_tools::sets::SetOrd; |
39 |
42 |
112 |
122 |
113 use AlgorithmConfig::*; |
123 use AlgorithmConfig::*; |
114 match self { |
124 match self { |
115 FB(fb) => FB(FBConfig { |
125 FB(fb) => FB(FBConfig { |
116 τ0 : cli.tau0.unwrap_or(fb.τ0), |
126 τ0 : cli.tau0.unwrap_or(fb.τ0), |
117 insertion : override_fb_generic(fb.insertion), |
127 generic : override_fb_generic(fb.generic), |
118 .. fb |
128 .. fb |
119 }), |
129 }), |
120 FISTA(fb) => FISTA(FBConfig { |
130 FISTA(fb) => FISTA(FBConfig { |
121 τ0 : cli.tau0.unwrap_or(fb.τ0), |
131 τ0 : cli.tau0.unwrap_or(fb.τ0), |
122 insertion : override_fb_generic(fb.insertion), |
132 generic : override_fb_generic(fb.generic), |
123 .. fb |
133 .. fb |
124 }), |
134 }), |
125 PDPS(pdps) => PDPS(PDPSConfig { |
135 PDPS(pdps) => PDPS(PDPSConfig { |
126 τ0 : cli.tau0.unwrap_or(pdps.τ0), |
136 τ0 : cli.tau0.unwrap_or(pdps.τ0), |
127 σ0 : cli.sigma0.unwrap_or(pdps.σ0), |
137 σ0 : cli.sigma0.unwrap_or(pdps.σ0), |
128 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
138 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
129 insertion : override_fb_generic(pdps.insertion), |
139 generic : override_fb_generic(pdps.generic), |
130 .. pdps |
140 .. pdps |
131 }), |
141 }), |
132 FW(fw) => FW(FWConfig { |
142 FW(fw) => FW(FWConfig { |
133 merging : cli.merging.clone().unwrap_or(fw.merging), |
143 merging : cli.merging.clone().unwrap_or(fw.merging), |
134 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), |
144 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), |
135 .. fw |
145 .. fw |
136 }), |
146 }), |
|
147 RadonFB(fb) => RadonFB(RadonFBConfig { |
|
148 τ0 : cli.tau0.unwrap_or(fb.τ0), |
|
149 insertion : override_fb_generic(fb.insertion), |
|
150 .. fb |
|
151 }), |
|
152 RadonFISTA(fb) => RadonFISTA(RadonFBConfig { |
|
153 τ0 : cli.tau0.unwrap_or(fb.τ0), |
|
154 insertion : override_fb_generic(fb.insertion), |
|
155 .. fb |
|
156 }), |
137 SlidingFB(sfb) => SlidingFB(SlidingFBConfig { |
157 SlidingFB(sfb) => SlidingFB(SlidingFBConfig { |
138 τ0 : cli.tau0.unwrap_or(sfb.τ0), |
158 τ0 : cli.tau0.unwrap_or(sfb.τ0), |
|
159 θ0 : cli.theta0.unwrap_or(sfb.θ0), |
|
160 transport_tolerance_ω: cli.transport_tolerance_omega.unwrap_or(sfb.transport_tolerance_ω), |
|
161 transport_tolerance_dv: cli.transport_tolerance_dv.unwrap_or(sfb.transport_tolerance_dv), |
139 insertion : override_fb_generic(sfb.insertion), |
162 insertion : override_fb_generic(sfb.insertion), |
140 .. sfb |
163 .. sfb |
141 }), |
164 }), |
142 } |
165 } |
143 } |
166 } |
185 FWRelax => AlgorithmConfig::FW(FWConfig{ |
214 FWRelax => AlgorithmConfig::FW(FWConfig{ |
186 variant : FWVariant::Relaxed, |
215 variant : FWVariant::Relaxed, |
187 .. Default::default() |
216 .. Default::default() |
188 }), |
217 }), |
189 PDPS => AlgorithmConfig::PDPS(Default::default()), |
218 PDPS => AlgorithmConfig::PDPS(Default::default()), |
|
219 RadonFB => AlgorithmConfig::RadonFB(Default::default()), |
|
220 RadonFISTA => AlgorithmConfig::RadonFISTA(Default::default()), |
190 SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), |
221 SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), |
191 } |
222 } |
192 } |
223 } |
193 |
224 |
194 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand |
225 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand |
362 // decide availability of sliding based on it. |
393 // decide availability of sliding based on it. |
363 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
394 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
364 // TODO: very weird that rust only compiles with Differentiable |
395 // TODO: very weird that rust only compiles with Differentiable |
365 // instead of the above one on references, which is required by |
396 // instead of the above one on references, which is required by |
366 // pointsource_sliding_fb_reg. |
397 // pointsource_sliding_fb_reg. |
367 + Differentiable<Loc<F, N>, Output = Loc<F, N>> |
398 + DifferentiableRealMapping<F, N> |
368 + Lipschitz<L2>, |
399 + Lipschitz<L2, FloatType=F>, |
369 // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
400 // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
370 AutoConvolution<P> : BoundedBy<F, K>, |
401 AutoConvolution<P> : BoundedBy<F, K>, |
371 K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> |
402 K : SimpleConvolutionKernel<F, N> |
|
403 + LocalAnalysis<F, Bounds<F>, N> |
372 + Copy + Serialize + std::fmt::Debug, |
404 + Copy + Serialize + std::fmt::Debug, |
373 //+ Differentiable<Loc<F, N>, Output = Loc<F, N>>, // TODO: shouldn't need to assume differentiability |
|
374 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
405 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
375 PlotLookup : Plotting<N>, |
406 PlotLookup : Plotting<N>, |
376 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
407 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
377 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
408 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
378 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
409 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
567 not_implemented(); |
598 not_implemented(); |
568 continue |
599 continue |
569 } |
600 } |
570 } |
601 } |
571 }, |
602 }, |
|
603 AlgorithmConfig::RadonFB(ref algconfig) => { |
|
604 match (regularisation, dataterm) { |
|
605 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
|
606 running(); |
|
607 pointsource_radon_fb_reg( |
|
608 &opA, &b, NonnegRadonRegTerm(α), algconfig, |
|
609 iterator, plotter |
|
610 ) |
|
611 }, |
|
612 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
|
613 running(); |
|
614 pointsource_radon_fb_reg( |
|
615 &opA, &b, RadonRegTerm(α), algconfig, |
|
616 iterator, plotter |
|
617 ) |
|
618 }, |
|
619 _ => { |
|
620 not_implemented(); |
|
621 continue |
|
622 } |
|
623 } |
|
624 }, |
|
625 AlgorithmConfig::RadonFISTA(ref algconfig) => { |
|
626 match (regularisation, dataterm) { |
|
627 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
|
628 running(); |
|
629 pointsource_radon_fista_reg( |
|
630 &opA, &b, NonnegRadonRegTerm(α), algconfig, |
|
631 iterator, plotter |
|
632 ) |
|
633 }, |
|
634 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
|
635 running(); |
|
636 pointsource_radon_fista_reg( |
|
637 &opA, &b, RadonRegTerm(α), algconfig, |
|
638 iterator, plotter |
|
639 ) |
|
640 }, |
|
641 _ => { |
|
642 not_implemented(); |
|
643 continue |
|
644 } |
|
645 } |
|
646 }, |
572 AlgorithmConfig::SlidingFB(ref algconfig) => { |
647 AlgorithmConfig::SlidingFB(ref algconfig) => { |
573 match (regularisation, dataterm) { |
648 match (regularisation, dataterm) { |
574 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
649 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
575 running(); |
650 running(); |
576 pointsource_sliding_fb_reg( |
651 pointsource_sliding_fb_reg( |
683 ) -> DynError |
758 ) -> DynError |
684 where F : Float + ToNalgebraRealField, |
759 where F : Float + ToNalgebraRealField, |
685 Sensor : RealMapping<F, N> + Support<F, N> + Clone, |
760 Sensor : RealMapping<F, N> + Support<F, N> + Clone, |
686 Spread : RealMapping<F, N> + Support<F, N> + Clone, |
761 Spread : RealMapping<F, N> + Support<F, N> + Clone, |
687 Kernel : RealMapping<F, N> + Support<F, N>, |
762 Kernel : RealMapping<F, N> + Support<F, N>, |
688 Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, |
763 Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>, |
|
764 //Differential<Loc<F, N>, Convolution<Sensor, Spread>> : RealVectorField<F, N, N>, |
689 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, |
765 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, |
690 𝒟::Codomain : RealMapping<F, N>, |
766 𝒟::Codomain : RealMapping<F, N>, |
691 A : ForwardModel<Loc<F, N>, F>, |
767 A : ForwardModel<Loc<F, N>, F>, |
692 A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, |
768 A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>, |
693 PlotLookup : Plotting<N>, |
769 PlotLookup : Plotting<N>, |
694 Cube<F, N> : SetOrd { |
770 Cube<F, N> : SetOrd { |
695 |
771 |
696 if cli.plot < PlotLevel::Data { |
772 if cli.plot < PlotLevel::Data { |
697 return Ok(()) |
773 return Ok(()) |
704 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); |
780 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); |
705 |
781 |
706 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); |
782 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); |
707 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); |
783 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); |
708 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); |
784 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); |
709 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); |
785 PlotLookup::plot_into_file_diff(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); |
710 |
786 |
711 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
787 let plotgrid2 = lingrid(&domain, &[resolution; N]); |
712 |
788 |
713 let ω_hat = op𝒟.apply(μ_hat); |
789 let ω_hat = op𝒟.apply(μ_hat); |
714 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
790 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); |
723 "Aᵀb".to_string(), &preadj_b, |
799 "Aᵀb".to_string(), &preadj_b, |
724 "Aᵀb̂".to_string(), Some(&preadj_b_hat), |
800 "Aᵀb̂".to_string(), Some(&preadj_b_hat), |
725 plotgrid2, None, &μ_hat, |
801 plotgrid2, None, &μ_hat, |
726 pfx("omega_b") |
802 pfx("omega_b") |
727 ); |
803 ); |
|
804 PlotLookup::plot_into_file_diff(&preadj_b, plotgrid2, pfx("preadj_b"), |
|
805 "preadj_b".to_string()); |
|
806 PlotLookup::plot_into_file_diff(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat"), |
|
807 "preadj_b_hat".to_string()); |
728 |
808 |
729 // Save true solution and observables |
809 // Save true solution and observables |
730 let pfx = |n| format!("{}{}", prefix, n); |
810 let pfx = |n| format!("{}{}", prefix, n); |
731 μ_hat.write_csv(pfx("orig.txt"))?; |
811 μ_hat.write_csv(pfx("orig.txt"))?; |
732 opA.write_observable(&b_hat, pfx("b_hat"))?; |
812 opA.write_observable(&b_hat, pfx("b_hat"))?; |