129 insertion : Default::default() |
135 insertion : Default::default() |
130 } |
136 } |
131 } |
137 } |
132 } |
138 } |
133 |
139 |
134 /// Trait for subdifferentiable objects |
140 /// Trait for data terms for the PDPS |
135 pub trait Subdifferentiable<F : Float, V, U=V> { |
141 #[replace_float_literals(F::cast_from(literal))] |
136 /// Calculate some subdifferential at `x` |
142 pub trait PDPSDataTerm<F : Float, V, const N : usize> : DataTerm<F, V, N> { |
137 fn some_subdifferential(&self, x : V) -> U; |
143 /// Calculate some subdifferential at `x` for the conjugate |
138 } |
144 fn some_subdifferential(&self, x : V) -> V; |
139 |
145 |
140 /// Type for indicating norm-2-squared data fidelity. |
146 /// Factor of strong convexity of the conjugate |
141 pub struct L2Squared; |
147 #[inline] |
142 |
148 fn factor_of_strong_convexity(&self) -> F { |
143 impl<F : Float, V : Euclidean<F>> Subdifferentiable<F, V> for L2Squared { |
149 0.0 |
|
150 } |
|
151 |
|
152 /// Perform dual update |
|
153 fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F); |
|
154 } |
|
155 |
|
156 |
|
157 #[replace_float_literals(F::cast_from(literal))] |
|
158 impl<F : Float, V : Euclidean<F> + AXPY<F>, const N : usize> |
|
159 PDPSDataTerm<F, V, N> |
|
160 for L2Squared { |
144 fn some_subdifferential(&self, x : V) -> V { x } |
161 fn some_subdifferential(&self, x : V) -> V { x } |
145 } |
162 |
146 |
163 fn factor_of_strong_convexity(&self) -> F { |
147 impl<F : Float + nalgebra::RealField> Subdifferentiable<F, DVector<F>> for L1 { |
164 1.0 |
|
165 } |
|
166 |
|
167 #[inline] |
|
168 fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) { |
|
169 y.axpy(1.0 / (1.0 + σ), &y_prev, σ / (1.0 + σ)); |
|
170 } |
|
171 } |
|
172 |
|
173 #[replace_float_literals(F::cast_from(literal))] |
|
174 impl<F : Float + nalgebra::RealField, const N : usize> |
|
175 PDPSDataTerm<F, DVector<F>, N> |
|
176 for L1 { |
148 fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> { |
177 fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> { |
149 // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well. |
178 // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well. |
150 x.iter_mut() |
179 x.iter_mut() |
151 .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) }); |
180 .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) }); |
152 x |
181 x |
153 } |
182 } |
154 } |
183 |
155 |
184 #[inline] |
156 /// Specialisation of [`generic_pointsource_fb_reg`] to PDPS. |
185 fn dual_update(&self, y : &mut DVector<F>, y_prev : &DVector<F>, σ : F) { |
157 pub struct PDPS< |
186 y.axpy(1.0, y_prev, σ); |
158 'a, |
|
159 F : Float + ToNalgebraRealField, |
|
160 A : ForwardModel<Loc<F, N>, F>, |
|
161 D, |
|
162 const N : usize |
|
163 > { |
|
164 /// The data |
|
165 b : &'a A::Observable, |
|
166 /// The forward operator |
|
167 opA : &'a A, |
|
168 /// Primal step length |
|
169 τ : F, |
|
170 // Dual step length |
|
171 σ : F, |
|
172 /// Whether acceleration should be applied (if data term supports) |
|
173 acceleration : Acceleration, |
|
174 /// The dataterm. Only used by the type system. |
|
175 _dataterm : D, |
|
176 /// Previous dual iterate. |
|
177 y_prev : A::Observable, |
|
178 } |
|
179 |
|
180 /// Implementation of [`FBSpecialisation`] for μPDPS with norm-2-squared data fidelity. |
|
181 #[replace_float_literals(F::cast_from(literal))] |
|
182 impl< |
|
183 'a, |
|
184 F : Float + ToNalgebraRealField, |
|
185 A : ForwardModel<Loc<F, N>, F>, |
|
186 const N : usize |
|
187 > FBSpecialisation<F, A::Observable, N> for PDPS<'a, F, A, L2Squared, N> |
|
188 where for<'b> &'b A::Observable : std::ops::Add<A::Observable, Output=A::Observable> { |
|
189 |
|
190 fn update( |
|
191 &mut self, |
|
192 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
193 μ_base : &DiscreteMeasure<Loc<F, N>, F> |
|
194 ) -> (A::Observable, Option<F>) { |
|
195 let σ = self.σ; |
|
196 let τ = self.τ; |
|
197 let ω = match self.acceleration { |
|
198 Acceleration::None => 1.0, |
|
199 Acceleration::Partial => { |
|
200 let ω = 1.0 / (1.0 + σ).sqrt(); |
|
201 self.σ = σ * ω; |
|
202 self.τ = τ / ω; |
|
203 ω |
|
204 }, |
|
205 Acceleration::Full => { |
|
206 let ω = 1.0 / (1.0 + 2.0 * σ).sqrt(); |
|
207 self.σ = σ * ω; |
|
208 self.τ = τ / ω; |
|
209 ω |
|
210 }, |
|
211 }; |
|
212 |
|
213 μ.prune(); |
|
214 |
|
215 let mut y = self.b.clone(); |
|
216 self.opA.gemv(&mut y, 1.0 + ω, μ, -1.0); |
|
217 self.opA.gemv(&mut y, -ω, μ_base, 1.0); |
|
218 y.axpy(1.0 / (1.0 + σ), &self.y_prev, σ / (1.0 + σ)); |
|
219 self.y_prev.copy_from(&y); |
|
220 |
|
221 (y, Some(self.τ)) |
|
222 } |
|
223 |
|
224 fn calculate_fit( |
|
225 &self, |
|
226 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
227 _y : &A::Observable |
|
228 ) -> F { |
|
229 self.calculate_fit_simple(μ) |
|
230 } |
|
231 |
|
232 fn calculate_fit_simple( |
|
233 &self, |
|
234 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
235 ) -> F { |
|
236 let mut residual = self.b.clone(); |
|
237 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
|
238 residual.norm2_squared_div2() |
|
239 } |
|
240 } |
|
241 |
|
242 /// Implementation of [`FBSpecialisation`] for μPDPS with norm-1 data fidelity. |
|
243 #[replace_float_literals(F::cast_from(literal))] |
|
244 impl< |
|
245 'a, |
|
246 F : Float + ToNalgebraRealField, |
|
247 A : ForwardModel<Loc<F, N>, F>, |
|
248 const N : usize |
|
249 > FBSpecialisation<F, A::Observable, N> for PDPS<'a, F, A, L1, N> |
|
250 where A::Observable : Projection<F, Linfinity> + Norm<F, L1>, |
|
251 for<'b> &'b A::Observable : std::ops::Add<A::Observable, Output=A::Observable> { |
|
252 fn update( |
|
253 &mut self, |
|
254 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
|
255 μ_base : &DiscreteMeasure<Loc<F, N>, F> |
|
256 ) -> (A::Observable, Option<F>) { |
|
257 let σ = self.σ; |
|
258 |
|
259 μ.prune(); |
|
260 |
|
261 //let ȳ = self.opA.apply(μ) * 2.0 - self.opA.apply(μ_base); |
|
262 //*y = proj_{[-1,1]}(&self.y_prev + (ȳ - self.b) * σ) |
|
263 let mut y = self.y_prev.clone(); |
|
264 self.opA.gemv(&mut y, 2.0 * σ, μ, 1.0); |
|
265 self.opA.gemv(&mut y, -σ, μ_base, 1.0); |
|
266 y.axpy(-σ, self.b, 1.0); |
|
267 y.proj_ball_mut(1.0, Linfinity); |
187 y.proj_ball_mut(1.0, Linfinity); |
268 self.y_prev.copy_from(&y); |
|
269 |
|
270 (y, None) |
|
271 } |
|
272 |
|
273 fn calculate_fit( |
|
274 &self, |
|
275 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
276 _y : &A::Observable |
|
277 ) -> F { |
|
278 self.calculate_fit_simple(μ) |
|
279 } |
|
280 |
|
281 fn calculate_fit_simple( |
|
282 &self, |
|
283 μ : &DiscreteMeasure<Loc<F, N>, F>, |
|
284 ) -> F { |
|
285 let mut residual = self.b.clone(); |
|
286 self.opA.gemv(&mut residual, 1.0, μ, -1.0); |
|
287 residual.norm(L1) |
|
288 } |
188 } |
289 } |
189 } |
290 |
190 |
291 /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting. |
191 /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting. |
292 /// |
192 /// |
304 pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>( |
204 pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>( |
305 opA : &'a A, |
205 opA : &'a A, |
306 b : &'a A::Observable, |
206 b : &'a A::Observable, |
307 reg : Reg, |
207 reg : Reg, |
308 op𝒟 : &'a 𝒟, |
208 op𝒟 : &'a 𝒟, |
309 config : &PDPSConfig<F>, |
209 pdpsconfig : &PDPSConfig<F>, |
310 iterator : I, |
210 iterator : I, |
311 plotter : SeqPlotter<F, N>, |
211 mut plotter : SeqPlotter<F, N>, |
312 dataterm : D, |
212 dataterm : D, |
313 ) -> DiscreteMeasure<Loc<F, N>, F> |
213 ) -> DiscreteMeasure<Loc<F, N>, F> |
314 where F : Float + ToNalgebraRealField, |
214 where F : Float + ToNalgebraRealField, |
315 I : AlgIteratorFactory<IterInfo<F, N>>, |
215 I : AlgIteratorFactory<IterInfo<F, N>>, |
316 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> |
216 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> |
317 + std::ops::Add<A::Observable, Output=A::Observable>, |
217 + std::ops::Add<A::Observable, Output=A::Observable>, |
318 //+ std::ops::Mul<F, Output=A::Observable>, // <-- FIXME: compiler overflow |
218 //+ std::ops::Mul<F, Output=A::Observable>, // <-- FIXME: compiler overflow |
319 A::Observable : std::ops::MulAssign<F>, |
219 A::Observable : std::ops::MulAssign<F>, |
320 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
220 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
321 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
221 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
322 + Lipschitz<𝒟, FloatType=F>, |
222 + Lipschitz<&'a 𝒟, FloatType=F>, |
323 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
223 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
324 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
224 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
325 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
225 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
326 𝒟::Codomain : RealMapping<F, N>, |
226 𝒟::Codomain : RealMapping<F, N>, |
327 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
227 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
328 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
228 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
329 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
229 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
330 PlotLookup : Plotting<N>, |
230 PlotLookup : Plotting<N>, |
331 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
231 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
332 PDPS<'a, F, A, D, N> : FBSpecialisation<F, A::Observable, N>, |
232 D : PDPSDataTerm<F, A::Observable, N>, |
333 D : Subdifferentiable<F, A::Observable>, |
|
334 Reg : RegTerm<F, N> { |
233 Reg : RegTerm<F, N> { |
335 |
234 |
336 let y = dataterm.some_subdifferential(-b); |
235 // Set up parameters |
|
236 let config = &pdpsconfig.insertion; |
|
237 let op𝒟norm = op𝒟.opnorm_bound(); |
337 let l = opA.lipschitz_factor(&op𝒟).unwrap().sqrt(); |
238 let l = opA.lipschitz_factor(&op𝒟).unwrap().sqrt(); |
338 let τ = config.τ0 / l; |
239 let mut τ = pdpsconfig.τ0 / l; |
339 let σ = config.σ0 / l; |
240 let mut σ = pdpsconfig.σ0 / l; |
340 |
241 let γ = dataterm.factor_of_strong_convexity(); |
341 let pdps = PDPS { |
242 |
342 b, |
243 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
343 opA, |
244 // by τ compared to the conditional gradient approach. |
344 τ, |
245 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
345 σ, |
246 let mut ε = tolerance.initial(); |
346 acceleration : config.acceleration, |
247 |
347 _dataterm : dataterm, |
248 // Initialise iterates |
348 y_prev : y.clone(), |
249 let mut μ = DiscreteMeasure::new(); |
349 }; |
250 let mut y = dataterm.some_subdifferential(-b); |
350 |
251 let mut y_prev = y.clone(); |
351 generic_pointsource_fb_reg( |
252 let mut stats = IterInfo::new(); |
352 opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, y, pdps |
253 |
353 ) |
254 // Run the algorithm |
354 } |
255 iterator.iterate(|state| { |
355 |
256 // Calculate smooth part of surrogate model. |
|
257 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
|
258 // has no significant overhead. For some reason Rust doesn't allow us simply moving |
|
259 // the residual and replacing it below before the end of this closure. |
|
260 y *= -τ; |
|
261 let r = std::mem::replace(&mut y, opA.empty_observable()); |
|
262 let minus_τv = opA.preadjoint().apply(r); |
|
263 |
|
264 // Save current base point |
|
265 let μ_base = μ.clone(); |
|
266 |
|
267 // Insert and reweigh |
|
268 let (d, within_tolerances) = insert_and_reweigh( |
|
269 &mut μ, &minus_τv, &μ_base, None, |
|
270 op𝒟, op𝒟norm, |
|
271 τ, ε, |
|
272 config, &reg, state, &mut stats |
|
273 ); |
|
274 |
|
275 // Prune and possibly merge spikes |
|
276 prune_and_maybe_simple_merge( |
|
277 &mut μ, &minus_τv, &μ_base, |
|
278 op𝒟, |
|
279 τ, ε, |
|
280 config, &reg, state, &mut stats |
|
281 ); |
|
282 |
|
283 // Update step length parameters |
|
284 let ω = match pdpsconfig.acceleration { |
|
285 Acceleration::None => 1.0, |
|
286 Acceleration::Partial => { |
|
287 let ω = 1.0 / (1.0 + γ * σ).sqrt(); |
|
288 σ = σ * ω; |
|
289 τ = τ / ω; |
|
290 ω |
|
291 }, |
|
292 Acceleration::Full => { |
|
293 let ω = 1.0 / (1.0 + 2.0 * γ * σ).sqrt(); |
|
294 σ = σ * ω; |
|
295 τ = τ / ω; |
|
296 ω |
|
297 }, |
|
298 }; |
|
299 |
|
300 // Do dual update |
|
301 y = b.clone(); // y = b |
|
302 opA.gemv(&mut y, 1.0 + ω, &μ, -1.0); // y = A[(1+ω)μ^{k+1}]-b |
|
303 opA.gemv(&mut y, -ω, &μ_base, 1.0); // y = A[(1+ω)μ^{k+1} - ω μ^k]-b |
|
304 dataterm.dual_update(&mut y, &y_prev, σ); |
|
305 y_prev.copy_from(&y); |
|
306 |
|
307 // Update main tolerance for next iteration |
|
308 let ε_prev = ε; |
|
309 ε = tolerance.update(ε, state.iteration()); |
|
310 stats.this_iters += 1; |
|
311 |
|
312 // Give function value if needed |
|
313 state.if_verbose(|| { |
|
314 // Plot if so requested |
|
315 plotter.plot_spikes( |
|
316 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, |
|
317 "start".to_string(), Some(&minus_τv), |
|
318 reg.target_bounds(τ, ε_prev), &μ, |
|
319 ); |
|
320 // Calculate mean inner iterations and reset relevant counters. |
|
321 // Return the statistics |
|
322 let res = IterInfo { |
|
323 value : dataterm.calculate_fit_op(&μ, opA, b) + reg.apply(&μ), |
|
324 n_spikes : μ.len(), |
|
325 ε : ε_prev, |
|
326 postprocessing: config.postprocessing.then(|| μ.clone()), |
|
327 .. stats |
|
328 }; |
|
329 stats = IterInfo::new(); |
|
330 res |
|
331 }) |
|
332 }); |
|
333 |
|
334 postprocess(μ, config, dataterm, opA, b) |
|
335 } |
|
336 |