src/run.rs

branch
dev
changeset 32
56c8adc32b09
parent 29
87649ccfa6a8
child 34
efa60bc4f743
equal deleted inserted replaced
30:bd13c2ae3450 32:56c8adc32b09
29 }; 29 };
30 use alg_tools::logger::Logger; 30 use alg_tools::logger::Logger;
31 use alg_tools::error::DynError; 31 use alg_tools::error::DynError;
32 use alg_tools::tabledump::TableDump; 32 use alg_tools::tabledump::TableDump;
33 use alg_tools::sets::Cube; 33 use alg_tools::sets::Cube;
34 use alg_tools::mapping::RealMapping; 34 use alg_tools::mapping::{RealMapping, Differentiable};
35 use alg_tools::nalgebra_support::ToNalgebraRealField; 35 use alg_tools::nalgebra_support::ToNalgebraRealField;
36 use alg_tools::euclidean::Euclidean; 36 use alg_tools::euclidean::Euclidean;
37 use alg_tools::norms::L1;
38 use alg_tools::lingrid::lingrid; 37 use alg_tools::lingrid::lingrid;
39 use alg_tools::sets::SetOrd; 38 use alg_tools::sets::SetOrd;
40 39
41 use crate::kernels::*; 40 use crate::kernels::*;
42 use crate::types::*; 41 use crate::types::*;
43 use crate::measures::*; 42 use crate::measures::*;
44 use crate::measures::merging::SpikeMerging; 43 use crate::measures::merging::SpikeMerging;
45 use crate::forward_model::*; 44 use crate::forward_model::*;
46 use crate::fb::{ 45 use crate::fb::{
47 FBConfig, 46 FBConfig,
47 FBGenericConfig,
48 pointsource_fb_reg, 48 pointsource_fb_reg,
49 FBMetaAlgorithm, 49 pointsource_fista_reg,
50 FBGenericConfig, 50 };
51 use crate::sliding_fb::{
52 SlidingFBConfig,
53 pointsource_sliding_fb_reg
51 }; 54 };
52 use crate::pdps::{ 55 use crate::pdps::{
53 PDPSConfig, 56 PDPSConfig,
54 L2Squared,
55 pointsource_pdps_reg, 57 pointsource_pdps_reg,
56 }; 58 };
57 use crate::frank_wolfe::{ 59 use crate::frank_wolfe::{
58 FWConfig, 60 FWConfig,
59 FWVariant, 61 FWVariant,
63 use crate::subproblem::InnerSettings; 65 use crate::subproblem::InnerSettings;
64 use crate::seminorms::*; 66 use crate::seminorms::*;
65 use crate::plot::*; 67 use crate::plot::*;
66 use crate::{AlgorithmOverrides, CommandLineArgs}; 68 use crate::{AlgorithmOverrides, CommandLineArgs};
67 use crate::tolerance::Tolerance; 69 use crate::tolerance::Tolerance;
68 use crate::regularisation::{Regularisation, RadonRegTerm, NonnegRadonRegTerm}; 70 use crate::regularisation::{
71 Regularisation,
72 RadonRegTerm,
73 NonnegRadonRegTerm
74 };
75 use crate::dataterm::{
76 L1,
77 L2Squared
78 };
79 use alg_tools::norms::L2;
69 80
70 /// Available algorithms and their configurations 81 /// Available algorithms and their configurations
71 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] 82 #[derive(Copy, Clone, Debug, Serialize, Deserialize)]
72 pub enum AlgorithmConfig<F : Float> { 83 pub enum AlgorithmConfig<F : Float> {
73 FB(FBConfig<F>), 84 FB(FBConfig<F>),
85 FISTA(FBConfig<F>),
74 FW(FWConfig<F>), 86 FW(FWConfig<F>),
75 PDPS(PDPSConfig<F>), 87 PDPS(PDPSConfig<F>),
88 SlidingFB(SlidingFBConfig<F>),
76 } 89 }
77 90
78 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { 91 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> {
79 assert!(v.len() == 3); 92 assert!(v.len() == 3);
80 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } 93 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] }
102 FB(fb) => FB(FBConfig { 115 FB(fb) => FB(FBConfig {
103 τ0 : cli.tau0.unwrap_or(fb.τ0), 116 τ0 : cli.tau0.unwrap_or(fb.τ0),
104 insertion : override_fb_generic(fb.insertion), 117 insertion : override_fb_generic(fb.insertion),
105 .. fb 118 .. fb
106 }), 119 }),
120 FISTA(fb) => FISTA(FBConfig {
121 τ0 : cli.tau0.unwrap_or(fb.τ0),
122 insertion : override_fb_generic(fb.insertion),
123 .. fb
124 }),
107 PDPS(pdps) => PDPS(PDPSConfig { 125 PDPS(pdps) => PDPS(PDPSConfig {
108 τ0 : cli.tau0.unwrap_or(pdps.τ0), 126 τ0 : cli.tau0.unwrap_or(pdps.τ0),
109 σ0 : cli.sigma0.unwrap_or(pdps.σ0), 127 σ0 : cli.sigma0.unwrap_or(pdps.σ0),
110 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), 128 acceleration : cli.acceleration.unwrap_or(pdps.acceleration),
111 insertion : override_fb_generic(pdps.insertion), 129 insertion : override_fb_generic(pdps.insertion),
113 }), 131 }),
114 FW(fw) => FW(FWConfig { 132 FW(fw) => FW(FWConfig {
115 merging : cli.merging.clone().unwrap_or(fw.merging), 133 merging : cli.merging.clone().unwrap_or(fw.merging),
116 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance), 134 tolerance : cli.tolerance.as_ref().map(unpack_tolerance).unwrap_or(fw.tolerance),
117 .. fw 135 .. fw
118 }) 136 }),
137 SlidingFB(sfb) => SlidingFB(SlidingFBConfig {
138 τ0 : cli.tau0.unwrap_or(sfb.τ0),
139 insertion : override_fb_generic(sfb.insertion),
140 .. sfb
141 }),
119 } 142 }
120 } 143 }
121 } 144 }
122 145
123 /// Helper struct for tagging and [`AlgorithmConfig`] or [`Experiment`] with a name. 146 /// Helper struct for tagging and [`AlgorithmConfig`] or [`Experiment`] with a name.
144 #[clap(name = "fwrelax")] 167 #[clap(name = "fwrelax")]
145 FWRelax, 168 FWRelax,
146 /// The μPDPS primal-dual proximal splitting method 169 /// The μPDPS primal-dual proximal splitting method
147 #[clap(name = "pdps")] 170 #[clap(name = "pdps")]
148 PDPS, 171 PDPS,
172 /// The Sliding FB method
173 #[clap(name = "sliding_fb", alias = "sfb")]
174 SlidingFB,
149 } 175 }
150 176
151 impl DefaultAlgorithm { 177 impl DefaultAlgorithm {
152 /// Returns the algorithm configuration corresponding to the algorithm shorthand 178 /// Returns the algorithm configuration corresponding to the algorithm shorthand
153 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { 179 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> {
154 use DefaultAlgorithm::*; 180 use DefaultAlgorithm::*;
155 match *self { 181 match *self {
156 FB => AlgorithmConfig::FB(Default::default()), 182 FB => AlgorithmConfig::FB(Default::default()),
157 FISTA => AlgorithmConfig::FB(FBConfig{ 183 FISTA => AlgorithmConfig::FISTA(Default::default()),
158 meta : FBMetaAlgorithm::InertiaFISTA,
159 .. Default::default()
160 }),
161 FW => AlgorithmConfig::FW(Default::default()), 184 FW => AlgorithmConfig::FW(Default::default()),
162 FWRelax => AlgorithmConfig::FW(FWConfig{ 185 FWRelax => AlgorithmConfig::FW(FWConfig{
163 variant : FWVariant::Relaxed, 186 variant : FWVariant::Relaxed,
164 .. Default::default() 187 .. Default::default()
165 }), 188 }),
166 PDPS => AlgorithmConfig::PDPS(Default::default()), 189 PDPS => AlgorithmConfig::PDPS(Default::default()),
190 SlidingFB => AlgorithmConfig::SlidingFB(Default::default()),
167 } 191 }
168 } 192 }
169 193
170 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand 194 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand
171 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { 195 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> {
331 Named<$type<F, NoiseDistr, S, K, P, N>> 355 Named<$type<F, NoiseDistr, S, K, P, N>>
332 where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, 356 where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>,
333 [usize; N] : Serialize, 357 [usize; N] : Serialize,
334 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, 358 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug,
335 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, 359 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug,
336 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy, 360 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy
361 // TODO: shold not have differentiability as a requirement, but
362 // decide availability of sliding based on it.
363 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
364 // TODO: very weird that rust only compiles with Differentiable
365 // instead of the above one on references, which is required by
366 // poitsource_sliding_fb_reg.
367 + Differentiable<Loc<F, N>, Output = Loc<F, N>>
368 + Lipschitz<L2>,
369 // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>,
337 AutoConvolution<P> : BoundedBy<F, K>, 370 AutoConvolution<P> : BoundedBy<F, K>,
338 K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> 371 K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N>
339 + Copy + Serialize + std::fmt::Debug, 372 + Copy + Serialize + std::fmt::Debug,
373 //+ Differentiable<Loc<F, N>, Output = Loc<F, N>>, // TODO: shouldn't need to assume differentiability
340 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, 374 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd,
341 PlotLookup : Plotting<N>, 375 PlotLookup : Plotting<N>,
342 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, 376 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>,
343 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 377 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
344 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 378 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
511 not_implemented(); 545 not_implemented();
512 continue 546 continue
513 } 547 }
514 } 548 }
515 }, 549 },
550 AlgorithmConfig::FISTA(ref algconfig) => {
551 match (regularisation, dataterm) {
552 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
553 running();
554 pointsource_fista_reg(
555 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
556 iterator, plotter
557 )
558 },
559 (Regularisation::Radon(α), DataTerm::L2Squared) => {
560 running();
561 pointsource_fista_reg(
562 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
563 iterator, plotter
564 )
565 },
566 _ => {
567 not_implemented();
568 continue
569 }
570 }
571 },
572 AlgorithmConfig::SlidingFB(ref algconfig) => {
573 match (regularisation, dataterm) {
574 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
575 running();
576 pointsource_sliding_fb_reg(
577 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig,
578 iterator, plotter
579 )
580 },
581 (Regularisation::Radon(α), DataTerm::L2Squared) => {
582 running();
583 pointsource_sliding_fb_reg(
584 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig,
585 iterator, plotter
586 )
587 },
588 _ => {
589 not_implemented();
590 continue
591 }
592 }
593 },
516 AlgorithmConfig::PDPS(ref algconfig) => { 594 AlgorithmConfig::PDPS(ref algconfig) => {
517 running(); 595 running();
518 match (regularisation, dataterm) { 596 match (regularisation, dataterm) {
519 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { 597 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => {
520 pointsource_pdps_reg( 598 pointsource_pdps_reg(

mercurial