--- a/src/run.rs Fri Apr 28 13:15:19 2023 +0300 +++ b/src/run.rs Tue Dec 31 09:34:24 2024 -0500 @@ -31,10 +31,9 @@ use alg_tools::error::DynError; use alg_tools::tabledump::TableDump; use alg_tools::sets::Cube; -use alg_tools::mapping::RealMapping; +use alg_tools::mapping::{RealMapping, Differentiable}; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::euclidean::Euclidean; -use alg_tools::norms::L1; use alg_tools::lingrid::lingrid; use alg_tools::sets::SetOrd; @@ -45,13 +44,16 @@ use crate::forward_model::*; use crate::fb::{ FBConfig, + FBGenericConfig, pointsource_fb_reg, - FBMetaAlgorithm, - FBGenericConfig, + pointsource_fista_reg, +}; +use crate::sliding_fb::{ + SlidingFBConfig, + pointsource_sliding_fb_reg }; use crate::pdps::{ PDPSConfig, - L2Squared, pointsource_pdps_reg, }; use crate::frank_wolfe::{ @@ -65,14 +67,25 @@ use crate::plot::*; use crate::{AlgorithmOverrides, CommandLineArgs}; use crate::tolerance::Tolerance; -use crate::regularisation::{Regularisation, RadonRegTerm, NonnegRadonRegTerm}; +use crate::regularisation::{ + Regularisation, + RadonRegTerm, + NonnegRadonRegTerm +}; +use crate::dataterm::{ + L1, + L2Squared +}; +use alg_tools::norms::L2; /// Available algorithms and their configurations #[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub enum AlgorithmConfig<F : Float> { FB(FBConfig<F>), + FISTA(FBConfig<F>), FW(FWConfig<F>), PDPS(PDPSConfig<F>), + SlidingFB(SlidingFBConfig<F>), } fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { @@ -104,6 +117,11 @@ insertion : override_fb_generic(fb.insertion), .. fb }), + FISTA(fb) => FISTA(FBConfig { + τ0 : cli.tau0.unwrap_or(fb.τ0), + insertion : override_fb_generic(fb.insertion), + .. fb + }), PDPS(pdps) => PDPS(PDPSConfig { τ0 : cli.tau0.unwrap_or(pdps.τ0), σ0 : cli.sigma0.unwrap_or(pdps.σ0), @@ -115,7 +133,12 @@ merging : cli.merging.clone().unwrap_or(fw.merging), tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), .. fw - }) + }), + SlidingFB(sfb) => SlidingFB(SlidingFBConfig { + τ0 : cli.tau0.unwrap_or(sfb.τ0), + insertion : override_fb_generic(sfb.insertion), + .. sfb + }), } } } @@ -146,6 +169,9 @@ /// The μPDPS primal-dual proximal splitting method #[clap(name = "pdps")] PDPS, + /// The Sliding FB method + #[clap(name = "sliding_fb", alias = "sfb")] + SlidingFB, } impl DefaultAlgorithm { @@ -154,16 +180,14 @@ use DefaultAlgorithm::*; match *self { FB => AlgorithmConfig::FB(Default::default()), - FISTA => AlgorithmConfig::FB(FBConfig{ - meta : FBMetaAlgorithm::InertiaFISTA, - .. Default::default() - }), + FISTA => AlgorithmConfig::FISTA(Default::default()), FW => AlgorithmConfig::FW(Default::default()), FWRelax => AlgorithmConfig::FW(FWConfig{ variant : FWVariant::Relaxed, .. Default::default() }), PDPS => AlgorithmConfig::PDPS(Default::default()), + SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), } } @@ -333,10 +357,20 @@ [usize; N] : Serialize, S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, - Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy, + Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy + // TODO: shold not have differentiability as a requirement, but + // decide availability of sliding based on it. + //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, + // TODO: very weird that rust only compiles with Differentiable + // instead of the above one on references, which is required by + // poitsource_sliding_fb_reg. + + Differentiable<Loc<F, N>, Output = Loc<F, N>> + + Lipschitz<L2>, + // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, AutoConvolution<P> : BoundedBy<F, K>, - K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> + K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> + Copy + Serialize + std::fmt::Debug, + //+ Differentiable<Loc<F, N>, Output = Loc<F, N>>, // TODO: shouldn't need to assume differentiability Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, PlotLookup : Plotting<N>, DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, @@ -513,6 +547,50 @@ } } }, + AlgorithmConfig::FISTA(ref algconfig) => { + match (regularisation, dataterm) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { + running(); + pointsource_fista_reg( + &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, + iterator, plotter + ) + }, + (Regularisation::Radon(α), DataTerm::L2Squared) => { + running(); + pointsource_fista_reg( + &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, + iterator, plotter + ) + }, + _ => { + not_implemented(); + continue + } + } + }, + AlgorithmConfig::SlidingFB(ref algconfig) => { + match (regularisation, dataterm) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { + running(); + pointsource_sliding_fb_reg( + &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, + iterator, plotter + ) + }, + (Regularisation::Radon(α), DataTerm::L2Squared) => { + running(); + pointsource_sliding_fb_reg( + &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, + iterator, plotter + ) + }, + _ => { + not_implemented(); + continue + } + } + }, AlgorithmConfig::PDPS(ref algconfig) => { running(); match (regularisation, dataterm) {