This is the task of <code>generic_pointsource_fb</code>, where we use <code>FBSpecialisation</code>
to replace the specific residual $Aμ-b$ by $y$.
For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ - b$.
For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$.
</p>

Based on zero initialisation for $μ$, we use the [`Subdifferentiable`] trait to make an
initialisation corresponding to the second part of the optimality conditions.
In the algorithm itself, standard proximal steps are taken with respect to $F\_0^* + ⟨b, ·⟩$.
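
As a concrete instance, for $F_0(y)=\frac{1}{2}\|y\|_2^2$ the dual proximal step has a
closed form (this is a sketch of the computation performed by <code>dual_update</code> below,
with $\tilde μ^{k+1} = (1+ω)μ^{k+1} - ωμ^k$ the inertial primal iterate):
$$
    y^{k+1} = \frac{y^k + σ(A\tilde μ^{k+1} - b)}{1 + σ}.
$$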
*/

use numeric_literals::replace_float_literals;
use serde::{Serialize, Deserialize};
use nalgebra::DVector;
use clap::ValueEnum;

use alg_tools::iterate::AlgIteratorFactory;
use alg_tools::loc::Loc;
use alg_tools::euclidean::Euclidean;
use alg_tools::linops::Mapping;
use alg_tools::norms::{
    Linfinity,
    Projection,
};
use alg_tools::bisection_tree::{
    /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed
    #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")]
    Full
}

#[replace_float_literals(F::cast_from(literal))]
impl Acceleration {
    /// PDPS parameter acceleration. Updates τ and σ and returns ω.
    /// This uses dual strong convexity, not primal.
    fn accelerate<F : Float>(self, τ : &mut F, σ : &mut F, γ : F) -> F {
        match self {
            Acceleration::None => 1.0,
            Acceleration::Partial => {
                let ω = 1.0 / (1.0 + γ * (*σ)).sqrt();
                *σ *= ω;
                *τ /= ω;
                ω
            },
            Acceleration::Full => {
                let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt();
                *σ *= ω;
                *τ /= ω;
                ω
            },
        }
    }
}

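// Illustrative sketch (hypothetical, not part of the algorithm): `accelerate` keeps the
// product τσ invariant, and for the accelerated variants shrinks σ while growing τ:
//
//     let (mut τ, mut σ) = (0.5_f64, 0.5);
//     let ω = Acceleration::Full.accelerate(&mut τ, &mut σ, 1.0);
//     assert!(ω < 1.0 && (τ * σ - 0.25).abs() < 1e-12);
//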
/// Settings for [`pointsource_pdps_reg`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct PDPSConfig<F : Float> {
    /// Primal step length scaling. We must have `τ0 * σ0 < 1`.
    pub τ0 : F,
    fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F);
}


#[replace_float_literals(F::cast_from(literal))]
impl<F, V, const N : usize> PDPSDataTerm<F, V, N>
for L2Squared
where
    F : Float,
    V : Euclidean<F> + AXPY<F>,
    for<'b> &'b V : Instance<V>,
{
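    // For F_0 = ½‖·‖² the subdifferential ∂F_0(x) is the singleton {x}, so x itself is a valid choice.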
    fn some_subdifferential(&self, x : V) -> V { x }

    fn factor_of_strong_convexity(&self) -> F {
        1.0
    }

    #[inline]
    fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) {
        y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ));
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float + nalgebra::RealField, const N : usize>
    op𝒟 : &'a 𝒟,
    pdpsconfig : &PDPSConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
    dataterm : D,
) -> RNDM<F, N>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
          + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
      𝒟::Codomain : RealMapping<F, N>,
      S : RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      K : RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
      BTNodeLookup : BTNode<F, usize, Bounds<F>, N>,
      PlotLookup : Plotting<N>,
      RNDM<F, N> : SpikeMerging<F>,
      D : PDPSDataTerm<F, A::Observable, N>,
      Reg : RegTerm<F, N> {

    // Check parameters
    assert!(pdpsconfig.τ0 > 0.0 &&
            pdpsconfig.σ0 > 0.0 &&
            pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0,
            "Invalid step length parameters");

    // Set up parameters
    let config = &pdpsconfig.generic;
    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
    let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt();
    let mut τ = pdpsconfig.τ0 / l;
    let mut σ = pdpsconfig.σ0 / l;
    let γ = dataterm.factor_of_strong_convexity();

    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled

    // Initialise iterates
    let mut μ = DiscreteMeasure::new();
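    // With μ initialised to zero, the dual optimality condition reads y ∈ ∂F_0(-b) (see the
    // module documentation); `some_subdifferential` supplies such an element.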
    let mut y = dataterm.some_subdifferential(-b);
    let mut y_prev = y.clone();
    let full_stats = |μ : &RNDM<F, N>, ε, stats| IterInfo {
        value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ),
        n_spikes : μ.len(),
        ε,
        // postprocessing: config.postprocessing.then(|| μ.clone()),
        .. stats
    };
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
        // Calculate smooth part of surrogate model.
        let τv = opA.preadjoint().apply(y * τ);

        // Save current base point
        let μ_base = μ.clone();

        // Insert and reweigh
        let (d, _within_tolerances) = insert_and_reweigh(
            &mut μ, &τv, &μ_base, None,
            op𝒟, op𝒟norm,
            τ, ε,
            config, &reg, &state, &mut stats
        );

        // Prune and possibly merge spikes
        if config.merge_now(&state) {
            stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
                let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base));
                reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
            });
        }
        stats.pruned += prune_with_stats(&mut μ);

        // Update step length parameters
        let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);

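        // For the L2-squared data term, the `gemv` calls and `dual_update` below realise the
        // closed-form proximal step y^{k+1} = (y^k + σ(A[(1+ω)μ^{k+1} - ω μ^k] - b))/(1+σ).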
        // Do dual update
        y = b.clone();                        // y = b
        opA.gemv(&mut y, 1.0 + ω, &μ, -1.0);  // y = A[(1+ω)μ^{k+1}]-b
        opA.gemv(&mut y, -ω, &μ_base, 1.0);   // y = A[(1+ω)μ^{k+1} - ω μ^k]-b
        dataterm.dual_update(&mut y, &y_prev, σ);
        y_prev.copy_from(&y);

        // Give statistics if requested
        let iter = state.iteration();
        stats.this_iters += 1;

        state.if_verbose(|| {
            plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ);
            full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        ε = tolerance.update(ε, iter);
    }

    postprocess(μ, config, dataterm, opA, b)
}