--- a/src/sliding_fb.rs Thu Feb 26 11:38:43 2026 -0500 +++ b/src/sliding_fb.rs Fri Jan 16 19:39:22 2026 -0500 @@ -20,6 +20,7 @@ use crate::types::*; use alg_tools::error::DynResult; use alg_tools::euclidean::Euclidean; +use alg_tools::instance::{ClosedSpace, Instance}; use alg_tools::iterate::AlgIteratorFactory; use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping}; use alg_tools::nalgebra_support::ToNalgebraRealField; @@ -70,6 +71,15 @@ pub insertion: InsertionConfig<F>, /// Guess for curvature bound calculations. pub guess: BoundedCurvatureGuess, + /// Always adaptive step length + pub always_adaptive_τ: bool, +} + +impl<'a, F: Float> Into<FBConfig<F>> for &'a SlidingFBConfig<F> { + fn into(self) -> FBConfig<F> { + let SlidingFBConfig { τ0, σp0, insertion, always_adaptive_τ, .. } = *self; + FBConfig { τ0, σp0, insertion, always_adaptive_τ } + } } #[replace_float_literals(F::cast_from(literal))] @@ -81,6 +91,7 @@ transport: Default::default(), insertion: Default::default(), guess: BoundedCurvatureGuess::BetterThanZero, + always_adaptive_τ: false, } } } @@ -271,12 +282,13 @@ F: Float + ToNalgebraRealField, I: AlgIteratorFactory<IterInfo<F>>, Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>, - Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>, - //for<'a> Dat::Differential<'a>: Lipschitz<&'a P, FloatType = F>, + Dat::DerivativeDomain: + ClosedMul<F> + DifferentiableRealMapping<N, F, Codomain = F> + ClosedSpace, RNDM<N, F>: SpikeMerging<F>, Reg: SlidingRegTerm<Loc<N, F>, F>, P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, + for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>, { // Check parameters ensure!(config.τ0 > 0.0, "Invalid step length parameter"); @@ -292,7 +304,9 @@ // * reg.radon_norm_bound(b.norm2_squared() / 2.0); //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; let ℓ = 0.0; - let τ = config.τ0 / prox_penalty.step_length_bound(&f)?; + let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, &config.into())?; + let τ = adaptive_τ.current(); + println!("TODO: τ in calculate_θ should be adaptive"); let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess); let transport_lip = maybe_transport_lip?; let calculate_θ = |ℓ_F, max_transport| { @@ -330,7 +344,9 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate initial transport - let v = f.differential(&μ); + let (fμ, v) = f.apply_and_differential(&μ); + let τ = adaptive_τ.update(&μ, fμ, &v); + let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); @@ -372,6 +388,10 @@ } }; + // We don't treat merge in adaptive Lipschitz. + println!("WARNING: finish_step does not work with sliding"); + adaptive_τ.finish_step(&μ); + stats.untransported_fraction = Some({ assert_eq!(μ_base_masses.len(), γ1.len()); let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));