35 This is the task of <code>generic_pointsource_fb</code>, where we use <code>FBSpecialisation</code> |
34 This is the task of <code>generic_pointsource_fb</code>, where we use <code>FBSpecialisation</code> |
36 to replace the specific residual $Aμ-b$ by $y$. |
35 to replace the specific residual $Aμ-b$ by $y$. |
37 For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$. |
36 For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$. |
38 For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$. |
37 For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$. |
39 </p> |
38 </p> |
40 |
|
41 Based on zero initialisation for $μ$, we use the [`Subdifferentiable`] trait to make an |
|
42 initialisation corresponding to the second part of the optimality conditions. |
|
43 In the algorithm itself, standard proximal steps are taking with respect to $F\_0^* + ⟨b, ·⟩$. |
|
44 */ |
39 */ |
45 |
40 |
46 use numeric_literals::replace_float_literals; |
41 use numeric_literals::replace_float_literals; |
47 use serde::{Serialize, Deserialize}; |
42 use serde::{Serialize, Deserialize}; |
48 use nalgebra::DVector; |
43 use nalgebra::DVector; |
49 use clap::ValueEnum; |
44 use clap::ValueEnum; |
50 |
45 |
51 use alg_tools::iterate:: AlgIteratorFactory; |
46 use alg_tools::iterate::AlgIteratorFactory; |
52 use alg_tools::sets::Cube; |
|
53 use alg_tools::loc::Loc; |
|
54 use alg_tools::euclidean::Euclidean; |
47 use alg_tools::euclidean::Euclidean; |
|
48 use alg_tools::linops::Mapping; |
55 use alg_tools::norms::{ |
49 use alg_tools::norms::{ |
56 L1, Linfinity, |
50 Linfinity, |
57 Projection, Norm, |
51 Projection, |
58 }; |
52 }; |
59 use alg_tools::bisection_tree::{ |
53 use alg_tools::mapping::{RealMapping, Instance}; |
60 BTFN, |
|
61 PreBTFN, |
|
62 Bounds, |
|
63 BTNodeLookup, |
|
64 BTNode, |
|
65 BTSearch, |
|
66 P2Minimise, |
|
67 SupportGenerator, |
|
68 LocalAnalysis, |
|
69 }; |
|
70 use alg_tools::mapping::RealMapping; |
|
71 use alg_tools::nalgebra_support::ToNalgebraRealField; |
54 use alg_tools::nalgebra_support::ToNalgebraRealField; |
72 use alg_tools::linops::AXPY; |
55 use alg_tools::linops::AXPY; |
73 |
56 |
74 use crate::types::*; |
57 use crate::types::*; |
75 use crate::measures::DiscreteMeasure; |
58 use crate::measures::{DiscreteMeasure, RNDM}; |
76 use crate::measures::merging::{ |
59 use crate::measures::merging::SpikeMerging; |
77 SpikeMerging, |
60 use crate::forward_model::{ |
78 }; |
61 ForwardModel, |
79 use crate::forward_model::ForwardModel; |
62 AdjointProductBoundedBy, |
80 use crate::seminorms::{ |
|
81 DiscreteMeasureOp, Lipschitz |
|
82 }; |
63 }; |
83 use crate::plot::{ |
64 use crate::plot::{ |
84 SeqPlotter, |
65 SeqPlotter, |
85 Plotting, |
66 Plotting, |
86 PlotLookup |
67 PlotLookup |
87 }; |
68 }; |
88 use crate::fb::{ |
69 use crate::fb::{ |
|
70 postprocess, |
|
71 prune_with_stats |
|
72 }; |
|
73 pub use crate::prox_penalty::{ |
89 FBGenericConfig, |
74 FBGenericConfig, |
90 FBSpecialisation, |
75 ProxPenalty |
91 generic_pointsource_fb_reg, |
76 }; |
92 RegTerm, |
77 use crate::regularisation::RegTerm; |
93 }; |
78 use crate::dataterm::{ |
94 use crate::regularisation::NonnegRadonRegTerm; |
79 DataTerm, |
|
80 L2Squared, |
|
81 L1 |
|
82 }; |
|
83 use crate::measures::merging::SpikeMergingMethod; |
|
84 |
95 |
85 |
96 /// Acceleration |
86 /// Acceleration |
97 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)] |
87 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)] |
98 pub enum Acceleration { |
88 pub enum Acceleration { |
99 /// No acceleration |
89 /// No acceleration |
105 /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed |
95 /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed |
106 #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")] |
96 #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")] |
107 Full |
97 Full |
108 } |
98 } |
109 |
99 |
110 /// Settings for [`pointsource_pdps`]. |
100 #[replace_float_literals(F::cast_from(literal))] |
|
101 impl Acceleration { |
|
102 /// PDPS parameter acceleration. Updates τ and σ and returns ω. |
|
103 /// This uses dual strong convexity, not primal. |
|
104 fn accelerate<F : Float>(self, τ : &mut F, σ : &mut F, γ : F) -> F { |
|
105 match self { |
|
106 Acceleration::None => 1.0, |
|
107 Acceleration::Partial => { |
|
108 let ω = 1.0 / (1.0 + γ * (*σ)).sqrt(); |
|
109 *σ *= ω; |
|
110 *τ /= ω; |
|
111 ω |
|
112 }, |
|
113 Acceleration::Full => { |
|
114 let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt(); |
|
115 *σ *= ω; |
|
116 *τ /= ω; |
|
117 ω |
|
118 }, |
|
119 } |
|
120 } |
|
121 } |
|
122 |
|
123 /// Settings for [`pointsource_pdps_reg`]. |
111 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
124 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
112 #[serde(default)] |
125 #[serde(default)] |
113 pub struct PDPSConfig<F : Float> { |
126 pub struct PDPSConfig<F : Float> { |
114 /// Primal step length scaling. We must have `τ0 * σ0 < 1`. |
127 /// Primal step length scaling. We must have `τ0 * σ0 < 1`. |
115 pub τ0 : F, |
128 pub τ0 : F, |
116 /// Dual step length scaling. We must have `τ0 * σ0 < 1`. |
129 /// Dual step length scaling. We must have `τ0 * σ0 < 1`. |
117 pub σ0 : F, |
130 pub σ0 : F, |
118 /// Accelerate if available |
131 /// Accelerate if available |
119 pub acceleration : Acceleration, |
132 pub acceleration : Acceleration, |
120 /// Generic parameters |
133 /// Generic parameters |
121 pub insertion : FBGenericConfig<F>, |
134 pub generic : FBGenericConfig<F>, |
122 } |
135 } |
123 |
136 |
124 #[replace_float_literals(F::cast_from(literal))] |
137 #[replace_float_literals(F::cast_from(literal))] |
125 impl<F : Float> Default for PDPSConfig<F> { |
138 impl<F : Float> Default for PDPSConfig<F> { |
126 fn default() -> Self { |
139 fn default() -> Self { |
127 let τ0 = 0.5; |
140 let τ0 = 5.0; |
128 PDPSConfig { |
141 PDPSConfig { |
129 τ0, |
142 τ0, |
130 σ0 : 0.99/τ0, |
143 σ0 : 0.99/τ0, |
131 acceleration : Acceleration::Partial, |
144 acceleration : Acceleration::Partial, |
132 insertion : Default::default() |
145 generic : FBGenericConfig { |
|
146 merging : SpikeMergingMethod { enabled : true, ..Default::default() }, |
|
147 .. Default::default() |
|
148 }, |
133 } |
149 } |
134 } |
150 } |
135 } |
151 } |
136 |
152 |
137 /// Trait for subdifferentiable objects |
153 /// Trait for data terms for the PDPS |
138 pub trait Subdifferentiable<F : Float, V, U=V> { |
154 #[replace_float_literals(F::cast_from(literal))] |
139 /// Calculate some subdifferential at `x` |
155 pub trait PDPSDataTerm<F : Float, V, const N : usize> : DataTerm<F, V, N> { |
140 fn some_subdifferential(&self, x : V) -> U; |
156 /// Calculate some subdifferential at `x` for the conjugate |
141 } |
157 fn some_subdifferential(&self, x : V) -> V; |
142 |
158 |
143 /// Type for indicating norm-2-squared data fidelity. |
159 /// Factor of strong convexity of the conjugate |
144 pub struct L2Squared; |
160 #[inline] |
145 |
161 fn factor_of_strong_convexity(&self) -> F { |
146 impl<F : Float, V : Euclidean<F>> Subdifferentiable<F, V> for L2Squared { |
162 0.0 |
|
163 } |
|
164 |
|
165 /// Perform dual update |
|
166 fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F); |
|
167 } |
|
168 |
|
169 |
|
170 #[replace_float_literals(F::cast_from(literal))] |
|
171 impl<F, V, const N : usize> PDPSDataTerm<F, V, N> |
|
172 for L2Squared |
|
173 where |
|
174 F : Float, |
|
175 V : Euclidean<F> + AXPY<F>, |
|
176 for<'b> &'b V : Instance<V>, |
|
177 { |
147 fn some_subdifferential(&self, x : V) -> V { x } |
178 fn some_subdifferential(&self, x : V) -> V { x } |
148 } |
179 |
149 |
180 fn factor_of_strong_convexity(&self) -> F { |
150 impl<F : Float + nalgebra::RealField> Subdifferentiable<F, DVector<F>> for L1 { |
181 1.0 |
|
182 } |
|
183 |
|
184 #[inline] |
|
185 fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) { |
|
186 y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ)); |
|
187 } |
|
188 } |
|
189 |
|
190 #[replace_float_literals(F::cast_from(literal))] |
|
191 impl<F : Float + nalgebra::RealField, const N : usize> |
|
192 PDPSDataTerm<F, DVector<F>, N> |
|
193 for L1 { |
151 fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> { |
194 fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> { |
152 // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well. |
195 // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well. |
153 x.iter_mut() |
196 x.iter_mut() |
154 .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) }); |
197 .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) }); |
155 x |
198 x |
156 } |
199 } |
157 } |
200 |
158 |
201 #[inline] |
159 /// Specialisation of [`generic_pointsource_fb_reg`] to PDPS. |
202 fn dual_update(&self, y : &mut DVector<F>, y_prev : &DVector<F>, σ : F) { |
160 pub struct PDPS< |
203 y.axpy(1.0, y_prev, σ); |
161 'a, |
|
162 F : Float + ToNalgebraRealField, |
|
163 A : ForwardModel<Loc<F, N>, F>, |
|
164 D, |
|
165 const N : usize |
|
166 > { |
|
167 /// The data |
|
168 b : &'a A::Observable, |
|
169 /// The forward operator |
|
170 opA : &'a A, |
|
171 /// Primal step length |
|
172 τ : F, |
|
173 // Dual step length |
|
174 σ : F, |
|
175 /// Whether acceleration should be applied (if data term supports) |
|
176 acceleration : Acceleration, |
|
177 /// The dataterm. Only used by the type system. |
|
178 _dataterm : D, |
|
179 /// Previous dual iterate. |
|
180 y_prev : A::Observable, |
|
181 } |
|
182 |
|
183 /// Implementation of [`FBSpecialisation`] for μPDPS with norm-2-squared data fidelity. |
|
184 #[replace_float_literals(F::cast_from(literal))] |
|
185 impl< |
|
186 'a, |
|
187 F : Float + ToNalgebraRealField, |
|
188 A : ForwardModel<Loc<F, N>, F>, |
|
189 const N : usize |
|
190 > FBSpecialisation<F, A::Observable, N> for PDPS<'a, F, A, L2Squared, N> |
|
191 where for<'b> &'b A::Observable : std::ops::Add<A::Observable, Output=A::Observable> { |
|
192 |
|
193 fn update( |
|
194 &mut self, |
|
195 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
196 μ_base : &DiscreteMeasure<Loc<F, N>, F> |
|
197 ) -> (A::Observable, Option<F>) { |
|
198 let σ = self.σ; |
|
199 let τ = self.τ; |
|
200 let ω = match self.acceleration { |
|
201 Acceleration::None => 1.0, |
|
202 Acceleration::Partial => { |
|
203 let ω = 1.0 / (1.0 + σ).sqrt(); |
|
204 self.σ = σ * ω; |
|
205 self.τ = τ / ω; |
|
206 ω |
|
207 }, |
|
208 Acceleration::Full => { |
|
209 let ω = 1.0 / (1.0 + 2.0 * σ).sqrt(); |
|
210 self.σ = σ * ω; |
|
211 self.τ = τ / ω; |
|
212 ω |
|
213 }, |
|
214 }; |
|
215 |
|
216 μ.prune(); |
|
217 |
|
218 let mut y = self.b.clone(); |
|
219 self.opA.gemv(&mut y, 1.0 + ω, μ, -1.0); |
|
220 self.opA.gemv(&mut y, -ω, μ_base, 1.0); |
|
221 y.axpy(1.0 / (1.0 + σ), &self.y_prev, σ / (1.0 + σ)); |
|
222 self.y_prev.copy_from(&y); |
|
223 |
|
224 (y, Some(self.τ)) |
|
225 } |
|
226 |
|
227 fn calculate_fit( |
|
228 &self, |
|
229 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
230 _y : &A::Observable |
|
231 ) -> F { |
|
232 self.calculate_fit_simple(μ) |
|
233 } |
|
234 |
|
235 fn calculate_fit_simple( |
|
236 &self, |
|
237 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
238 ) -> F { |
|
239 let mut residual = self.b.clone(); |
|
240 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
|
241 residual.norm2_squared_div2() |
|
242 } |
|
243 } |
|
244 |
|
245 /// Implementation of [`FBSpecialisation`] for μPDPS with norm-1 data fidelity. |
|
246 #[replace_float_literals(F::cast_from(literal))] |
|
247 impl< |
|
248 'a, |
|
249 F : Float + ToNalgebraRealField, |
|
250 A : ForwardModel<Loc<F, N>, F>, |
|
251 const N : usize |
|
252 > FBSpecialisation<F, A::Observable, N> for PDPS<'a, F, A, L1, N> |
|
253 where A::Observable : Projection<F, Linfinity> + Norm<F, L1>, |
|
254 for<'b> &'b A::Observable : std::ops::Add<A::Observable, Output=A::Observable> { |
|
255 fn update( |
|
256 &mut self, |
|
257 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
258 μ_base : &DiscreteMeasure<Loc<F, N>, F> |
|
259 ) -> (A::Observable, Option<F>) { |
|
260 let σ = self.σ; |
|
261 |
|
262 μ.prune(); |
|
263 |
|
264 //let ȳ = self.opA.apply(μ) * 2.0 - self.opA.apply(μ_base); |
|
265 //*y = proj_{[-1,1]}(&self.y_prev + (ȳ - self.b) * σ) |
|
266 let mut y = self.y_prev.clone(); |
|
267 self.opA.gemv(&mut y, 2.0 * σ, μ, 1.0); |
|
268 self.opA.gemv(&mut y, -σ, μ_base, 1.0); |
|
269 y.axpy(-σ, self.b, 1.0); |
|
270 y.proj_ball_mut(1.0, Linfinity); |
204 y.proj_ball_mut(1.0, Linfinity); |
271 self.y_prev.copy_from(&y); |
|
272 |
|
273 (y, None) |
|
274 } |
|
275 |
|
276 fn calculate_fit( |
|
277 &self, |
|
278 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
279 _y : &A::Observable |
|
280 ) -> F { |
|
281 self.calculate_fit_simple(μ) |
|
282 } |
|
283 |
|
284 fn calculate_fit_simple( |
|
285 &self, |
|
286 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
287 ) -> F { |
|
288 let mut residual = self.b.clone(); |
|
289 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
|
290 residual.norm(L1) |
|
291 } |
205 } |
292 } |
206 } |
293 |
207 |
294 /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting. |
208 /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting. |
295 /// |
209 /// |
302 /// |
216 /// |
303 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript. |
217 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript. |
304 /// |
218 /// |
305 /// Returns the final iterate. |
219 /// Returns the final iterate. |
306 #[replace_float_literals(F::cast_from(literal))] |
220 #[replace_float_literals(F::cast_from(literal))] |
307 pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>( |
221 pub fn pointsource_pdps_reg<F, I, A, D, Reg, P, const N : usize>( |
308 opA : &'a A, |
222 opA : &A, |
309 b : &'a A::Observable, |
223 b : &A::Observable, |
310 reg : Reg, |
224 reg : Reg, |
311 op𝒟 : &'a 𝒟, |
225 prox_penalty : &P, |
312 config : &PDPSConfig<F>, |
226 pdpsconfig : &PDPSConfig<F>, |
313 iterator : I, |
227 iterator : I, |
314 plotter : SeqPlotter<F, N>, |
228 mut plotter : SeqPlotter<F, N>, |
315 dataterm : D, |
229 dataterm : D, |
316 ) -> DiscreteMeasure<Loc<F, N>, F> |
230 ) -> RNDM<F, N> |
317 where F : Float + ToNalgebraRealField, |
231 where |
318 I : AlgIteratorFactory<IterInfo<F, N>>, |
232 F : Float + ToNalgebraRealField, |
319 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> |
233 I : AlgIteratorFactory<IterInfo<F, N>>, |
320 + std::ops::Add<A::Observable, Output=A::Observable>, |
234 A : ForwardModel<RNDM<F, N>, F> |
321 //+ std::ops::Mul<F, Output=A::Observable>, // <-- FIXME: compiler overflow |
235 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
322 A::Observable : std::ops::MulAssign<F>, |
236 A::PreadjointCodomain : RealMapping<F, N>, |
323 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
237 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
324 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
238 PlotLookup : Plotting<N>, |
325 + Lipschitz<𝒟, FloatType=F>, |
239 RNDM<F, N> : SpikeMerging<F>, |
326 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
240 D : PDPSDataTerm<F, A::Observable, N>, |
327 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
241 Reg : RegTerm<F, N>, |
328 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
242 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
329 𝒟::Codomain : RealMapping<F, N>, |
243 { |
330 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
244 |
331 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
245 // Check parameters |
332 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
246 assert!(pdpsconfig.τ0 > 0.0 && |
333 PlotLookup : Plotting<N>, |
247 pdpsconfig.σ0 > 0.0 && |
334 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
248 pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, |
335 PDPS<'a, F, A, D, N> : FBSpecialisation<F, A::Observable, N>, |
249 "Invalid step length parameters"); |
336 D : Subdifferentiable<F, A::Observable>, |
250 |
337 Reg : RegTerm<F, N> { |
251 // Set up parameters |
338 |
252 let config = &pdpsconfig.generic; |
339 let y = dataterm.some_subdifferential(-b); |
253 let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt(); |
340 let l = opA.lipschitz_factor(&op𝒟).unwrap().sqrt(); |
254 let mut τ = pdpsconfig.τ0 / l; |
341 let τ = config.τ0 / l; |
255 let mut σ = pdpsconfig.σ0 / l; |
342 let σ = config.σ0 / l; |
256 let γ = dataterm.factor_of_strong_convexity(); |
343 |
257 |
344 let pdps = PDPS { |
258 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
345 b, |
259 // by τ compared to the conditional gradient approach. |
346 opA, |
260 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
347 τ, |
261 let mut ε = tolerance.initial(); |
348 σ, |
262 |
349 acceleration : config.acceleration, |
263 // Initialise iterates |
350 _dataterm : dataterm, |
264 let mut μ = DiscreteMeasure::new(); |
351 y_prev : y.clone(), |
265 let mut y = dataterm.some_subdifferential(-b); |
|
266 let mut y_prev = y.clone(); |
|
267 let full_stats = |μ : &RNDM<F, N>, ε, stats| IterInfo { |
|
268 value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ), |
|
269 n_spikes : μ.len(), |
|
270 ε, |
|
271 // postprocessing: config.postprocessing.then(|| μ.clone()), |
|
272 .. stats |
352 }; |
273 }; |
353 |
274 let mut stats = IterInfo::new(); |
354 generic_pointsource_fb_reg( |
275 |
355 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, y, pdps |
276 // Run the algorithm |
356 ) |
277 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
357 } |
278 // Calculate smooth part of surrogate model. |
358 |
279 let mut τv = opA.preadjoint().apply(y * τ); |
359 // |
280 |
360 // Deprecated interfaces |
281 // Save current base point |
361 // |
282 let μ_base = μ.clone(); |
362 |
283 |
363 #[deprecated(note = "Use `pointsource_pdps_reg`")] |
284 // Insert and reweigh |
364 pub fn pointsource_pdps<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, const N : usize>( |
285 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
365 opA : &'a A, |
286 &mut μ, &mut τv, &μ_base, None, |
366 b : &'a A::Observable, |
287 τ, ε, |
367 α : F, |
288 config, ®, &state, &mut stats |
368 op𝒟 : &'a 𝒟, |
289 ); |
369 config : &PDPSConfig<F>, |
290 |
370 iterator : I, |
291 // Prune and possibly merge spikes |
371 plotter : SeqPlotter<F, N>, |
292 if config.merge_now(&state) { |
372 dataterm : D, |
293 stats.merged += prox_penalty.merge_spikes_no_fitness( |
373 ) -> DiscreteMeasure<Loc<F, N>, F> |
294 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, |
374 where F : Float + ToNalgebraRealField, |
295 ); |
375 I : AlgIteratorFactory<IterInfo<F, N>>, |
296 } |
376 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> |
297 stats.pruned += prune_with_stats(&mut μ); |
377 + std::ops::Add<A::Observable, Output=A::Observable>, |
298 |
378 A::Observable : std::ops::MulAssign<F>, |
299 // Update step length parameters |
379 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
300 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |
380 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
301 |
381 + Lipschitz<𝒟, FloatType=F>, |
302 // Do dual update |
382 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
303 y = b.clone(); // y = b |
383 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
304 opA.gemv(&mut y, 1.0 + ω, &μ, -1.0); // y = A[(1+ω)μ^{k+1}]-b |
384 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
305 opA.gemv(&mut y, -ω, &μ_base, 1.0); // y = A[(1+ω)μ^{k+1} - ω μ^k]-b |
385 𝒟::Codomain : RealMapping<F, N>, |
306 dataterm.dual_update(&mut y, &y_prev, σ); |
386 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
307 y_prev.copy_from(&y); |
387 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
308 |
388 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
309 // Give statistics if requested |
389 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
310 let iter = state.iteration(); |
390 PlotLookup : Plotting<N>, |
311 stats.this_iters += 1; |
391 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
312 |
392 PDPS<'a, F, A, D, N> : FBSpecialisation<F, A::Observable, N>, |
313 state.if_verbose(|| { |
393 D : Subdifferentiable<F, A::Observable> { |
314 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); |
394 |
315 full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) |
395 pointsource_pdps_reg(opA, b, NonnegRadonRegTerm(α), op𝒟, config, iterator, plotter, dataterm) |
316 }); |
396 } |
317 |
|
318 ε = tolerance.update(ε, iter); |
|
319 } |
|
320 |
|
321 postprocess(μ, config, dataterm, opA, b) |
|
322 } |
|
323 |