src/pdps.rs

branch
dev
changeset 35
b087e3eab191
parent 34
efa60bc4f743
child 37
c5d8bd1a7728
equal deleted inserted replaced
34:efa60bc4f743 35:b087e3eab191
4 This corresponds to the manuscript 4 This corresponds to the manuscript
5 5
6 * Valkonen T. - _Proximal methods for point source localisation_, 6 * Valkonen T. - _Proximal methods for point source localisation_,
7 [arXiv:2212.02991](https://arxiv.org/abs/2212.02991). 7 [arXiv:2212.02991](https://arxiv.org/abs/2212.02991).
8 8
9 The main routine is [`pointsource_pdps`]. It is based on specialisation of 9 The main routine is [`pointsource_pdps_reg`].
10 [`generic_pointsource_fb_reg`] through relevant [`FBSpecialisation`] implementations.
11 Both norm-2-squared and norm-1 data terms are supported. That is, implemented are solvers for 10 Both norm-2-squared and norm-1 data terms are supported. That is, implemented are solvers for
12 <div> 11 <div>
13 $$ 12 $$
14 \min_{μ ∈ ℳ(Ω)}~ F_0(Aμ - b) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ), 13 \min_{μ ∈ ℳ(Ω)}~ F_0(Aμ - b) + α \|μ\|_{ℳ(Ω)} + δ_{≥ 0}(μ),
15 $$ 14 $$
35 This is the task of <code>generic_pointsource_fb</code>, where we use <code>FBSpecialisation</code> 34 This is the task of <code>generic_pointsource_fb</code>, where we use <code>FBSpecialisation</code>
36 to replace the specific residual $Aμ-b$ by $y$. 35 to replace the specific residual $Aμ-b$ by $y$.
37 For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$. 36 For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$.
38 For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$. 37 For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$.
39 </p> 38 </p>
40
41 Based on zero initialisation for $μ$, we use the [`Subdifferentiable`] trait to make an
42 initialisation corresponding to the second part of the optimality conditions.
43 In the algorithm itself, standard proximal steps are taken with respect to $F\_0^* + ⟨b, ·⟩$.
44 */ 39 */
45 40
46 use numeric_literals::replace_float_literals; 41 use numeric_literals::replace_float_literals;
47 use serde::{Serialize, Deserialize}; 42 use serde::{Serialize, Deserialize};
48 use nalgebra::DVector; 43 use nalgebra::DVector;
49 use clap::ValueEnum; 44 use clap::ValueEnum;
50 45
51 use alg_tools::iterate::{ 46 use alg_tools::iterate::AlgIteratorFactory;
52 AlgIteratorFactory,
53 AlgIteratorState,
54 };
55 use alg_tools::loc::Loc; 47 use alg_tools::loc::Loc;
56 use alg_tools::euclidean::Euclidean; 48 use alg_tools::euclidean::Euclidean;
57 use alg_tools::linops::Apply; 49 use alg_tools::linops::Mapping;
58 use alg_tools::norms::{ 50 use alg_tools::norms::{
59 Linfinity, 51 Linfinity,
60 Projection, 52 Projection,
61 }; 53 };
62 use alg_tools::bisection_tree::{ 54 use alg_tools::bisection_tree::{
67 BTNode, 59 BTNode,
68 BTSearch, 60 BTSearch,
69 SupportGenerator, 61 SupportGenerator,
70 LocalAnalysis, 62 LocalAnalysis,
71 }; 63 };
72 use alg_tools::mapping::RealMapping; 64 use alg_tools::mapping::{RealMapping, Instance};
73 use alg_tools::nalgebra_support::ToNalgebraRealField; 65 use alg_tools::nalgebra_support::ToNalgebraRealField;
74 use alg_tools::linops::AXPY; 66 use alg_tools::linops::AXPY;
75 67
76 use crate::types::*; 68 use crate::types::*;
77 use crate::measures::DiscreteMeasure; 69 use crate::measures::{DiscreteMeasure, RNDM, Radon};
78 use crate::measures::merging::SpikeMerging; 70 use crate::measures::merging::SpikeMerging;
79 use crate::forward_model::ForwardModel; 71 use crate::forward_model::{
72 AdjointProductBoundedBy,
73 ForwardModel
74 };
80 use crate::seminorms::DiscreteMeasureOp; 75 use crate::seminorms::DiscreteMeasureOp;
81 use crate::plot::{ 76 use crate::plot::{
82 SeqPlotter, 77 SeqPlotter,
83 Plotting, 78 Plotting,
84 PlotLookup 79 PlotLookup
85 }; 80 };
86 use crate::fb::{ 81 use crate::fb::{
87 FBGenericConfig, 82 FBGenericConfig,
88 insert_and_reweigh, 83 insert_and_reweigh,
89 postprocess, 84 postprocess,
90 prune_and_maybe_simple_merge 85 prune_with_stats
91 }; 86 };
92 use crate::regularisation::RegTerm; 87 use crate::regularisation::RegTerm;
93 use crate::dataterm::{ 88 use crate::dataterm::{
94 DataTerm, 89 DataTerm,
95 L2Squared, 90 L2Squared,
108 /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed 103 /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed
109 #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")] 104 #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")]
110 Full 105 Full
111 } 106 }
112 107
113 /// Settings for [`pointsource_pdps`]. 108 #[replace_float_literals(F::cast_from(literal))]
109 impl Acceleration {
110 /// PDPS parameter acceleration. Updates τ and σ and returns ω.
111 /// This uses dual strong convexity, not primal.
112 fn accelerate<F : Float>(self, τ : &mut F, σ : &mut F, γ : F) -> F {
113 match self {
114 Acceleration::None => 1.0,
115 Acceleration::Partial => {
116 let ω = 1.0 / (1.0 + γ * (*σ)).sqrt();
117 *σ *= ω;
118 *τ /= ω;
119 ω
120 },
121 Acceleration::Full => {
122 let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt();
123 *σ *= ω;
124 *τ /= ω;
125 ω
126 },
127 }
128 }
129 }
130
131 /// Settings for [`pointsource_pdps_reg`].
114 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 132 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
115 #[serde(default)] 133 #[serde(default)]
116 pub struct PDPSConfig<F : Float> { 134 pub struct PDPSConfig<F : Float> {
117 /// Primal step length scaling. We must have `τ0 * σ0 < 1`. 135 /// Primal step length scaling. We must have `τ0 * σ0 < 1`.
118 pub τ0 : F, 136 pub τ0 : F,
153 fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F); 171 fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F);
154 } 172 }
155 173
156 174
157 #[replace_float_literals(F::cast_from(literal))] 175 #[replace_float_literals(F::cast_from(literal))]
158 impl<F : Float, V : Euclidean<F> + AXPY<F>, const N : usize> 176 impl<F, V, const N : usize> PDPSDataTerm<F, V, N>
159 PDPSDataTerm<F, V, N> 177 for L2Squared
160 for L2Squared { 178 where
179 F : Float,
180 V : Euclidean<F> + AXPY<F>,
181 for<'b> &'b V : Instance<V>,
182 {
161 fn some_subdifferential(&self, x : V) -> V { x } 183 fn some_subdifferential(&self, x : V) -> V { x }
162 184
163 fn factor_of_strong_convexity(&self) -> F { 185 fn factor_of_strong_convexity(&self) -> F {
164 1.0 186 1.0
165 } 187 }
166 188
167 #[inline] 189 #[inline]
168 fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) { 190 fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) {
169 y.axpy(1.0 / (1.0 + σ), &y_prev, σ / (1.0 + σ)); 191 y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ));
170 } 192 }
171 } 193 }
172 194
173 #[replace_float_literals(F::cast_from(literal))] 195 #[replace_float_literals(F::cast_from(literal))]
174 impl<F : Float + nalgebra::RealField, const N : usize> 196 impl<F : Float + nalgebra::RealField, const N : usize>
208 op𝒟 : &'a 𝒟, 230 op𝒟 : &'a 𝒟,
209 pdpsconfig : &PDPSConfig<F>, 231 pdpsconfig : &PDPSConfig<F>,
210 iterator : I, 232 iterator : I,
211 mut plotter : SeqPlotter<F, N>, 233 mut plotter : SeqPlotter<F, N>,
212 dataterm : D, 234 dataterm : D,
213 ) -> DiscreteMeasure<Loc<F, N>, F> 235 ) -> RNDM<F, N>
214 where F : Float + ToNalgebraRealField, 236 where F : Float + ToNalgebraRealField,
215 I : AlgIteratorFactory<IterInfo<F, N>>, 237 I : AlgIteratorFactory<IterInfo<F, N>>,
216 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> 238 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
217 + std::ops::Add<A::Observable, Output=A::Observable>,
218 //+ std::ops::Mul<F, Output=A::Observable>, // <-- FIXME: compiler overflow
219 A::Observable : std::ops::MulAssign<F>,
220 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 239 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
221 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> 240 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
222 + Lipschitz<&'a 𝒟, FloatType=F>, 241 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
223 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 242 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
224 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 243 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
225 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 244 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
226 𝒟::Codomain : RealMapping<F, N>, 245 𝒟::Codomain : RealMapping<F, N>,
227 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 246 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
228 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 247 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
229 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 248 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
230 PlotLookup : Plotting<N>, 249 PlotLookup : Plotting<N>,
231 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 250 RNDM<F, N> : SpikeMerging<F>,
232 D : PDPSDataTerm<F, A::Observable, N>, 251 D : PDPSDataTerm<F, A::Observable, N>,
233 Reg : RegTerm<F, N> { 252 Reg : RegTerm<F, N> {
234 253
254 // Check parameters
255 assert!(pdpsconfig.τ0 > 0.0 &&
256 pdpsconfig.σ0 > 0.0 &&
257 pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0,
258 "Invalid step length parameters");
259
235 // Set up parameters 260 // Set up parameters
236 let config = &pdpsconfig.generic; 261 let config = &pdpsconfig.generic;
237 let op𝒟norm = op𝒟.opnorm_bound(); 262 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
238 let l = opA.lipschitz_factor(&op𝒟).unwrap().sqrt(); 263 let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt();
239 let mut τ = pdpsconfig.τ0 / l; 264 let mut τ = pdpsconfig.τ0 / l;
240 let mut σ = pdpsconfig.σ0 / l; 265 let mut σ = pdpsconfig.σ0 / l;
241 let γ = dataterm.factor_of_strong_convexity(); 266 let γ = dataterm.factor_of_strong_convexity();
242 267
243 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 268 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
247 272
248 // Initialise iterates 273 // Initialise iterates
249 let mut μ = DiscreteMeasure::new(); 274 let mut μ = DiscreteMeasure::new();
250 let mut y = dataterm.some_subdifferential(-b); 275 let mut y = dataterm.some_subdifferential(-b);
251 let mut y_prev = y.clone(); 276 let mut y_prev = y.clone();
277 let full_stats = |μ : &RNDM<F, N>, ε, stats| IterInfo {
278 value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ),
279 n_spikes : μ.len(),
280 ε,
281 // postprocessing: config.postprocessing.then(|| μ.clone()),
282 .. stats
283 };
252 let mut stats = IterInfo::new(); 284 let mut stats = IterInfo::new();
253 285
254 // Run the algorithm 286 // Run the algorithm
255 iterator.iterate(|state| { 287 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
256 // Calculate smooth part of surrogate model. 288 // Calculate smooth part of surrogate model.
257 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` 289 let τv = opA.preadjoint().apply(y * τ);
258 // has no significant overhead. For some reason Rust doesn't allow us simply moving
259 // the residual and replacing it below before the end of this closure.
260 y *= -τ;
261 let r = std::mem::replace(&mut y, opA.empty_observable());
262 let minus_τv = opA.preadjoint().apply(r);
263 290
264 // Save current base point 291 // Save current base point
265 let μ_base = μ.clone(); 292 let μ_base = μ.clone();
266 293
267 // Insert and reweigh 294 // Insert and reweigh
268 let (d, within_tolerances) = insert_and_reweigh( 295 let (d, _within_tolerances) = insert_and_reweigh(
269 &mut μ, &minus_τv, &μ_base, None, 296 &mut μ, &τv, &μ_base, None,
270 op𝒟, op𝒟norm, 297 op𝒟, op𝒟norm,
271 τ, ε, 298 τ, ε,
272 config, &reg, state, &mut stats 299 config, &reg, &state, &mut stats
273 ); 300 );
274 301
275 // Prune and possibly merge spikes 302 // Prune and possibly merge spikes
276 prune_and_maybe_simple_merge( 303 if config.merge_now(&state) {
277 &mut μ, &minus_τv, &μ_base, 304 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
278 op𝒟, 305 let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base));
279 τ, ε, 306 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
280 config, &reg, state, &mut stats 307 });
281 ); 308 }
309 stats.pruned += prune_with_stats(&mut μ);
282 310
283 // Update step length parameters 311 // Update step length parameters
284 let ω = match pdpsconfig.acceleration { 312 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);
285 Acceleration::None => 1.0,
286 Acceleration::Partial => {
287 let ω = 1.0 / (1.0 + γ * σ).sqrt();
288 σ = σ * ω;
289 τ = τ / ω;
290 ω
291 },
292 Acceleration::Full => {
293 let ω = 1.0 / (1.0 + 2.0 * γ * σ).sqrt();
294 σ = σ * ω;
295 τ = τ / ω;
296 ω
297 },
298 };
299 313
300 // Do dual update 314 // Do dual update
301 y = b.clone(); // y = b 315 y = b.clone(); // y = b
302 opA.gemv(&mut y, 1.0 + ω, &μ, -1.0); // y = A[(1+ω)μ^{k+1}]-b 316 opA.gemv(&mut y, 1.0 + ω, &μ, -1.0); // y = A[(1+ω)μ^{k+1}]-b
303 opA.gemv(&mut y, -ω, &μ_base, 1.0); // y = A[(1+ω)μ^{k+1} - ω μ^k]-b 317 opA.gemv(&mut y, -ω, &μ_base, 1.0); // y = A[(1+ω)μ^{k+1} - ω μ^k]-b
304 dataterm.dual_update(&mut y, &y_prev, σ); 318 dataterm.dual_update(&mut y, &y_prev, σ);
305 y_prev.copy_from(&y); 319 y_prev.copy_from(&y);
306 320
307 // Update main tolerance for next iteration 321 // Give statistics if requested
308 let ε_prev = ε; 322 let iter = state.iteration();
309 ε = tolerance.update(ε, state.iteration());
310 stats.this_iters += 1; 323 stats.this_iters += 1;
311 324
312 // Give function value if needed
313 state.if_verbose(|| { 325 state.if_verbose(|| {
314 // Plot if so requested 326 plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ);
315 plotter.plot_spikes( 327 full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
316 format!("iter {} end; {}", state.iteration(), within_tolerances), &d, 328 });
317 "start".to_string(), Some(&minus_τv), 329
318 reg.target_bounds(τ, ε_prev), &μ, 330 ε = tolerance.update(ε, iter);
319 ); 331 }
320 // Calculate mean inner iterations and reset relevant counters.
321 // Return the statistics
322 let res = IterInfo {
323 value : dataterm.calculate_fit_op(&μ, opA, b) + reg.apply(&μ),
324 n_spikes : μ.len(),
325 ε : ε_prev,
326 postprocessing: config.postprocessing.then(|| μ.clone()),
327 .. stats
328 };
329 stats = IterInfo::new();
330 res
331 })
332 });
333 332
334 postprocess(μ, config, dataterm, opA, b) 333 postprocess(μ, config, dataterm, opA, b)
335 } 334 }
336 335

mercurial