diff -r 4f468d35fa29 -r 32328a74c790 src/fb.rs --- a/src/fb.rs Thu Feb 26 11:38:43 2026 -0500 +++ b/src/fb.rs Fri Jan 16 19:39:22 2026 -0500 @@ -80,14 +80,16 @@ use crate::measures::merging::SpikeMerging; use crate::measures::{DiscreteMeasure, RNDM}; use crate::plot::Plotter; +use crate::prox_penalty::StepLengthBoundValue; pub use crate::prox_penalty::{InsertionConfig, ProxPenalty, StepLengthBound}; use crate::regularisation::RegTerm; use crate::types::*; use alg_tools::error::DynResult; -use alg_tools::instance::Instance; +use alg_tools::instance::{ClosedSpace, Instance}; use alg_tools::iterate::AlgIteratorFactory; -use alg_tools::mapping::DifferentiableMapping; +use alg_tools::mapping::{DifferentiableMapping, Mapping}; use alg_tools::nalgebra_support::ToNalgebraRealField; +use anyhow::anyhow; use colored::Colorize; use numeric_literals::replace_float_literals; use serde::{Deserialize, Serialize}; @@ -102,16 +104,14 @@ pub σp0: F, /// Generic parameters pub insertion: InsertionConfig, + /// Always adaptive step length + pub always_adaptive_τ: bool, } #[replace_float_literals(F::cast_from(literal))] impl Default for FBConfig { fn default() -> Self { - FBConfig { - τ0: 0.99, - σp0: 0.99, - insertion: Default::default(), - } + FBConfig { τ0: 0.99, σp0: 0.99, always_adaptive_τ: false, insertion: Default::default() } } } @@ -122,6 +122,125 @@ n_before_prune - μ.len() } +/// Adaptive step length and Lipschitz parameter estimation state. 
+#[derive(Clone, Debug, Serialize)] +pub enum AdaptiveStepLength { + Adaptive { + l: F, + μ_old: RNDM, + fμ_old: F, + μ_dist: F, + τ0: F, + l_is_initial: bool, + }, + Fixed { + τ: F, + }, +} + +#[replace_float_literals(F::cast_from(literal))] +impl AdaptiveStepLength { + pub fn new(f: &Dat, prox_penalty: &P, fbconfig: &FBConfig) -> DynResult + where + F: ToNalgebraRealField, + Dat: DifferentiableMapping, Codomain = F>, + P: ProxPenalty, Dat::DerivativeDomain, Reg, F> + StepLengthBound, + Reg: RegTerm, F>, + { + match ( + prox_penalty.step_length_bound(&f), + fbconfig.always_adaptive_τ, + ) { + (StepLengthBoundValue::LipschitzFactor(l), false) => { + Ok(AdaptiveStepLength::Fixed { τ: fbconfig.τ0 / l }) + } + (StepLengthBoundValue::LipschitzFactor(l), true) => { + let μ_old = DiscreteMeasure::new(); + let fμ_old = f.apply(&μ_old); + Ok(AdaptiveStepLength::Adaptive { + l: l, + μ_old, + fμ_old, + μ_dist: 0.0, + τ0: fbconfig.τ0, + l_is_initial: false, + }) + } + (StepLengthBoundValue::UnreliableLipschitzFactor(l), _) => { + println!("Lipschitz factor is unreliable; calculating adaptively."); + let μ_old = DiscreteMeasure::new(); + let fμ_old = f.apply(&μ_old); + Ok(AdaptiveStepLength::Adaptive { + l: l, + μ_old, + fμ_old, + μ_dist: 0.0, + τ0: fbconfig.τ0, + l_is_initial: true, + }) + } + (StepLengthBoundValue::Failure, _) => Err(anyhow!("No Lipschitz estimate available")), + } + } + + /// Returns the current value of the step length parameter. + pub fn current(&self) -> F { + match *self { + AdaptiveStepLength::Adaptive { τ0, l, .. } => τ0 / l, + AdaptiveStepLength::Fixed { τ } => τ, + } + } + + /// Update adaptive Lipschitz factor and return new step length parameter `τ`. + /// + /// Inputs: + /// * `μ`: current point + /// * `fμ`: value of the function `f` at `μ`. + /// * `v`: derivative of the function `f` at `μ`. + /// Returns the updated step length parameter `τ`. 
+ pub fn update<'a, G>(&mut self, μ: &RNDM, fμ: F, v: &'a G) -> F + where + G: ClosedMul + Mapping, Codomain = F> + ClosedSpace, + &'a G: Instance, + { + match self { + AdaptiveStepLength::Adaptive { l, μ_old, fμ_old, μ_dist, τ0, l_is_initial } => { + // Estimate step length parameter + let b = *fμ_old - fμ - μ_old.apply(v) + μ.apply(v); + let d = *μ_dist; + if d.abs() > F::EPSILON && μ.len() > 0 && μ_old.len() > 0 { + let lc = b / (d * d / 2.0); + dbg!(b, d, lc); + if *l_is_initial { + *l = lc; + *l_is_initial = false; + } else { + *l = l.max(lc); + } + } + + // Store for next iteration + *μ_old = μ.clone(); + *fμ_old = fμ; + + return *τ0 / *l; + } + AdaptiveStepLength::Fixed { τ } => *τ, + } + } + + /// Finalises a step, storing μ and its distance to the previous μ. + /// + /// This is not included in [`Self::update`], as this function is to be called + /// before pruning and merging, while μ and its previous version in their internal + /// presentation still having matching indices for the same coordinate. + pub fn finish_step(&mut self, μ: &RNDM) { + if let AdaptiveStepLength::Adaptive { μ_dist, μ_old, .. 
} = self { + *μ_dist = μ.dist_matching(&μ_old); + } + } +} + #[replace_float_literals(F::cast_from(literal))] pub(crate) fn postprocess) -> F, const N: usize>( mut μ: RNDM, @@ -157,24 +276,26 @@ fbconfig: &FBConfig, iterator: I, mut plotter: Plot, - μ0 : Option>, + μ0: Option>, ) -> DynResult> where F: Float + ToNalgebraRealField, I: AlgIteratorFactory>, RNDM: SpikeMerging, Dat: DifferentiableMapping, Codomain = F>, - Dat::DerivativeDomain: ClosedMul, + Dat::DerivativeDomain: ClosedMul + Mapping, Codomain = F> + ClosedSpace, Reg: RegTerm, F>, P: ProxPenalty, Dat::DerivativeDomain, Reg, F> + StepLengthBound, Plot: Plotter>, + for<'a> &'a Dat::DerivativeDomain: Instance, { // Set up parameters let config = &fbconfig.insertion; - let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; + let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, fbconfig)?; + // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. - let tolerance = config.tolerance * τ * reg.tolerance_scaling(); + let tolerance = config.tolerance * adaptive_τ.current() * reg.tolerance_scaling(); let mut ε = tolerance.initial(); // Initialise iterates @@ -194,7 +315,10 @@ for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. // TODO: optimise τ to be applied to residual. - let mut τv = f.differential(&μ) * τ; + let (fμ, v) = f.apply_and_differential(&μ); + let τ = adaptive_τ.update(&μ, fμ, &v); + dbg!(τ); + let mut τv = v * τ; // Save current base point let μ_base = μ.clone(); @@ -204,6 +328,9 @@ &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, )?; + // We don't treat merge in adaptive Lipschitz. 
+ adaptive_τ.finish_step(&μ); + // Prune and possibly merge spikes if config.merge_now(&state) { stats.merged += prox_penalty.merge_spikes( @@ -257,25 +384,27 @@ fbconfig: &FBConfig, iterator: I, mut plotter: Plot, - μ0: Option> + μ0: Option>, ) -> DynResult> where F: Float + ToNalgebraRealField, I: AlgIteratorFactory>, RNDM: SpikeMerging, Dat: DifferentiableMapping, Codomain = F>, - Dat::DerivativeDomain: ClosedMul, + Dat::DerivativeDomain: ClosedMul + Mapping, Codomain = F> + ClosedSpace, Reg: RegTerm, F>, P: ProxPenalty, Dat::DerivativeDomain, Reg, F> + StepLengthBound, Plot: Plotter>, + for<'a> &'a Dat::DerivativeDomain: Instance, { // Set up parameters let config = &fbconfig.insertion; - let τ = fbconfig.τ0 / prox_penalty.step_length_bound(&f)?; + let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, fbconfig)?; + let mut λ = 1.0; // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled // by τ compared to the conditional gradient approach. - let tolerance = config.tolerance * τ * reg.tolerance_scaling(); + let tolerance = config.tolerance * adaptive_τ.current() * reg.tolerance_scaling(); let mut ε = tolerance.initial(); // Initialise iterates @@ -296,7 +425,9 @@ // Run the algorithm for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { // Calculate smooth part of surrogate model. - let mut τv = f.differential(&μ) * τ; + let (fμ, v) = f.apply_and_differential(&μ); + let τ = adaptive_τ.update(&μ, fμ, &v); + let mut τv = v * τ; // Save current base point let μ_base = μ.clone(); @@ -306,6 +437,9 @@ &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, )?; + // We don't treat merge in adaptive Lipschitz. + adaptive_τ.finish_step(&μ); + // (Do not) merge spikes. if config.merge_now(&state) && !warned_merging { let err = format!("Merging not supported for μFISTA");