src/forward_pdps.rs

branch
dev
changeset 39
6316d68b58af
parent 37
c5d8bd1a7728
equal deleted inserted replaced
37:c5d8bd1a7728 39:6316d68b58af
49 } 49 }
50 50
51 #[replace_float_literals(F::cast_from(literal))] 51 #[replace_float_literals(F::cast_from(literal))]
52 impl<F : Float> Default for ForwardPDPSConfig<F> { 52 impl<F : Float> Default for ForwardPDPSConfig<F> {
53 fn default() -> Self { 53 fn default() -> Self {
54 let τ0 = 0.99;
55 ForwardPDPSConfig { 54 ForwardPDPSConfig {
56 τ0, 55 τ0 : 0.99,
57 σd0 : 0.1, 56 σd0 : 0.05,
58 σp0 : 0.99, 57 σp0 : 0.99,
59 insertion : Default::default() 58 insertion : Default::default()
60 } 59 }
61 } 60 }
62 } 61 }
170 169
171 // Run the algorithm 170 // Run the algorithm
172 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { 171 for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) {
173 // Calculate initial transport 172 // Calculate initial transport
174 let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ); 173 let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ);
175 let z_base = z.clone();
176 let μ_base = μ.clone(); 174 let μ_base = μ.clone();
177 175
178 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. 176 // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
179 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( 177 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh(
180 &mut μ, &mut τv, &μ_base, None, 178 &mut μ, &mut τv, &μ_base, None,
181 τ, ε, &config.insertion, 179 τ, ε, &config.insertion,
182 &reg, &state, &mut stats, 180 &reg, &state, &mut stats,
183 ); 181 );
184 182
185 // // Merge spikes. 183 // Merge spikes.
186 // // This expects the prune below to prune γ. 184 // This crucially expects the merge routine to be stable with respect to spike locations,
187 // // TODO: This may not work correctly in all cases. 185 // and not to performing any pruning. That is be to done below simultaneously for γ.
188 // let ins = &config.insertion; 186 // Merge spikes.
189 // if ins.merge_now(&state) { 187 // This crucially expects the merge routine to be stable with respect to spike locations,
190 // if let SpikeMergingMethod::None = ins.merging { 188 // and not to performing any pruning. That is be to done below simultaneously for γ.
191 // } else { 189 let ins = &config.insertion;
192 // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { 190 if ins.merge_now(&state) {
193 // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; 191 stats.merged += prox_penalty.merge_spikes_no_fitness(
194 // let mut d = &τv̆ + op𝒟.preapply(ν); 192 &mut μ, &mut τv, &μ_base, None, τ, ε, ins, &reg,
195 // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) 193 //Some(|μ̃ : &RNDM<F, N>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()),
196 // }); 194 );
197 // } 195 }
198 // }
199 196
200 // Prune spikes with zero weight. 197 // Prune spikes with zero weight.
201 stats.pruned += prune_with_stats(&mut μ); 198 stats.pruned += prune_with_stats(&mut μ);
202 199
203 // Do z variable primal update 200 // Do z variable primal update
204 z.axpy(-σ_p/τ, τz, 1.0); // TODO: simplify nasty factors 201 let mut z_new = τz;
205 opKz.adjoint().gemv(&mut z, -σ_p, &y, 1.0); 202 opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p/τ);
206 z = fnR.prox(σ_p, z); 203 z_new = fnR.prox(σ_p, z_new + &z);
207 // Do dual update 204 // Do dual update
208 // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] 205 // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}]
209 opKz.gemv(&mut y, σ_d*(1.0 + ω), &z, 1.0); 206 opKz.gemv(&mut y, σ_d*(1.0 + ω), &z_new, 1.0);
210 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b 207 // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
211 opKz.gemv(&mut y, -σ_d*ω, z_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b 208 opKz.gemv(&mut y, -σ_d*ω, z, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
212 y = starH.prox(σ_d, y); 209 y = starH.prox(σ_d, y);
210 z = z_new;
213 211
214 // Update residual 212 // Update residual
215 residual = calculate_residual(Pair(&μ, &z), opA, b); 213 residual = calculate_residual(Pair(&μ, &z), opA, b);
216 214
217 // Update step length parameters 215 // Update step length parameters
234 (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2() 232 (opA.apply(Pair(μ̃, &z))-b).norm2_squared_div2()
235 //+ fnR.apply(z) + reg.apply(μ) 233 //+ fnR.apply(z) + reg.apply(μ)
236 + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) 234 + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z))
237 }; 235 };
238 236
239 μ.merge_spikes_fitness(config.insertion.merging, fit, |&v| v); 237 μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v);
240 μ.prune(); 238 μ.prune();
241 Pair(μ, z) 239 Pair(μ, z)
242 } 240 }

mercurial