42 use serde::{Serialize, Deserialize}; |
42 use serde::{Serialize, Deserialize}; |
43 use nalgebra::DVector; |
43 use nalgebra::DVector; |
44 use clap::ValueEnum; |
44 use clap::ValueEnum; |
45 |
45 |
46 use alg_tools::iterate::AlgIteratorFactory; |
46 use alg_tools::iterate::AlgIteratorFactory; |
47 use alg_tools::loc::Loc; |
|
48 use alg_tools::euclidean::Euclidean; |
47 use alg_tools::euclidean::Euclidean; |
49 use alg_tools::linops::Mapping; |
48 use alg_tools::linops::Mapping; |
50 use alg_tools::norms::{ |
49 use alg_tools::norms::{ |
51 Linfinity, |
50 Linfinity, |
52 Projection, |
51 Projection, |
53 }; |
52 }; |
54 use alg_tools::bisection_tree::{ |
|
55 BTFN, |
|
56 PreBTFN, |
|
57 Bounds, |
|
58 BTNodeLookup, |
|
59 BTNode, |
|
60 BTSearch, |
|
61 SupportGenerator, |
|
62 LocalAnalysis, |
|
63 }; |
|
64 use alg_tools::mapping::{RealMapping, Instance}; |
53 use alg_tools::mapping::{RealMapping, Instance}; |
65 use alg_tools::nalgebra_support::ToNalgebraRealField; |
54 use alg_tools::nalgebra_support::ToNalgebraRealField; |
66 use alg_tools::linops::AXPY; |
55 use alg_tools::linops::AXPY; |
67 |
56 |
68 use crate::types::*; |
57 use crate::types::*; |
69 use crate::measures::{DiscreteMeasure, RNDM, Radon}; |
58 use crate::measures::{DiscreteMeasure, RNDM}; |
70 use crate::measures::merging::SpikeMerging; |
59 use crate::measures::merging::SpikeMerging; |
71 use crate::forward_model::{ |
60 use crate::forward_model::{ |
|
61 ForwardModel, |
72 AdjointProductBoundedBy, |
62 AdjointProductBoundedBy, |
73 ForwardModel |
63 }; |
74 }; |
|
75 use crate::seminorms::DiscreteMeasureOp; |
|
76 use crate::plot::{ |
64 use crate::plot::{ |
77 SeqPlotter, |
65 SeqPlotter, |
78 Plotting, |
66 Plotting, |
79 PlotLookup |
67 PlotLookup |
80 }; |
68 }; |
81 use crate::fb::{ |
69 use crate::fb::{ |
82 FBGenericConfig, |
|
83 insert_and_reweigh, |
|
84 postprocess, |
70 postprocess, |
85 prune_with_stats |
71 prune_with_stats |
|
72 }; |
|
73 pub use crate::prox_penalty::{ |
|
74 FBGenericConfig, |
|
75 ProxPenalty |
86 }; |
76 }; |
87 use crate::regularisation::RegTerm; |
77 use crate::regularisation::RegTerm; |
88 use crate::dataterm::{ |
78 use crate::dataterm::{ |
89 DataTerm, |
79 DataTerm, |
90 L2Squared, |
80 L2Squared, |
221 /// |
211 /// |
222 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript. |
212 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript. |
223 /// |
213 /// |
224 /// Returns the final iterate. |
214 /// Returns the final iterate. |
225 #[replace_float_literals(F::cast_from(literal))] |
215 #[replace_float_literals(F::cast_from(literal))] |
226 pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>( |
216 pub fn pointsource_pdps_reg<F, I, A, D, Reg, P, const N : usize>( |
227 opA : &'a A, |
217 opA : &A, |
228 b : &'a A::Observable, |
218 b : &A::Observable, |
229 reg : Reg, |
219 reg : Reg, |
230 op𝒟 : &'a 𝒟, |
220 prox_penalty : &P, |
231 pdpsconfig : &PDPSConfig<F>, |
221 pdpsconfig : &PDPSConfig<F>, |
232 iterator : I, |
222 iterator : I, |
233 mut plotter : SeqPlotter<F, N>, |
223 mut plotter : SeqPlotter<F, N>, |
234 dataterm : D, |
224 dataterm : D, |
235 ) -> RNDM<F, N> |
225 ) -> RNDM<F, N> |
236 where F : Float + ToNalgebraRealField, |
226 where |
237 I : AlgIteratorFactory<IterInfo<F, N>>, |
227 F : Float + ToNalgebraRealField, |
238 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
228 I : AlgIteratorFactory<IterInfo<F, N>>, |
239 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
229 A : ForwardModel<RNDM<F, N>, F> |
240 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
230 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
241 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, |
231 A::PreadjointCodomain : RealMapping<F, N>, |
242 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
232 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
243 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
233 PlotLookup : Plotting<N>, |
244 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
234 RNDM<F, N> : SpikeMerging<F>, |
245 𝒟::Codomain : RealMapping<F, N>, |
235 D : PDPSDataTerm<F, A::Observable, N>, |
246 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
236 Reg : RegTerm<F, N>, |
247 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
237 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
248 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
238 { |
249 PlotLookup : Plotting<N>, |
|
250 RNDM<F, N> : SpikeMerging<F>, |
|
251 D : PDPSDataTerm<F, A::Observable, N>, |
|
252 Reg : RegTerm<F, N> { |
|
253 |
239 |
254 // Check parameters |
240 // Check parameters |
255 assert!(pdpsconfig.τ0 > 0.0 && |
241 assert!(pdpsconfig.τ0 > 0.0 && |
256 pdpsconfig.σ0 > 0.0 && |
242 pdpsconfig.σ0 > 0.0 && |
257 pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, |
243 pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, |
258 "Invalid step length parameters"); |
244 "Invalid step length parameters"); |
259 |
245 |
260 // Set up parameters |
246 // Set up parameters |
261 let config = &pdpsconfig.generic; |
247 let config = &pdpsconfig.generic; |
262 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); |
248 let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt(); |
263 let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt(); |
|
264 let mut τ = pdpsconfig.τ0 / l; |
249 let mut τ = pdpsconfig.τ0 / l; |
265 let mut σ = pdpsconfig.σ0 / l; |
250 let mut σ = pdpsconfig.σ0 / l; |
266 let γ = dataterm.factor_of_strong_convexity(); |
251 let γ = dataterm.factor_of_strong_convexity(); |
267 |
252 |
268 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
253 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
284 let mut stats = IterInfo::new(); |
269 let mut stats = IterInfo::new(); |
285 |
270 |
286 // Run the algorithm |
271 // Run the algorithm |
287 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
272 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
288 // Calculate smooth part of surrogate model. |
273 // Calculate smooth part of surrogate model. |
289 let τv = opA.preadjoint().apply(y * τ); |
274 let mut τv = opA.preadjoint().apply(y * τ); |
290 |
275 |
291 // Save current base point |
276 // Save current base point |
292 let μ_base = μ.clone(); |
277 let μ_base = μ.clone(); |
293 |
278 |
294 // Insert and reweigh |
279 // Insert and reweigh |
295 let (d, _within_tolerances) = insert_and_reweigh( |
280 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
296 &mut μ, &τv, &μ_base, None, |
281 &mut μ, &mut τv, &μ_base, None, |
297 op𝒟, op𝒟norm, |
|
298 τ, ε, |
282 τ, ε, |
299 config, ®, &state, &mut stats |
283 config, ®, &state, &mut stats |
300 ); |
284 ); |
301 |
285 |
302 // Prune and possibly merge spikes |
286 // Prune and possibly merge spikes |
303 if config.merge_now(&state) { |
287 if config.merge_now(&state) { |
304 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { |
288 stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, ®); |
305 let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base)); |
|
306 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
|
307 }); |
|
308 } |
289 } |
309 stats.pruned += prune_with_stats(&mut μ); |
290 stats.pruned += prune_with_stats(&mut μ); |
310 |
291 |
311 // Update step length parameters |
292 // Update step length parameters |
312 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |
293 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |