src/sliding_pdps.rs

branch
dev
changeset 39
6316d68b58af
parent 37
c5d8bd1a7728
child 41
b6bdb6cb4d44
--- a/src/sliding_pdps.rs	Thu Jan 23 23:35:28 2025 +0100
+++ b/src/sliding_pdps.rs	Thu Jan 23 23:34:05 2025 +0100
@@ -12,7 +12,7 @@
 use alg_tools::iterate::AlgIteratorFactory;
 use alg_tools::euclidean::Euclidean;
 use alg_tools::mapping::{Mapping, DifferentiableRealMapping, Instance};
-use alg_tools::norms::Norm;
+use alg_tools::norms::{Norm, Dist};
 use alg_tools::direct_product::Pair;
 use alg_tools::nalgebra_support::ToNalgebraRealField;
 use alg_tools::linops::{
@@ -45,7 +45,11 @@
     initial_transport,
     aposteriori_transport,
 };
-use crate::dataterm::{calculate_residual, calculate_residual2};
+use crate::dataterm::{
+    calculate_residual2,
+    calculate_residual,
+};
+
 
 /// Settings for [`pointsource_sliding_pdps_pair`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
@@ -66,12 +70,11 @@
 #[replace_float_literals(F::cast_from(literal))]
 impl<F : Float> Default for SlidingPDPSConfig<F> {
     fn default() -> Self {
-        let τ0 = 0.99;
         SlidingPDPSConfig {
-            τ0,
-            σd0 : 0.1,
+            τ0 : 0.99,
+            σd0 : 0.05,
             σp0 : 0.99,
-            transport : Default::default(),
+            transport : TransportConfig { θ0 : 0.1, ..Default::default()},
             insertion : Default::default()
         }
     }
@@ -134,7 +137,7 @@
     for<'b> KOpZ::Adjoint<'b> : GEMV<F, Y>,
     Y : AXPY<F> + Euclidean<F, Output=Y> + Clone + ClosedAdd,
     for<'b> &'b Y : Instance<Y>,
-    Z : AXPY<F, Owned=Z> + Euclidean<F, Output=Z> + Clone + Norm<F, L2>,
+    Z : AXPY<F, Owned=Z> + Euclidean<F, Output=Z> + Clone + Norm<F, L2> + Dist<F, L2>,
     for<'b> &'b Z : Instance<Z>,
     R : Prox<Z, Codomain=F>,
     H : Conjugable<Y, F, Codomain=F>,
@@ -202,7 +205,7 @@
             g : calculate_θ
         },
         None => TransportStepLength::FullyAdaptive{
-            l : 0.0,
+            l : F::EPSILON,
             max_transport : 0.0,
             g : calculate_θ
         },
@@ -234,7 +237,6 @@
         // Calculate initial transport
         let Pair(v, _) = opA.preadjoint().apply(&residual);
         //opKμ.preadjoint().apply_add(&mut v, y);
-        let z_base = z.clone();
         // We want to proceed as in Example 4.12 but with v and v̆ as in §5.
         // With A(ν, z) = A_μ ν + A_z z, following Example 5.1, we have
         // P_ℳ[F'(ν, z) + Ξ(ν, z, y)]= A_ν^*[A_ν ν + A_z z] + K_μ ν = A_ν^*A(ν, z) + K_μ ν,
@@ -250,28 +252,33 @@
 
         // Solve finite-dimensional subproblem several times until the dual variable for the
         // regularisation term conforms to the assumptions made for the transport above.
-        let (maybe_d, _within_tolerances, Pair(τv̆, τz̆)) = 'adapt_transport: loop {
+        let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop {
             // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
             let residual_μ̆ = calculate_residual2(Pair(&γ1, &z),
                                                  Pair(&μ_base_minus_γ0, &zero_z),
                                                  opA, b);
-            let mut τv̆z = opA.preadjoint().apply(residual_μ̆ * τ);
+            let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ);
             // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0);
 
             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
             let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
-                &mut μ, &mut τv̆z.0, &γ1, Some(&μ_base_minus_γ0),
+                &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0),
                 τ, ε, &config.insertion,
                 &reg, &state, &mut stats,
             );
 
+            // Do z variable primal update here to be able to estimate B_{v̆^k-v^{k+1}}
+            let mut z_new = τz̆;
+            opKz.adjoint().gemv(&mut z_new, -σ_p, &y, -σ_p/τ);
+            z_new = fnR.prox(σ_p, z_new + &z);
+
             // A posteriori transport adaptation.
-            // TODO: this does not properly treat v^{k+1} - v̆^k that depends on z^{k+1}!
             if aposteriori_transport(
                 &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses,
+                Some(z_new.dist(&z, L2)),
                 ε, &config.transport
             ) {
-                break 'adapt_transport (maybe_d, within_tolerances, τv̆z)
+                break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new)
             }
         };
 
@@ -287,20 +294,16 @@
             (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
         });
 
-        // // Merge spikes.
-        // // This expects the prune below to prune γ.
-        // // TODO: This may not work correctly in all cases.
-        // let ins = &config.insertion;
-        // if ins.merge_now(&state) {
-        //     if let SpikeMergingMethod::None = ins.merging {
-        //     } else {
-        //         stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| {
-        //             let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0;
-        //             let mut d = &τv̆ + op𝒟.preapply(ν);
-        //             reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins)
-        //         });
-        //     }
-        // }
+        // Merge spikes.
+        // This crucially expects the merge routine to be stable with respect to spike locations,
+        // and not to perform any pruning. That is to be done below simultaneously for γ.
+        let ins = &config.insertion;
+        if ins.merge_now(&state) {
+            stats.merged += prox_penalty.merge_spikes_no_fitness(
+                &mut μ, &mut τv̆, &γ1, Some(&μ_base_minus_γ0), τ, ε, ins, &reg,
+                //Some(|μ̃ : &RNDM<F, N>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()),
+            );
+        }
 
         // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
         // latter needs to be pruned when μ is.
@@ -313,16 +316,13 @@
             μ = μ_new;
         }
 
-        // Do z variable primal update
-        z.axpy(-σ_p/τ, τz̆, 1.0); // TODO: simplify nasty factors
-        opKz.adjoint().gemv(&mut z, -σ_p, &y, 1.0);
-        z = fnR.prox(σ_p, z);
         // Do dual update
         // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0);    // y = y + σ_d K[(1+ω)(μ,z)^{k+1}]
-        opKz.gemv(&mut y, σ_d*(1.0 + ω), &z, 1.0);
+        opKz.gemv(&mut y, σ_d*(1.0 + ω), &z_new, 1.0);
         // opKμ.gemv(&mut y, -σ_d*ω, μ_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
-        opKz.gemv(&mut y, -σ_d*ω, z_base, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
+        opKz.gemv(&mut y, -σ_d*ω, z, 1.0);// y = y + σ_d K[(1+ω)(μ,z)^{k+1} - ω (μ,z)^k]-b
         y = starH.prox(σ_d, y);
+        z = z_new;
 
         // Update residual
         residual = calculate_residual(Pair(&μ, &z), opA, b);
@@ -349,7 +349,7 @@
         + fnH.apply(/* opKμ.apply(&μ̃) + */ opKz.apply(&z))
     };
 
-    μ.merge_spikes_fitness(config.insertion.merging, fit, |&v| v);
+    μ.merge_spikes_fitness(config.insertion.final_merging_method(), fit, |&v| v);
     μ.prune();
     Pair(μ, z)
 }

mercurial