| 91 None, |
66 None, |
| 92 /// Partial acceleration, $ω = 1/\sqrt{1+σ}$ |
67 /// Partial acceleration, $ω = 1/\sqrt{1+σ}$ |
| 93 #[clap(name = "partial", help = "Partial acceleration, ω = 1/√(1+σ)")] |
68 #[clap(name = "partial", help = "Partial acceleration, ω = 1/√(1+σ)")] |
| 94 Partial, |
69 Partial, |
| 95 /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed |
70 /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed |
| 96 #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")] |
71 #[clap( |
| 97 Full |
72 name = "full", |
| |
73 help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed" |
| |
74 )] |
| |
75 Full, |
| 98 } |
76 } |
| 99 |
77 |
| 100 #[replace_float_literals(F::cast_from(literal))] |
78 #[replace_float_literals(F::cast_from(literal))] |
| 101 impl Acceleration { |
79 impl Acceleration { |
| 102 /// PDPS parameter acceleration. Updates τ and σ and returns ω. |
80 /// PDPS parameter acceleration. Updates τ and σ and returns ω. |
| 103 /// This uses dual strong convexity, not primal. |
81 /// This uses dual strong convexity, not primal. |
| 104 fn accelerate<F : Float>(self, τ : &mut F, σ : &mut F, γ : F) -> F { |
82 fn accelerate<F: Float>(self, τ: &mut F, σ: &mut F, γ: F) -> F { |
| 105 match self { |
83 match self { |
| 106 Acceleration::None => 1.0, |
84 Acceleration::None => 1.0, |
| 107 Acceleration::Partial => { |
85 Acceleration::Partial => { |
| 108 let ω = 1.0 / (1.0 + γ * (*σ)).sqrt(); |
86 let ω = 1.0 / (1.0 + γ * (*σ)).sqrt(); |
| 109 *σ *= ω; |
87 *σ *= ω; |
| 110 *τ /= ω; |
88 *τ /= ω; |
| 111 ω |
89 ω |
| 112 }, |
90 } |
| 113 Acceleration::Full => { |
91 Acceleration::Full => { |
| 114 let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt(); |
92 let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt(); |
| 115 *σ *= ω; |
93 *σ *= ω; |
| 116 *τ /= ω; |
94 *τ /= ω; |
| 117 ω |
95 ω |
| 118 }, |
96 } |
| 119 } |
97 } |
| 120 } |
98 } |
| 121 } |
99 } |
| 122 |
100 |
| 123 /// Settings for [`pointsource_pdps_reg`]. |
101 /// Settings for [`pointsource_pdps_reg`]. |
| 124 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
102 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 125 #[serde(default)] |
103 #[serde(default)] |
| 126 pub struct PDPSConfig<F : Float> { |
104 pub struct PDPSConfig<F: Float> { |
| 127 /// Primal step length scaling. We must have `τ0 * σ0 < 1`. |
105 /// Primal step length scaling. We must have `τ0 * σ0 < 1`. |
| 128 pub τ0 : F, |
106 pub τ0: F, |
| 129 /// Dual step length scaling. We must have `τ0 * σ0 < 1`. |
107 /// Dual step length scaling. We must have `τ0 * σ0 < 1`. |
| 130 pub σ0 : F, |
108 pub σ0: F, |
| 131 /// Accelerate if available |
109 /// Accelerate if available |
| 132 pub acceleration : Acceleration, |
110 pub acceleration: Acceleration, |
| 133 /// Generic parameters |
111 /// Generic parameters |
| 134 pub generic : FBGenericConfig<F>, |
112 pub generic: InsertionConfig<F>, |
| 135 } |
113 } |
| 136 |
114 |
| 137 #[replace_float_literals(F::cast_from(literal))] |
115 #[replace_float_literals(F::cast_from(literal))] |
| 138 impl<F : Float> Default for PDPSConfig<F> { |
116 impl<F: Float> Default for PDPSConfig<F> { |
| 139 fn default() -> Self { |
117 fn default() -> Self { |
| 140 let τ0 = 5.0; |
118 let τ0 = 5.0; |
| 141 PDPSConfig { |
119 PDPSConfig { |
| 142 τ0, |
120 τ0, |
| 143 σ0 : 0.99/τ0, |
121 σ0: 0.99 / τ0, |
| 144 acceleration : Acceleration::Partial, |
122 acceleration: Acceleration::Partial, |
| 145 generic : FBGenericConfig { |
123 generic: InsertionConfig { |
| 146 merging : SpikeMergingMethod { enabled : true, ..Default::default() }, |
124 merging: SpikeMergingMethod { enabled: true, ..Default::default() }, |
| 147 .. Default::default() |
125 ..Default::default() |
| 148 }, |
126 }, |
| 149 } |
127 } |
| 150 } |
128 } |
| 151 } |
129 } |
| 152 |
130 |
| 153 /// Trait for data terms for the PDPS |
|
| 154 #[replace_float_literals(F::cast_from(literal))] |
|
| 155 pub trait PDPSDataTerm<F : Float, V, const N : usize> : DataTerm<F, V, N> { |
|
| 156 /// Calculate some subdifferential at `x` for the conjugate |
|
| 157 fn some_subdifferential(&self, x : V) -> V; |
|
| 158 |
|
| 159 /// Factor of strong convexity of the conjugate |
|
| 160 #[inline] |
|
| 161 fn factor_of_strong_convexity(&self) -> F { |
|
| 162 0.0 |
|
| 163 } |
|
| 164 |
|
| 165 /// Perform dual update |
|
| 166 fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F); |
|
| 167 } |
|
| 168 |
|
| 169 |
|
| 170 #[replace_float_literals(F::cast_from(literal))] |
|
| 171 impl<F, V, const N : usize> PDPSDataTerm<F, V, N> |
|
| 172 for L2Squared |
|
| 173 where |
|
| 174 F : Float, |
|
| 175 V : Euclidean<F> + AXPY<F>, |
|
| 176 for<'b> &'b V : Instance<V>, |
|
| 177 { |
|
| 178 fn some_subdifferential(&self, x : V) -> V { x } |
|
| 179 |
|
| 180 fn factor_of_strong_convexity(&self) -> F { |
|
| 181 1.0 |
|
| 182 } |
|
| 183 |
|
| 184 #[inline] |
|
| 185 fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) { |
|
| 186 y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ)); |
|
| 187 } |
|
| 188 } |
|
| 189 |
|
| 190 #[replace_float_literals(F::cast_from(literal))] |
|
| 191 impl<F : Float + nalgebra::RealField, const N : usize> |
|
| 192 PDPSDataTerm<F, DVector<F>, N> |
|
| 193 for L1 { |
|
| 194 fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> { |
|
| 195 // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well. |
|
| 196 x.iter_mut() |
|
| 197 .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) }); |
|
| 198 x |
|
| 199 } |
|
| 200 |
|
| 201 #[inline] |
|
| 202 fn dual_update(&self, y : &mut DVector<F>, y_prev : &DVector<F>, σ : F) { |
|
| 203 y.axpy(1.0, y_prev, σ); |
|
| 204 y.proj_ball_mut(1.0, Linfinity); |
|
| 205 } |
|
| 206 } |
|
| 207 |
|
| 208 /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting. |
131 /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting. |
| 209 /// |
132 /// |
| 210 /// The `dataterm` should be either [`L1`] for norm-1 data term or [`L2Squared`] for norm-2-squared. |
|
| 211 /// The settings in `config` have their [respective documentation](PDPSConfig). `opA` is the |
133 /// The settings in `config` have their [respective documentation](PDPSConfig). `opA` is the |
| 212 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
134 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. |
| 213 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution |
135 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution |
| 214 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
136 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control |
| 215 /// as documented in [`alg_tools::iterate`]. |
137 /// as documented in [`alg_tools::iterate`]. |
| 216 /// |
138 /// |
| 217 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript. |
139 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript. |
| 218 /// |
140 /// |
| 219 /// Returns the final iterate. |
141 /// Returns the final iterate. |
| 220 #[replace_float_literals(F::cast_from(literal))] |
142 #[replace_float_literals(F::cast_from(literal))] |
| 221 pub fn pointsource_pdps_reg<F, I, A, D, Reg, P, const N : usize>( |
143 pub fn pointsource_pdps_reg<'a, F, I, A, Phi, Reg, Plot, P, const N: usize>( |
| 222 opA : &A, |
144 f: &'a DataTerm<F, RNDM<N, F>, A, Phi>, |
| 223 b : &A::Observable, |
145 reg: &Reg, |
| 224 reg : Reg, |
146 prox_penalty: &P, |
| 225 prox_penalty : &P, |
147 pdpsconfig: &PDPSConfig<F>, |
| 226 pdpsconfig : &PDPSConfig<F>, |
148 iterator: I, |
| 227 iterator : I, |
149 mut plotter: Plot, |
| 228 mut plotter : SeqPlotter<F, N>, |
150 μ0 : Option<RNDM<N, F>>, |
| 229 dataterm : D, |
151 ) -> DynResult<RNDM<N, F>> |
| 230 ) -> RNDM<F, N> |
|
| 231 where |
152 where |
| 232 F : Float + ToNalgebraRealField, |
153 F: Float + ToNalgebraRealField, |
| 233 I : AlgIteratorFactory<IterInfo<F, N>>, |
154 I: AlgIteratorFactory<IterInfo<F>>, |
| 234 A : ForwardModel<RNDM<F, N>, F> |
155 A: ForwardModel<RNDM<N, F>, F>, |
| 235 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
156 for<'b> &'b A::Observable: Instance<A::Observable>, |
| 236 A::PreadjointCodomain : RealMapping<F, N>, |
157 A::Observable: AXPY<Field = F>, |
| 237 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
158 RNDM<N, F>: SpikeMerging<F>, |
| 238 PlotLookup : Plotting<N>, |
159 Reg: RegTerm<Loc<N, F>, F>, |
| 239 RNDM<F, N> : SpikeMerging<F>, |
160 Phi: Conjugable<A::Observable, F>, |
| 240 D : PDPSDataTerm<F, A::Observable, N>, |
161 for<'b> Phi::Conjugate<'b>: Prox<A::Observable>, |
| 241 Reg : RegTerm<F, N>, |
162 P: ProxPenalty<Loc<N, F>, A::PreadjointCodomain, Reg, F> + StepLengthBoundPD<F, A, RNDM<N, F>>, |
| 242 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
163 Plot: Plotter<P::ReturnMapping, A::PreadjointCodomain, RNDM<N, F>>, |
| 243 { |
164 { |
| 244 |
|
| 245 // Check parameters |
165 // Check parameters |
| 246 assert!(pdpsconfig.τ0 > 0.0 && |
166 ensure!( |
| 247 pdpsconfig.σ0 > 0.0 && |
167 pdpsconfig.τ0 > 0.0 && pdpsconfig.σ0 > 0.0 && pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, |
| 248 pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, |
168 "Invalid step length parameters" |
| 249 "Invalid step length parameters"); |
169 ); |
| |
170 |
| |
171 let opA = f.operator(); |
| |
172 let b = f.data(); |
| |
173 let phistar = f.fidelity().conjugate(); |
| 250 |
174 |
| 251 // Set up parameters |
175 // Set up parameters |
| 252 let config = &pdpsconfig.generic; |
176 let config = &pdpsconfig.generic; |
| 253 let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt(); |
177 let l = prox_penalty.step_length_bound_pd(opA)?; |
| 254 let mut τ = pdpsconfig.τ0 / l; |
178 let mut τ = pdpsconfig.τ0 / l; |
| 255 let mut σ = pdpsconfig.σ0 / l; |
179 let mut σ = pdpsconfig.σ0 / l; |
| 256 let γ = dataterm.factor_of_strong_convexity(); |
180 let γ = phistar.factor_of_strong_convexity(); |
| 257 |
181 |
| 258 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
182 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 259 // by τ compared to the conditional gradient approach. |
183 // by τ compared to the conditional gradient approach. |
| 260 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
184 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); |
| 261 let mut ε = tolerance.initial(); |
185 let mut ε = tolerance.initial(); |
| 262 |
186 |
| 263 // Initialise iterates |
187 // Initialise iterates |
| 264 let mut μ = DiscreteMeasure::new(); |
188 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
| 265 let mut y = dataterm.some_subdifferential(-b); |
189 let mut y = f.residual(&μ); |
| 266 let mut y_prev = y.clone(); |
190 let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo { |
| 267 let full_stats = |μ : &RNDM<F, N>, ε, stats| IterInfo { |
191 value: f.apply(μ) + reg.apply(μ), |
| 268 value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ), |
192 n_spikes: μ.len(), |
| 269 n_spikes : μ.len(), |
|
| 270 ε, |
193 ε, |
| 271 // postprocessing: config.postprocessing.then(|| μ.clone()), |
194 // postprocessing: config.postprocessing.then(|| μ.clone()), |
| 272 .. stats |
195 ..stats |
| 273 }; |
196 }; |
| 274 let mut stats = IterInfo::new(); |
197 let mut stats = IterInfo::new(); |
| 275 |
198 |
| 276 // Run the algorithm |
199 // Run the algorithm |
| 277 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
200 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 278 // Calculate smooth part of surrogate model. |
201 // Calculate smooth part of surrogate model. |
| 279 let mut τv = opA.preadjoint().apply(y * τ); |
202 // FIXME: the clone is required to avoid compiler overflows with reference-Mul requirement above. |
| |
203 let mut τv = opA.preadjoint().apply(y.clone() * τ); |
| 280 |
204 |
| 281 // Save current base point |
205 // Save current base point |
| 282 let μ_base = μ.clone(); |
206 let μ_base = μ.clone(); |
| 283 |
207 |
| 284 // Insert and reweigh |
208 // Insert and reweigh |
| 285 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
209 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| 286 &mut μ, &mut τv, &μ_base, None, |
210 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
| 287 τ, ε, |
211 )?; |
| 288 config, ®, &state, &mut stats |
|
| 289 ); |
|
| 290 |
212 |
| 291 // Prune and possibly merge spikes |
213 // Prune and possibly merge spikes |
| 292 if config.merge_now(&state) { |
214 if config.merge_now(&state) { |
| 293 stats.merged += prox_penalty.merge_spikes_no_fitness( |
215 stats.merged += prox_penalty |
| 294 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, |
216 .merge_spikes_no_fitness(&mut μ, &mut τv, &μ_base, None, τ, ε, config, ®); |
| 295 ); |
|
| 296 } |
217 } |
| 297 stats.pruned += prune_with_stats(&mut μ); |
218 stats.pruned += prune_with_stats(&mut μ); |
| 298 |
219 |
| 299 // Update step length parameters |
220 // Update step length parameters |
| 300 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |
221 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |
| 301 |
222 |
| 302 // Do dual update |
223 // Do dual update |
| 303 y = b.clone(); // y = b |
224 // y = y_prev + τb |
| 304 opA.gemv(&mut y, 1.0 + ω, &μ, -1.0); // y = A[(1+ω)μ^{k+1}]-b |
225 y.axpy(τ, b, 1.0); |
| 305 opA.gemv(&mut y, -ω, &μ_base, 1.0); // y = A[(1+ω)μ^{k+1} - ω μ^k]-b |
226 // y = y_prev - τ(A[(1+ω)μ^{k+1}]-b) |
| 306 dataterm.dual_update(&mut y, &y_prev, σ); |
227 opA.gemv(&mut y, -τ * (1.0 + ω), &μ, 1.0); |
| 307 y_prev.copy_from(&y); |
228 // y = y_prev - τ(A[(1+ω)μ^{k+1} - ω μ^k]-b) |
| |
229 opA.gemv(&mut y, τ * ω, &μ_base, 1.0); |
| |
230 y = phistar.prox(τ, y); |
| 308 |
231 |
| 309 // Give statistics if requested |
232 // Give statistics if requested |
| 310 let iter = state.iteration(); |
233 let iter = state.iteration(); |
| 311 stats.this_iters += 1; |
234 stats.this_iters += 1; |
| 312 |
235 |