src/forward_pdps.rs

branch
dev
changeset 39
6316d68b58af
parent 37
c5d8bd1a7728
--- a/src/forward_pdps.rs	Thu Jan 23 23:35:28 2025 +0100
+++ b/src/forward_pdps.rs	Thu Jan 23 23:34:05 2025 +0100
@@ -51,10 +51,9 @@
 #[replace_float_literals(F::cast_from(literal))]
 impl<F : Float> Default for ForwardPDPSConfig<F> {
     fn default() -> Self {
-        let τ0 = 0.99;
         ForwardPDPSConfig {
-            τ0,
-            σd0 : 0.1,
+            τ0 : 0.99,
+            σd0 : 0.05,
             σp0 : 0.99,
             insertion : Default::default()
         }
@@ -172,7 +171,6 @@
     for state in iterator.iter_init(|| full_stats(&residual, &μ, &z, ε, stats.clone())) {
         // Calculate initial transport
         let Pair(mut τv, τz) = opA.preadjoint().apply(residual * τ);
-        let z_base = z.clone();
         let μ_base = μ.clone();
 
         // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
@@ -182,34 +180,34 @@
             &reg, &state, &mut stats,
         );
 
-        // // Merge spikes.
-        // // This expects the prune below to prune γ.
-        // // TODO: This may not work correctly in all cases.
-        // let ins = &config.insertion;
-        // if ins.merge_now(&state) {
-        //     if let SpikeMergingMethod::None = ins.merging {
-        //     } else {
-        //         stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| {
-        //             let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0;
-        //             let mut d = &τv̆ + op𝒟.preapply(ν);
-        //             reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins)
-        //         });
-        //     }
-        // }
+        // Merge spikes.
+        // This crucially expects the merge routine to be stable with respect to spike locations,
+        // and not to performing any pruning. That is be to done below simultaneously for γ.
+        // Merge spikes.
+        // This crucially expects the merge routine to be stable with respect to spike locations,
+        // and not to performing any pruning. That is be to done below simultaneously for γ.
+        let ins = &config.insertion;
+        if ins.merge_now(&state) {
+            stats.merged += prox_penalty.merge_spikes_no_fitness(
+                &mut μ, &mut τv, &μ_base, None, τ, ε, ins, &reg,
+                //Some(|μ̃ : &RNDM<F, N>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()),
+            );
+        }
 
         // Prune spikes with zero weight.
         stats.pruned += prune_with_stats(&mut μ);
 
         // Do z variable primal update
-        z.axpy(-σ_p/τ, τz, 1.0); // TODO: simplify nasty factors
-        opKz.adjoint().gemv(&mut z, -σ_p, &y, 1.0);
-        z = fnR.prox(σ_p, z);
+        let mut z_new = τz;
+        opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p/τ);
+        z_new = fnR.prox(σ_p, z_new + &z);
         // Do dual update
         // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0);    // y = y + σ_d K[(1+ω)(μ,z)^{k+1}]
-        opKz.gemv(&mut y, σ_d*(1.0 + ω), &z, 1.0);
+        opKz.gemv(&mut y, σ_d*(1.0 + ω), &z_new, 1.0);
         // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
-        opKz.gemv(&mut y, -σ_d*ω, z_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
+        opKz.gemv(&mut y, -σ_d*ω, z, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
         y = starH.prox(σ_d, y);
+        z = z_new;
 
         // Update residual
         residual = calculate_residual(Pair(&μ, &z), opA, b);
@@ -236,7 +234,7 @@
         + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z))
     };
 
-    μ.merge_spikes_fitness(config.insertion.merging, fit, |&v| v);
+    μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v);
     μ.prune();
     Pair(μ, z)
 }

mercurial