src/sliding_fb.rs

branch
dev
changeset 62
32328a74c790
parent 61
4f468d35fa29
equal deleted inserted replaced
61:4f468d35fa29 62:32328a74c790
18 use crate::prox_penalty::{ProxPenalty, StepLengthBound}; 18 use crate::prox_penalty::{ProxPenalty, StepLengthBound};
19 use crate::regularisation::SlidingRegTerm; 19 use crate::regularisation::SlidingRegTerm;
20 use crate::types::*; 20 use crate::types::*;
21 use alg_tools::error::DynResult; 21 use alg_tools::error::DynResult;
22 use alg_tools::euclidean::Euclidean; 22 use alg_tools::euclidean::Euclidean;
23 use alg_tools::instance::{ClosedSpace, Instance};
23 use alg_tools::iterate::AlgIteratorFactory; 24 use alg_tools::iterate::AlgIteratorFactory;
24 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping}; 25 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping};
25 use alg_tools::nalgebra_support::ToNalgebraRealField; 26 use alg_tools::nalgebra_support::ToNalgebraRealField;
26 use alg_tools::norms::Norm; 27 use alg_tools::norms::Norm;
27 use anyhow::ensure; 28 use anyhow::ensure;
68 pub transport: TransportConfig<F>, 69 pub transport: TransportConfig<F>,
69 /// Generic parameters 70 /// Generic parameters
70 pub insertion: InsertionConfig<F>, 71 pub insertion: InsertionConfig<F>,
71 /// Guess for curvature bound calculations. 72 /// Guess for curvature bound calculations.
72 pub guess: BoundedCurvatureGuess, 73 pub guess: BoundedCurvatureGuess,
74 /// Always adaptive step length
75 pub always_adaptive_τ: bool,
76 }
77
78 impl<'a, F: Float> Into<FBConfig<F>> for &'a SlidingFBConfig<F> {
79 fn into(self) -> FBConfig<F> {
80 let SlidingFBConfig { τ0, σp0, insertion, always_adaptive_τ, .. } = *self;
81 FBConfig { τ0, σp0, insertion, always_adaptive_τ }
82 }
73 } 83 }
74 84
75 #[replace_float_literals(F::cast_from(literal))] 85 #[replace_float_literals(F::cast_from(literal))]
76 impl<F: Float> Default for SlidingFBConfig<F> { 86 impl<F: Float> Default for SlidingFBConfig<F> {
77 fn default() -> Self { 87 fn default() -> Self {
79 τ0: 0.99, 89 τ0: 0.99,
80 σp0: 0.99, 90 σp0: 0.99,
81 transport: Default::default(), 91 transport: Default::default(),
82 insertion: Default::default(), 92 insertion: Default::default(),
83 guess: BoundedCurvatureGuess::BetterThanZero, 93 guess: BoundedCurvatureGuess::BetterThanZero,
94 always_adaptive_τ: false,
84 } 95 }
85 } 96 }
86 } 97 }
87 98
88 /// Internal type of adaptive transport step length calculation 99 /// Internal type of adaptive transport step length calculation
269 ) -> DynResult<RNDM<N, F>> 280 ) -> DynResult<RNDM<N, F>>
270 where 281 where
271 F: Float + ToNalgebraRealField, 282 F: Float + ToNalgebraRealField,
272 I: AlgIteratorFactory<IterInfo<F>>, 283 I: AlgIteratorFactory<IterInfo<F>>,
273 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>, 284 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>,
274 Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>, 285 Dat::DerivativeDomain:
275 //for<'a> Dat::Differential<'a>: Lipschitz<&'a P, FloatType = F>, 286 ClosedMul<F> + DifferentiableRealMapping<N, F, Codomain = F> + ClosedSpace,
276 RNDM<N, F>: SpikeMerging<F>, 287 RNDM<N, F>: SpikeMerging<F>,
277 Reg: SlidingRegTerm<Loc<N, F>, F>, 288 Reg: SlidingRegTerm<Loc<N, F>, F>,
278 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, 289 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>,
279 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, 290 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>,
291 for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>,
280 { 292 {
281 // Check parameters 293 // Check parameters
282 ensure!(config.τ0 > 0.0, "Invalid step length parameter"); 294 ensure!(config.τ0 > 0.0, "Invalid step length parameter");
283 config.transport.check()?; 295 config.transport.check()?;
284 296
290 // let opAnorm = opA.opnorm_bound(Radon, L2); 302 // let opAnorm = opA.opnorm_bound(Radon, L2);
291 //let max_transport = config.max_transport.scale 303 //let max_transport = config.max_transport.scale
292 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); 304 // * reg.radon_norm_bound(b.norm2_squared() / 2.0);
293 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; 305 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
294 let ℓ = 0.0; 306 let ℓ = 0.0;
295 let τ = config.τ0 / prox_penalty.step_length_bound(&f)?; 307 let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, &config.into())?;
308 let τ = adaptive_τ.current();
309 println!("TODO: τ in calculate_θ should be adaptive");
296 let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess); 310 let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess);
297 let transport_lip = maybe_transport_lip?; 311 let transport_lip = maybe_transport_lip?;
298 let calculate_θ = |ℓ_F, max_transport| { 312 let calculate_θ = |ℓ_F, max_transport| {
299 let ℓ_r = transport_lip * max_transport; 313 let ℓ_r = transport_lip * max_transport;
300 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r)) 314 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r))
328 let mut stats = IterInfo::new(); 342 let mut stats = IterInfo::new();
329 343
330 // Run the algorithm 344 // Run the algorithm
331 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { 345 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
332 // Calculate initial transport 346 // Calculate initial transport
333 let v = f.differential(&μ); 347 let (fμ, v) = f.apply_and_differential(&μ);
348 let τ = adaptive_τ.update(&μ, fμ, &v);
349
334 let (μ_base_masses, mut μ_base_minus_γ0) = 350 let (μ_base_masses, mut μ_base_minus_γ0) =
335 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); 351 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v);
336 352
337 // Solve finite-dimensional subproblem several times until the dual variable for the 353 // Solve finite-dimensional subproblem several times until the dual variable for the
338 // regularisation term conforms to the assumptions made for the transport above. 354 // regularisation term conforms to the assumptions made for the transport above.
369 &config.transport, 385 &config.transport,
370 ) { 386 ) {
371 break 'adapt_transport (maybe_d, within_tolerances, τv̆); 387 break 'adapt_transport (maybe_d, within_tolerances, τv̆);
372 } 388 }
373 }; 389 };
390
391 // We don't treat merge in adaptive Lipschitz.
392 println!("WARNING: finish_step does not work with sliding");
393 adaptive_τ.finish_step(&μ);
374 394
375 stats.untransported_fraction = Some({ 395 stats.untransported_fraction = Some({
376 assert_eq!(μ_base_masses.len(), γ1.len()); 396 assert_eq!(μ_base_masses.len(), γ1.len());
377 let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0)); 397 let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));
378 let source = μ_base_masses.iter().map(|v| v.abs()).sum(); 398 let source = μ_base_masses.iter().map(|v| v.abs()).sum();

mercurial