src/pdps.rs

branch
dev
changeset 61
4f468d35fa29
parent 39
6316d68b58af
child 63
7a8a55fd41c0
equal deleted inserted replaced
60:9738b51d90d7 61:4f468d35fa29
36 For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$. 36 For $F_0(y)=\frac{1}{2}\|y\|_2^2$ the second part reads $y = Aμ -b$.
37 For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$. 37 For $F_0(y)=\|y\|_1$ the second part reads $y ∈ ∂\|·\|_1(Aμ - b)$.
38 </p> 38 </p>
39 */ 39 */
40 40
41 use crate::fb::{postprocess, prune_with_stats};
42 use crate::forward_model::ForwardModel;
43 use crate::measures::merging::SpikeMerging;
44 use crate::measures::merging::SpikeMergingMethod;
45 use crate::measures::{DiscreteMeasure, RNDM};
46 use crate::plot::Plotter;
47 pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBoundPD};
48 use crate::regularisation::RegTerm;
49 use crate::types::*;
50 use alg_tools::convex::{Conjugable, ConvexMapping, Prox};
51 use alg_tools::error::DynResult;
52 use alg_tools::iterate::AlgIteratorFactory;
53 use alg_tools::linops::{Mapping, AXPY};
54 use alg_tools::mapping::{DataTerm, Instance};
55 use alg_tools::nalgebra_support::ToNalgebraRealField;
56 use anyhow::ensure;
57 use clap::ValueEnum;
41 use numeric_literals::replace_float_literals; 58 use numeric_literals::replace_float_literals;
42 use serde::{Serialize, Deserialize}; 59 use serde::{Deserialize, Serialize};
43 use nalgebra::DVector;
44 use clap::ValueEnum;
45
46 use alg_tools::iterate::AlgIteratorFactory;
47 use alg_tools::euclidean::Euclidean;
48 use alg_tools::linops::Mapping;
49 use alg_tools::norms::{
50 Linfinity,
51 Projection,
52 };
53 use alg_tools::mapping::{RealMapping, Instance};
54 use alg_tools::nalgebra_support::ToNalgebraRealField;
55 use alg_tools::linops::AXPY;
56
57 use crate::types::*;
58 use crate::measures::{DiscreteMeasure, RNDM};
59 use crate::measures::merging::SpikeMerging;
60 use crate::forward_model::{
61 ForwardModel,
62 AdjointProductBoundedBy,
63 };
64 use crate::plot::{
65 SeqPlotter,
66 Plotting,
67 PlotLookup
68 };
69 use crate::fb::{
70 postprocess,
71 prune_with_stats
72 };
73 pub use crate::prox_penalty::{
74 FBGenericConfig,
75 ProxPenalty
76 };
77 use crate::regularisation::RegTerm;
78 use crate::dataterm::{
79 DataTerm,
80 L2Squared,
81 L1
82 };
83 use crate::measures::merging::SpikeMergingMethod;
84
85 60
86 /// Acceleration 61 /// Acceleration
87 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)] 62 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)]
88 pub enum Acceleration { 63 pub enum Acceleration {
89 /// No acceleration 64 /// No acceleration
91 None, 66 None,
92 /// Partial acceleration, $ω = 1/\sqrt{1+σ}$ 67 /// Partial acceleration, $ω = 1/\sqrt{1+σ}$
93 #[clap(name = "partial", help = "Partial acceleration, ω = 1/√(1+σ)")] 68 #[clap(name = "partial", help = "Partial acceleration, ω = 1/√(1+σ)")]
94 Partial, 69 Partial,
95 /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed 70 /// Full acceleration, $ω = 1/\sqrt{1+2σ}$; no gap convergence guaranteed
96 #[clap(name = "full", help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed")] 71 #[clap(
97 Full 72 name = "full",
73 help = "Full acceleration, ω = 1/√(1+2σ); no gap convergence guaranteed"
74 )]
75 Full,
98 } 76 }
99 77
100 #[replace_float_literals(F::cast_from(literal))] 78 #[replace_float_literals(F::cast_from(literal))]
101 impl Acceleration { 79 impl Acceleration {
102 /// PDPS parameter acceleration. Updates τ and σ and returns ω. 80 /// PDPS parameter acceleration. Updates τ and σ and returns ω.
103 /// This uses dual strong convexity, not primal. 81 /// This uses dual strong convexity, not primal.
104 fn accelerate<F : Float>(self, τ : &mut F, σ : &mut F, γ : F) -> F { 82 fn accelerate<F: Float>(self, τ: &mut F, σ: &mut F, γ: F) -> F {
105 match self { 83 match self {
106 Acceleration::None => 1.0, 84 Acceleration::None => 1.0,
107 Acceleration::Partial => { 85 Acceleration::Partial => {
108 let ω = 1.0 / (1.0 + γ * (*σ)).sqrt(); 86 let ω = 1.0 / (1.0 + γ * (*σ)).sqrt();
109 *σ *= ω; 87 *σ *= ω;
110 *τ /= ω; 88 *τ /= ω;
111 ω 89 ω
112 }, 90 }
113 Acceleration::Full => { 91 Acceleration::Full => {
114 let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt(); 92 let ω = 1.0 / (1.0 + 2.0 * γ * (*σ)).sqrt();
115 *σ *= ω; 93 *σ *= ω;
116 *τ /= ω; 94 *τ /= ω;
117 ω 95 ω
118 }, 96 }
119 } 97 }
120 } 98 }
121 } 99 }
122 100
123 /// Settings for [`pointsource_pdps_reg`]. 101 /// Settings for [`pointsource_pdps_reg`].
124 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 102 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
125 #[serde(default)] 103 #[serde(default)]
126 pub struct PDPSConfig<F : Float> { 104 pub struct PDPSConfig<F: Float> {
127 /// Primal step length scaling. We must have `τ0 * σ0 < 1`. 105 /// Primal step length scaling. We must have `τ0 * σ0 < 1`.
128 pub τ0 : F, 106 pub τ0: F,
129 /// Dual step length scaling. We must have `τ0 * σ0 < 1`. 107 /// Dual step length scaling. We must have `τ0 * σ0 < 1`.
130 pub σ0 : F, 108 pub σ0: F,
131 /// Accelerate if available 109 /// Accelerate if available
132 pub acceleration : Acceleration, 110 pub acceleration: Acceleration,
133 /// Generic parameters 111 /// Generic parameters
134 pub generic : FBGenericConfig<F>, 112 pub generic: InsertionConfig<F>,
135 } 113 }
136 114
137 #[replace_float_literals(F::cast_from(literal))] 115 #[replace_float_literals(F::cast_from(literal))]
138 impl<F : Float> Default for PDPSConfig<F> { 116 impl<F: Float> Default for PDPSConfig<F> {
139 fn default() -> Self { 117 fn default() -> Self {
140 let τ0 = 5.0; 118 let τ0 = 5.0;
141 PDPSConfig { 119 PDPSConfig {
142 τ0, 120 τ0,
143 σ0 : 0.99/τ0, 121 σ0: 0.99 / τ0,
144 acceleration : Acceleration::Partial, 122 acceleration: Acceleration::Partial,
145 generic : FBGenericConfig { 123 generic: InsertionConfig {
146 merging : SpikeMergingMethod { enabled : true, ..Default::default() }, 124 merging: SpikeMergingMethod { enabled: true, ..Default::default() },
147 .. Default::default() 125 ..Default::default()
148 }, 126 },
149 } 127 }
150 } 128 }
151 } 129 }
152 130
153 /// Trait for data terms for the PDPS
154 #[replace_float_literals(F::cast_from(literal))]
155 pub trait PDPSDataTerm<F : Float, V, const N : usize> : DataTerm<F, V, N> {
156 /// Calculate some subdifferential at `x` for the conjugate
157 fn some_subdifferential(&self, x : V) -> V;
158
159 /// Factor of strong convexity of the conjugate
160 #[inline]
161 fn factor_of_strong_convexity(&self) -> F {
162 0.0
163 }
164
165 /// Perform dual update
166 fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F);
167 }
168
169
170 #[replace_float_literals(F::cast_from(literal))]
171 impl<F, V, const N : usize> PDPSDataTerm<F, V, N>
172 for L2Squared
173 where
174 F : Float,
175 V : Euclidean<F> + AXPY<F>,
176 for<'b> &'b V : Instance<V>,
177 {
178 fn some_subdifferential(&self, x : V) -> V { x }
179
180 fn factor_of_strong_convexity(&self) -> F {
181 1.0
182 }
183
184 #[inline]
185 fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) {
186 y.axpy(1.0 / (1.0 + σ), y_prev, σ / (1.0 + σ));
187 }
188 }
189
190 #[replace_float_literals(F::cast_from(literal))]
191 impl<F : Float + nalgebra::RealField, const N : usize>
192 PDPSDataTerm<F, DVector<F>, N>
193 for L1 {
194 fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> {
195 // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well.
196 x.iter_mut()
197 .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) });
198 x
199 }
200
201 #[inline]
202 fn dual_update(&self, y : &mut DVector<F>, y_prev : &DVector<F>, σ : F) {
203 y.axpy(1.0, y_prev, σ);
204 y.proj_ball_mut(1.0, Linfinity);
205 }
206 }
207
208 /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting. 131 /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting.
209 /// 132 ///
210 /// The `dataterm` should be either [`L1`] for norm-1 data term or [`L2Squared`] for norm-2-squared.
211 /// The settings in `config` have their [respective documentation](PDPSConfig). `opA` is the 133 /// The settings in `config` have their [respective documentation](PDPSConfig). `opA` is the
212 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight. 134 /// forward operator $A$, $b$ the observable, and $\lambda$ the regularisation weight.
213 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution 135 /// The operator `op𝒟` is used for forming the proximal term. Typically it is a convolution
214 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control 136 /// operator. Finally, the `iterator` is an outer loop verbosity and iteration count control
215 /// as documented in [`alg_tools::iterate`]. 137 /// as documented in [`alg_tools::iterate`].
216 /// 138 ///
217 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript. 139 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript.
218 /// 140 ///
219 /// Returns the final iterate. 141 /// Returns the final iterate.
220 #[replace_float_literals(F::cast_from(literal))] 142 #[replace_float_literals(F::cast_from(literal))]
221 pub fn pointsource_pdps_reg<F, I, A, D, Reg, P, const N : usize>( 143 pub fn pointsource_pdps_reg<'a, F, I, A, Phi, Reg, Plot, P, const N: usize>(
222 opA : &A, 144 f: &'a DataTerm<F, RNDM<N, F>, A, Phi>,
223 b : &A::Observable, 145 reg: &Reg,
224 reg : Reg, 146 prox_penalty: &P,
225 prox_penalty : &P, 147 pdpsconfig: &PDPSConfig<F>,
226 pdpsconfig : &PDPSConfig<F>, 148 iterator: I,
227 iterator : I, 149 mut plotter: Plot,
228 mut plotter : SeqPlotter<F, N>, 150 μ0 : Option<RNDM<N, F>>,
229 dataterm : D, 151 ) -> DynResult<RNDM<N, F>>
230 ) -> RNDM<F, N>
231 where 152 where
232 F : Float + ToNalgebraRealField, 153 F: Float + ToNalgebraRealField,
233 I : AlgIteratorFactory<IterInfo<F, N>>, 154 I: AlgIteratorFactory<IterInfo<F>>,
234 A : ForwardModel<RNDM<F, N>, F> 155 A: ForwardModel<RNDM<N, F>, F>,
235 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, 156 for<'b> &'b A::Observable: Instance<A::Observable>,
236 A::PreadjointCodomain : RealMapping<F, N>, 157 A::Observable: AXPY<Field = F>,
237 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, 158 RNDM<N, F>: SpikeMerging<F>,
238 PlotLookup : Plotting<N>, 159 Reg: RegTerm<Loc<N, F>, F>,
239 RNDM<F, N> : SpikeMerging<F>, 160 Phi: Conjugable<A::Observable, F>,
240 D : PDPSDataTerm<F, A::Observable, N>, 161 for<'b> Phi::Conjugate<'b>: Prox<A::Observable>,
241 Reg : RegTerm<F, N>, 162 P: ProxPenalty<Loc<N, F>, A::PreadjointCodomain, Reg, F> + StepLengthBoundPD<F, A, RNDM<N, F>>,
242 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, 163 Plot: Plotter<P::ReturnMapping, A::PreadjointCodomain, RNDM<N, F>>,
243 { 164 {
244
245 // Check parameters 165 // Check parameters
246 assert!(pdpsconfig.τ0 > 0.0 && 166 ensure!(
247 pdpsconfig.σ0 > 0.0 && 167 pdpsconfig.τ0 > 0.0 && pdpsconfig.σ0 > 0.0 && pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0,
248 pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, 168 "Invalid step length parameters"
249 "Invalid step length parameters"); 169 );
170
171 let opA = f.operator();
172 let b = f.data();
173 let phistar = f.fidelity().conjugate();
250 174
251 // Set up parameters 175 // Set up parameters
252 let config = &pdpsconfig.generic; 176 let config = &pdpsconfig.generic;
253 let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt(); 177 let l = prox_penalty.step_length_bound_pd(opA)?;
254 let mut τ = pdpsconfig.τ0 / l; 178 let mut τ = pdpsconfig.τ0 / l;
255 let mut σ = pdpsconfig.σ0 / l; 179 let mut σ = pdpsconfig.σ0 / l;
256 let γ = dataterm.factor_of_strong_convexity(); 180 let γ = phistar.factor_of_strong_convexity();
257 181
258 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 182 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
259 // by τ compared to the conditional gradient approach. 183 // by τ compared to the conditional gradient approach.
260 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 184 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
261 let mut ε = tolerance.initial(); 185 let mut ε = tolerance.initial();
262 186
263 // Initialise iterates 187 // Initialise iterates
264 let mut μ = DiscreteMeasure::new(); 188 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
265 let mut y = dataterm.some_subdifferential(-b); 189 let mut y = f.residual(&μ);
266 let mut y_prev = y.clone(); 190 let full_stats = |μ: &RNDM<N, F>, ε, stats| IterInfo {
267 let full_stats = |μ : &RNDM<F, N>, ε, stats| IterInfo { 191 value: f.apply(μ) + reg.apply(μ),
268 value : dataterm.calculate_fit_op(μ, opA, b) + reg.apply(μ), 192 n_spikes: μ.len(),
269 n_spikes : μ.len(),
270 ε, 193 ε,
271 // postprocessing: config.postprocessing.then(|| μ.clone()), 194 // postprocessing: config.postprocessing.then(|| μ.clone()),
272 .. stats 195 ..stats
273 }; 196 };
274 let mut stats = IterInfo::new(); 197 let mut stats = IterInfo::new();
275 198
276 // Run the algorithm 199 // Run the algorithm
277 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { 200 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
278 // Calculate smooth part of surrogate model. 201 // Calculate smooth part of surrogate model.
279 let mut τv = opA.preadjoint().apply(y * τ); 202 // FIXME: the clone is required to avoid compiler overflows with reference-Mul requirement above.
203 let mut τv = opA.preadjoint().apply(y.clone() * τ);
280 204
281 // Save current base point 205 // Save current base point
282 let μ_base = μ.clone(); 206 let μ_base = μ.clone();
283 207
284 // Insert and reweigh 208 // Insert and reweigh
285 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( 209 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
286 &mut μ, &mut τv, &μ_base, None, 210 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
287 τ, ε, 211 )?;
288 config, &reg, &state, &mut stats
289 );
290 212
291 // Prune and possibly merge spikes 213 // Prune and possibly merge spikes
292 if config.merge_now(&state) { 214 if config.merge_now(&state) {
293 stats.merged += prox_penalty.merge_spikes_no_fitness( 215 stats.merged += prox_penalty
294 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, 216 .merge_spikes_no_fitness(&mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg);
295 );
296 } 217 }
297 stats.pruned += prune_with_stats(&mut μ); 218 stats.pruned += prune_with_stats(&mut μ);
298 219
299 // Update step length parameters 220 // Update step length parameters
300 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); 221 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);
301 222
302 // Do dual update 223 // Do dual update
303 y = b.clone(); // y = b 224 // y = y_prev + τb
304 opA.gemv(&mut y, 1.0 + ω, &μ, -1.0); // y = A[(1+ω)μ^{k+1}]-b 225 y.axpy(τ, b, 1.0);
305 opA.gemv(&mut y, -ω, &μ_base, 1.0); // y = A[(1+ω)μ^{k+1} - ω μ^k]-b 226 // y = y_prev - τ(A[(1+ω)μ^{k+1}]-b)
306 dataterm.dual_update(&mut y, &y_prev, σ); 227 opA.gemv(&mut y, -τ * (1.0 + ω), &μ, 1.0);
307 y_prev.copy_from(&y); 228 // y = y_prev - τ(A[(1+ω)μ^{k+1} - ω μ^k]-b)
229 opA.gemv(&mut y, τ * ω, &μ_base, 1.0);
230 y = phistar.prox(τ, y);
308 231
309 // Give statistics if requested 232 // Give statistics if requested
310 let iter = state.iteration(); 233 let iter = state.iteration();
311 stats.this_iters += 1; 234 stats.this_iters += 1;
312 235
316 }); 239 });
317 240
318 ε = tolerance.update(ε, iter); 241 ε = tolerance.update(ε, iter);
319 } 242 }
320 243
321 postprocess(μ, config, dataterm, opA, b) 244 postprocess(μ, config, |μ| f.apply(μ))
322 } 245 }
323

mercurial