diff -r c5d8bd1a7728 -r 6316d68b58af src/forward_pdps.rs --- a/src/forward_pdps.rs Thu Jan 23 23:35:28 2025 +0100 +++ b/src/forward_pdps.rs Thu Jan 23 23:34:05 2025 +0100 @@ -51,10 +51,9 @@ #[replace_float_literals(F::cast_from(literal))] impl Default for ForwardPDPSConfig { fn default() -> Self { - let τ0 = 0.99; ForwardPDPSConfig { - τ0, - σd0 : 0.1, + τ0 : 0.99, + σd0 : 0.05, σp0 : 0.99, insertion : Default::default() } @@ -172,7 +171,6 @@ for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) { // Calculate initial transport let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ); - let z_base = z.clone(); let μ_base = μ.clone(); // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes. @@ -182,34 +180,34 @@ ®, &state, &mut stats, ); - // // Merge spikes. - // // This expects the prune below to prune γ. - // // TODO: This may not work correctly in all cases. - // let ins = &config.insertion; - // if ins.merge_now(&state) { - // if let SpikeMergingMethod::None = ins.merging { - // } else { - // stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| { - // let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0; - // let mut d = &τv̆ + op𝒟.preapply(ν); - // reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins) - // }); - // } - // } + // Merge spikes. + // This crucially expects the merge routine to be stable with respect to spike locations, + // and not to performing any pruning. That is be to done below simultaneously for γ. + // Merge spikes. + // This crucially expects the merge routine to be stable with respect to spike locations, + // and not to performing any pruning. That is be to done below simultaneously for γ. + let ins = &config.insertion; + if ins.merge_now(&state) { + stats.merged += prox_penalty.merge_spikes_no_fitness( + &mut μ, &mut τv, &μ_base, None, τ, ε, ins, ®, + //Some(|μ̃ : &RNDM| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()), + ); + } // Prune spikes with zero weight. stats.pruned += prune_with_stats(&mut μ); // Do z variable primal update - z.axpy(-σ_p/τ, τz, 1.0); // TODO: simplify nasty factors - opKz.adjoint().gemv(&mut z, -σ_p, &y, 1.0); - z = fnR.prox(σ_p, z); + let mut z_new = τz; + opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p/τ); + z_new = fnR.prox(σ_p, z_new + &z); // Do dual update // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0); // y = y + σ_d K[(1+ω)(μ,z)^{k+1}] - opKz.gemv(&mut y, σ_d*(1.0 + ω), &z, 1.0); + opKz.gemv(&mut y, σ_d*(1.0 + ω), &z_new, 1.0); // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b - opKz.gemv(&mut y, -σ_d*ω, z_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b + opKz.gemv(&mut y, -σ_d*ω, z, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b y = starH.prox(σ_d, y); + z = z_new; // Update residual residual = calculate_residual(Pair(&μ, &z), opA, b); @@ -236,7 +234,7 @@ + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z)) }; - μ.merge_spikes_fitness(config.insertion.merging, fit, |&v| v); + μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v); μ.prune(); Pair(μ, z) }