| 42 use serde::{Serialize, Deserialize}; |
42 use serde::{Serialize, Deserialize}; |
| 43 use nalgebra::DVector; |
43 use nalgebra::DVector; |
| 44 use clap::ValueEnum; |
44 use clap::ValueEnum; |
| 45 |
45 |
| 46 use alg_tools::iterate::AlgIteratorFactory; |
46 use alg_tools::iterate::AlgIteratorFactory; |
| 47 use alg_tools::loc::Loc; |
|
| 48 use alg_tools::euclidean::Euclidean; |
47 use alg_tools::euclidean::Euclidean; |
| 49 use alg_tools::linops::Mapping; |
48 use alg_tools::linops::Mapping; |
| 50 use alg_tools::norms::{ |
49 use alg_tools::norms::{ |
| 51 Linfinity, |
50 Linfinity, |
| 52 Projection, |
51 Projection, |
| 53 }; |
52 }; |
| 54 use alg_tools::bisection_tree::{ |
|
| 55 BTFN, |
|
| 56 PreBTFN, |
|
| 57 Bounds, |
|
| 58 BTNodeLookup, |
|
| 59 BTNode, |
|
| 60 BTSearch, |
|
| 61 SupportGenerator, |
|
| 62 LocalAnalysis, |
|
| 63 }; |
|
| 64 use alg_tools::mapping::{RealMapping, Instance}; |
53 use alg_tools::mapping::{RealMapping, Instance}; |
| 65 use alg_tools::nalgebra_support::ToNalgebraRealField; |
54 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 66 use alg_tools::linops::AXPY; |
55 use alg_tools::linops::AXPY; |
| 67 |
56 |
| 68 use crate::types::*; |
57 use crate::types::*; |
| 69 use crate::measures::{DiscreteMeasure, RNDM, Radon}; |
58 use crate::measures::{DiscreteMeasure, RNDM}; |
| 70 use crate::measures::merging::SpikeMerging; |
59 use crate::measures::merging::SpikeMerging; |
| 71 use crate::forward_model::{ |
60 use crate::forward_model::{ |
| |
61 ForwardModel, |
| 72 AdjointProductBoundedBy, |
62 AdjointProductBoundedBy, |
| 73 ForwardModel |
63 }; |
| 74 }; |
|
| 75 use crate::seminorms::DiscreteMeasureOp; |
|
| 76 use crate::plot::{ |
64 use crate::plot::{ |
| 77 SeqPlotter, |
65 SeqPlotter, |
| 78 Plotting, |
66 Plotting, |
| 79 PlotLookup |
67 PlotLookup |
| 80 }; |
68 }; |
| 81 use crate::fb::{ |
69 use crate::fb::{ |
| 82 FBGenericConfig, |
|
| 83 insert_and_reweigh, |
|
| 84 postprocess, |
70 postprocess, |
| 85 prune_with_stats |
71 prune_with_stats |
| |
72 }; |
| |
73 pub use crate::prox_penalty::{ |
| |
74 FBGenericConfig, |
| |
75 ProxPenalty |
| 86 }; |
76 }; |
| 87 use crate::regularisation::RegTerm; |
77 use crate::regularisation::RegTerm; |
| 88 use crate::dataterm::{ |
78 use crate::dataterm::{ |
| 89 DataTerm, |
79 DataTerm, |
| 90 L2Squared, |
80 L2Squared, |
| 221 /// |
211 /// |
| 222 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript. |
212 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript. |
| 223 /// |
213 /// |
| 224 /// Returns the final iterate. |
214 /// Returns the final iterate. |
| 225 #[replace_float_literals(F::cast_from(literal))] |
215 #[replace_float_literals(F::cast_from(literal))] |
| 226 pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>( |
216 pub fn pointsource_pdps_reg<F, I, A, D, Reg, P, const N : usize>( |
| 227 opA : &'a A, |
217 opA : &A, |
| 228 b : &'a A::Observable, |
218 b : &A::Observable, |
| 229 reg : Reg, |
219 reg : Reg, |
| 230 op𝒟 : &'a 𝒟, |
220 prox_penalty : &P, |
| 231 pdpsconfig : &PDPSConfig<F>, |
221 pdpsconfig : &PDPSConfig<F>, |
| 232 iterator : I, |
222 iterator : I, |
| 233 mut plotter : SeqPlotter<F, N>, |
223 mut plotter : SeqPlotter<F, N>, |
| 234 dataterm : D, |
224 dataterm : D, |
| 235 ) -> RNDM<F, N> |
225 ) -> RNDM<F, N> |
| 236 where F : Float + ToNalgebraRealField, |
226 where |
| 237 I : AlgIteratorFactory<IterInfo<F, N>>, |
227 F : Float + ToNalgebraRealField, |
| 238 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
228 I : AlgIteratorFactory<IterInfo<F, N>>, |
| 239 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
229 A : ForwardModel<RNDM<F, N>, F> |
| 240 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> |
230 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>, |
| 241 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, |
231 A::PreadjointCodomain : RealMapping<F, N>, |
| 242 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
232 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
| 243 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
233 PlotLookup : Plotting<N>, |
| 244 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, |
234 RNDM<F, N> : SpikeMerging<F>, |
| 245 𝒟::Codomain : RealMapping<F, N>, |
235 D : PDPSDataTerm<F, A::Observable, N>, |
| 246 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
236 Reg : RegTerm<F, N>, |
| 247 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
237 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>, |
| 248 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
238 { |
| 249 PlotLookup : Plotting<N>, |
|
| 250 RNDM<F, N> : SpikeMerging<F>, |
|
| 251 D : PDPSDataTerm<F, A::Observable, N>, |
|
| 252 Reg : RegTerm<F, N> { |
|
| 253 |
239 |
| 254 // Check parameters |
240 // Check parameters |
| 255 assert!(pdpsconfig.τ0 > 0.0 && |
241 assert!(pdpsconfig.τ0 > 0.0 && |
| 256 pdpsconfig.σ0 > 0.0 && |
242 pdpsconfig.σ0 > 0.0 && |
| 257 pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, |
243 pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, |
| 258 "Invalid step length parameters"); |
244 "Invalid step length parameters"); |
| 259 |
245 |
| 260 // Set up parameters |
246 // Set up parameters |
| 261 let config = &pdpsconfig.generic; |
247 let config = &pdpsconfig.generic; |
| 262 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); |
248 let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt(); |
| 263 let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt(); |
|
| 264 let mut τ = pdpsconfig.τ0 / l; |
249 let mut τ = pdpsconfig.τ0 / l; |
| 265 let mut σ = pdpsconfig.σ0 / l; |
250 let mut σ = pdpsconfig.σ0 / l; |
| 266 let γ = dataterm.factor_of_strong_convexity(); |
251 let γ = dataterm.factor_of_strong_convexity(); |
| 267 |
252 |
| 268 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
253 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 284 let mut stats = IterInfo::new(); |
269 let mut stats = IterInfo::new(); |
| 285 |
270 |
| 286 // Run the algorithm |
271 // Run the algorithm |
| 287 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
272 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 288 // Calculate smooth part of surrogate model. |
273 // Calculate smooth part of surrogate model. |
| 289 let τv = opA.preadjoint().apply(y * τ); |
274 let mut τv = opA.preadjoint().apply(y * τ); |
| 290 |
275 |
| 291 // Save current base point |
276 // Save current base point |
| 292 let μ_base = μ.clone(); |
277 let μ_base = μ.clone(); |
| 293 |
278 |
| 294 // Insert and reweigh |
279 // Insert and reweigh |
| 295 let (d, _within_tolerances) = insert_and_reweigh( |
280 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| 296 &mut μ, &τv, &μ_base, None, |
281 &mut μ, &mut τv, &μ_base, None, |
| 297 op𝒟, op𝒟norm, |
|
| 298 τ, ε, |
282 τ, ε, |
| 299 config, &reg, &state, &mut stats |
283 config, &reg, &state, &mut stats |
| 300 ); |
284 ); |
| 301 |
285 |
| 302 // Prune and possibly merge spikes |
286 // Prune and possibly merge spikes |
| 303 if config.merge_now(&state) { |
287 if config.merge_now(&state) { |
| 304 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { |
288 stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, &reg); |
| 305 let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base)); |
|
| 306 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
|
| 307 }); |
|
| 308 } |
289 } |
| 309 stats.pruned += prune_with_stats(&mut μ); |
290 stats.pruned += prune_with_stats(&mut μ); |
| 310 |
291 |
| 311 // Update step length parameters |
292 // Update step length parameters |
| 312 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |
293 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |