src/run.rs

branch
dev
changeset 34
efa60bc4f743
parent 32
56c8adc32b09
child 35
b087e3eab191
--- a/src/run.rs	Tue Aug 01 10:32:12 2023 +0300
+++ b/src/run.rs	Thu Aug 29 00:00:00 2024 -0500
@@ -31,7 +31,10 @@
 use alg_tools::error::DynError;
 use alg_tools::tabledump::TableDump;
 use alg_tools::sets::Cube;
-use alg_tools::mapping::{RealMapping, Differentiable};
+use alg_tools::mapping::{
+    RealMapping,
+    DifferentiableRealMapping
+};
 use alg_tools::nalgebra_support::ToNalgebraRealField;
 use alg_tools::euclidean::Euclidean;
 use alg_tools::lingrid::lingrid;
@@ -48,6 +51,11 @@
     pointsource_fb_reg,
     pointsource_fista_reg,
 };
+use crate::radon_fb::{
+    RadonFBConfig,
+    pointsource_radon_fb_reg,
+    pointsource_radon_fista_reg,
+};
 use crate::sliding_fb::{
     SlidingFBConfig,
     pointsource_sliding_fb_reg
@@ -85,6 +93,8 @@
     FISTA(FBConfig<F>),
     FW(FWConfig<F>),
     PDPS(PDPSConfig<F>),
+    RadonFB(RadonFBConfig<F>),
+    RadonFISTA(RadonFBConfig<F>),
     SlidingFB(SlidingFBConfig<F>),
 }
 
@@ -114,19 +124,19 @@
         match self {
             FB(fb) => FB(FBConfig {
                 τ0 : cli.tau0.unwrap_or(fb.τ0),
-                insertion : override_fb_generic(fb.insertion),
+                generic : override_fb_generic(fb.generic),
                 .. fb
             }),
             FISTA(fb) => FISTA(FBConfig {
                 τ0 : cli.tau0.unwrap_or(fb.τ0),
-                insertion : override_fb_generic(fb.insertion),
+                generic : override_fb_generic(fb.generic),
                 .. fb
             }),
             PDPS(pdps) => PDPS(PDPSConfig {
                 τ0 : cli.tau0.unwrap_or(pdps.τ0),
                 σ0 : cli.sigma0.unwrap_or(pdps.σ0),
                 acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
-                insertion : override_fb_generic(pdps.insertion),
+                generic : override_fb_generic(pdps.generic),
                 .. pdps
             }),
             FW(fw) => FW(FWConfig {
@@ -134,8 +144,21 @@
                 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance),
                 .. fw
             }),
+            RadonFB(fb) => RadonFB(RadonFBConfig {
+                τ0 : cli.tau0.unwrap_or(fb.τ0),
+                insertion : override_fb_generic(fb.insertion),
+                .. fb
+            }),
+            RadonFISTA(fb) => RadonFISTA(RadonFBConfig {
+                τ0 : cli.tau0.unwrap_or(fb.τ0),
+                insertion : override_fb_generic(fb.insertion),
+                .. fb
+            }),
             SlidingFB(sfb) => SlidingFB(SlidingFBConfig {
                 τ0 : cli.tau0.unwrap_or(sfb.τ0),
+                θ0 : cli.theta0.unwrap_or(sfb.θ0),
+                transport_tolerance_ω: cli.transport_tolerance_omega.unwrap_or(sfb.transport_tolerance_ω),
+                transport_tolerance_dv: cli.transport_tolerance_dv.unwrap_or(sfb.transport_tolerance_dv),
                 insertion : override_fb_generic(sfb.insertion),
                 .. sfb
             }),
@@ -169,6 +192,12 @@
     /// The μPDPS primal-dual proximal splitting method
     #[clap(name = "pdps")]
     PDPS,
+    /// The RadonFB forward-backward method
+    #[clap(name = "radon_fb")]
+    RadonFB,
+    /// The RadonFISTA inertial forward-backward method
+    #[clap(name = "radon_fista")]
+    RadonFISTA,
     /// The Sliding FB method
     #[clap(name = "sliding_fb", alias = "sfb")]
     SlidingFB,
@@ -187,6 +216,8 @@
                 .. Default::default()
             }),
             PDPS => AlgorithmConfig::PDPS(Default::default()),
+            RadonFB => AlgorithmConfig::RadonFB(Default::default()),
+            RadonFISTA => AlgorithmConfig::RadonFISTA(Default::default()),
             SlidingFB => AlgorithmConfig::SlidingFB(Default::default()),
         }
     }
@@ -364,13 +395,13 @@
                          // TODO: very weird that rust only compiles with Differentiable
                          // instead of the above one on references, which is required by
                          // poitsource_sliding_fb_reg.
-                         + Differentiable<Loc<F, N>, Output = Loc<F, N>>
-                         + Lipschitz<L2>,
+                         + DifferentiableRealMapping<F, N>
+                         + Lipschitz<L2, FloatType=F>,
       // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
       AutoConvolution<P> : BoundedBy<F, K>,
-      K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N>
+      K : SimpleConvolutionKernel<F, N>
+          + LocalAnalysis<F, Bounds<F>, N>
           + Copy + Serialize + std::fmt::Debug,
-          //+ Differentiable<Loc<F, N>, Output = Loc<F, N>>, // TODO: shouldn't need to assume differentiability
       Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
       PlotLookup : Plotting<N>,
       DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
@@ -569,6 +600,50 @@
                         }
                     }
                 },
+                AlgorithmConfig::RadonFB(ref algconfig) => {
+                    match (regularisation, dataterm) {
+                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
+                            running();
+                            pointsource_radon_fb_reg(
+                                &opA, &b, NonnegRadonRegTerm(α), algconfig,
+                                iterator, plotter
+                            )
+                        },
+                        (Regularisation::Radon(α), DataTerm::L2Squared) => {
+                            running();
+                            pointsource_radon_fb_reg(
+                                &opA, &b, RadonRegTerm(α), algconfig,
+                                iterator, plotter
+                            )
+                        },
+                        _ => {
+                            not_implemented();
+                            continue
+                        }
+                    }
+                },
+                AlgorithmConfig::RadonFISTA(ref algconfig) => {
+                    match (regularisation, dataterm) {
+                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
+                            running();
+                            pointsource_radon_fista_reg(
+                                &opA, &b, NonnegRadonRegTerm(α), algconfig,
+                                iterator, plotter
+                            )
+                        },
+                        (Regularisation::Radon(α), DataTerm::L2Squared) => {
+                            running();
+                            pointsource_radon_fista_reg(
+                                &opA, &b, RadonRegTerm(α), algconfig,
+                                iterator, plotter
+                            )
+                        },
+                        _ => {
+                            not_implemented();
+                            continue
+                        }
+                    }
+                },
                 AlgorithmConfig::SlidingFB(ref algconfig) => {
                     match (regularisation, dataterm) {
                         (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
@@ -685,11 +760,12 @@
       Sensor : RealMapping<F, N> + Support<F, N> + Clone,
       Spread : RealMapping<F, N> + Support<F, N> + Clone,
       Kernel : RealMapping<F, N> + Support<F, N>,
-      Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>,
+      Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>,
+      //Differential<Loc<F, N>, Convolution<Sensor, Spread>> : RealVectorField<F, N, N>,
       𝒟 : DiscreteMeasureOp<Loc<F, N>, F>,
       𝒟::Codomain : RealMapping<F, N>,
       A : ForwardModel<Loc<F, N>, F>,
-      A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>,
+      A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>,
       PlotLookup : Plotting<N>,
       Cube<F, N> : SetOrd {
 
@@ -706,7 +782,7 @@
     PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string());
     PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string());
     PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string());
-    PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string());
+    PlotLookup::plot_into_file_diff(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string());
 
     let plotgrid2 = lingrid(&domain, &[resolution; N]);
 
@@ -725,6 +801,10 @@
         plotgrid2, None, &μ_hat,
         pfx("omega_b")
     );
+    PlotLookup::plot_into_file_diff(&preadj_b, plotgrid2, pfx("preadj_b"),
+                                    "preadj_b".to_string());
+    PlotLookup::plot_into_file_diff(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat"),
+                                    "preadj_b_hat".to_string());
 
     // Save true solution and observables
     let pfx = |n| format!("{}{}", prefix, n);

mercurial