2 Solver for the point source localisation problem using a sliding |
2 Solver for the point source localisation problem using a sliding |
3 primal-dual proximal splitting method. |
3 primal-dual proximal splitting method. |
4 */ |
4 */ |
5 |
5 |
6 use numeric_literals::replace_float_literals; |
6 use numeric_literals::replace_float_literals; |
7 use serde::{Serialize, Deserialize}; |
7 use serde::{Deserialize, Serialize}; |
8 //use colored::Colorize; |
8 //use colored::Colorize; |
9 //use nalgebra::{DVector, DMatrix}; |
9 //use nalgebra::{DVector, DMatrix}; |
10 use std::iter::Iterator; |
10 use std::iter::Iterator; |
11 |
11 |
|
12 use alg_tools::convex::{Conjugable, Prox}; |
|
13 use alg_tools::direct_product::Pair; |
|
14 use alg_tools::euclidean::Euclidean; |
12 use alg_tools::iterate::AlgIteratorFactory; |
15 use alg_tools::iterate::AlgIteratorFactory; |
13 use alg_tools::euclidean::Euclidean; |
16 use alg_tools::linops::{Adjointable, BoundedLinear, IdOp, AXPY, GEMV}; |
14 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; |
17 use alg_tools::mapping::{DifferentiableRealMapping, Instance, Mapping}; |
15 use alg_tools::norms::{Norm, Dist}; |
|
16 use alg_tools::direct_product::Pair; |
|
17 use alg_tools::nalgebra_support::ToNalgebraRealField; |
18 use alg_tools::nalgebra_support::ToNalgebraRealField; |
18 use alg_tools::linops::{ |
19 use alg_tools::norms::{Dist, Norm}; |
19 BoundedLinear, AXPY, GEMV, Adjointable, IdOp, |
20 use alg_tools::norms::{PairNorm, L2}; |
20 }; |
21 |
21 use alg_tools::convex::{Conjugable, Prox}; |
22 use crate::forward_model::{AdjointProductPairBoundedBy, BoundedCurvature, ForwardModel}; |
22 use alg_tools::norms::{L2, PairNorm}; |
23 use crate::measures::merging::SpikeMerging; |
23 |
24 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
24 use crate::types::*; |
25 use crate::types::*; |
25 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
|
26 use crate::measures::merging::SpikeMerging; |
|
27 use crate::forward_model::{ |
|
28 ForwardModel, |
|
29 AdjointProductPairBoundedBy, |
|
30 BoundedCurvature, |
|
31 }; |
|
32 // use crate::transport::TransportLipschitz; |
26 // use crate::transport::TransportLipschitz; |
33 //use crate::tolerance::Tolerance; |
27 //use crate::tolerance::Tolerance; |
34 use crate::plot::{ |
|
35 SeqPlotter, |
|
36 Plotting, |
|
37 PlotLookup |
|
38 }; |
|
39 use crate::fb::*; |
28 use crate::fb::*; |
|
29 use crate::plot::{PlotLookup, Plotting, SeqPlotter}; |
40 use crate::regularisation::SlidingRegTerm; |
30 use crate::regularisation::SlidingRegTerm; |
41 // use crate::dataterm::L2Squared; |
31 // use crate::dataterm::L2Squared; |
|
32 use crate::dataterm::{calculate_residual, calculate_residual2}; |
42 use crate::sliding_fb::{ |
33 use crate::sliding_fb::{ |
43 TransportConfig, |
34 aposteriori_transport, initial_transport, TransportConfig, TransportStepLength, |
44 TransportStepLength, |
|
45 initial_transport, |
|
46 aposteriori_transport, |
|
47 }; |
35 }; |
48 use crate::dataterm::{ |
|
49 calculate_residual2, |
|
50 calculate_residual, |
|
51 }; |
|
52 |
|
53 |
36 |
54 /// Settings for [`pointsource_sliding_pdps_pair`]. |
37 /// Settings for [`pointsource_sliding_pdps_pair`]. |
55 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
38 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
56 #[serde(default)] |
39 #[serde(default)] |
57 pub struct SlidingPDPSConfig<F : Float> { |
40 pub struct SlidingPDPSConfig<F: Float> { |
58 /// Primal step length scaling. |
41 /// Primal step length scaling. |
59 pub τ0 : F, |
42 pub τ0: F, |
60 /// Primal step length scaling. |
43 /// Primal step length scaling. |
61 pub σp0 : F, |
44 pub σp0: F, |
62 /// Dual step length scaling. |
45 /// Dual step length scaling. |
63 pub σd0 : F, |
46 pub σd0: F, |
64 /// Transport parameters |
47 /// Transport parameters |
65 pub transport : TransportConfig<F>, |
48 pub transport: TransportConfig<F>, |
66 /// Generic parameters |
49 /// Generic parameters |
67 pub insertion : FBGenericConfig<F>, |
50 pub insertion: FBGenericConfig<F>, |
68 } |
51 } |
69 |
52 |
70 #[replace_float_literals(F::cast_from(literal))] |
53 #[replace_float_literals(F::cast_from(literal))] |
71 impl<F : Float> Default for SlidingPDPSConfig<F> { |
54 impl<F: Float> Default for SlidingPDPSConfig<F> { |
72 fn default() -> Self { |
55 fn default() -> Self { |
73 SlidingPDPSConfig { |
56 SlidingPDPSConfig { |
74 τ0 : 0.99, |
57 τ0: 0.99, |
75 σd0 : 0.05, |
58 σd0: 0.05, |
76 σp0 : 0.99, |
59 σp0: 0.99, |
77 transport : TransportConfig { θ0 : 0.9, ..Default::default()}, |
60 transport: TransportConfig { |
78 insertion : Default::default() |
61 θ0: 0.9, |
|
62 ..Default::default() |
|
63 }, |
|
64 insertion: Default::default(), |
79 } |
65 } |
80 } |
66 } |
81 } |
67 } |
82 |
68 |
83 type MeasureZ<F, Z, const N : usize> = Pair<RNDM<F, N>, Z>; |
69 type MeasureZ<F, Z, const N: usize> = Pair<RNDM<F, N>, Z>; |
84 |
70 |
85 /// Iteratively solve the pointsource localisation problem with an additional variable |
71 /// Iteratively solve the pointsource localisation problem with an additional variable |
86 /// using sliding primal-dual proximal splitting. |
72 /// using sliding primal-dual proximal splitting. |
87 /// |
73 /// |
88 /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. |
74 /// The parametrisation is as for [`crate::forward_pdps::pointsource_forward_pdps_pair`]. |
89 #[replace_float_literals(F::cast_from(literal))] |
75 #[replace_float_literals(F::cast_from(literal))] |
90 pub fn pointsource_sliding_pdps_pair< |
76 pub fn pointsource_sliding_pdps_pair< |
91 F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize |
77 F, |
|
78 I, |
|
79 A, |
|
80 S, |
|
81 Reg, |
|
82 P, |
|
83 Z, |
|
84 R, |
|
85 Y, |
|
86 /*KOpM, */ KOpZ, |
|
87 H, |
|
88 const N: usize, |
92 >( |
89 >( |
93 opA : &A, |
90 opA: &A, |
94 b : &A::Observable, |
91 b: &A::Observable, |
95 reg : Reg, |
92 reg: Reg, |
96 prox_penalty : &P, |
93 prox_penalty: &P, |
97 config : &SlidingPDPSConfig<F>, |
94 config: &SlidingPDPSConfig<F>, |
98 iterator : I, |
95 iterator: I, |
99 mut plotter : SeqPlotter<F, N>, |
96 mut plotter: SeqPlotter<F, N>, |
100 //opKμ : KOpM, |
97 //opKμ : KOpM, |
101 opKz : &KOpZ, |
98 opKz: &KOpZ, |
102 fnR : &R, |
99 fnR: &R, |
103 fnH : &H, |
100 fnH: &H, |
104 mut z : Z, |
101 mut z: Z, |
105 mut y : Y, |
102 mut y: Y, |
106 ) -> MeasureZ<F, Z, N> |
103 ) -> MeasureZ<F, Z, N> |
107 where |
104 where |
108 F : Float + ToNalgebraRealField, |
105 F: Float + ToNalgebraRealField, |
109 I : AlgIteratorFactory<IterInfo<F, N>>, |
106 I: AlgIteratorFactory<IterInfo<F, N>>, |
110 A : ForwardModel< |
107 A: ForwardModel<MeasureZ<F, Z, N>, F, PairNorm<Radon, L2, L2>, PreadjointCodomain = Pair<S, Z>> |
111 MeasureZ<F, Z, N>, |
108 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType = F> |
112 F, |
109 + BoundedCurvature<FloatType = F>, |
113 PairNorm<Radon, L2, L2>, |
110 S: DifferentiableRealMapping<F, N>, |
114 PreadjointCodomain = Pair<S, Z>, |
111 for<'b> &'b A::Observable: std::ops::Neg<Output = A::Observable> + Instance<A::Observable>, |
115 > |
112 PlotLookup: Plotting<N>, |
116 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType=F> |
113 RNDM<F, N>: SpikeMerging<F>, |
117 + BoundedCurvature<FloatType=F>, |
114 Reg: SlidingRegTerm<F, N>, |
118 S : DifferentiableRealMapping<F, N>, |
115 P: ProxPenalty<F, S, Reg, N>, |
119 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
|
120 PlotLookup : Plotting<N>, |
|
121 RNDM<F, N> : SpikeMerging<F>, |
|
122 Reg : SlidingRegTerm<F, N>, |
|
123 P : ProxPenalty<F, S, Reg, N>, |
|
124 // KOpM : Linear<RNDM<F, N>, Codomain=Y> |
116 // KOpM : Linear<RNDM<F, N>, Codomain=Y> |
125 // + GEMV<F, RNDM<F, N>> |
117 // + GEMV<F, RNDM<F, N>> |
126 // + Preadjointable< |
118 // + Preadjointable< |
127 // RNDM<F, N>, Y, |
119 // RNDM<F, N>, Y, |
128 // PreadjointCodomain = S, |
120 // PreadjointCodomain = S, |
129 // > |
121 // > |
130 // + TransportLipschitz<L2Squared, FloatType=F> |
122 // + TransportLipschitz<L2Squared, FloatType=F> |
131 // + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, |
123 // + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>, |
132 // for<'b> KOpM::Preadjoint<'b> : GEMV<F, Y>, |
124 // for<'b> KOpM::Preadjoint<'b> : GEMV<F, Y>, |
133 // Since Z is Hilbert, we may just as well use adjoints for K_z. |
125 // Since Z is Hilbert, we may just as well use adjoints for K_z. |
134 KOpZ : BoundedLinear<Z, L2, L2, F, Codomain=Y> |
126 KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y> |
135 + GEMV<F, Z> |
127 + GEMV<F, Z> |
136 + Adjointable<Z, Y, AdjointCodomain = Z>, |
128 + Adjointable<Z, Y, AdjointCodomain = Z>, |
137 for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>, |
129 for<'b> KOpZ::Adjoint<'b>: GEMV<F, Y>, |
138 Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd, |
130 Y: AXPY<F> + Euclidean<F, Output = Y> + Clone + ClosedAdd, |
139 for<'b> &'b Y : Instance<Y>, |
131 for<'b> &'b Y: Instance<Y>, |
140 Z : AXPY<F, Owned=Z> + Euclidean<F, Output=Z> + Clone + Norm<F, L2> + Dist<F, L2>, |
132 Z: AXPY<F, Owned = Z> + Euclidean<F, Output = Z> + Clone + Norm<F, L2> + Dist<F, L2>, |
141 for<'b> &'b Z : Instance<Z>, |
133 for<'b> &'b Z: Instance<Z>, |
142 R : Prox<Z, Codomain=F>, |
134 R: Prox<Z, Codomain = F>, |
143 H : Conjugable<Y, F, Codomain=F>, |
135 H: Conjugable<Y, F, Codomain = F>, |
144 for<'b> H::Conjugate<'b> : Prox<Y>, |
136 for<'b> H::Conjugate<'b>: Prox<Y>, |
145 { |
137 { |
146 |
|
147 // Check parameters |
138 // Check parameters |
148 assert!(config.τ0 > 0.0 && |
139 assert!( |
149 config.τ0 < 1.0 && |
140 config.τ0 > 0.0 |
150 config.σp0 > 0.0 && |
141 && config.τ0 < 1.0 |
151 config.σp0 < 1.0 && |
142 && config.σp0 > 0.0 |
152 config.σd0 > 0.0 && |
143 && config.σp0 < 1.0 |
153 config.σp0 * config.σd0 <= 1.0, |
144 && config.σd0 > 0.0 |
154 "Invalid step length parameters"); |
145 && config.σp0 * config.σd0 <= 1.0, |
|
146 "Invalid step length parameters" |
|
147 ); |
155 config.transport.check(); |
148 config.transport.check(); |
156 |
149 |
157 // Initialise iterates |
150 // Initialise iterates |
158 let mut μ = DiscreteMeasure::new(); |
151 let mut μ = DiscreteMeasure::new(); |
159 let mut γ1 = DiscreteMeasure::new(); |
152 let mut γ1 = DiscreteMeasure::new(); |
182 // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0} |
177 // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0} |
183 // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L) |
178 // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L) |
184 // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0}) |
179 // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0}) |
185 let φ = 1.0 - config.σp0; |
180 let φ = 1.0 - config.σp0; |
186 let a = 1.0 - σ_p * l_z; |
181 let a = 1.0 - σ_p * l_z; |
187 let τ = config.τ0 * φ / ( σ_d * bigM * a + φ * l ); |
182 let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l); |
188 let ψ = 1.0 - τ * l; |
183 let ψ = 1.0 - τ * l; |
189 let β = σ_p * config.σd0 * nKz / a; // σ_p * σ_d * (nKz * nK_z) / a; |
184 let β = σ_p * config.σd0 * nKz / a; // σ_p * σ_d * (nKz * nK_z) / a; |
190 assert!(β < 1.0); |
185 assert!(β < 1.0); |
191 // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as: |
186 // Now we need κ‖K_μ(π_♯^1 - π_♯^0)γ‖^2 ≤ (1/θ - τ[ℓ_F + ℓ]) ∫ c_2 dγ for κ defined as: |
192 let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); |
187 let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM); |
193 // The factor two in the manuscript disappears due to the definition of 𝚹 being |
188 // The factor two in the manuscript disappears due to the definition of 𝚹 being |
194 // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. |
189 // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2. |
195 let (maybe_ℓ_v0, maybe_transport_lip) = opA.curvature_bound_components(); |
190 let (maybe_ℓ_F0, maybe_transport_lip) = opA.curvature_bound_components(); |
196 let transport_lip = maybe_transport_lip.unwrap(); |
191 let transport_lip = maybe_transport_lip.unwrap(); |
197 let calculate_θ = |ℓ_v, max_transport| { |
192 let calculate_θ = |ℓ_F, max_transport| { |
198 let ℓ_F = ℓ_v + transport_lip * max_transport; |
193 let ℓ_r = transport_lip * max_transport; |
199 config.transport.θ0 / (τ*(ℓ + ℓ_F) + κ * bigθ * max_transport) |
194 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport) |
200 }; |
195 }; |
201 let mut θ_or_adaptive = match maybe_ℓ_v0 { |
196 let mut θ_or_adaptive = match maybe_ℓ_F0 { |
202 // We assume that the residual is decreasing. |
197 // We assume that the residual is decreasing. |
203 Some(ℓ_v0) => TransportStepLength::AdaptiveMax { |
198 Some(ℓ_F0) => TransportStepLength::AdaptiveMax { |
204 l: ℓ_v0 * b.norm2(), // TODO: could estimate by computing the real residual |
199 l: ℓ_F0 * b.norm2(), // TODO: could estimate by computing the real residual |
205 max_transport : 0.0, |
200 max_transport: 0.0, |
206 g : calculate_θ |
201 g: calculate_θ, |
207 }, |
202 }, |
208 None => TransportStepLength::FullyAdaptive { |
203 None => TransportStepLength::FullyAdaptive { |
209 l : F::EPSILON, |
204 l: F::EPSILON, |
210 max_transport : 0.0, |
205 max_transport: 0.0, |
211 g : calculate_θ |
206 g: calculate_θ, |
212 }, |
207 }, |
213 }; |
208 }; |
214 // Acceleration is not currently supported |
209 // Acceleration is not currently supported |
215 // let γ = dataterm.factor_of_strong_convexity(); |
210 // let γ = dataterm.factor_of_strong_convexity(); |
216 let ω = 1.0; |
211 let ω = 1.0; |
242 // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have |
239 // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have |
243 // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν, |
240 // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν, |
244 // where A_ν^* becomes a multiplier. |
241 // where A_ν^* becomes a multiplier. |
245 // This is much easier with K_μ = 0, which is the only reason why we are enforcing it. |
242 // This is much easier with K_μ = 0, which is the only reason why we are enforcing it. |
246 // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0. |
243 // TODO: Write a version of initial_transport that can deal with K_μ ≠ 0. |
247 |
244 |
248 let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport( |
245 let (μ_base_masses, mut μ_base_minus_γ0) = |
249 &mut γ1, &mut μ, τ, &mut θ_or_adaptive, v, |
246 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); |
250 ); |
|
251 |
247 |
252 // Solve finite-dimensional subproblem several times until the dual variable for the |
248 // Solve finite-dimensional subproblem several times until the dual variable for the |
253 // regularisation term conforms to the assumptions made for the transport above. |
249 // regularisation term conforms to the assumptions made for the transport above. |
254 let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop { |
250 let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop { |
255 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
251 // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b) |
256 let residual_μ̆ = calculate_residual2(Pair(&γ1, &z), |
252 let residual_μ̆ = |
257 Pair(&μ_base_minus_γ0, &zero_z), |
253 calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b); |
258 opA, b); |
|
259 let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ); |
254 let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ); |
260 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); |
255 // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0); |
261 |
256 |
262 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
257 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
263 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
258 let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh( |
264 &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), |
259 &mut μ, |
265 τ, ε, &config.insertion, |
260 &mut τv̆, |
266 &reg, &state, &mut stats, |
261 &γ1, |
|
262 Some(&μ_base_minus_γ0), |
|
263 τ, |
|
264 ε, |
|
265 &config.insertion, |
|
266 &reg, |
|
267 &state, |
|
268 &mut stats, |
267 ); |
269 ); |
268 |
270 |
269 // Do z variable primal update here to be able to estimate B_{v̆^k-v^{k+1}} |
271 // Do z variable primal update here to be able to estimate B_{v̆^k-v^{k+1}} |
270 let mut z_new = τz̆; |
272 let mut z_new = τz̆; |
271 opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p/τ); |
273 opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p / τ); |
272 z_new = fnR.prox(σ_p, z_new + &z); |
274 z_new = fnR.prox(σ_p, z_new + &z); |
273 |
275 |
274 // A posteriori transport adaptation. |
276 // A posteriori transport adaptation. |
275 if aposteriori_transport( |
277 if aposteriori_transport( |
276 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses, |
278 &mut γ1, |
|
279 &mut μ, |
|
280 &mut μ_base_minus_γ0, |
|
281 &μ_base_masses, |
277 Some(z_new.dist(&z, L2)), |
282 Some(z_new.dist(&z, L2)), |
278 ε, &config.transport |
283 ε, |
|
284 &config.transport, |
279 ) { |
285 ) { |
280 break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new) |
286 break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new); |
281 } |
287 } |
282 }; |
288 }; |
283 |
289 |
284 stats.untransported_fraction = Some({ |
290 stats.untransported_fraction = Some({ |
285 assert_eq!(μ_base_masses.len(), γ1.len()); |
291 assert_eq!(μ_base_masses.len(), γ1.len()); |