src/run.rs

branch
dev
changeset 32
56c8adc32b09
parent 29
87649ccfa6a8
child 34
efa60bc4f743
--- a/src/run.rs	Fri Apr 28 13:15:19 2023 +0300
+++ b/src/run.rs	Tue Dec 31 09:34:24 2024 -0500
@@ -31,10 +31,9 @@
 use alg_tools::error::DynError;
 use alg_tools::tabledump::TableDump;
 use alg_tools::sets::Cube;
-use alg_tools::mapping::RealMapping;
+use alg_tools::mapping::{RealMapping, Differentiable};
 use alg_tools::nalgebra_support::ToNalgebraRealField;
 use alg_tools::euclidean::Euclidean;
-use alg_tools::norms::L1;
 use alg_tools::lingrid::lingrid;
 use alg_tools::sets::SetOrd;
 
@@ -45,13 +44,16 @@
 use crate::forward_model::*;
 use crate::fb::{
     FBConfig,
+    FBGenericConfig,
     pointsource_fb_reg,
-    FBMetaAlgorithm,
-    FBGenericConfig,
+    pointsource_fista_reg,
+};
+use crate::sliding_fb::{
+    SlidingFBConfig,
+    pointsource_sliding_fb_reg
 };
 use crate::pdps::{
     PDPSConfig,
-    L2Squared,
     pointsource_pdps_reg,
 };
 use crate::frank_wolfe::{
@@ -65,14 +67,25 @@
 use crate::plot::*;
 use crate::{AlgorithmOverrides, CommandLineArgs};
 use crate::tolerance::Tolerance;
-use crate::regularisation::{Regularisation, RadonRegTerm, NonnegRadonRegTerm};
+use crate::regularisation::{
+    Regularisation,
+    RadonRegTerm,
+    NonnegRadonRegTerm
+};
+use crate::dataterm::{
+    L1,
+    L2Squared
+};
+use alg_tools::norms::L2;
 
 /// Available algorithms and their configurations
 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
 pub enum AlgorithmConfig<F : Float> {
     FB(FBConfig<F>),
+    FISTA(FBConfig<F>),
     FW(FWConfig<F>),
     PDPS(PDPSConfig<F>),
+    SlidingFB(SlidingFBConfig<F>),
 }
 
 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> {
@@ -104,6 +117,11 @@
                 insertion : override_fb_generic(fb.insertion),
                 .. fb
             }),
+            FISTA(fb) => FISTA(FBConfig {
+                τ0 : cli.tau0.unwrap_or(fb.τ0),
+                insertion : override_fb_generic(fb.insertion),
+                .. fb
+            }),
             PDPS(pdps) => PDPS(PDPSConfig {
                 τ0 : cli.tau0.unwrap_or(pdps.τ0),
                 σ0 : cli.sigma0.unwrap_or(pdps.σ0),
@@ -115,7 +133,12 @@
                 merging : cli.merging.clone().unwrap_or(fw.merging),
                 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance),
                 .. fw
-            })
+            }),
+            SlidingFB(sfb) => SlidingFB(SlidingFBConfig {
+                τ0 : cli.tau0.unwrap_or(sfb.τ0),
+                insertion : override_fb_generic(sfb.insertion),
+                .. sfb
+            }),
         }
     }
 }
@@ -146,6 +169,9 @@
     /// The μPDPS primal-dual proximal splitting method
     #[clap(name = "pdps")]
     PDPS,
+    /// The Sliding FB method
+    #[clap(name = "sliding_fb", alias = "sfb")]
+    SlidingFB,
 }
 
 impl DefaultAlgorithm {
@@ -154,16 +180,14 @@
         use DefaultAlgorithm::*;
         match *self {
             FB => AlgorithmConfig::FB(Default::default()),
-            FISTA => AlgorithmConfig::FB(FBConfig{
-                meta : FBMetaAlgorithm::InertiaFISTA,
-                .. Default::default()
-            }),
+            FISTA => AlgorithmConfig::FISTA(Default::default()),
             FW => AlgorithmConfig::FW(Default::default()),
             FWRelax => AlgorithmConfig::FW(FWConfig{
                 variant : FWVariant::Relaxed,
                 .. Default::default()
             }),
             PDPS => AlgorithmConfig::PDPS(Default::default()),
+            SlidingFB => AlgorithmConfig::SlidingFB(Default::default()),
         }
     }
 
@@ -333,10 +357,20 @@
       [usize; N] : Serialize,
       S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
       P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
-      Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy,
+      Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy
+                         // TODO: shold not have differentiability as a requirement, but
+                         // decide availability of sliding based on it.
+                         //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
+                         // TODO: very weird that rust only compiles with Differentiable
+                         // instead of the above one on references, which is required by
+                         // poitsource_sliding_fb_reg.
+                         + Differentiable<Loc<F, N>, Output = Loc<F, N>>
+                         + Lipschitz<L2>,
+      // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
       AutoConvolution<P> : BoundedBy<F, K>,
-      K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> 
+      K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N>
           + Copy + Serialize + std::fmt::Debug,
+          //+ Differentiable<Loc<F, N>, Output = Loc<F, N>>, // TODO: shouldn't need to assume differentiability
       Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
       PlotLookup : Plotting<N>,
       DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
@@ -513,6 +547,50 @@
                         }
                     }
                 },
+                AlgorithmConfig::FISTA(ref algconfig) => {
+                    match (regularisation, dataterm) {
+                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
+                            running();
+                            pointsource_fista_reg(
+                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
+                                iterator, plotter
+                            )
+                        },
+                        (Regularisation::Radon(α), DataTerm::L2Squared) => {
+                            running();
+                            pointsource_fista_reg(
+                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
+                                iterator, plotter
+                            )
+                        },
+                        _ => {
+                            not_implemented();
+                            continue
+                        }
+                    }
+                },
+                AlgorithmConfig::SlidingFB(ref algconfig) => {
+                    match (regularisation, dataterm) {
+                        (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
+                            running();
+                            pointsource_sliding_fb_reg(
+                                &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
+                                iterator, plotter
+                            )
+                        },
+                        (Regularisation::Radon(α), DataTerm::L2Squared) => {
+                            running();
+                            pointsource_sliding_fb_reg(
+                                &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
+                                iterator, plotter
+                            )
+                        },
+                        _ => {
+                            not_implemented();
+                            continue
+                        }
+                    }
+                },
                 AlgorithmConfig::PDPS(ref algconfig) => {
                     running();
                     match (regularisation, dataterm) {

mercurial