29 }; |
29 }; |
30 use alg_tools::logger::Logger; |
30 use alg_tools::logger::Logger; |
31 use alg_tools::error::DynError; |
31 use alg_tools::error::DynError; |
32 use alg_tools::tabledump::TableDump; |
32 use alg_tools::tabledump::TableDump; |
33 use alg_tools::sets::Cube; |
33 use alg_tools::sets::Cube; |
34 use alg_tools::mapping::RealMapping; |
34 use alg_tools::mapping::{RealMapping, Differentiable}; |
35 use alg_tools::nalgebra_support::ToNalgebraRealField; |
35 use alg_tools::nalgebra_support::ToNalgebraRealField; |
36 use alg_tools::euclidean::Euclidean; |
36 use alg_tools::euclidean::Euclidean; |
37 use alg_tools::norms::L1; |
|
38 use alg_tools::lingrid::lingrid; |
37 use alg_tools::lingrid::lingrid; |
39 use alg_tools::sets::SetOrd; |
38 use alg_tools::sets::SetOrd; |
40 |
39 |
41 use crate::kernels::*; |
40 use crate::kernels::*; |
42 use crate::types::*; |
41 use crate::types::*; |
43 use crate::measures::*; |
42 use crate::measures::*; |
44 use crate::measures::merging::SpikeMerging; |
43 use crate::measures::merging::SpikeMerging; |
45 use crate::forward_model::*; |
44 use crate::forward_model::*; |
46 use crate::fb::{ |
45 use crate::fb::{ |
47 FBConfig, |
46 FBConfig, |
|
47 FBGenericConfig, |
48 pointsource_fb_reg, |
48 pointsource_fb_reg, |
49 FBMetaAlgorithm, |
49 pointsource_fista_reg, |
50 FBGenericConfig, |
50 }; |
|
51 use crate::sliding_fb::{ |
|
52 SlidingFBConfig, |
|
53 pointsource_sliding_fb_reg |
51 }; |
54 }; |
52 use crate::pdps::{ |
55 use crate::pdps::{ |
53 PDPSConfig, |
56 PDPSConfig, |
54 L2Squared, |
|
55 pointsource_pdps_reg, |
57 pointsource_pdps_reg, |
56 }; |
58 }; |
57 use crate::frank_wolfe::{ |
59 use crate::frank_wolfe::{ |
58 FWConfig, |
60 FWConfig, |
59 FWVariant, |
61 FWVariant, |
63 use crate::subproblem::InnerSettings; |
65 use crate::subproblem::InnerSettings; |
64 use crate::seminorms::*; |
66 use crate::seminorms::*; |
65 use crate::plot::*; |
67 use crate::plot::*; |
66 use crate::{AlgorithmOverrides, CommandLineArgs}; |
68 use crate::{AlgorithmOverrides, CommandLineArgs}; |
67 use crate::tolerance::Tolerance; |
69 use crate::tolerance::Tolerance; |
68 use crate::regularisation::{Regularisation, RadonRegTerm, NonnegRadonRegTerm}; |
70 use crate::regularisation::{ |
|
71 Regularisation, |
|
72 RadonRegTerm, |
|
73 NonnegRadonRegTerm |
|
74 }; |
|
75 use crate::dataterm::{ |
|
76 L1, |
|
77 L2Squared |
|
78 }; |
|
79 use alg_tools::norms::L2; |
69 |
80 |
70 /// Available algorithms and their configurations |
81 /// Available algorithms and their configurations |
71 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
82 #[derive(Copy, Clone, Debug, Serialize, Deserialize)] |
72 pub enum AlgorithmConfig<F : Float> { |
83 pub enum AlgorithmConfig<F : Float> { |
73 FB(FBConfig<F>), |
84 FB(FBConfig<F>), |
|
85 FISTA(FBConfig<F>), |
74 FW(FWConfig<F>), |
86 FW(FWConfig<F>), |
75 PDPS(PDPSConfig<F>), |
87 PDPS(PDPSConfig<F>), |
|
88 SlidingFB(SlidingFBConfig<F>), |
76 } |
89 } |
77 |
90 |
78 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { |
91 fn unpack_tolerance<F : Float>(v : &Vec<F>) -> Tolerance<F> { |
79 assert!(v.len() == 3); |
92 assert!(v.len() == 3); |
80 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } |
93 Tolerance::Power { initial : v[0], factor : v[1], exponent : v[2] } |
102 FB(fb) => FB(FBConfig { |
115 FB(fb) => FB(FBConfig { |
103 τ0 : cli.tau0.unwrap_or(fb.τ0), |
116 τ0 : cli.tau0.unwrap_or(fb.τ0), |
104 insertion : override_fb_generic(fb.insertion), |
117 insertion : override_fb_generic(fb.insertion), |
105 .. fb |
118 .. fb |
106 }), |
119 }), |
|
120 FISTA(fb) => FISTA(FBConfig { |
|
121 τ0 : cli.tau0.unwrap_or(fb.τ0), |
|
122 insertion : override_fb_generic(fb.insertion), |
|
123 .. fb |
|
124 }), |
107 PDPS(pdps) => PDPS(PDPSConfig { |
125 PDPS(pdps) => PDPS(PDPSConfig { |
108 τ0 : cli.tau0.unwrap_or(pdps.τ0), |
126 τ0 : cli.tau0.unwrap_or(pdps.τ0), |
109 σ0 : cli.sigma0.unwrap_or(pdps.σ0), |
127 σ0 : cli.sigma0.unwrap_or(pdps.σ0), |
110 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
128 acceleration : cli.acceleration.unwrap_or(pdps.acceleration), |
111 insertion : override_fb_generic(pdps.insertion), |
129 insertion : override_fb_generic(pdps.insertion), |
144 #[clap(name = "fwrelax")] |
167 #[clap(name = "fwrelax")] |
145 FWRelax, |
168 FWRelax, |
146 /// The μPDPS primal-dual proximal splitting method |
169 /// The μPDPS primal-dual proximal splitting method |
147 #[clap(name = "pdps")] |
170 #[clap(name = "pdps")] |
148 PDPS, |
171 PDPS, |
|
172 /// The Sliding FB method |
|
173 #[clap(name = "sliding_fb", alias = "sfb")] |
|
174 SlidingFB, |
149 } |
175 } |
150 |
176 |
151 impl DefaultAlgorithm { |
177 impl DefaultAlgorithm { |
152 /// Returns the algorithm configuration corresponding to the algorithm shorthand |
178 /// Returns the algorithm configuration corresponding to the algorithm shorthand |
153 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { |
179 pub fn default_config<F : Float>(&self) -> AlgorithmConfig<F> { |
154 use DefaultAlgorithm::*; |
180 use DefaultAlgorithm::*; |
155 match *self { |
181 match *self { |
156 FB => AlgorithmConfig::FB(Default::default()), |
182 FB => AlgorithmConfig::FB(Default::default()), |
157 FISTA => AlgorithmConfig::FB(FBConfig{ |
183 FISTA => AlgorithmConfig::FISTA(Default::default()), |
158 meta : FBMetaAlgorithm::InertiaFISTA, |
|
159 .. Default::default() |
|
160 }), |
|
161 FW => AlgorithmConfig::FW(Default::default()), |
184 FW => AlgorithmConfig::FW(Default::default()), |
162 FWRelax => AlgorithmConfig::FW(FWConfig{ |
185 FWRelax => AlgorithmConfig::FW(FWConfig{ |
163 variant : FWVariant::Relaxed, |
186 variant : FWVariant::Relaxed, |
164 .. Default::default() |
187 .. Default::default() |
165 }), |
188 }), |
166 PDPS => AlgorithmConfig::PDPS(Default::default()), |
189 PDPS => AlgorithmConfig::PDPS(Default::default()), |
|
190 SlidingFB => AlgorithmConfig::SlidingFB(Default::default()), |
167 } |
191 } |
168 } |
192 } |
169 |
193 |
170 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand |
194 /// Returns the [`Named`] algorithm corresponding to the algorithm shorthand |
171 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { |
195 pub fn get_named<F : Float>(&self) -> Named<AlgorithmConfig<F>> { |
331 Named<$type<F, NoiseDistr, S, K, P, N>> |
355 Named<$type<F, NoiseDistr, S, K, P, N>> |
332 where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, |
356 where F : ClapFloat + nalgebra::RealField + ToNalgebraRealField<MixedType=F>, |
333 [usize; N] : Serialize, |
357 [usize; N] : Serialize, |
334 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
358 S : Sensor<F, N> + Copy + Serialize + std::fmt::Debug, |
335 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
359 P : Spread<F, N> + Copy + Serialize + std::fmt::Debug, |
336 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy, |
360 Convolution<S, P>: Spread<F, N> + Bounded<F> + LocalAnalysis<F, Bounds<F>, N> + Copy |
|
361 // TODO: shold not have differentiability as a requirement, but |
|
362 // decide availability of sliding based on it. |
|
363 //+ for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
|
364 // TODO: very weird that rust only compiles with Differentiable |
|
365 // instead of the above one on references, which is required by |
|
366 // poitsource_sliding_fb_reg. |
|
367 + Differentiable<Loc<F, N>, Output = Loc<F, N>> |
|
368 + Lipschitz<L2>, |
|
369 // <DefaultSG<F, S, P, N> as ForwardModel<Loc<F, N>, F>::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output = Loc<F, N>>, |
337 AutoConvolution<P> : BoundedBy<F, K>, |
370 AutoConvolution<P> : BoundedBy<F, K>, |
338 K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> |
371 K : SimpleConvolutionKernel<F, N> + LocalAnalysis<F, Bounds<F>, N> |
339 + Copy + Serialize + std::fmt::Debug, |
372 + Copy + Serialize + std::fmt::Debug, |
|
373 //+ Differentiable<Loc<F, N>, Output = Loc<F, N>>, // TODO: shouldn't need to assume differentiability |
340 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
374 Cube<F, N>: P2Minimise<Loc<F, N>, F> + SetOrd, |
341 PlotLookup : Plotting<N>, |
375 PlotLookup : Plotting<N>, |
342 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
376 DefaultBT<F, N> : SensorGridBT<F, S, P, N, Depth=DynamicDepth> + BTSearch<F, N>, |
343 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
377 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
344 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
378 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
511 not_implemented(); |
545 not_implemented(); |
512 continue |
546 continue |
513 } |
547 } |
514 } |
548 } |
515 }, |
549 }, |
|
550 AlgorithmConfig::FISTA(ref algconfig) => { |
|
551 match (regularisation, dataterm) { |
|
552 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
|
553 running(); |
|
554 pointsource_fista_reg( |
|
555 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
|
556 iterator, plotter |
|
557 ) |
|
558 }, |
|
559 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
|
560 running(); |
|
561 pointsource_fista_reg( |
|
562 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
|
563 iterator, plotter |
|
564 ) |
|
565 }, |
|
566 _ => { |
|
567 not_implemented(); |
|
568 continue |
|
569 } |
|
570 } |
|
571 }, |
|
572 AlgorithmConfig::SlidingFB(ref algconfig) => { |
|
573 match (regularisation, dataterm) { |
|
574 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
|
575 running(); |
|
576 pointsource_sliding_fb_reg( |
|
577 &opA, &b, NonnegRadonRegTerm(α), &op𝒟, algconfig, |
|
578 iterator, plotter |
|
579 ) |
|
580 }, |
|
581 (Regularisation::Radon(α), DataTerm::L2Squared) => { |
|
582 running(); |
|
583 pointsource_sliding_fb_reg( |
|
584 &opA, &b, RadonRegTerm(α), &op𝒟, algconfig, |
|
585 iterator, plotter |
|
586 ) |
|
587 }, |
|
588 _ => { |
|
589 not_implemented(); |
|
590 continue |
|
591 } |
|
592 } |
|
593 }, |
516 AlgorithmConfig::PDPS(ref algconfig) => { |
594 AlgorithmConfig::PDPS(ref algconfig) => { |
517 running(); |
595 running(); |
518 match (regularisation, dataterm) { |
596 match (regularisation, dataterm) { |
519 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
597 (Regularisation::NonnegRadon(α), DataTerm::L2Squared) => { |
520 pointsource_pdps_reg( |
598 pointsource_pdps_reg( |