src/run.rs

branch
dev
changeset 34
efa60bc4f743
parent 32
56c8adc32b09
child 35
b087e3eab191
equal deleted inserted replaced
33:aec67cdd6b14 34:efa60bc4f743
29 }; 29 };
30 use alg_tools::logger::Logger; 30 use alg_tools::logger::Logger;
31 use alg_tools::error::DynError; 31 use alg_tools::error::DynError;
32 use alg_tools::tabledump::TableDump; 32 use alg_tools::tabledump::TableDump;
33 use alg_tools::sets::Cube; 33 use alg_tools::sets::Cube;
34 use alg_tools::mapping::{RealMapping, Differentiable}; 34 use alg_tools::mapping::{
35 RealMapping,
36 DifferentiableRealMapping
37 };
35 use alg_tools::nalgebra_support::ToNalgebraRealField; 38 use alg_tools::nalgebra_support::ToNalgebraRealField;
36 use alg_tools::euclidean::Euclidean; 39 use alg_tools::euclidean::Euclidean;
37 use alg_tools::lingrid::lingrid; 40 use alg_tools::lingrid::lingrid;
38 use alg_tools::sets::SetOrd; 41 use alg_tools::sets::SetOrd;
39 42
46 FBConfig, 49 FBConfig,
47 FBGenericConfig, 50 FBGenericConfig,
48 pointsource_fb_reg, 51 pointsource_fb_reg,
49 pointsource_fista_reg, 52 pointsource_fista_reg,
50 }; 53 };
54 use crate::radon_fb::{
55 RadonFBConfig,
56 pointsource_radon_fb_reg,
57 pointsource_radon_fista_reg,
58 };
51 use crate::sliding_fb::{ 59 use crate::sliding_fb::{
52 SlidingFBConfig, 60 SlidingFBConfig,
53 pointsource_sliding_fb_reg 61 pointsource_sliding_fb_reg
54 }; 62 };
55 use crate::pdps::{ 63 use crate::pdps::{
83 pub enum AlgorithmConfig<F : Float> { 91 pub enum AlgorithmConfig<F : Float> {
84 FB(FBConfig<F>), 92 FB(FBConfig<F>),
85 FISTA(FBConfig<F>), 93 FISTA(FBConfig<F>),
86 FW(FWConfig<F>), 94 FW(FWConfig<F>),
87 PDPS(PDPSConfig<F>), 95 PDPS(PDPSConfig<F>),
96 RadonFB(RadonFBConfig<F>),
97 RadonFISTA(RadonFBConfig<F>),
88 SlidingFB(SlidingFBConfig<F>), 98 SlidingFB(SlidingFBConfig<F>),
89 } 99 }
90 100
91 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { 101 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> {
92 assert!(v.len() == 3); 102 assert!(v.len() == 3);
112 122
113 use AlgorithmConfig::*; 123 use AlgorithmConfig::*;
114 match self { 124 match self {
115 FB(fb) => FB(FBConfig { 125 FB(fb) => FB(FBConfig {
116 τ0 : cli.tau0.unwrap_or(fb.τ0), 126 τ0 : cli.tau0.unwrap_or(fb.τ0),
117 insertion : override_fb_generic(fb.insertion), 127 generic : override_fb_generic(fb.generic),
118 .. fb 128 .. fb
119 }), 129 }),
120 FISTA(fb) => FISTA(FBConfig { 130 FISTA(fb) => FISTA(FBConfig {
121 τ0 : cli.tau0.unwrap_or(fb.τ0), 131 τ0 : cli.tau0.unwrap_or(fb.τ0),
122 insertion : override_fb_generic(fb.insertion), 132 generic : override_fb_generic(fb.generic),
123 .. fb 133 .. fb
124 }), 134 }),
125 PDPS(pdps) => PDPS(PDPSConfig { 135 PDPS(pdps) => PDPS(PDPSConfig {
126 τ0 : cli.tau0.unwrap_or(pdps.τ0), 136 τ0 : cli.tau0.unwrap_or(pdps.τ0),
127 σ0 : cli.sigma0.unwrap_or(pdps.σ0), 137 σ0 : cli.sigma0.unwrap_or(pdps.σ0),
128 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), 138 acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
129 insertion : override_fb_generic(pdps.insertion), 139 generic : override_fb_generic(pdps.generic),
130 .. pdps 140 .. pdps
131 }), 141 }),
132 FW(fw) => FW(FWConfig { 142 FW(fw) => FW(FWConfig {
133 merging : cli.merging.clone().unwrap_or(fw.merging), 143 merging : cli.merging.clone().unwrap_or(fw.merging),
134 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), 144 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance),
135 .. fw 145 .. fw
136 }), 146 }),
147 RadonFB(fb) => RadonFB(RadonFBConfig {
148 τ0 : cli.tau0.unwrap_or(fb.τ0),
149 insertion : override_fb_generic(fb.insertion),
150 .. fb
151 }),
152 RadonFISTA(fb) => RadonFISTA(RadonFBConfig {
153 τ0 : cli.tau0.unwrap_or(fb.τ0),
154 insertion : override_fb_generic(fb.insertion),
155 .. fb
156 }),
137 SlidingFB(sfb) => SlidingFB(SlidingFBConfig { 157 SlidingFB(sfb) => SlidingFB(SlidingFBConfig {
138 τ0 : cli.tau0.unwrap_or(sfb.τ0), 158 τ0 : cli.tau0.unwrap_or(sfb.τ0),
159 θ0 : cli.theta0.unwrap_or(sfb.θ0),
160 transport_tolerance_ω: cli.transport_tolerance_omega.unwrap_or(sfb.transport_tolerance_ω),
161 transport_tolerance_dv: cli.transport_tolerance_dv.unwrap_or(sfb.transport_tolerance_dv),
139 insertion : override_fb_generic(sfb.insertion), 162 insertion : override_fb_generic(sfb.insertion),
140 .. sfb 163 .. sfb
141 }), 164 }),
142 } 165 }
143 } 166 }
167 #[clap(name = "fwrelax")] 190 #[clap(name = "fwrelax")]
168 FWRelax, 191 FWRelax,
169 /// The μPDPS primal-dual proximal splitting method 192 /// The μPDPS primal-dual proximal splitting method
170 #[clap(name = "pdps")] 193 #[clap(name = "pdps")]
171 PDPS, 194 PDPS,
195 /// The RadonFB forward-backward method
196 #[clap(name = "radon_fb")]
197 RadonFB,
198 /// The RadonFISTA inertial forward-backward method
199 #[clap(name = "radon_fista")]
200 RadonFISTA,
172 /// The Sliding FB method 201 /// The Sliding FB method
173 #[clap(name = "sliding_fb", alias = "sfb")] 202 #[clap(name = "sliding_fb", alias = "sfb")]
174 SlidingFB, 203 SlidingFB,
175 } 204 }
176 205
185 FWRelax => AlgorithmConfig::FW(FWConfig{ 214 FWRelax => AlgorithmConfig::FW(FWConfig{
186 variant : FWVariant::Relaxed, 215 variant : FWVariant::Relaxed,
187 .. Default::default() 216 .. Default::default()
188 }), 217 }),
189 PDPS => AlgorithmConfig::PDPS(Default::default()), 218 PDPS => AlgorithmConfig::PDPS(Default::default()),
219 RadonFB => AlgorithmConfig::RadonFB(Default::default()),
220 RadonFISTA => AlgorithmConfig::RadonFISTA(Default::default()),
190 SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), 221 SlidingFB => AlgorithmConfig::SlidingFB(Default::default()),
191 } 222 }
192 } 223 }
193 224
194 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand 225 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand
362 // decide availability of sliding based on it. 393 // decide availability of sliding based on it.
363 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, 394 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
364 // TODO: very weird that rust only compiles with Differentiable 395 // TODO: very weird that rust only compiles with Differentiable
365 // instead of the above one on references, which is required by 396 // instead of the above one on references, which is required by
366 // poitsource_sliding_fb_reg. 397 // poitsource_sliding_fb_reg.
367 + Differentiable<Loc<F, N>, Output = Loc<F, N>> 398 + DifferentiableRealMapping<F, N>
368 + Lipschitz<L2>, 399 + Lipschitz<L2, FloatType=F>,
369 // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, 400 // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
370 AutoConvolution<P> : BoundedBy<F, K>, 401 AutoConvolution<P> : BoundedBy<F, K>,
371 K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> 402 K : SimpleConvolutionKernel<F, N>
403 + LocalAnalysis<F, Bounds<F>, N>
372 + Copy + Serialize + std::fmt::Debug, 404 + Copy + Serialize + std::fmt::Debug,
373 //+ Differentiable<Loc<F, N>, Output = Loc<F, N>>, // TODO: shouldn't need to assume differentiability
374 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, 405 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
375 PlotLookup : Plotting<N>, 406 PlotLookup : Plotting<N>,
376 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, 407 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
377 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 408 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
378 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 409 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
567 not_implemented(); 598 not_implemented();
568 continue 599 continue
569 } 600 }
570 } 601 }
571 }, 602 },
603 AlgorithmConfig::RadonFB(ref algconfig) => {
604 match (regularisation, dataterm) {
605 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
606 running();
607 pointsource_radon_fb_reg(
608 &opA, &b, NonnegRadonRegTerm(α), algconfig,
609 iterator, plotter
610 )
611 },
612 (Regularisation::Radon(α), DataTerm::L2Squared) => {
613 running();
614 pointsource_radon_fb_reg(
615 &opA, &b, RadonRegTerm(α), algconfig,
616 iterator, plotter
617 )
618 },
619 _ => {
620 not_implemented();
621 continue
622 }
623 }
624 },
625 AlgorithmConfig::RadonFISTA(ref algconfig) => {
626 match (regularisation, dataterm) {
627 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
628 running();
629 pointsource_radon_fista_reg(
630 &opA, &b, NonnegRadonRegTerm(α), algconfig,
631 iterator, plotter
632 )
633 },
634 (Regularisation::Radon(α), DataTerm::L2Squared) => {
635 running();
636 pointsource_radon_fista_reg(
637 &opA, &b, RadonRegTerm(α), algconfig,
638 iterator, plotter
639 )
640 },
641 _ => {
642 not_implemented();
643 continue
644 }
645 }
646 },
572 AlgorithmConfig::SlidingFB(ref algconfig) => { 647 AlgorithmConfig::SlidingFB(ref algconfig) => {
573 match (regularisation, dataterm) { 648 match (regularisation, dataterm) {
574 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { 649 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
575 running(); 650 running();
576 pointsource_sliding_fb_reg( 651 pointsource_sliding_fb_reg(
683 ) -> DynError 758 ) -> DynError
684 where F : Float + ToNalgebraRealField, 759 where F : Float + ToNalgebraRealField,
685 Sensor : RealMapping<F, N> + Support<F, N> + Clone, 760 Sensor : RealMapping<F, N> + Support<F, N> + Clone,
686 Spread : RealMapping<F, N> + Support<F, N> + Clone, 761 Spread : RealMapping<F, N> + Support<F, N> + Clone,
687 Kernel : RealMapping<F, N> + Support<F, N>, 762 Kernel : RealMapping<F, N> + Support<F, N>,
688 Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, 763 Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>,
764 //Differential<Loc<F, N>, Convolution<Sensor, Spread>> : RealVectorField<F, N, N>,
689 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, 765 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>,
690 𝒟::Codomain : RealMapping<F, N>, 766 𝒟::Codomain : RealMapping<F, N>,
691 A : ForwardModel<Loc<F, N>, F>, 767 A : ForwardModel<Loc<F, N>, F>,
692 A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, 768 A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>,
693 PlotLookup : Plotting<N>, 769 PlotLookup : Plotting<N>,
694 Cube<F, N> : SetOrd { 770 Cube<F, N> : SetOrd {
695 771
696 if cli.plot < PlotLevel::Data { 772 if cli.plot < PlotLevel::Data {
697 return Ok(()) 773 return Ok(())
704 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]); 780 let plotgrid = lingrid(&[[-kernel_plot_width, kernel_plot_width]; N].into(), &[resolution; N]);
705 781
706 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); 782 PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string());
707 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); 783 PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string());
708 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); 784 PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string());
709 PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); 785 PlotLookup::plot_into_file_diff(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string());
710 786
711 let plotgrid2 = lingrid(&domain, &[resolution; N]); 787 let plotgrid2 = lingrid(&domain, &[resolution; N]);
712 788
713 let ω_hat = op𝒟.apply(μ_hat); 789 let ω_hat = op𝒟.apply(μ_hat);
714 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b); 790 let noise = opA.preadjoint().apply(opA.apply(μ_hat) - b);
723 "Aᵀb".to_string(), &preadj_b, 799 "Aᵀb".to_string(), &preadj_b,
724 "Aᵀb̂".to_string(), Some(&preadj_b_hat), 800 "Aᵀb̂".to_string(), Some(&preadj_b_hat),
725 plotgrid2, None, &μ_hat, 801 plotgrid2, None, &μ_hat,
726 pfx("omega_b") 802 pfx("omega_b")
727 ); 803 );
804 PlotLookup::plot_into_file_diff(&preadj_b, plotgrid2, pfx("preadj_b"),
805 "preadj_b".to_string());
806 PlotLookup::plot_into_file_diff(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat"),
807 "preadj_b_hat".to_string());
728 808
729 // Save true solution and observables 809 // Save true solution and observables
730 let pfx = |n| format!("{}{}", prefix, n); 810 let pfx = |n| format!("{}{}", prefix, n);
731 μ_hat.write_csv(pfx("orig.txt"))?; 811 μ_hat.write_csv(pfx("orig.txt"))?;
732 opA.write_observable(&b_hat, pfx("b_hat"))?; 812 opA.write_observable(&b_hat, pfx("b_hat"))?;

mercurial