| 1 /*! |
1 /*! |
| 2 Solver for the point source localisation problem using a |
2 Solver for the point source localisation problem using a |
| 3 primal-dual proximal splitting with a forward step. |
3 primal-dual proximal splitting with a forward step. |
| 4 */ |
4 */ |
| 5 |
5 |
| |
6 use crate::fb::*; |
| |
7 use crate::measures::merging::SpikeMerging; |
| |
8 use crate::measures::{DiscreteMeasure, RNDM}; |
| |
9 use crate::plot::Plotter; |
| |
10 use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair}; |
| |
11 use crate::regularisation::RegTerm; |
| |
12 use crate::types::*; |
| |
13 use alg_tools::convex::{Conjugable, Prox, Zero}; |
| |
14 use alg_tools::direct_product::Pair; |
| |
15 use alg_tools::error::DynResult; |
| |
16 use alg_tools::euclidean::ClosedEuclidean; |
| |
17 use alg_tools::iterate::AlgIteratorFactory; |
| |
18 use alg_tools::linops::{BoundedLinear, IdOp, SimplyAdjointable, ZeroOp, AXPY, GEMV}; |
| |
19 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance}; |
| |
20 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| |
21 use alg_tools::norms::L2; |
| |
22 use anyhow::ensure; |
| 6 use numeric_literals::replace_float_literals; |
23 use numeric_literals::replace_float_literals; |
| 7 use serde::{Serialize, Deserialize}; |
24 use serde::{Deserialize, Serialize}; |
| 8 |
|
| 9 use alg_tools::iterate::AlgIteratorFactory; |
|
| 10 use alg_tools::euclidean::Euclidean; |
|
| 11 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance}; |
|
| 12 use alg_tools::norms::Norm; |
|
| 13 use alg_tools::direct_product::Pair; |
|
| 14 use alg_tools::nalgebra_support::ToNalgebraRealField; |
|
| 15 use alg_tools::linops::{ |
|
| 16 BoundedLinear, AXPY, GEMV, Adjointable, IdOp, |
|
| 17 }; |
|
| 18 use alg_tools::convex::{Conjugable, Prox}; |
|
| 19 use alg_tools::norms::{L2, PairNorm}; |
|
| 20 |
|
| 21 use crate::types::*; |
|
| 22 use crate::measures::{DiscreteMeasure, Radon, RNDM}; |
|
| 23 use crate::measures::merging::SpikeMerging; |
|
| 24 use crate::forward_model::{ |
|
| 25 ForwardModel, |
|
| 26 AdjointProductPairBoundedBy, |
|
| 27 }; |
|
| 28 use crate::plot::{ |
|
| 29 SeqPlotter, |
|
| 30 Plotting, |
|
| 31 PlotLookup |
|
| 32 }; |
|
| 33 use crate::fb::*; |
|
| 34 use crate::regularisation::RegTerm; |
|
| 35 use crate::dataterm::calculate_residual; |
|
| 36 |
25 |
| 37 /// Settings for [`pointsource_forward_pdps_pair`]. |
26 /// Settings for [`pointsource_forward_pdps_pair`]. |
| 38 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
27 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 39 #[serde(default)] |
28 #[serde(default)] |
| 40 pub struct ForwardPDPSConfig<F : Float> { |
29 pub struct ForwardPDPSConfig<F: Float> { |
| 41 /// Primal step length scaling. |
30 /// Overall primal step length scaling. |
| 42 pub τ0 : F, |
31 pub τ0: F, |
| 43 /// Primal step length scaling. |
32 /// Primal step length scaling for additional variable. |
| 44 pub σp0 : F, |
33 pub σp0: F, |
| 45 /// Dual step length scaling. |
34 /// Dual step length scaling for additional variable. |
| 46 pub σd0 : F, |
35 /// |
| |
36 /// Taken zero for [`pointsource_fb_pair`]. |
| |
37 pub σd0: F, |
| 47 /// Generic parameters |
38 /// Generic parameters |
| 48 pub insertion : FBGenericConfig<F>, |
39 pub insertion: InsertionConfig<F>, |
| 49 } |
40 } |
| 50 |
41 |
| 51 #[replace_float_literals(F::cast_from(literal))] |
42 #[replace_float_literals(F::cast_from(literal))] |
| 52 impl<F : Float> Default for ForwardPDPSConfig<F> { |
43 impl<F: Float> Default for ForwardPDPSConfig<F> { |
| 53 fn default() -> Self { |
44 fn default() -> Self { |
| 54 ForwardPDPSConfig { |
45 ForwardPDPSConfig { τ0: 0.99, σd0: 0.05, σp0: 0.99, insertion: Default::default() } |
| 55 τ0 : 0.99, |
|
| 56 σd0 : 0.05, |
|
| 57 σp0 : 0.99, |
|
| 58 insertion : Default::default() |
|
| 59 } |
|
| 60 } |
46 } |
| 61 } |
47 } |
| 62 |
48 |
| 63 type MeasureZ<F, Z, const N : usize> = Pair<RNDM<F, N>, Z>; |
49 type MeasureZ<F, Z, const N: usize> = Pair<RNDM<N, F>, Z>; |
| 64 |
50 |
| 65 /// Iteratively solve the pointsource localisation with an additional variable |
51 /// Iteratively solve the pointsource localisation with an additional variable |
| 66 /// using primal-dual proximal splitting with a forward step. |
52 /// using primal-dual proximal splitting with a forward step. |
| |
53 /// |
| |
54 /// The problem is |
| |
55 /// $$ |
| |
56 /// \min_{μ, z}~ F(μ, z) + R(z) + H(K_z z) + Q(μ), |
| |
57 /// $$ |
| |
58 /// where |
| |
59 /// * The data term $F$ is given in `f`, |
| |
60 /// * the measure (Radon or positivity-constrained Radon) regulariser in $Q$ is given in `reg`, |
| |
61 /// * the functions $R$ and $H$ are given in `fnR` and `fnH`, and |
| |
62 /// * the operator $K_z$ in `opKz`. |
| |
63 /// |
| |
64 /// This is dualised to |
| |
65 /// $$ |
| |
66 /// \min_{μ, z}\max_y~ F(μ, z) + R(z) + ⟨K_z z, y⟩ + Q(μ) - H^*(y). |
| |
67 /// $$ |
| |
68 /// |
| |
69 /// The algorithm is controlled by: |
| |
70 /// * the proximal penalty in `prox_penalty`. |
| |
71 /// * the initial iterates in `z`, `y` |
| |
72 /// * The configuration in `config`. |
| |
73 /// * The `iterator` that controls stopping and reporting. |
| |
74 /// Moreover, plotting is performed by `plotter`. |
| |
75 /// |
| |
76 /// The step lengths need to satisfy |
| |
77 /// $$ |
| |
78 /// τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 |
| |
79 /// $$ ^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
80 /// with $1 > σ_p L_z$ and $1 > τ L$. |
| |
81 /// Since we are given “scalings” $τ_0$, $σ_{p,0}$, and $σ_{d,0}$ in `config`, we take |
| |
82 /// $σ_d=σ_{d,0}/‖K_z‖$, and $σ_p = σ_{p,0} / (L_z σ_d‖K_z‖)$. This satisfies the |
| |
83 /// part $[σ_p L_z + σ_pσ_d‖K_z‖^2] < 1$. Then with these cohices, we solve |
| |
84 /// $$ |
| |
85 /// τ = τ_0 \frac{1 - σ_{p,0}}{(σ_d M (1-σ_p L_z) + (1 - σ_{p,0} L)}. |
| |
86 /// $$ |
| 67 #[replace_float_literals(F::cast_from(literal))] |
87 #[replace_float_literals(F::cast_from(literal))] |
| 68 pub fn pointsource_forward_pdps_pair< |
88 pub fn pointsource_forward_pdps_pair< |
| 69 F, I, A, S, Reg, P, Z, R, Y, /*KOpM, */ KOpZ, H, const N : usize |
89 F, |
| |
90 I, |
| |
91 S, |
| |
92 Dat, |
| |
93 Reg, |
| |
94 P, |
| |
95 Z, |
| |
96 R, |
| |
97 Y, |
| |
98 /*KOpM, */ KOpZ, |
| |
99 H, |
| |
100 Plot, |
| |
101 const N: usize, |
| 70 >( |
102 >( |
| 71 opA : &A, |
103 f: &Dat, |
| 72 b : &A::Observable, |
104 reg: &Reg, |
| 73 reg : Reg, |
105 prox_penalty: &P, |
| 74 prox_penalty : &P, |
106 config: &ForwardPDPSConfig<F>, |
| 75 config : &ForwardPDPSConfig<F>, |
107 iterator: I, |
| 76 iterator : I, |
108 mut plotter: Plot, |
| 77 mut plotter : SeqPlotter<F, N>, |
109 (μ0, mut z, mut y): (Option<RNDM<N, F>>, Z, Y), |
| 78 //opKμ : KOpM, |
110 //opKμ : KOpM, |
| 79 opKz : &KOpZ, |
111 opKz: &KOpZ, |
| 80 fnR : &R, |
112 fnR: &R, |
| 81 fnH : &H, |
113 fnH: &H, |
| 82 mut z : Z, |
114 ) -> DynResult<MeasureZ<F, Z, N>> |
| 83 mut y : Y, |
|
| 84 ) -> MeasureZ<F, Z, N> |
|
| 85 where |
115 where |
| 86 F : Float + ToNalgebraRealField, |
116 F: Float + ToNalgebraRealField, |
| 87 I : AlgIteratorFactory<IterInfo<F, N>>, |
117 I: AlgIteratorFactory<IterInfo<F>>, |
| 88 A : ForwardModel< |
118 Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>, |
| 89 MeasureZ<F, Z, N>, |
119 //Pair<S, Z>: ClosedMul<F>, // Doesn't really need to be closed, if make this signature more complex… |
| 90 F, |
120 S: DifferentiableRealMapping<N, F> + ClosedMul<F>, |
| 91 PairNorm<Radon, L2, L2>, |
121 RNDM<N, F>: SpikeMerging<F>, |
| 92 PreadjointCodomain = Pair<S, Z>, |
122 Reg: RegTerm<Loc<N, F>, F>, |
| 93 > |
123 P: ProxPenalty<Loc<N, F>, S, Reg, F>, |
| 94 + AdjointProductPairBoundedBy<MeasureZ<F, Z, N>, P, IdOp<Z>, FloatType=F>, |
124 for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>, |
| 95 S: DifferentiableRealMapping<F, N>, |
125 KOpZ: BoundedLinear<Z, L2, L2, F, Codomain = Y> |
| 96 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>, |
126 + GEMV<F, Z, Y> |
| 97 PlotLookup : Plotting<N>, |
127 + SimplyAdjointable<Z, Y, Codomain = Y, AdjointCodomain = Z>, |
| 98 RNDM<F, N> : SpikeMerging<F>, |
128 KOpZ::SimpleAdjoint: GEMV<F, Y, Z>, |
| 99 Reg : RegTerm<F, N>, |
129 Y: ClosedEuclidean<F>, |
| 100 P : ProxPenalty<F, S, Reg, N>, |
130 for<'b> &'b Y: Instance<Y>, |
| 101 KOpZ : BoundedLinear<Z, L2, L2, F, Codomain=Y> |
131 Z: ClosedEuclidean<F>, |
| 102 + GEMV<F, Z> |
132 for<'b> &'b Z: Instance<Z>, |
| 103 + Adjointable<Z, Y, AdjointCodomain = Z>, |
133 R: Prox<Z, Codomain = F>, |
| 104 for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>, |
134 H: Conjugable<Y, F, Codomain = F>, |
| 105 Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd, |
135 for<'b> H::Conjugate<'b>: Prox<Y>, |
| 106 for<'b> &'b Y : Instance<Y>, |
136 Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>, |
| 107 Z : AXPY<F, Owned=Z> + Euclidean<F, Output=Z> + Clone + Norm<F, L2>, |
|
| 108 for<'b> &'b Z : Instance<Z>, |
|
| 109 R : Prox<Z, Codomain=F>, |
|
| 110 H : Conjugable<Y, F, Codomain=F>, |
|
| 111 for<'b> H::Conjugate<'b> : Prox<Y>, |
|
| 112 { |
137 { |
| 113 |
|
| 114 // Check parameters |
138 // Check parameters |
| 115 assert!(config.τ0 > 0.0 && |
139 // ensure!( |
| 116 config.τ0 < 1.0 && |
140 // config.τ0 > 0.0 |
| 117 config.σp0 > 0.0 && |
141 // && config.τ0 < 1.0 |
| 118 config.σp0 < 1.0 && |
142 // && config.σp0 > 0.0 |
| 119 config.σd0 > 0.0 && |
143 // && config.σp0 < 1.0 |
| 120 config.σp0 * config.σd0 <= 1.0, |
144 // && config.σd0 >= 0.0 |
| 121 "Invalid step length parameters"); |
145 // && config.σp0 * config.σd0 <= 1.0, |
| |
146 // "Invalid step length parameters" |
| |
147 // ); |
| 122 |
148 |
| 123 // Initialise iterates |
149 // Initialise iterates |
| 124 let mut μ = DiscreteMeasure::new(); |
150 let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new()); |
| 125 let mut residual = calculate_residual(Pair(&μ, &z), opA, b); |
|
| 126 |
151 |
| 127 // Set up parameters |
152 // Set up parameters |
| 128 let bigM = 0.0; //opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt(); |
153 let bigM = 0.0; //opKμ.adjoint_product_bound(prox_penalty).unwrap().sqrt(); |
| 129 let nKz = opKz.opnorm_bound(L2, L2); |
154 let nKz = opKz.opnorm_bound(L2, L2)?; |
| 130 let opIdZ = IdOp::new(); |
155 let idOpZ = IdOp::new(); |
| 131 let (l, l_z) = opA.adjoint_product_pair_bound(prox_penalty, &opIdZ).unwrap(); |
156 let opKz_adj = opKz.adjoint(); |
| |
157 let (l, l_z) = Pair(prox_penalty, &idOpZ).step_length_bound_pair(&f)?; |
| 132 // We need to satisfy |
158 // We need to satisfy |
| 133 // |
159 // |
| 134 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 |
160 // τσ_dM(1-σ_p L_z)/(1 - τ L) + [σ_p L_z + σ_pσ_d‖K_z‖^2] < 1 |
| 135 // ^^^^^^^^^^^^^^^^^^^^^^^^^ |
161 // ^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 136 // with 1 > σ_p L_z and 1 > τ L. |
162 // with 1 > σ_p L_z and 1 > τ L. |
| 137 // |
163 // |
| 138 // To do so, we first solve σ_p and σ_d from standard PDPS step length condition |
164 // To do so, we first solve σ_p and σ_d from standard PDPS step length condition |
| 139 // ^^^^^ < 1. then we solve τ from the rest. |
165 // ^^^^^ < 1. then we solve τ from the rest. |
| 140 let σ_d = config.σd0 / nKz; |
166 // If opKZ is the zero operator, then we set σ_d = 0 for τ to be calculated correctly below. |
| |
167 let σ_d = if nKz == 0.0 { 0.0 } else { config.σd0 / nKz }; |
| 141 let σ_p = config.σp0 / (l_z + config.σd0 * nKz); |
168 let σ_p = config.σp0 / (l_z + config.σd0 * nKz); |
| 142 // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0} |
169 // Observe that = 1 - ^^^^^^^^^^^^^^^^^^^^^ = 1 - σ_{p,0} |
| 143 // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L) |
170 // We get the condition τσ_d M (1-σ_p L_z) < (1-σ_{p,0})*(1-τ L) |
| 144 // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0}) |
171 // ⟺ τ [ σ_d M (1-σ_p L_z) + (1-σ_{p,0}) L ] < (1-σ_{p,0}) |
| 145 let φ = 1.0 - config.σp0; |
172 let φ = 1.0 - config.σp0; |
| 146 let a = 1.0 - σ_p * l_z; |
173 let a = 1.0 - σ_p * l_z; |
| 147 let τ = config.τ0 * φ / ( σ_d * bigM * a + φ * l ); |
174 let τ = config.τ0 * φ / (σ_d * bigM * a + φ * l); |
| 148 // Acceleration is not currently supported |
175 // Acceleration is not currently supported |
| 149 // let γ = dataterm.factor_of_strong_convexity(); |
176 // let γ = dataterm.factor_of_strong_convexity(); |
| 150 let ω = 1.0; |
177 let ω = 1.0; |
| 151 |
178 |
| 152 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
179 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled |
| 155 let mut ε = tolerance.initial(); |
182 let mut ε = tolerance.initial(); |
| 156 |
183 |
| 157 let starH = fnH.conjugate(); |
184 let starH = fnH.conjugate(); |
| 158 |
185 |
| 159 // Statistics |
186 // Statistics |
| 160 let full_stats = |residual : &A::Observable, μ : &RNDM<F, N>, z : &Z, ε, stats| IterInfo { |
187 let full_stats = |μ: &RNDM<N, F>, z: &Z, ε, stats| IterInfo { |
| 161 value : residual.norm2_squared_div2() + fnR.apply(z) |
188 value: f.apply(Pair(μ, z)) |
| 162 + reg.apply(μ) + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)), |
189 + fnR.apply(z) |
| 163 n_spikes : μ.len(), |
190 + reg.apply(μ) |
| |
191 + fnH.apply(/* opKμ.apply(μ) + */ opKz.apply(z)), |
| |
192 n_spikes: μ.len(), |
| 164 ε, |
193 ε, |
| 165 // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), |
194 // postprocessing: config.insertion.postprocessing.then(|| μ.clone()), |
| 166 .. stats |
195 ..stats |
| 167 }; |
196 }; |
| 168 let mut stats = IterInfo::new(); |
197 let mut stats = IterInfo::new(); |
| 169 |
198 |
| 170 // Run the algorithm |
199 // Run the algorithm |
| 171 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { |
200 for state in iterator.iter_init(|| full_stats(&μ, &z, ε, stats.clone())) { |
| 172 // Calculate initial transport |
201 // Calculate initial transport |
| 173 let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ); |
202 let Pair(mut τv, τz) = f.differential(Pair(&μ, &z)); |
| 174 let μ_base = μ.clone(); |
203 let μ_base = μ.clone(); |
| 175 |
204 |
| 176 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
205 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. |
| 177 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
206 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
| 178 &mut μ, &mut τv, &μ_base, None, |
207 &mut μ, |
| 179 τ, ε, &config.insertion, |
208 &mut τv, |
| 180 ®, &state, &mut stats, |
209 &μ_base, |
| 181 ); |
210 None, |
| |
211 τ, |
| |
212 ε, |
| |
213 &config.insertion, |
| |
214 ®, |
| |
215 &state, |
| |
216 &mut stats, |
| |
217 )?; |
| 182 |
218 |
| 183 // Merge spikes. |
219 // Merge spikes. |
| 184 // This crucially expects the merge routine to be stable with respect to spike locations, |
220 // This crucially expects the merge routine to be stable with respect to spike locations, |
| 185 // and not to performing any pruning. That is be to done below simultaneously for γ. |
221 // and not to performing any pruning. That is be to done below simultaneously for γ. |
| 186 // Merge spikes. |
222 // Merge spikes. |
| 187 // This crucially expects the merge routine to be stable with respect to spike locations, |
223 // This crucially expects the merge routine to be stable with respect to spike locations, |
| 188 // and not to performing any pruning. That is be to done below simultaneously for γ. |
224 // and not to performing any pruning. That is be to done below simultaneously for γ. |
| 189 let ins = &config.insertion; |
225 let ins = &config.insertion; |
| 190 if ins.merge_now(&state) { |
226 if ins.merge_now(&state) { |
| 191 stats.merged += prox_penalty.merge_spikes_no_fitness( |
227 stats.merged += prox_penalty.merge_spikes_no_fitness( |
| 192 &mut μ, &mut τv, &μ_base, None, τ, ε, ins, ®, |
228 &mut μ, &mut τv, &μ_base, None, τ, ε, ins, |
| 193 //Some(|μ̃ : &RNDM<F, N>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), |
229 ®, |
| |
230 //Some(|μ̃ : &RNDM<N, F>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), |
| 194 ); |
231 ); |
| 195 } |
232 } |
| 196 |
233 |
| 197 // Prune spikes with zero weight. |
234 // Prune spikes with zero weight. |
| 198 stats.pruned += prune_with_stats(&mut μ); |
235 stats.pruned += prune_with_stats(&mut μ); |
| 199 |
236 |
| 200 // Do z variable primal update |
237 // Do z variable primal update |
| 201 let mut z_new = τz; |
238 let mut z_new = τz; |
| 202 opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p/τ); |
239 opKz_adj.gemv(&mut z_new, -σ_p, &y, -σ_p / τ); |
| 203 z_new = fnR.prox(σ_p, z_new + &z); |
240 z_new = fnR.prox(σ_p, z_new + &z); |
| 204 // Do dual update |
241 // Do dual update |
| 205 // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] |
242 // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] |
| 206 opKz.gemv(&mut y, σ_d*(1.0 + ω), &z_new, 1.0); |
243 opKz.gemv(&mut y, σ_d * (1.0 + ω), &z_new, 1.0); |
| 207 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
244 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
| 208 opKz.gemv(&mut y, -σ_d*ω, z, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
245 opKz.gemv(&mut y, -σ_d * ω, z, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b |
| 209 y = starH.prox(σ_d, y); |
246 y = starH.prox(σ_d, y); |
| 210 z = z_new; |
247 z = z_new; |
| 211 |
|
| 212 // Update residual |
|
| 213 residual = calculate_residual(Pair(&μ, &z), opA, b); |
|
| 214 |
248 |
| 215 // Update step length parameters |
249 // Update step length parameters |
| 216 // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |
250 // let ω = pdpsconfig.acceleration.accelerate(&mut τ, &mut σ, γ); |
| 217 |
251 |
| 218 // Give statistics if requested |
252 // Give statistics if requested |
| 219 let iter = state.iteration(); |
253 let iter = state.iteration(); |
| 220 stats.this_iters += 1; |
254 stats.this_iters += 1; |
| 221 |
255 |
| 222 state.if_verbose(|| { |
256 state.if_verbose(|| { |
| 223 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); |
257 plotter.plot_spikes(iter, maybe_d.as_ref(), Some(&τv), &μ); |
| 224 full_stats(&residual, &μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) |
258 full_stats(&μ, &z, ε, std::mem::replace(&mut stats, IterInfo::new())) |
| 225 }); |
259 }); |
| 226 |
260 |
| 227 // Update main tolerance for next iteration |
261 // Update main tolerance for next iteration |
| 228 ε = tolerance.update(ε, iter); |
262 ε = tolerance.update(ε, iter); |
| 229 } |
263 } |
| 230 |
264 |
| 231 let fit = |μ̃ : &RNDM<F, N>| { |
265 let fit = |μ̃: &RNDM<N, F>| { |
| 232 (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2() |
266 f.apply(Pair(μ̃, &z)) /*+ fnR.apply(z) + reg.apply(μ)*/ |
| 233 //+ fnR.apply(z) + reg.apply(μ) |
|
| 234 + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) |
267 + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) |
| 235 }; |
268 }; |
| 236 |
269 |
| 237 μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v); |
270 μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v); |
| 238 μ.prune(); |
271 μ.prune(); |
| 239 Pair(μ, z) |
272 Ok(Pair(μ, z)) |
| 240 } |
273 } |
| |
274 |
| |
275 /// Iteratively solve the pointsource localisation with an additional variable |
| |
276 /// using forward-backward splitting. |
| |
277 /// |
| |
278 /// The implementation uses [`pointsource_forward_pdps_pair`] with appropriate dummy |
| |
279 /// variables, operators, and functions. |
| |
280 #[replace_float_literals(F::cast_from(literal))] |
| |
281 pub fn pointsource_fb_pair<F, I, S, Dat, Reg, P, Z, R, Plot, const N: usize>( |
| |
282 f: &Dat, |
| |
283 reg: &Reg, |
| |
284 prox_penalty: &P, |
| |
285 config: &FBConfig<F>, |
| |
286 iterator: I, |
| |
287 plotter: Plot, |
| |
288 (μ0, z): (Option<RNDM<N, F>>, Z), |
| |
289 //opKμ : KOpM, |
| |
290 fnR: &R, |
| |
291 ) -> DynResult<MeasureZ<F, Z, N>> |
| |
292 where |
| |
293 F: Float + ToNalgebraRealField, |
| |
294 I: AlgIteratorFactory<IterInfo<F>>, |
| |
295 Dat: DifferentiableMapping<MeasureZ<F, Z, N>, Codomain = F, DerivativeDomain = Pair<S, Z>>, |
| |
296 S: DifferentiableRealMapping<N, F> + ClosedMul<F>, |
| |
297 RNDM<N, F>: SpikeMerging<F>, |
| |
298 Reg: RegTerm<Loc<N, F>, F>, |
| |
299 P: ProxPenalty<Loc<N, F>, S, Reg, F>, |
| |
300 for<'a> Pair<&'a P, &'a IdOp<Z>>: StepLengthBoundPair<F, Dat>, |
| |
301 Z: ClosedEuclidean<F> + AXPY<Field = F> + Clone, |
| |
302 for<'b> &'b Z: Instance<Z>, |
| |
303 R: Prox<Z, Codomain = F>, |
| |
304 Plot: Plotter<P::ReturnMapping, S, RNDM<N, F>>, |
| |
305 // We should not need to explicitly require this: |
| |
306 for<'b> &'b Loc<0, F>: Instance<Loc<0, F>>, |
| |
307 { |
| |
308 let opKz = ZeroOp::new_dualisable(Loc([]), z.dual_origin()); |
| |
309 let fnH = Zero::new(); |
| |
310 // Convert config. We don't implement From (that could be done with the o2o crate), as σd0 |
| |
311 // needs to be chosen in a general case; for the problem of this fucntion, anything is valid. |
| |
312 let &FBConfig { τ0, σp0, insertion } = config; |
| |
313 let pdps_config = ForwardPDPSConfig { τ0, σp0, insertion, σd0: 0.0 }; |
| |
314 |
| |
315 pointsource_forward_pdps_pair( |
| |
316 f, |
| |
317 reg, |
| |
318 prox_penalty, |
| |
319 &pdps_config, |
| |
320 iterator, |
| |
321 plotter, |
| |
322 (μ0, z, Loc([])), |
| |
323 &opKz, |
| |
324 fnR, |
| |
325 &fnH, |
| |
326 ) |
| |
327 } |