src/fb.rs

branch
dev
changeset 62
32328a74c790
parent 61
4f468d35fa29
--- a/src/fb.rs	Thu Feb 26 11:38:43 2026 -0500
+++ b/src/fb.rs	Fri Jan 16 19:39:22 2026 -0500
@@ -80,14 +80,16 @@
 use crate::measures::merging::SpikeMerging;
 use crate::measures::{DiscreteMeasure, RNDM};
 use crate::plot::Plotter;
+use crate::prox_penalty::StepLengthBoundValue;
 pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound};
 use crate::regularisation::RegTerm;
 use crate::types::*;
 use alg_tools::error::DynResult;
-use alg_tools::instance::Instance;
+use alg_tools::instance::{ClosedSpace, Instance};
 use alg_tools::iterate::AlgIteratorFactory;
-use alg_tools::mapping::DifferentiableMapping;
+use alg_tools::mapping::{DifferentiableMapping, Mapping};
 use alg_tools::nalgebra_support::ToNalgebraRealField;
+use anyhow::anyhow;
 use colored::Colorize;
 use numeric_literals::replace_float_literals;
 use serde::{Deserialize, Serialize};
@@ -102,16 +104,14 @@
     pub σp0: F,
     /// Generic parameters
     pub insertion: InsertionConfig<F>,
+    /// Use adaptive step length estimation even when a reliable Lipschitz factor is available.
+    pub always_adaptive_τ: bool,
 }
 
 #[replace_float_literals(F::cast_from(literal))]
 impl<F: Float> Default for FBConfig<F> {
     fn default() -> Self {
-        FBConfig {
-            τ0: 0.99,
-            σp0: 0.99,
-            insertion: Default::default(),
-        }
+        FBConfig { τ0: 0.99, σp0: 0.99, always_adaptive_τ: false, insertion: Default::default() }
     }
 }
 
@@ -122,6 +122,125 @@
     n_before_prune - μ.len()
 }
 
+/// Adaptive step length and Lipschitz parameter estimation state.
+#[derive(Clone, Debug, Serialize)]
+pub enum AdaptiveStepLength<const N: usize, F: Float> {
+    Adaptive {
+        l: F,
+        μ_old: RNDM<N, F>,
+        fμ_old: F,
+        μ_dist: F,
+        τ0: F,
+        l_is_initial: bool,
+    },
+    Fixed {
+        τ: F,
+    },
+}
+
+#[replace_float_literals(F::cast_from(literal))]
+impl<const N: usize, F: Float> AdaptiveStepLength<N, F> {
+    pub fn new<Dat, Reg, P>(f: &Dat, prox_penalty: &P, fbconfig: &FBConfig<F>) -> DynResult<Self>
+    where
+        F: ToNalgebraRealField,
+        Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
+        P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
+        Reg: RegTerm<Loc<N, F>, F>,
+    {
+        match (
+            prox_penalty.step_length_bound(&f),
+            fbconfig.always_adaptive_τ,
+        ) {
+            (StepLengthBoundValue::LipschitzFactor(l), false) => {
+                Ok(AdaptiveStepLength::Fixed { τ: fbconfig.τ0 / l })
+            }
+            (StepLengthBoundValue::LipschitzFactor(l), true) => {
+                let μ_old = DiscreteMeasure::new();
+                let fμ_old = f.apply(&μ_old);
+                Ok(AdaptiveStepLength::Adaptive {
+                    l: l,
+                    μ_old,
+                    fμ_old,
+                    μ_dist: 0.0,
+                    τ0: fbconfig.τ0,
+                    l_is_initial: false,
+                })
+            }
+            (StepLengthBoundValue::UnreliableLipschitzFactor(l), _) => {
+                println!("Lipschitz factor is unreliable; calculating adaptively.");
+                let μ_old = DiscreteMeasure::new();
+                let fμ_old = f.apply(&μ_old);
+                Ok(AdaptiveStepLength::Adaptive {
+                    l: l,
+                    μ_old,
+                    fμ_old,
+                    μ_dist: 0.0,
+                    τ0: fbconfig.τ0,
+                    l_is_initial: true,
+                })
+            }
+            (StepLengthBoundValue::Failure, _) => Err(anyhow!("No Lipschitz estimate available")),
+        }
+    }
+
+    /// Returns the current value of the step length parameter.
+    pub fn current(&self) -> F {
+        match *self {
+            AdaptiveStepLength::Adaptive { τ0, l, .. } => τ0 / l,
+            AdaptiveStepLength::Fixed { τ } => τ,
+        }
+    }
+
+    /// Update adaptive Lipschitz factor and return new step length parameter `τ`.
+    ///
+    /// Inputs:
+    /// * `μ`: current point
+    /// * `fμ`: value of the function `f` at `μ`.
+    /// * `v`: derivative of the function `f` at `μ`.
+    /// The fractional step length parameter `τ0` ∈ $[0, 1)$ is taken from `self`.
+    pub fn update<'a, G>(&mut self, μ: &RNDM<N, F>, fμ: F, v: &'a G) -> F
+    where
+        G: ClosedMul<F> + Mapping<Loc<N, F>, Codomain = F> + ClosedSpace,
+        &'a G: Instance<G>,
+    {
+        match self {
+            AdaptiveStepLength::Adaptive { l, μ_old, fμ_old, μ_dist, τ0, l_is_initial } => {
+                // Estimate step length parameter
+                let b = *fμ_old - fμ - μ_old.apply(v) + μ.apply(v);
+                let d = *μ_dist;
+                if d.abs() > F::EPSILON && μ.len() > 0 && μ_old.len() > 0 {
+                    let lc = b / (d * d / 2.0);
+                    dbg!(b, d, lc);
+                    if *l_is_initial {
+                        *l = lc;
+                        *l_is_initial = false;
+                    } else {
+                        *l = l.max(lc);
+                    }
+                }
+
+                // Store for next iteration
+                *μ_old = μ.clone();
+                *fμ_old = fμ;
+
+                return *τ0 / *l;
+            }
+            AdaptiveStepLength::Fixed { τ } => *τ,
+        }
+    }
+
+    /// Finalises a step, storing μ and its distance to the previous μ.
+    ///
+    /// This is not included in [`Self::update`], as this function is to be called
+    /// before pruning and merging, while μ and its previous version still have
+    /// matching indices for the same coordinate in their internal representation.
+    pub fn finish_step(&mut self, μ: &RNDM<N, F>) {
+        if let AdaptiveStepLength::Adaptive { μ_dist, μ_old, .. } = self {
+            *μ_dist = μ.dist_matching(&μ_old);
+        }
+    }
+}
+
 #[replace_float_literals(F::cast_from(literal))]
 pub(crate) fn postprocess<F: Float, Dat: Fn(&RNDM<N, F>) -> F, const N: usize>(
     mut μ: RNDM<N, F>,
@@ -157,24 +276,26 @@
     fbconfig: &FBConfig<F>,
     iterator: I,
     mut plotter: Plot,
-    μ0 : Option<RNDM<N, F>>,
+    μ0: Option<RNDM<N, F>>,
 ) -> DynResult<RNDM<N, F>>
 where
     F: Float + ToNalgebraRealField,
     I: AlgIteratorFactory<IterInfo<F>>,
     RNDM<N, F>: SpikeMerging<F>,
     Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
-    Dat::DerivativeDomain: ClosedMul<F>,
+    Dat::DerivativeDomain: ClosedMul<F> + Mapping<Loc<N, F>, Codomain = F> + ClosedSpace,
     Reg: RegTerm<Loc<N, F>, F>,
     P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
     Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
+    for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>,
 {
     // Set up parameters
     let config = &fbconfig.insertion;
-    let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?;
+    let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, fbconfig)?;
+
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
-    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
+    let tolerance = config.tolerance * adaptive_τ.current() * reg.tolerance_scaling();
     let mut ε = tolerance.initial();
 
     // Initialise iterates
@@ -194,7 +315,10 @@
     for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
         // TODO: optimise τ to be applied to residual.
-        let mut τv = f.differential(&μ) * τ;
+        let (fμ, v) = f.apply_and_differential(&μ);
+        let τ = adaptive_τ.update(&μ, fμ, &v);
+        dbg!(τ);
+        let mut τv = v * τ;
 
         // Save current base point
         let μ_base = μ.clone();
@@ -204,6 +328,9 @@
             &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
         )?;
 
+        // Merging is not accounted for in the adaptive Lipschitz estimate.
+        adaptive_τ.finish_step(&μ);
+
         // Prune and possibly merge spikes
         if config.merge_now(&state) {
             stats.merged += prox_penalty.merge_spikes(
@@ -257,25 +384,27 @@
     fbconfig: &FBConfig<F>,
     iterator: I,
     mut plotter: Plot,
-    μ0: Option<RNDM<N, F>>
+    μ0: Option<RNDM<N, F>>,
 ) -> DynResult<RNDM<N, F>>
 where
     F: Float + ToNalgebraRealField,
     I: AlgIteratorFactory<IterInfo<F>>,
     RNDM<N, F>: SpikeMerging<F>,
     Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F>,
-    Dat::DerivativeDomain: ClosedMul<F>,
+    Dat::DerivativeDomain: ClosedMul<F> + Mapping<Loc<N, F>, Codomain = F> + ClosedSpace,
     Reg: RegTerm<Loc<N, F>, F>,
     P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
     Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
+    for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>,
 {
     // Set up parameters
     let config = &fbconfig.insertion;
-    let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?;
+    let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, fbconfig)?;
+
     let mut λ = 1.0;
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
-    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
+    let tolerance = config.tolerance * adaptive_τ.current() * reg.tolerance_scaling();
     let mut ε = tolerance.initial();
 
     // Initialise iterates
@@ -296,7 +425,9 @@
     // Run the algorithm
     for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
         // Calculate smooth part of surrogate model.
-        let mut τv = f.differential(&μ) * τ;
+        let (fμ, v) = f.apply_and_differential(&μ);
+        let τ = adaptive_τ.update(&μ, fμ, &v);
+        let mut τv = v * τ;
 
         // Save current base point
         let μ_base = μ.clone();
@@ -306,6 +437,9 @@
             &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats,
         )?;
 
+        // Merging is not accounted for in the adaptive Lipschitz estimate.
+        adaptive_τ.finish_step(&μ);
+
         // (Do not) merge spikes.
         if config.merge_now(&state) && !warned_merging {
             let err = format!("Merging not supported for μFISTA");

mercurial