src/sliding_fb.rs

branch
dev
changeset 68
00d0881f89a6
parent 63
7a8a55fd41c0
--- a/src/sliding_fb.rs	Thu Mar 19 18:21:17 2026 -0500
+++ b/src/sliding_fb.rs	Mon Apr 13 22:29:26 2026 -0500
@@ -40,6 +40,8 @@
     pub tolerance_mult_con: F,
     /// maximum number of adaptation iterations, until cancelling transport.
     pub max_attempts: usize,
+    /// Maximum number of failed transportations for a single source point
+    pub max_fail: usize,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
@@ -56,7 +58,13 @@
 #[replace_float_literals(F::cast_from(literal))]
 impl<F: Float> Default for TransportConfig<F> {
     fn default() -> Self {
-        TransportConfig { θ0: 0.9, adaptation: 0.9, tolerance_mult_con: 100.0, max_attempts: 2 }
+        TransportConfig {
+            θ0: 0.9,
+            adaptation: 0.9,
+            tolerance_mult_con: 100.0,
+            max_attempts: 2,
+            max_fail: usize::MAX,
+        }
     }
 }
 
@@ -114,6 +122,8 @@
     α_γ: F,
     /// Helper for pruning
     prune: bool,
+    /// Fail count
+    fail_count: usize,
 }
 
 #[derive(Clone, Debug, Serialize)]
@@ -207,6 +217,7 @@
         _τ: F,
         τθ_or_adaptive: &mut TransportStepLength<F, G>,
         v: D,
+        tconfig: &TransportConfig<F>,
     ) where
         G: Fn(F, F) -> F,
         D: DifferentiableRealMapping<N, F>,
@@ -217,11 +228,15 @@
         for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) {
             ρ.α_μ_orig = δ.α;
             ρ.x = δ.x;
-            // If old transport has opposing sign, the new transport will be none.
-            ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) {
-                0.0
+            if ρ.fail_count > tconfig.max_fail {
+                ρ.α_γ = 0.0
             } else {
-                δ.α
+                // If old transport has opposing sign, the new transport will be none.
+                ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) {
+                    0.0
+                } else {
+                    δ.α
+                }
             }
         }
 
@@ -233,20 +248,25 @@
             α_μ_orig: δ.α,
             α_γ: δ.α,
             prune: false,
+            fail_count: 0,
         }));
 
         // Calculate transport rays.
         match *τθ_or_adaptive {
             Fixed(θ) => {
                 for ρ in self.iter_mut() {
-                    ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ);
+                    if ρ.fail_count <= tconfig.max_fail {
+                        ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ);
+                    }
                 }
             }
             AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θτ } => {
                 *max_transport = max_transport.max(self.norm(Radon));
                 let θτ = calculate_θτ(ℓ_F, *max_transport);
                 for ρ in self.iter_mut() {
-                    ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ);
+                    if ρ.fail_count <= tconfig.max_fail {
+                        ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ);
+                    }
                 }
             }
             FullyAdaptive {
@@ -261,16 +281,18 @@
                 for _i in 0..=1 {
                     let mut changes = false;
                     for ρ in self.iter_mut() {
-                        let dv_x = v.differential(&ρ.x);
-                        let g = &dv_x * (ρ.α_γ.signum() * θτ);
-                        ρ.y = ρ.x - g;
-                        let n = g.norm2();
-                        if n >= F::EPSILON {
-                            // Estimate Lipschitz factor of ∇v
-                            let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n;
-                            *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F);
-                            θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport);
-                            changes = true
+                        if ρ.fail_count < tconfig.max_fail {
+                            let dv_x = v.differential(&ρ.x);
+                            let g = &dv_x * (ρ.α_γ.signum() * θτ);
+                            ρ.y = ρ.x - g;
+                            let n = g.norm2();
+                            if n >= F::EPSILON {
+                                // Estimate Lipschitz factor of ∇v
+                                let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n;
+                                *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F);
+                                θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport);
+                                changes = true
+                            }
                         }
                     }
                     if !changes {
@@ -431,6 +453,14 @@
             }
         }
 
+        for ρ in self.iter_mut() {
+            if ρ.α_γ == 0.0 {
+                ρ.fail_count += 1;
+            } else if all_ok {
+                ρ.fail_count = 0;
+            }
+        }
+
         all_ok
     }
 
@@ -597,7 +627,7 @@
     for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate initial transport
         let v = f.differential(&μ);
-        γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v);
+        γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v, &config.transport);
 
         let mut attempts = 0;
 

mercurial