src/pdps.rs

branch
dev
changeset 37
c5d8bd1a7728
parent 35
b087e3eab191
child 39
6316d68b58af
equal deleted inserted replaced
36:fb911f72e698 37:c5d8bd1a7728
42 use serde::{Serialize, Deserialize}; 42 use serde::{Serialize, Deserialize};
43 use nalgebra::DVector; 43 use nalgebra::DVector;
44 use clap::ValueEnum; 44 use clap::ValueEnum;
45 45
46 use alg_tools::iterate::AlgIteratorFactory; 46 use alg_tools::iterate::AlgIteratorFactory;
47 use alg_tools::loc::Loc;
48 use alg_tools::euclidean::Euclidean; 47 use alg_tools::euclidean::Euclidean;
49 use alg_tools::linops::Mapping; 48 use alg_tools::linops::Mapping;
50 use alg_tools::norms::{ 49 use alg_tools::norms::{
51 Linfinity, 50 Linfinity,
52 Projection, 51 Projection,
53 }; 52 };
54 use alg_tools::bisection_tree::{
55 BTFN,
56 PreBTFN,
57 Bounds,
58 BTNodeLookup,
59 BTNode,
60 BTSearch,
61 SupportGenerator,
62 LocalAnalysis,
63 };
64 use alg_tools::mapping::{RealMapping, Instance}; 53 use alg_tools::mapping::{RealMapping, Instance};
65 use alg_tools::nalgebra_support::ToNalgebraRealField; 54 use alg_tools::nalgebra_support::ToNalgebraRealField;
66 use alg_tools::linops::AXPY; 55 use alg_tools::linops::AXPY;
67 56
68 use crate::types::*; 57 use crate::types::*;
69 use crate::measures::{DiscreteMeasure, RNDM, Radon}; 58 use crate::measures::{DiscreteMeasure, RNDM};
70 use crate::measures::merging::SpikeMerging; 59 use crate::measures::merging::SpikeMerging;
71 use crate::forward_model::{ 60 use crate::forward_model::{
61 ForwardModel,
72 AdjointProductBoundedBy, 62 AdjointProductBoundedBy,
73 ForwardModel 63 };
74 };
75 use crate::seminorms::DiscreteMeasureOp;
76 use crate::plot::{ 64 use crate::plot::{
77 SeqPlotter, 65 SeqPlotter,
78 Plotting, 66 Plotting,
79 PlotLookup 67 PlotLookup
80 }; 68 };
81 use crate::fb::{ 69 use crate::fb::{
82 FBGenericConfig,
83 insert_and_reweigh,
84 postprocess, 70 postprocess,
85 prune_with_stats 71 prune_with_stats
72 };
73 pub use crate::prox_penalty::{
74 FBGenericConfig,
75 ProxPenalty
86 }; 76 };
87 use crate::regularisation::RegTerm; 77 use crate::regularisation::RegTerm;
88 use crate::dataterm::{ 78 use crate::dataterm::{
89 DataTerm, 79 DataTerm,
90 L2Squared, 80 L2Squared,
221 /// 211 ///
222 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript. 212 /// For the mathematical formulation, see the [module level](self) documentation and the manuscript.
223 /// 213 ///
224 /// Returns the final iterate. 214 /// Returns the final iterate.
225 #[replace_float_literals(F::cast_from(literal))] 215 #[replace_float_literals(F::cast_from(literal))]
226 pub fn pointsource_pdps_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, D, Reg, const N : usize>( 216 pub fn pointsource_pdps_reg<F, I, A, D, Reg, P, const N : usize>(
227 opA : &'a A, 217 opA : &A,
228 b : &'a A::Observable, 218 b : &A::Observable,
229 reg : Reg, 219 reg : Reg,
230 op𝒟 : &'a 𝒟, 220 prox_penalty : &P,
231 pdpsconfig : &PDPSConfig<F>, 221 pdpsconfig : &PDPSConfig<F>,
232 iterator : I, 222 iterator : I,
233 mut plotter : SeqPlotter<F, N>, 223 mut plotter : SeqPlotter<F, N>,
234 dataterm : D, 224 dataterm : D,
235 ) -> RNDM<F, N> 225 ) -> RNDM<F, N>
236 where F : Float + ToNalgebraRealField, 226 where
237 I : AlgIteratorFactory<IterInfo<F, N>>, 227 F : Float + ToNalgebraRealField,
238 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, 228 I : AlgIteratorFactory<IterInfo<F, N>>,
239 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 229 A : ForwardModel<RNDM<F, N>, F>
240 A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> 230 + AdjointProductBoundedBy<RNDM<F, N>, P, FloatType=F>,
241 + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, 231 A::PreadjointCodomain : RealMapping<F, N>,
242 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, 232 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
243 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, 233 PlotLookup : Plotting<N>,
244 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>, 234 RNDM<F, N> : SpikeMerging<F>,
245 𝒟::Codomain : RealMapping<F, N>, 235 D : PDPSDataTerm<F, A::Observable, N>,
246 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 236 Reg : RegTerm<F, N>,
247 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, 237 P : ProxPenalty<F, A::PreadjointCodomain, Reg, N>,
248 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 238 {
249 PlotLookup : Plotting<N>,
250 RNDM<F, N> : SpikeMerging<F>,
251 D : PDPSDataTerm<F, A::Observable, N>,
252 Reg : RegTerm<F, N> {
253 239
254 // Check parameters 240 // Check parameters
255 assert!(pdpsconfig.τ0 > 0.0 && 241 assert!(pdpsconfig.τ0 > 0.0 &&
256 pdpsconfig.σ0 > 0.0 && 242 pdpsconfig.σ0 > 0.0 &&
257 pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0, 243 pdpsconfig.τ0 * pdpsconfig.σ0 <= 1.0,
258 "Invalid step length parameters"); 244 "Invalid step length parameters");
259 245
260 // Set up parameters 246 // Set up parameters
261 let config = &pdpsconfig.generic; 247 let config = &pdpsconfig.generic;
262 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); 248 let l = opA.adjoint_product_bound(prox_penalty).unwrap().sqrt();
263 let l = opA.adjoint_product_bound(&op𝒟).unwrap().sqrt();
264 let mut τ = pdpsconfig.τ0 / l; 249 let mut τ = pdpsconfig.τ0 / l;
265 let mut σ = pdpsconfig.σ0 / l; 250 let mut σ = pdpsconfig.σ0 / l;
266 let γ = dataterm.factor_of_strong_convexity(); 251 let γ = dataterm.factor_of_strong_convexity();
267 252
268 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 253 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
284 let mut stats = IterInfo::new(); 269 let mut stats = IterInfo::new();
285 270
286 // Run the algorithm 271 // Run the algorithm
287 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { 272 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
288 // Calculate smooth part of surrogate model. 273 // Calculate smooth part of surrogate model.
289 let τv = opA.preadjoint().apply(y * τ); 274 let mut τv = opA.preadjoint().apply(y * τ);
290 275
291 // Save current base point 276 // Save current base point
292 let μ_base = μ.clone(); 277 let μ_base = μ.clone();
293 278
294 // Insert and reweigh 279 // Insert and reweigh
295 let (d, _within_tolerances) = insert_and_reweigh( 280 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
296 &mut μ, &τv, &μ_base, None, 281 &mut μ, &mut τv, &μ_base, None,
297 op𝒟, op𝒟norm,
298 τ, ε, 282 τ, ε,
299 config, &reg, &state, &mut stats 283 config, &reg, &state, &mut stats
300 ); 284 );
301 285
302 // Prune and possibly merge spikes 286 // Prune and possibly merge spikes
303 if config.merge_now(&state) { 287 if config.merge_now(&state) {
304 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { 288 stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, &reg);
305 let mut d = &τv + op𝒟.preapply(μ_candidate.sub_matching(&μ_base));
306 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
307 });
308 } 289 }
309 stats.pruned += prune_with_stats(&mut μ); 290 stats.pruned += prune_with_stats(&mut μ);
310 291
311 // Update step length parameters 292 // Update step length parameters
312 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); 293 let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ);
321 // Give statistics if requested 302 // Give statistics if requested
322 let iter = state.iteration(); 303 let iter = state.iteration();
323 stats.this_iters += 1; 304 stats.this_iters += 1;
324 305
325 state.if_verbose(|| { 306 state.if_verbose(|| {
326 plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); 307 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ);
327 full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new())) 308 full_stats(&μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
328 }); 309 });
329 310
330 ε = tolerance.update(ε, iter); 311 ε = tolerance.update(ε, iter);
331 } 312 }

mercurial