src/sliding_fb.rs

branch
dev
changeset 62
32328a74c790
parent 61
4f468d35fa29
--- a/src/sliding_fb.rs	Thu Feb 26 11:38:43 2026 -0500
+++ b/src/sliding_fb.rs	Fri Jan 16 19:39:22 2026 -0500
@@ -20,6 +20,7 @@
 use crate::types::*;
 use alg_tools::error::DynResult;
 use alg_tools::euclidean::Euclidean;
+use alg_tools::instance::{ClosedSpace, Instance};
 use alg_tools::iterate::AlgIteratorFactory;
 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping};
 use alg_tools::nalgebra_support::ToNalgebraRealField;
@@ -70,6 +71,15 @@
     pub insertion: InsertionConfig<F>,
     /// Guess for curvature bound calculations.
     pub guess: BoundedCurvatureGuess,
+    /// Always adaptive step length
+    pub always_adaptive_τ: bool,
+}
+
+impl<'a, F: Float> Into<FBConfig<F>> for &'a SlidingFBConfig<F> {
+    fn into(self) -> FBConfig<F> {
+        let SlidingFBConfig { τ0, σp0, insertion, always_adaptive_τ, .. } = *self;
+        FBConfig { τ0, σp0, insertion, always_adaptive_τ }
+    }
 }
 
 #[replace_float_literals(F::cast_from(literal))]
@@ -81,6 +91,7 @@
             transport: Default::default(),
             insertion: Default::default(),
             guess: BoundedCurvatureGuess::BetterThanZero,
+            always_adaptive_τ: false,
         }
     }
 }
@@ -271,12 +282,13 @@
     F: Float + ToNalgebraRealField,
     I: AlgIteratorFactory<IterInfo<F>>,
     Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>,
-    Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>,
-    //for<'a> Dat::Differential<'a>: Lipschitz<&'a P, FloatType = F>,
+    Dat::DerivativeDomain:
+        ClosedMul<F> + DifferentiableRealMapping<N, F, Codomain = F> + ClosedSpace,
     RNDM<N, F>: SpikeMerging<F>,
     Reg: SlidingRegTerm<Loc<N, F>, F>,
     P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
     Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
+    for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>,
 {
     // Check parameters
     ensure!(config.τ0 > 0.0, "Invalid step length parameter");
@@ -292,7 +304,9 @@
     //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0);
     //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
     let ℓ = 0.0;
-    let τ = config.τ0 / prox_penalty.step_length_bound(&f)?;
+    let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, &config.into())?;
+    let τ = adaptive_τ.current();
+    println!("TODO: τ in calculate_θ should be adaptive");
     let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess);
     let transport_lip = maybe_transport_lip?;
     let calculate_θ = |ℓ_F, max_transport| {
@@ -330,7 +344,9 @@
     // Run the algorithm
     for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate initial transport
-        let v = f.differential(&μ);
+        let (fμ, v) = f.apply_and_differential(&μ);
+        let τ = adaptive_τ.update(&μ, fμ, &v);
+
         let (μ_base_masses, mut μ_base_minus_γ0) =
             initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v);
 
@@ -372,6 +388,10 @@
             }
         };
 
+        // We don't treat merge in adaptive Lipschitz.
+        println!("WARNING: finish_step does not work with sliding");
+        adaptive_τ.finish_step(&μ);
+
         stats.untransported_fraction = Some({
             assert_eq!(μ_base_masses.len(), γ1.len());
             let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));

mercurial