--- a/src/run.rs Tue Aug 01 10:32:12 2023 +0300 +++ b/src/run.rs Thu Aug 29 00:00:00 2024 -0500 @@ -31,7 +31,10 @@ use alg_tools::error::DynError; use alg_tools::tabledump::TableDump; use alg_tools::sets::Cube; -use alg_tools::mapping::{RealMapping, Differentiable}; +use alg_tools::mapping::{ + RealMapping, + DifferentiableRealMapping +}; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::euclidean::Euclidean; use alg_tools::lingrid::lingrid; @@ -48,6 +51,11 @@ pointsource_fb_reg, pointsource_fista_reg, }; +use crate::radon_fb::{ + RadonFBConfig, + pointsource_radon_fb_reg, + pointsource_radon_fista_reg, +}; use crate::sliding_fb::{ SlidingFBConfig, pointsource_sliding_fb_reg @@ -85,6 +93,8 @@ FISTA(FBConfig<F>), FW(FWConfig<F>), PDPS(PDPSConfig<F>), + RadonFB(RadonFBConfig<F>), + RadonFISTA(RadonFBConfig<F>), SlidingFB(SlidingFBConfig<F>), } @@ -114,19 +124,19 @@ match self { FB(fb) => FB(FBConfig { τ0 : cli.tau0.unwrap_or(fb.τ0), - insertion : override_fb_generic(fb.insertion), + generic : override_fb_generic(fb.generic), .. fb }), FISTA(fb) => FISTA(FBConfig { τ0 : cli.tau0.unwrap_or(fb.τ0), - insertion : override_fb_generic(fb.insertion), + generic : override_fb_generic(fb.generic), .. fb }), PDPS(pdps) => PDPS(PDPSConfig { τ0 : cli.tau0.unwrap_or(pdps.τ0), σ0 : cli.sigma0.unwrap_or(pdps.σ0), acceleration : cli.acceleration.unwrap_or(pdps.acceleration), - insertion : override_fb_generic(pdps.insertion), + generic : override_fb_generic(pdps.generic), .. pdps }), FW(fw) => FW(FWConfig { @@ -134,8 +144,21 @@ tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), .. fw }), + RadonFB(fb) => RadonFB(RadonFBConfig { + τ0 : cli.tau0.unwrap_or(fb.τ0), + insertion : override_fb_generic(fb.insertion), + .. fb + }), + RadonFISTA(fb) => RadonFISTA(RadonFBConfig { + τ0 : cli.tau0.unwrap_or(fb.τ0), + insertion : override_fb_generic(fb.insertion), + .. fb + }), SlidingFB(sfb) => SlidingFB(SlidingFBConfig { τ0 : cli.tau0.unwrap_or(sfb.τ0), + θ0 : cli.theta0.unwrap_or(sfb.θ0), + transport_tolerance_ω: cli.transport_tolerance_omega.unwrap_or(sfb.transport_tolerance_ω), + transport_tolerance_dv: cli.transport_tolerance_dv.unwrap_or(sfb.transport_tolerance_dv), insertion : override_fb_generic(sfb.insertion), .. sfb }), @@ -169,6 +192,12 @@ /// The μPDPS primal-dual proximal splitting method #[clap(name = "pdps")] PDPS, + /// The RadonFB forward-backward method + #[clap(name = "radon_fb")] + RadonFB, + /// The RadonFISTA inertial forward-backward method + #[clap(name = "radon_fista")] + RadonFISTA, /// The Sliding FB method #[clap(name = "sliding_fb", alias = "sfb")] SlidingFB, @@ -187,6 +216,8 @@ .. Default::default() }), PDPS => AlgorithmConfig::PDPS(Default::default()), + RadonFB => AlgorithmConfig::RadonFB(Default::default()), + RadonFISTA => AlgorithmConfig::RadonFISTA(Default::default()), SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), } } @@ -364,13 +395,13 @@ // TODO: very weird that rust only compiles with Differentiable // instead of the above one on references, which is required by // poitsource_sliding_fb_reg. - + Differentiable<Loc<F, N>, Output = Loc<F, N>> - + Lipschitz<L2>, + + DifferentiableRealMapping<F, N> + + Lipschitz<L2, FloatType=F>, // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, AutoConvolution<P> : BoundedBy<F, K>, - K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> + K : SimpleConvolutionKernel<F, N> + + LocalAnalysis<F, Bounds<F>, N> + Copy + Serialize + std::fmt::Debug, - //+ Differentiable<Loc<F, N>, Output = Loc<F, N>>, // TODO: shouldn't need to assume differentiability Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, PlotLookup : Plotting<N>, DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, @@ -569,6 +600,50 @@ } } }, + AlgorithmConfig::RadonFB(ref algconfig) => { + match (regularisation, dataterm) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { + running(); + pointsource_radon_fb_reg( + &opA, &b, NonnegRadonRegTerm(α), algconfig, + iterator, plotter + ) + }, + (Regularisation::Radon(α), DataTerm::L2Squared) => { + running(); + pointsource_radon_fb_reg( + &opA, &b, RadonRegTerm(α), algconfig, + iterator, plotter + ) + }, + _ => { + not_implemented(); + continue + } + } + }, + AlgorithmConfig::RadonFISTA(ref algconfig) => { + match (regularisation, dataterm) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { + running(); + pointsource_radon_fista_reg( + &opA, &b, NonnegRadonRegTerm(α), algconfig, + iterator, plotter + ) + }, + (Regularisation::Radon(α), DataTerm::L2Squared) => { + running(); + pointsource_radon_fista_reg( + &opA, &b, RadonRegTerm(α), algconfig, + iterator, plotter + ) + }, + _ => { + not_implemented(); + continue + } + } + }, AlgorithmConfig::SlidingFB(ref algconfig) => { match (regularisation, dataterm) { (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { @@ -685,11 +760,12 @@ Sensor : RealMapping<F, N> + Support<F, N> + Clone, Spread : RealMapping<F, N> + Support<F, N> + Clone, Kernel : RealMapping<F, N> + Support<F, N>, - Convolution<Sensor, Spread> : RealMapping<F, N> + Support<F, N>, + Convolution<Sensor, Spread> : DifferentiableRealMapping<F, N> + Support<F, N>, + //Differential<Loc<F, N>, Convolution<Sensor, Spread>> : RealVectorField<F, N, N>, 𝒟 : DiscreteMeasureOp<Loc<F, N>, F>, 𝒟::Codomain : RealMapping<F, N>, A : ForwardModel<Loc<F, N>, F>, - A::PreadjointCodomain : RealMapping<F, N> + Bounded<F>, + A::PreadjointCodomain : DifferentiableRealMapping<F, N> + Bounded<F>, PlotLookup : Plotting<N>, Cube<F, N> : SetOrd { @@ -706,7 +782,7 @@ PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); - PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); + PlotLookup::plot_into_file_diff(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); let plotgrid2 = lingrid(&domain, &[resolution; N]); @@ -725,6 +801,10 @@ plotgrid2, None, &μ_hat, pfx("omega_b") ); + PlotLookup::plot_into_file_diff(&preadj_b, plotgrid2, pfx("preadj_b"), + "preadj_b".to_string()); + PlotLookup::plot_into_file_diff(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat"), + "preadj_b_hat".to_string()); // Save true solution and observables let pfx = |n| format!("{}{}", prefix, n);