| 6 use numeric_literals::replace_float_literals; |
6 use numeric_literals::replace_float_literals; |
| 7 use serde::{Serialize, Deserialize}; |
7 use serde::{Serialize, Deserialize}; |
| 8 |
8 |
| 9 use alg_tools::iterate::AlgIteratorFactory; |
9 use alg_tools::iterate::AlgIteratorFactory; |
| 10 use alg_tools::euclidean::Euclidean; |
10 use alg_tools::euclidean::Euclidean; |
| 11 use alg_tools::sets::Cube; |
11 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; |
| 12 use alg_tools::loc::Loc; |
|
| 13 use alg_tools::mapping::{Mapping, Instance}; |
|
| 14 use alg_tools::norms::Norm; |
12 use alg_tools::norms::Norm; |
| 15 use alg_tools::direct_product::Pair; |
13 use alg_tools::direct_product::Pair; |
| 16 use alg_tools::bisection_tree::{ |
|
| 17 BTFN, |
|
| 18 PreBTFN, |
|
| 19 Bounds, |
|
| 20 BTNodeLookup, |
|
| 21 BTNode, |
|
| 22 BTSearch, |
|
| 23 P2Minimise, |
|
| 24 SupportGenerator, |
|
| 25 LocalAnalysis, |
|
| 26 //Bounded, |
|
| 27 }; |
|
| 28 use alg_tools::mapping::RealMapping; |
|
| 29 use alg_tools::nalgebra_support::ToNalgebraRealField; |
14 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 30 use alg_tools::linops::{ |
15 use alg_tools::linops::{ |
| 31 BoundedLinear, AXPY, GEMV, Adjointable, IdOp, |
16 BoundedLinear, AXPY, GEMV, Adjointable, IdOp, |
| 32 }; |
17 }; |
| 33 use alg_tools::convex::{Conjugable, Prox}; |
18 use alg_tools::convex::{Conjugable, Prox}; |
| 34 use alg_tools::norms::{L2, Linfinity, PairNorm}; |
19 use alg_tools::norms::{L2, PairNorm}; |
| 35 |
20 |
| 36 use crate::types::*; |
21 use crate::types::*; |
| 37 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
22 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
| 38 use crate::measures::merging::SpikeMerging; |
23 use crate::measures::merging::SpikeMerging; |
| 39 use crate::forward_model::{ |
24 use crate::forward_model::{ |
| 40 ForwardModel, |
25 ForwardModel, |
| 41 AdjointProductPairBoundedBy, |
26 AdjointProductPairBoundedBy, |
| 42 }; |
27 }; |
| 43 use crate::seminorms::DiscreteMeasureOp; |
|
| 44 use crate::plot::{ |
28 use crate::plot::{ |
| 45 SeqPlotter, |
29 SeqPlotter, |
| 46 Plotting, |
30 Plotting, |
| 47 PlotLookup |
31 PlotLookup |
| 48 }; |
32 }; |
| 81 |
65 |
| 82 /// Iteratively solve the pointsource localisation with an additional variable |
66 /// Iteratively solve the pointsource localisation with an additional variable |
| 83 /// using primal-dual proximal splitting with a forward step. |
67 /// using primal-dual proximal splitting with a forward step. |
| 84 #[replace_float_literals(F::cast_from(literal))] |
68 #[replace_float_literals(F::cast_from(literal))] |
| 85 pub fn pointsource_forward_pdps_pair< |
69 pub fn pointsource_forward_pdps_pair< |
| 86 'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize |
70 F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize |
| 87 >( |
71 >( |
| 88 opA : &'a A, |
72 opA : &A, |
| 89 b : &A::Observable, |
73 b : &A::Observable, |
| 90 reg : Reg, |
74 reg : Reg, |
| 91 op𝒟 : &'a 𝒟, |
75 prox_penalty : &P, |
| 92 config : &ForwardPDPSConfig<F>, |
76 config : &ForwardPDPSConfig<F>, |
| 93 iterator : I, |
77 iterator : I, |
| 94 mut plotter : SeqPlotter<F, N>, |
78 mut plotter : SeqPlotter<F, N>, |
| 95 //opKμ : KOpM, |
79 //opKμ : KOpM, |
| 96 opKz : &KOpZ, |
80 opKz : &KOpZ, |
| 100 mut y : Y, |
84 mut y : Y, |
| 101 ) -> MeasureZ<F, Z, N> |
85 ) -> MeasureZ<F, Z, N> |
| 102 where |
86 where |
| 103 F : Float + ToNalgebraRealField, |
87 F : Float + ToNalgebraRealField, |
| 104 I : AlgIteratorFactory<IterInfo<F, N>>, |
88 I : AlgIteratorFactory<IterInfo<F, N>>, |
| 105 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
|
| 106 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, |
|
| 107 A : ForwardModel< |
89 A : ForwardModel< |
| 108 MeasureZ<F, Z, N>, |
90 MeasureZ<F, Z, N>, |
| 109 F, |
91 F, |
| 110 PairNorm<Radon, L2, L2>, |
92 PairNorm<Radon, L2, L2>, |
| 111 PreadjointCodomain = Pair<BTFN<F, GA, BTA, N>, Z>, |
93 PreadjointCodomain = Pair<S, Z>, |
| 112 > |
94 > |
| 113 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, 𝒟, IdOp<Z>, FloatType=F>, |
95 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType=F>, |
| 114 BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
96 S: DifferentiableRealMapping<F, N>, |
| 115 G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone, |
97 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
| 116 𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>, |
|
| 117 Codomain = BTFN<F, G𝒟, BT𝒟, N>>, |
|
| 118 BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>, |
|
| 119 S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 120 K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>, |
|
| 121 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
|
| 122 Cube<F, N>: P2Minimise<Loc<F, N>, F>, |
|
| 123 PlotLookup : Plotting<N>, |
98 PlotLookup : Plotting<N>, |
| 124 RNDM<F, N> : SpikeMerging<F>, |
99 RNDM<F, N> : SpikeMerging<F>, |
| 125 Reg : RegTerm<F, N>, |
100 Reg : RegTerm<F, N>, |
| |
101 P : ProxPenalty<F, S, Reg, N>, |
| 126 KOpZ : BoundedLinear<Z, L2, L2, F, Codomain=Y> |
102 KOpZ : BoundedLinear<Z, L2, L2, F, Codomain=Y> |
| 127 + GEMV<F, Z> |
103 + GEMV<F, Z> |
| 128 + Adjointable<Z, Y, AdjointCodomain = Z>, |
104 + Adjointable<Z, Y, AdjointCodomain = Z>, |
| 129 for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>, |
105 for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>, |
| 130 Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd, |
106 Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd, |
| 148 // Initialise iterates |
124 // Initialise iterates |
| 149 let mut μ = DiscreteMeasure::new(); |
125 let mut μ = DiscreteMeasure::new(); |
| 150 let mut residual = calculate_residual(Pair(&μ, &z), opA, b); |
126 let mut residual = calculate_residual(Pair(&μ, &z), opA, b); |
| 151 |
127 |
| 152 // Set up parameters |
128 // Set up parameters |
| 153 let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity); |
129 let bigM = 0.0; //opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt(); |
| 154 let bigM = 0.0; //opKμ.adjoint_product_bound(&op𝒟).unwrap().sqrt(); |
|
| 155 let nKz = opKz.opnorm_bound(L2, L2); |
130 let nKz = opKz.opnorm_bound(L2, L2); |
| 156 let opIdZ = IdOp::new(); |
131 let opIdZ = IdOp::new(); |
| 157 let (l, l_z) = opA.adjoint_product_pair_bound(&op𝒟, &opIdZ).unwrap(); |
132 let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); |
| 158 // We need to satisfy |
133 // We need to satisfy |
| 159 // |
134 // |
| 160 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 |
135 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 |
| 161 // ^^^^^^^^^^^^^^^^^^^^^^^^^ |
136 // ^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 162 // with 1 > σ_p L_z and 1 > τ L. |
137 // with 1 > σ_p L_z and 1 > τ L. |
| 194 let mut stats = IterInfo::new(); |
169 let mut stats = IterInfo::new(); |
| 195 |
170 |
| 196 // Run the algorithm |
171 // Run the algorithm |
| 197 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { |
172 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { |
| 198 // Calculate initial transport |
173 // Calculate initial transport |
| 199 let Pair(τv, τz) = opA.preadjoint().apply(residual * τ); |
174 let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ); |
| 200 let z_base = z.clone(); |
175 let z_base = z.clone(); |
| 201 let μ_base = μ.clone(); |
176 let μ_base = μ.clone(); |
| 202 |
177 |
| 203 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
178 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
| 204 let (d, _within_tolerances) = insert_and_reweigh( |
179 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| 205 &mut μ, &τv, &μ_base, None, |
180 &mut μ, &mut τv, &μ_base, None, |
| 206 op𝒟, op𝒟norm, |
|
| 207 τ, ε, &config.insertion, |
181 τ, ε, &config.insertion, |
| 208 ®, &state, &mut stats, |
182 ®, &state, &mut stats, |
| 209 ); |
183 ); |
| 210 |
184 |
| 211 // // Merge spikes. |
185 // // Merge spikes. |
| 246 // Give statistics if requested |
220 // Give statistics if requested |
| 247 let iter = state.iteration(); |
221 let iter = state.iteration(); |
| 248 stats.this_iters += 1; |
222 stats.this_iters += 1; |
| 249 |
223 |
| 250 state.if_verbose(|| { |
224 state.if_verbose(|| { |
| 251 plotter.plot_spikes(iter, Some(&d), Some(&τv), &μ); |
225 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); |
| 252 full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) |
226 full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| 253 }); |
227 }); |
| 254 |
228 |
| 255 // Update main tolerance for next iteration |
229 // Update main tolerance for next iteration |
| 256 ε = tolerance.update(ε, iter); |
230 ε = tolerance.update(ε, iter); |