| 18 use crate::prox_penalty::{ProxPenalty, StepLengthBound}; |
18 use crate::prox_penalty::{ProxPenalty, StepLengthBound}; |
| 19 use crate::regularisation::SlidingRegTerm; |
19 use crate::regularisation::SlidingRegTerm; |
| 20 use crate::types::*; |
20 use crate::types::*; |
| 21 use alg_tools::error::DynResult; |
21 use alg_tools::error::DynResult; |
| 22 use alg_tools::euclidean::Euclidean; |
22 use alg_tools::euclidean::Euclidean; |
| |
23 use alg_tools::instance::{ClosedSpace, Instance}; |
| 23 use alg_tools::iterate::AlgIteratorFactory; |
24 use alg_tools::iterate::AlgIteratorFactory; |
| 24 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping}; |
25 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping}; |
| 25 use alg_tools::nalgebra_support::ToNalgebraRealField; |
26 use alg_tools::nalgebra_support::ToNalgebraRealField; |
| 26 use alg_tools::norms::Norm; |
27 use alg_tools::norms::Norm; |
| 27 use anyhow::ensure; |
28 use anyhow::ensure; |
| 68 pub transport: TransportConfig<F>, |
69 pub transport: TransportConfig<F>, |
| 69 /// Generic parameters |
70 /// Generic parameters |
| 70 pub insertion: InsertionConfig<F>, |
71 pub insertion: InsertionConfig<F>, |
| 71 /// Guess for curvature bound calculations. |
72 /// Guess for curvature bound calculations. |
| 72 pub guess: BoundedCurvatureGuess, |
73 pub guess: BoundedCurvatureGuess, |
| |
74 /// Always adaptive step length |
| |
75 pub always_adaptive_τ: bool, |
| |
76 } |
| |
77 |
| |
78 impl<'a, F: Float> Into<FBConfig<F>> for &'a SlidingFBConfig<F> { |
| |
79 fn into(self) -> FBConfig<F> { |
| |
80 let SlidingFBConfig { τ0, σp0, insertion, always_adaptive_τ, .. } = *self; |
| |
81 FBConfig { τ0, σp0, insertion, always_adaptive_τ } |
| |
82 } |
| 73 } |
83 } |
| 74 |
84 |
| 75 #[replace_float_literals(F::cast_from(literal))] |
85 #[replace_float_literals(F::cast_from(literal))] |
| 76 impl<F: Float> Default for SlidingFBConfig<F> { |
86 impl<F: Float> Default for SlidingFBConfig<F> { |
| 77 fn default() -> Self { |
87 fn default() -> Self { |
| 79 τ0: 0.99, |
89 τ0: 0.99, |
| 80 σp0: 0.99, |
90 σp0: 0.99, |
| 81 transport: Default::default(), |
91 transport: Default::default(), |
| 82 insertion: Default::default(), |
92 insertion: Default::default(), |
| 83 guess: BoundedCurvatureGuess::BetterThanZero, |
93 guess: BoundedCurvatureGuess::BetterThanZero, |
| |
94 always_adaptive_τ: false, |
| 84 } |
95 } |
| 85 } |
96 } |
| 86 } |
97 } |
| 87 |
98 |
| 88 /// Internal type of adaptive transport step length calculation |
99 /// Internal type of adaptive transport step length calculation |
| 269 ) -> DynResult<RNDM<N, F>> |
280 ) -> DynResult<RNDM<N, F>> |
| 270 where |
281 where |
| 271 F: Float + ToNalgebraRealField, |
282 F: Float + ToNalgebraRealField, |
| 272 I: AlgIteratorFactory<IterInfo<F>>, |
283 I: AlgIteratorFactory<IterInfo<F>>, |
| 273 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>, |
284 Dat: DifferentiableMapping<RNDM<N, F>, Codomain = F> + BoundedCurvature<F>, |
| 274 Dat::DerivativeDomain: DifferentiableRealMapping<N, F> + ClosedMul<F>, |
285 Dat::DerivativeDomain: |
| 275 //for<'a> Dat::Differential<'a>: Lipschitz<&'a P, FloatType = F>, |
286 ClosedMul<F> + DifferentiableRealMapping<N, F, Codomain = F> + ClosedSpace, |
| 276 RNDM<N, F>: SpikeMerging<F>, |
287 RNDM<N, F>: SpikeMerging<F>, |
| 277 Reg: SlidingRegTerm<Loc<N, F>, F>, |
288 Reg: SlidingRegTerm<Loc<N, F>, F>, |
| 278 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
289 P: ProxPenalty<Loc<N, F>, Dat::DerivativeDomain, Reg, F> + StepLengthBound<F, Dat>, |
| 279 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, |
290 Plot: Plotter<P::ReturnMapping, Dat::DerivativeDomain, RNDM<N, F>>, |
| |
291 for<'a> &'a Dat::DerivativeDomain: Instance<Dat::DerivativeDomain>, |
| 280 { |
292 { |
| 281 // Check parameters |
293 // Check parameters |
| 282 ensure!(config.τ0 > 0.0, "Invalid step length parameter"); |
294 ensure!(config.τ0 > 0.0, "Invalid step length parameter"); |
| 283 config.transport.check()?; |
295 config.transport.check()?; |
| 284 |
296 |
| 290 // let opAnorm = opA.opnorm_bound(Radon, L2); |
302 // let opAnorm = opA.opnorm_bound(Radon, L2); |
| 291 //let max_transport = config.max_transport.scale |
303 //let max_transport = config.max_transport.scale |
| 292 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); |
304 // * reg.radon_norm_bound(b.norm2_squared() / 2.0); |
| 293 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; |
305 //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport; |
| 294 let ℓ = 0.0; |
306 let ℓ = 0.0; |
| 295 let τ = config.τ0 / prox_penalty.step_length_bound(&f)?; |
307 let mut adaptive_τ = AdaptiveStepLength::new(f, prox_penalty, &config.into())?; |
| |
308 let τ = adaptive_τ.current(); |
| |
309 println!("TODO: τ in calculate_θ should be adaptive"); |
| 296 let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess); |
310 let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess); |
| 297 let transport_lip = maybe_transport_lip?; |
311 let transport_lip = maybe_transport_lip?; |
| 298 let calculate_θ = |ℓ_F, max_transport| { |
312 let calculate_θ = |ℓ_F, max_transport| { |
| 299 let ℓ_r = transport_lip * max_transport; |
313 let ℓ_r = transport_lip * max_transport; |
| 300 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r)) |
314 config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r)) |
| 328 let mut stats = IterInfo::new(); |
342 let mut stats = IterInfo::new(); |
| 329 |
343 |
| 330 // Run the algorithm |
344 // Run the algorithm |
| 331 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
345 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 332 // Calculate initial transport |
346 // Calculate initial transport |
| 333 let v = f.differential(&μ); |
347 let (fμ, v) = f.apply_and_differential(&μ); |
| |
348 let τ = adaptive_τ.update(&μ, fμ, &v); |
| |
349 |
| 334 let (μ_base_masses, mut μ_base_minus_γ0) = |
350 let (μ_base_masses, mut μ_base_minus_γ0) = |
| 335 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); |
351 initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v); |
| 336 |
352 |
| 337 // Solve finite-dimensional subproblem several times until the dual variable for the |
353 // Solve finite-dimensional subproblem several times until the dual variable for the |
| 338 // regularisation term conforms to the assumptions made for the transport above. |
354 // regularisation term conforms to the assumptions made for the transport above. |
| 369 &config.transport, |
385 &config.transport, |
| 370 ) { |
386 ) { |
| 371 break 'adapt_transport (maybe_d, within_tolerances, τv̆); |
387 break 'adapt_transport (maybe_d, within_tolerances, τv̆); |
| 372 } |
388 } |
| 373 }; |
389 }; |
| |
390 |
| |
391 // We don't treat merge in adaptive Lipschitz. |
| |
392 println!("WARNING: finish_step does not work with sliding"); |
| |
393 adaptive_τ.finish_step(&μ); |
| 374 |
394 |
| 375 stats.untransported_fraction = Some({ |
395 stats.untransported_fraction = Some({ |
| 376 assert_eq!(μ_base_masses.len(), γ1.len()); |
396 assert_eq!(μ_base_masses.len(), γ1.len()); |
| 377 let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0)); |
397 let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0)); |
| 378 let source = μ_base_masses.iter().map(|v| v.abs()).sum(); |
398 let source = μ_base_masses.iter().map(|v| v.abs()).sum(); |