6 use numeric_literals::replace_float_literals; |
6 use numeric_literals::replace_float_literals; |
7 use serde::{Serialize, Deserialize}; |
7 use serde::{Serialize, Deserialize}; |
8 |
8 |
9 use alg_tools::iterate::AlgIteratorFactory; |
9 use alg_tools::iterate::AlgIteratorFactory; |
10 use alg_tools::euclidean::Euclidean; |
10 use alg_tools::euclidean::Euclidean; |
11 use alg_tools::sets::Cube; |
11 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; |
12 use alg_tools::loc::Loc; |
|
13 use alg_tools::mapping::{Mapping, Instance}; |
|
14 use alg_tools::norms::Norm; |
12 use alg_tools::norms::Norm; |
15 use alg_tools::direct_product::Pair; |
13 use alg_tools::direct_product::Pair; |
16 use alg_tools::bisection_tree::{ |
|
17 BTFN, |
|
18 PreBTFN, |
|
19 Bounds, |
|
20 BTNodeLookup, |
|
21 BTNode, |
|
22 BTSearch, |
|
23 P2Minimise, |
|
24 SupportGenerator, |
|
25 LocalAnalysis, |
|
26 //Bounded, |
|
27 }; |
|
28 use alg_tools::mapping::RealMapping; |
|
29 use alg_tools::nalgebra_support::ToNalgebraRealField; |
14 use alg_tools::nalgebra_support::ToNalgebraRealField; |
30 use alg_tools::linops::{ |
15 use alg_tools::linops::{ |
31 BoundedLinear, AXPY, GEMV, Adjointable, IdOp, |
16 BoundedLinear, AXPY, GEMV, Adjointable, IdOp, |
32 }; |
17 }; |
33 use alg_tools::convex::{Conjugable, Prox}; |
18 use alg_tools::convex::{Conjugable, Prox}; |
34 use alg_tools::norms::{L2, Linfinity, PairNorm}; |
19 use alg_tools::norms::{L2, PairNorm}; |
35 |
20 |
36 use crate::types::*; |
21 use crate::types::*; |
37 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
22 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
38 use crate::measures::merging::SpikeMerging; |
23 use crate::measures::merging::SpikeMerging; |
39 use crate::forward_model::{ |
24 use crate::forward_model::{ |
40 ForwardModel, |
25 ForwardModel, |
41 AdjointProductPairBoundedBy, |
26 AdjointProductPairBoundedBy, |
42 }; |
27 }; |
43 use crate::seminorms::DiscreteMeasureOp; |
|
44 use crate::plot::{ |
28 use crate::plot::{ |
45 SeqPlotter, |
29 SeqPlotter, |
46 Plotting, |
30 Plotting, |
47 PlotLookup |
31 PlotLookup |
48 }; |
32 }; |
81 |
65 |
82 /// Iteratively solve the pointsource localisation with an additional variable |
66 /// Iteratively solve the pointsource localisation with an additional variable |
83 /// using primal-dual proximal splitting with a forward step. |
67 /// using primal-dual proximal splitting with a forward step. |
84 #[replace_float_literals(F::cast_from(literal))] |
68 #[replace_float_literals(F::cast_from(literal))] |
85 pub fn pointsource_forward_pdps_pair< |
69 pub fn pointsource_forward_pdps_pair< |
86 'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize |
70 F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize |
87 >( |
71 >( |
88 opA : &'a A, |
72 opA : &A, |
89 b : &A::Observable, |
73 b : &A::Observable, |
90 reg : Reg, |
74 reg : Reg, |
91 op𝒟 : &'a 𝒟, |
75 prox_penalty : &P, |
92 config : &ForwardPDPSConfig<F>, |
76 config : &ForwardPDPSConfig<F>, |
93 iterator : I, |
77 iterator : I, |
94 mut plotter : SeqPlotter<F, N>, |
78 mut plotter : SeqPlotter<F, N>, |
95 //opKμ : KOpM, |
79 //opKμ : KOpM, |
96 opKz : &KOpZ, |
80 opKz : &KOpZ, |
100 mut y : Y, |
84 mut y : Y, |
101 ) -> MeasureZ<F, Z, N> |
85 ) -> MeasureZ<F, Z, N> |
102 where |
86 where |
103 F : Float + ToNalgebraRealField, |
87 F : Float + ToNalgebraRealField, |
104 I : AlgIteratorFactory<IterInfo<F, N>>, |
88 I : AlgIteratorFactory<IterInfo<F, N>>, |
105 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
|
106 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
107 A : ForwardModel< |
89 A : ForwardModel< |
108 MeasureZ<F, Z, N>, |
90 MeasureZ<F, Z, N>, |
109 F, |
91 F, |
110 PairNorm<Radon, L2, L2>, |
92 PairNorm<Radon, L2, L2>, |
111 PreadjointCodomain = Pair<BTFN<F, GA, BTA, N>, Z>, |
93 PreadjointCodomain = Pair<S, Z>, |
112 > |
94 > |
113 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, 𝒟, IdOp<Z>, FloatType=F>, |
95 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType=F>, |
114 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
96 S: DifferentiableRealMapping<F, N>, |
115 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
97 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
116 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, |
|
117 Codomain = BTFN<F, G𝒟, BT𝒟, N>>, |
|
118 BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
119 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
120 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
121 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
122 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
|
123 PlotLookup : Plotting<N>, |
98 PlotLookup : Plotting<N>, |
124 RNDM<F, N> : SpikeMerging<F>, |
99 RNDM<F, N> : SpikeMerging<F>, |
125 Reg : RegTerm<F, N>, |
100 Reg : RegTerm<F, N>, |
|
101 P : ProxPenalty<F, S, Reg, N>, |
126 KOpZ : BoundedLinear<Z, L2, L2, F, Codomain=Y> |
102 KOpZ : BoundedLinear<Z, L2, L2, F, Codomain=Y> |
127 + GEMV<F, Z> |
103 + GEMV<F, Z> |
128 + Adjointable<Z, Y, AdjointCodomain = Z>, |
104 + Adjointable<Z, Y, AdjointCodomain = Z>, |
129 for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>, |
105 for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>, |
130 Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd, |
106 Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd, |
148 // Initialise iterates |
124 // Initialise iterates |
149 let mut μ = DiscreteMeasure::new(); |
125 let mut μ = DiscreteMeasure::new(); |
150 let mut residual = calculate_residual(Pair(&μ, &z), opA, b); |
126 let mut residual = calculate_residual(Pair(&μ, &z), opA, b); |
151 |
127 |
152 // Set up parameters |
128 // Set up parameters |
153 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); |
129 let bigM = 0.0; //opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt(); |
154 let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt(); |
|
155 let nKz = opKz.opnorm_bound(L2, L2); |
130 let nKz = opKz.opnorm_bound(L2, L2); |
156 let opIdZ = IdOp::new(); |
131 let opIdZ = IdOp::new(); |
157 let (l, l_z) = opA.adjoint_product_pair_bound(&op𝒟, &opIdZ).unwrap(); |
132 let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); |
158 // We need to satisfy |
133 // We need to satisfy |
159 // |
134 // |
160 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 |
135 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 |
161 // ^^^^^^^^^^^^^^^^^^^^^^^^^ |
136 // ^^^^^^^^^^^^^^^^^^^^^^^^^ |
162 // with 1 > σ_p L_z and 1 > τ L. |
137 // with 1 > σ_p L_z and 1 > τ L. |
194 let mut stats = IterInfo::new(); |
169 let mut stats = IterInfo::new(); |
195 |
170 |
196 // Run the algorithm |
171 // Run the algorithm |
197 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { |
172 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { |
198 // Calculate initial transport |
173 // Calculate initial transport |
199 let Pair(τv, τz) = opA.preadjoint().apply(residual * τ); |
174 let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ); |
200 let z_base = z.clone(); |
175 let z_base = z.clone(); |
201 let μ_base = μ.clone(); |
176 let μ_base = μ.clone(); |
202 |
177 |
203 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
178 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
204 let (d, _within_tolerances) = insert_and_reweigh( |
179 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
205 &mut μ, &τv, &μ_base, None, |
180 &mut μ, &mut τv, &μ_base, None, |
206 op𝒟, op𝒟norm, |
|
207 τ, ε, &config.insertion, |
181 τ, ε, &config.insertion, |
208 ®, &state, &mut stats, |
182 ®, &state, &mut stats, |
209 ); |
183 ); |
210 |
184 |
211 // // Merge spikes. |
185 // // Merge spikes. |
246 // Give statistics if requested |
220 // Give statistics if requested |
247 let iter = state.iteration(); |
221 let iter = state.iteration(); |
248 stats.this_iters += 1; |
222 stats.this_iters += 1; |
249 |
223 |
250 state.if_verbose(|| { |
224 state.if_verbose(|| { |
251 plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); |
225 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); |
252 full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) |
226 full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) |
253 }); |
227 }); |
254 |
228 |
255 // Update main tolerance for next iteration |
229 // Update main tolerance for next iteration |
256 ε = tolerance.update(ε, iter); |
230 ε = tolerance.update(ε, iter); |