diff -r aec67cdd6b14 -r efa60bc4f743 src/run.rs --- a/src/run.rs Tue Aug 01 10:32:12 2023 +0300 +++ b/src/run.rs Thu Aug 29 00:00:00 2024 -0500 @@ -31,7 +31,10 @@ use alg_tools::error::DynError; use alg_tools::tabledump::TableDump; use alg_tools::sets::Cube; -use alg_tools::mapping::{RealMapping, Differentiable}; +use alg_tools::mapping::{ + RealMapping, + DifferentiableRealMapping +}; use alg_tools::nalgebra_support::ToNalgebraRealField; use alg_tools::euclidean::Euclidean; use alg_tools::lingrid::lingrid; @@ -48,6 +51,11 @@ pointsource_fb_reg, pointsource_fista_reg, }; +use crate::radon_fb::{ + RadonFBConfig, + pointsource_radon_fb_reg, + pointsource_radon_fista_reg, +}; use crate::sliding_fb::{ SlidingFBConfig, pointsource_sliding_fb_reg @@ -85,6 +93,8 @@ FISTA(FBConfig), FW(FWConfig), PDPS(PDPSConfig), + RadonFB(RadonFBConfig), + RadonFISTA(RadonFBConfig), SlidingFB(SlidingFBConfig), } @@ -114,19 +124,19 @@ match self { FB(fb) => FB(FBConfig { τ0 : cli.tau0.unwrap_or(fb.τ0), - insertion : override_fb_generic(fb.insertion), + generic : override_fb_generic(fb.generic), .. fb }), FISTA(fb) => FISTA(FBConfig { τ0 : cli.tau0.unwrap_or(fb.τ0), - insertion : override_fb_generic(fb.insertion), + generic : override_fb_generic(fb.generic), .. fb }), PDPS(pdps) => PDPS(PDPSConfig { τ0 : cli.tau0.unwrap_or(pdps.τ0), σ0 : cli.sigma0.unwrap_or(pdps.σ0), acceleration : cli.acceleration.unwrap_or(pdps.acceleration), - insertion : override_fb_generic(pdps.insertion), + generic : override_fb_generic(pdps.generic), .. pdps }), FW(fw) => FW(FWConfig { @@ -134,8 +144,21 @@ tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), .. fw }), + RadonFB(fb) => RadonFB(RadonFBConfig { + τ0 : cli.tau0.unwrap_or(fb.τ0), + insertion : override_fb_generic(fb.insertion), + .. fb + }), + RadonFISTA(fb) => RadonFISTA(RadonFBConfig { + τ0 : cli.tau0.unwrap_or(fb.τ0), + insertion : override_fb_generic(fb.insertion), + .. fb + }), SlidingFB(sfb) => SlidingFB(SlidingFBConfig { τ0 : cli.tau0.unwrap_or(sfb.τ0), + θ0 : cli.theta0.unwrap_or(sfb.θ0), + transport_tolerance_ω: cli.transport_tolerance_omega.unwrap_or(sfb.transport_tolerance_ω), + transport_tolerance_dv: cli.transport_tolerance_dv.unwrap_or(sfb.transport_tolerance_dv), insertion : override_fb_generic(sfb.insertion), .. sfb }), @@ -169,6 +192,12 @@ /// The μPDPS primal-dual proximal splitting method #[clap(name = "pdps")] PDPS, + /// The RadonFB forward-backward method + #[clap(name = "radon_fb")] + RadonFB, + /// The RadonFISTA inertial forward-backward method + #[clap(name = "radon_fista")] + RadonFISTA, /// The Sliding FB method #[clap(name = "sliding_fb", alias = "sfb")] SlidingFB, @@ -187,6 +216,8 @@ .. Default::default() }), PDPS => AlgorithmConfig::PDPS(Default::default()), + RadonFB => AlgorithmConfig::RadonFB(Default::default()), + RadonFISTA => AlgorithmConfig::RadonFISTA(Default::default()), SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), } } @@ -364,13 +395,13 @@ // TODO: very weird that rust only compiles with Differentiable // instead of the above one on references, which is required by // poitsource_sliding_fb_reg. - + Differentiable, Output = Loc> - + Lipschitz, + + DifferentiableRealMapping + + Lipschitz, // as ForwardModel, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc, Output = Loc>, AutoConvolution

: BoundedBy, - K : SimpleConvolutionKernel + LocalAnalysis, N> + K : SimpleConvolutionKernel + + LocalAnalysis, N> + Copy + Serialize + std::fmt::Debug, - //+ Differentiable, Output = Loc>, // TODO: shouldn't need to assume differentiability Cube: P2Minimise, F> + SetOrd, PlotLookup : Plotting, DefaultBT : SensorGridBT + BTSearch, @@ -569,6 +600,50 @@ } } }, + AlgorithmConfig::RadonFB(ref algconfig) => { + match (regularisation, dataterm) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { + running(); + pointsource_radon_fb_reg( + &opA, &b, NonnegRadonRegTerm(α), algconfig, + iterator, plotter + ) + }, + (Regularisation::Radon(α), DataTerm::L2Squared) => { + running(); + pointsource_radon_fb_reg( + &opA, &b, RadonRegTerm(α), algconfig, + iterator, plotter + ) + }, + _ => { + not_implemented(); + continue + } + } + }, + AlgorithmConfig::RadonFISTA(ref algconfig) => { + match (regularisation, dataterm) { + (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { + running(); + pointsource_radon_fista_reg( + &opA, &b, NonnegRadonRegTerm(α), algconfig, + iterator, plotter + ) + }, + (Regularisation::Radon(α), DataTerm::L2Squared) => { + running(); + pointsource_radon_fista_reg( + &opA, &b, RadonRegTerm(α), algconfig, + iterator, plotter + ) + }, + _ => { + not_implemented(); + continue + } + } + }, AlgorithmConfig::SlidingFB(ref algconfig) => { match (regularisation, dataterm) { (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { @@ -685,11 +760,12 @@ Sensor : RealMapping + Support + Clone, Spread : RealMapping + Support + Clone, Kernel : RealMapping + Support, - Convolution : RealMapping + Support, + Convolution : DifferentiableRealMapping + Support, + //Differential, Convolution> : RealVectorField, 𝒟 : DiscreteMeasureOp, F>, 𝒟::Codomain : RealMapping, A : ForwardModel, F>, - A::PreadjointCodomain : RealMapping + Bounded, + A::PreadjointCodomain : DifferentiableRealMapping + Bounded, PlotLookup : Plotting, Cube : SetOrd { @@ -706,7 +782,7 @@ PlotLookup::plot_into_file(sensor, plotgrid, pfx("sensor"), "sensor".to_string()); PlotLookup::plot_into_file(kernel, plotgrid, pfx("kernel"), "kernel".to_string()); PlotLookup::plot_into_file(spread, plotgrid, pfx("spread"), "spread".to_string()); - PlotLookup::plot_into_file(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); + PlotLookup::plot_into_file_diff(&base, plotgrid, pfx("base_sensor"), "base_sensor".to_string()); let plotgrid2 = lingrid(&domain, &[resolution; N]); @@ -725,6 +801,10 @@ plotgrid2, None, &μ_hat, pfx("omega_b") ); + PlotLookup::plot_into_file_diff(&preadj_b, plotgrid2, pfx("preadj_b"), + "preadj_b".to_string()); + PlotLookup::plot_into_file_diff(&preadj_b_hat, plotgrid2, pfx("preadj_b_hat"), + "preadj_b_hat".to_string()); // Save true solution and observables let pfx = |n| format!("{}{}", prefix, n);