src/pdps.rs

branch:     dev
changeset:  32:56c8adc32b09
parent:     29:87649ccfa6a8
child:      34:efa60bc4f743

diff -r 30:bd13c2ae3450 -r 32:56c8adc32b09 src/pdps.rs
@@ -46,16 +46,20 @@
 use numeric_literals::replace_float_literals;
 use serde::{Serialize, Deserialize};
 use nalgebra::DVector;
 use clap::ValueEnum;
 
-use alg_tools::iterate:: AlgIteratorFactory;
+use alg_tools::iterate::{
+    AlgIteratorFactory,
+    AlgIteratorState,
+};
 use alg_tools::loc::Loc;
 use alg_tools::euclidean::Euclidean;
+use alg_tools::linops::Apply;
 use alg_tools::norms::{
-    L1, Linfinity,
-    Projection, Norm,
+    Linfinity,
+    Projection,
 };
 use alg_tools::bisection_tree::{
     BTFN,
     PreBTFN,
     Bounds,
@@ -69,27 +73,29 @@
 use alg_tools::nalgebra_support::ToNalgebraRealField;
 use alg_tools::linops::AXPY;
 
 use crate::types::*;
 use crate::measures::DiscreteMeasure;
-use crate::measures::merging::{
-    SpikeMerging,
-};
+use crate::measures::merging::SpikeMerging;
 use crate::forward_model::ForwardModel;
-use crate::seminorms::{
-    DiscreteMeasureOp, Lipschitz
-};
+use crate::seminorms::DiscreteMeasureOp;
 use crate::plot::{
     SeqPlotter,
     Plotting,
     PlotLookup
 };
 use crate::fb::{
     FBGenericConfig,
-    FBSpecialisation,
-    generic_pointsource_fb_reg,
-    RegTerm,
+    insert_and_reweigh,
+    postprocess,
+    prune_and_maybe_simple_merge
+};
+use crate::regularisation::RegTerm;
+use crate::dataterm::{
+    DataTerm,
+    L2Squared,
+    L1
 };
 
 /// Acceleration
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, ValueEnum, Debug)]
 pub enum Acceleration {
@@ -129,164 +135,58 @@
             insertion : Default::default()
         }
     }
 }
 
-/// Trait for subdifferentiable objects
-pub trait Subdifferentiable<F : Float, V, U=V> {
-    /// Calculate some subdifferential at `x`
-    fn some_subdifferential(&self, x : V) -> U;
-}
-
-/// Type for indicating norm-2-squared data fidelity.
-pub struct L2Squared;
-
-impl<F : Float, V : Euclidean<F>> Subdifferentiable<F, V> for L2Squared {
+/// Trait for data terms for the PDPS
+#[replace_float_literals(F::cast_from(literal))]
+pub trait PDPSDataTerm<F : Float, V, const N : usize> : DataTerm<F, V, N> {
+    /// Calculate some subdifferential at `x` for the conjugate
+    fn some_subdifferential(&self, x : V) -> V;
+
+    /// Factor of strong convexity of the conjugate
+    #[inline]
+    fn factor_of_strong_convexity(&self) -> F {
+        0.0
+    }
+
+    /// Perform dual update
+    fn dual_update(&self, _y : &mut V, _y_prev : &V, _σ : F);
+}
+
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F : Float, V : Euclidean<F> + AXPY<F>, const N : usize>
+PDPSDataTerm<F, V, N>
+for L2Squared {
     fn some_subdifferential(&self, x : V) -> V { x }
-}
-
-impl<F : Float + nalgebra::RealField> Subdifferentiable<F, DVector<F>> for L1 {
+
+    fn factor_of_strong_convexity(&self) -> F {
+        1.0
+    }
+
+    #[inline]
+    fn dual_update(&self, y : &mut V, y_prev : &V, σ : F) {
+        y.axpy(1.0 / (1.0 + σ), &y_prev, σ / (1.0 + σ));
+    }
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<F : Float + nalgebra::RealField, const N : usize>
+PDPSDataTerm<F, DVector<F>, N>
+for L1 {
     fn some_subdifferential(&self, mut x : DVector<F>) -> DVector<F> {
         // nalgebra sucks for providing second copies of the same stuff that's elsewhere as well.
         x.iter_mut()
             .for_each(|v| if *v != F::ZERO { *v = *v/<F as NumTraitsFloat>::abs(*v) });
         x
     }
-}
-
-/// Specialisation of [`generic_pointsource_fb_reg`] to PDPS.
-pub struct PDPS<
-    'a,
-    F : Float + ToNalgebraRealField,
-    A : ForwardModel<Loc<F, N>, F>,
-    D,
-    const N : usize
-> {
-    /// The data
-    b : &'a A::Observable,
-    /// The forward operator
-    opA : &'a A,
-    /// Primal step length
-    τ : F,
-    // Dual step length
-    σ : F,
-    /// Whether acceleration should be applied (if data term supports)
-    acceleration : Acceleration,
-    /// The dataterm. Only used by the type system.
-    _dataterm : D,
-    /// Previous dual iterate.
-    y_prev : A::Observable,
-}
-
-/// Implementation of [`FBSpecialisation`] for μPDPS with norm-2-squared data fidelity.
-#[replace_float_literals(F::cast_from(literal))]
-impl<
-    'a,
-    F : Float + ToNalgebraRealField,
-    A : ForwardModel<Loc<F, N>, F>,
-    const N : usize
-> FBSpecialisation<F, A::Observable, N> for PDPS<'a, F, A, L2Squared, N>
-where for<'b> &'b A::Observable : std::ops::Add<A::Observable, Output=A::Observable> {
-
-    fn update(
-        &mut self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
-        μ_base : &DiscreteMeasure<Loc<F, N>, F>
-    ) -> (A::Observable, Option<F>) {
-        let σ = self.σ;
-        let τ = self.τ;
-        let ω = match self.acceleration {
-            Acceleration::None => 1.0,
-            Acceleration::Partial => {
-                let ω = 1.0 / (1.0 + σ).sqrt();
-                self.σ = σ * ω;
-                self.τ = τ / ω;
-                ω
-            },
-            Acceleration::Full => {
-                let ω = 1.0 / (1.0 + 2.0 * σ).sqrt();
-                self.σ = σ * ω;
-                self.τ = τ / ω;
-                ω
-            },
-        };
-
-        μ.prune();
-
-        let mut y = self.b.clone();
-        self.opA.gemv(&mut y, 1.0 + ω, μ, -1.0);
-        self.opA.gemv(&mut y, -ω, μ_base, 1.0);
-        y.axpy(1.0 / (1.0 + σ), &self.y_prev, σ / (1.0 + σ));
-        self.y_prev.copy_from(&y);
-
-        (y, Some(self.τ))
-    }
-
-    fn calculate_fit(
-        &self,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-        _y : &A::Observable
-    ) -> F {
-        self.calculate_fit_simple(μ)
-    }
-
-    fn calculate_fit_simple(
-        &self,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-    ) -> F {
-        let mut residual = self.b.clone();
-        self.opA.gemv(&mut residual, 1.0, μ, -1.0);
-        residual.norm2_squared_div2()
-    }
-}
-
-/// Implementation of [`FBSpecialisation`] for μPDPS with norm-1 data fidelity.
-#[replace_float_literals(F::cast_from(literal))]
-impl<
-    'a,
-    F : Float + ToNalgebraRealField,
-    A : ForwardModel<Loc<F, N>, F>,
-    const N : usize
-> FBSpecialisation<F, A::Observable, N> for PDPS<'a, F, A, L1, N>
-where A::Observable : Projection<F, Linfinity> + Norm<F, L1>,
-      for<'b> &'b A::Observable : std::ops::Add<A::Observable, Output=A::Observable> {
-    fn update(
-        &mut self,
-        μ : &mut DiscreteMeasure<Loc<F, N>, F>,
-        μ_base : &DiscreteMeasure<Loc<F, N>, F>
-    ) -> (A::Observable, Option<F>) {
-        let σ = self.σ;
-
-        μ.prune();
-
-        //let ȳ = self.opA.apply(μ) * 2.0 - self.opA.apply(μ_base);
-        //*y = proj_{[-1,1]}(&self.y_prev + (ȳ - self.b) * σ)
-        let mut y = self.y_prev.clone();
-        self.opA.gemv(&mut y, 2.0 * σ, μ, 1.0);
-        self.opA.gemv(&mut y, -σ, μ_base, 1.0);
-        y.axpy(-σ, self.b, 1.0);
+
+    #[inline]
+    fn dual_update(&self, y : &mut DVector<F>, y_prev : &DVector<F>, σ : F) {
+        y.axpy(1.0, y_prev, σ);
         y.proj_ball_mut(1.0, Linfinity);
-        self.y_prev.copy_from(&y);
-
-        (y, None)
-    }
-
-    fn calculate_fit(
-        &self,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-        _y : &A::Observable
-    ) -> F {
-        self.calculate_fit_simple(μ)
-    }
-
-    fn calculate_fit_simple(
-        &self,
-        μ : &DiscreteMeasure<Loc<F, N>, F>,
-    ) -> F {
-        let mut residual = self.b.clone();
-        self.opA.gemv(&mut residual, 1.0, μ, -1.0);
-        residual.norm(L1)
     }
 }
 
 /// Iteratively solve the pointsource localisation problem using primal-dual proximal splitting.
 ///
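For reference, the two `dual_update` implementations above are the proximal steps of the Fenchel conjugates of the corresponding data fidelities. Writing $\tilde\mu^{k+1} = (1+\omega)\mu^{k+1} - \omega\mu^k$ for the inertial primal point formed in `pointsource_pdps_reg` below, they amount to

$$ y^{k+1} = \frac{y^k + \sigma(A\tilde\mu^{k+1} - b)}{1 + \sigma} $$

for the squared-distance data term, and

$$ y^{k+1} = \operatorname{proj}_{[-1,1]^m}\bigl(y^k + \sigma(A\tilde\mu^{k+1} - b)\bigr) $$

for the ℓ¹ data term; here $m$ is only notation for the dimension of the observable.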
@@ -304,52 +204,133 @@
 pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>(
     opA : &'a A,
     b : &'a A::Observable,
     reg : Reg,
     op𝒟 : &'a 𝒟,
-    config : &PDPSConfig<F>,
+    pdpsconfig : &PDPSConfig<F>,
     iterator : I,
-    plotter : SeqPlotter<F, N>,
+    mut plotter : SeqPlotter<F, N>,
     dataterm : D,
 ) -> DiscreteMeasure<Loc<F, N>, F>
 where F : Float + ToNalgebraRealField,
       I : AlgIteratorFactory<IterInfo<F, N>>,
       for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>
                                   + std::ops::Add<A::Observable, Output=A::Observable>,
                                   //+ std::ops::Mul<F, Output=A::Observable>, // <-- FIXME: compiler overflow
       A::Observable : std::ops::MulAssign<F>,
       GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
       A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
-          + Lipschitz<𝒟, FloatType=F>,
+          + Lipschitz<&'a 𝒟, FloatType=F>,
       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
       𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
       𝒟::Codomain : RealMapping<F, N>,
       S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
       K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
       BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
       PlotLookup : Plotting<N>,
       DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
-      PDPS<'a, F, A, D, N> : FBSpecialisation<F, A::Observable, N>,
-      D : Subdifferentiable<F, A::Observable>,
+      D : PDPSDataTerm<F, A::Observable, N>,
       Reg : RegTerm<F, N> {
 
-    let y = dataterm.some_subdifferential(-b);
+    // Set up parameters
+    let config = &pdpsconfig.insertion;
+    let op𝒟norm = op𝒟.opnorm_bound();
     let l = opA.lipschitz_factor(&op𝒟).unwrap().sqrt();
-    let τ = config.τ0 / l;
-    let σ = config.σ0 / l;
-
-    let pdps = PDPS {
-        b,
-        opA,
-        τ,
-        σ,
-        acceleration : config.acceleration,
-        _dataterm : dataterm,
-        y_prev : y.clone(),
-    };
-
-    generic_pointsource_fb_reg(
-        opA, reg, op𝒟, τ, &config.insertion, iterator, plotter, y, pdps
-    )
-}
-
+    let mut τ = pdpsconfig.τ0 / l;
+    let mut σ = pdpsconfig.σ0 / l;
+    let γ = dataterm.factor_of_strong_convexity();
+
+    // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
+    // by τ compared to the conditional gradient approach.
+    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
+    let mut ε = tolerance.initial();
+
+    // Initialise iterates
+    let mut μ = DiscreteMeasure::new();
+    let mut y = dataterm.some_subdifferential(-b);
+    let mut y_prev = y.clone();
+    let mut stats = IterInfo::new();
+
+    // Run the algorithm
+    iterator.iterate(|state| {
+        // Calculate smooth part of surrogate model.
+        // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
+        // has no significant overhead. For some reason Rust doesn't allow us simply moving
+        // the residual and replacing it below before the end of this closure.
+        y *= -τ;
+        let r = std::mem::replace(&mut y, opA.empty_observable());
+        let minus_τv = opA.preadjoint().apply(r);
+
+        // Save current base point
+        let μ_base = μ.clone();
+
+        // Insert and reweigh
+        let (d, within_tolerances) = insert_and_reweigh(
+            &mut μ, &minus_τv, &μ_base, None,
+            op𝒟, op𝒟norm,
+            τ, ε,
+            config, &reg, state, &mut stats
+        );
+
+        // Prune and possibly merge spikes
+        prune_and_maybe_simple_merge(
+            &mut μ, &minus_τv, &μ_base,
+            op𝒟,
+            τ, ε,
+            config, &reg, state, &mut stats
+        );
+
+        // Update step length parameters
+        let ω = match pdpsconfig.acceleration {
+            Acceleration::None => 1.0,
+            Acceleration::Partial => {
+                let ω = 1.0 / (1.0 + γ * σ).sqrt();
+                σ = σ * ω;
+                τ = τ / ω;
+                ω
+            },
+            Acceleration::Full => {
+                let ω = 1.0 / (1.0 + 2.0 * γ * σ).sqrt();
+                σ = σ * ω;
+                τ = τ / ω;
+                ω
+            },
+        };
+
+        // Do dual update
+        y = b.clone();                          // y = b
+        opA.gemv(&mut y, 1.0 + ω, &μ, -1.0);    // y = A[(1+ω)μ^{k+1}]-b
+        opA.gemv(&mut y, -ω, &μ_base, 1.0);     // y = A[(1+ω)μ^{k+1} - ω μ^k]-b
+        dataterm.dual_update(&mut y, &y_prev, σ);
+        y_prev.copy_from(&y);
+
+        // Update main tolerance for next iteration
+        let ε_prev = ε;
+        ε = tolerance.update(ε, state.iteration());
+        stats.this_iters += 1;
+
+        // Give function value if needed
+        state.if_verbose(|| {
+            // Plot if so requested
+            plotter.plot_spikes(
+                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
+                "start".to_string(), Some(&minus_τv),
+                reg.target_bounds(τ, ε_prev), &μ,
+            );
+            // Calculate mean inner iterations and reset relevant counters.
+            // Return the statistics
+            let res = IterInfo {
+                value : dataterm.calculate_fit_op(&μ, opA, b) + reg.apply(&μ),
+                n_spikes : μ.len(),
+                ε : ε_prev,
+                postprocessing: config.postprocessing.then(|| μ.clone()),
+                .. stats
+            };
+            stats = IterInfo::new();
+            res
+        })
+    });
+
+    postprocess(μ, config, dataterm, opA, b)
+}
+
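The step length update in the rewritten `pointsource_pdps_reg` is the usual primal-dual acceleration rule driven by the strong convexity factor γ reported by the data term. A minimal standalone sketch of just that rule in plain f64 (the helper `accelerate`, the `main` driver, and the sample values are illustrative only; the formulas are those of the match on `pdpsconfig.acceleration` above):

/// Mirrors the `Acceleration` enum above; repeated here only to keep the sketch self-contained.
#[derive(Clone, Copy)]
enum Acceleration { None, Partial, Full }

/// One step of the acceleration rule from `pointsource_pdps_reg`:
/// returns ω and rescales (τ, σ) in place. `γ` is the factor of strong
/// convexity of the conjugate data term (1.0 for the squared distance,
/// 0.0 by default).
fn accelerate(acc : Acceleration, γ : f64, τ : &mut f64, σ : &mut f64) -> f64 {
    let ω = match acc {
        Acceleration::None => 1.0,
        Acceleration::Partial => 1.0 / (1.0 + γ * *σ).sqrt(),
        Acceleration::Full => 1.0 / (1.0 + 2.0 * γ * *σ).sqrt(),
    };
    *σ *= ω;
    *τ /= ω;
    ω
}

fn main() {
    // Illustrative starting step lengths; the real code derives them from τ0, σ0 and
    // the square root of the Lipschitz factor.
    let (mut τ, mut σ) = (0.5, 0.5);
    for k in 0..5 {
        let ω = accelerate(Acceleration::Full, 1.0, &mut τ, &mut σ);
        println!("k={k}: ω={ω:.4}, τ={τ:.4}, σ={σ:.4}");
    }
}

With γ = 0, the default `factor_of_strong_convexity`, both accelerated variants reduce to ω = 1, so the step lengths stay fixed.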
