Mon, 13 Apr 2026 22:29:26 -0500
Automatic transport disabling after sufficient failures, for efficiency
| src/sliding_fb.rs | file | annotate | diff | comparison | revisions | |
| src/sliding_pdps.rs | file | annotate | diff | comparison | revisions |
--- a/src/sliding_fb.rs Thu Mar 19 18:21:17 2026 -0500 +++ b/src/sliding_fb.rs Mon Apr 13 22:29:26 2026 -0500 @@ -40,6 +40,8 @@ pub tolerance_mult_con: F, /// maximum number of adaptation iterations, until cancelling transport. pub max_attempts: usize, + /// Maximum number of failed transportations for a single source point + pub max_fail: usize, } #[replace_float_literals(F::cast_from(literal))] @@ -56,7 +58,13 @@ #[replace_float_literals(F::cast_from(literal))] impl<F: Float> Default for TransportConfig<F> { fn default() -> Self { - TransportConfig { θ0: 0.9, adaptation: 0.9, tolerance_mult_con: 100.0, max_attempts: 2 } + TransportConfig { + θ0: 0.9, + adaptation: 0.9, + tolerance_mult_con: 100.0, + max_attempts: 2, + max_fail: usize::MAX, + } } } @@ -114,6 +122,8 @@ α_γ: F, /// Helper for pruning prune: bool, + /// Fail count + fail_count: usize, } #[derive(Clone, Debug, Serialize)] @@ -207,6 +217,7 @@ _τ: F, τθ_or_adaptive: &mut TransportStepLength<F, G>, v: D, + tconfig: &TransportConfig<F>, ) where G: Fn(F, F) -> F, D: DifferentiableRealMapping<N, F>, @@ -217,11 +228,15 @@ for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) { ρ.α_μ_orig = δ.α; ρ.x = δ.x; - // If old transport has opposing sign, the new transport will be none. - ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) { - 0.0 + if ρ.fail_count > tconfig.max_fail { + ρ.α_γ = 0.0 } else { - δ.α + // If old transport has opposing sign, the new transport will be none. + ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) { + 0.0 + } else { + δ.α + } } } @@ -233,20 +248,25 @@ α_μ_orig: δ.α, α_γ: δ.α, prune: false, + fail_count: 0, })); // Calculate transport rays. match *τθ_or_adaptive { Fixed(θ) => { for ρ in self.iter_mut() { - ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ); + if ρ.fail_count <= tconfig.max_fail { + ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ); + } } } AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θτ } => { *max_transport = max_transport.max(self.norm(Radon)); let θτ = calculate_θτ(ℓ_F, *max_transport); for ρ in self.iter_mut() { - ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ); + if ρ.fail_count <= tconfig.max_fail { + ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ); + } } } FullyAdaptive { @@ -261,16 +281,18 @@ for _i in 0..=1 { let mut changes = false; for ρ in self.iter_mut() { - let dv_x = v.differential(&ρ.x); - let g = &dv_x * (ρ.α_γ.signum() * θτ); - ρ.y = ρ.x - g; - let n = g.norm2(); - if n >= F::EPSILON { - // Estimate Lipschitz factor of ∇v - let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n; - *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F); - θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport); - changes = true + if ρ.fail_count < tconfig.max_fail { + let dv_x = v.differential(&ρ.x); + let g = &dv_x * (ρ.α_γ.signum() * θτ); + ρ.y = ρ.x - g; + let n = g.norm2(); + if n >= F::EPSILON { + // Estimate Lipschitz factor of ∇v + let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n; + *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F); + θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport); + changes = true + } } } if !changes { @@ -431,6 +453,14 @@ } } + for ρ in self.iter_mut() { + if ρ.α_γ == 0.0 { + ρ.fail_count += 1; + } else if all_ok { + ρ.fail_count = 0; + } + } + all_ok } @@ -597,7 +627,7 @@ for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate initial transport let v = f.differential(&μ); - γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v); + γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v, &config.transport); let mut attempts = 0;