| 38 pub adaptation: F, |
38 pub adaptation: F, |
| 39 /// A posteriori transport tolerance multiplier (C_pos) |
39 /// A posteriori transport tolerance multiplier (C_pos) |
| 40 pub tolerance_mult_con: F, |
40 pub tolerance_mult_con: F, |
| 41 /// maximum number of adaptation iterations, until cancelling transport. |
41 /// maximum number of adaptation iterations, until cancelling transport. |
| 42 pub max_attempts: usize, |
42 pub max_attempts: usize, |
| |
43 /// Maximum number of failed transportations for a single source point |
| |
44 pub max_fail: usize, |
| 43 } |
45 } |
| 44 |
46 |
| 45 #[replace_float_literals(F::cast_from(literal))] |
47 #[replace_float_literals(F::cast_from(literal))] |
| 46 impl<F: Float> TransportConfig<F> { |
48 impl<F: Float> TransportConfig<F> { |
| 47 /// Check that the parameters are ok. Panics if not. |
49 /// Check that the parameters are ok. Panics if not. |
| 54 } |
56 } |
| 55 |
57 |
| 56 #[replace_float_literals(F::cast_from(literal))] |
58 #[replace_float_literals(F::cast_from(literal))] |
| 57 impl<F: Float> Default for TransportConfig<F> { |
59 impl<F: Float> Default for TransportConfig<F> { |
| 58 fn default() -> Self { |
60 fn default() -> Self { |
| 59 TransportConfig { θ0: 0.9, adaptation: 0.9, tolerance_mult_con: 100.0, max_attempts: 2 } |
61 TransportConfig { |
| |
62 θ0: 0.9, |
| |
63 adaptation: 0.9, |
| |
64 tolerance_mult_con: 100.0, |
| |
65 max_attempts: 2, |
| |
66 max_fail: usize::MAX, |
| |
67 } |
| 60 } |
68 } |
| 61 } |
69 } |
| 62 |
70 |
| 63 /// Settings for [`pointsource_sliding_fb_reg`]. |
71 /// Settings for [`pointsource_sliding_fb_reg`]. |
| 64 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
72 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 112 α_μ_orig: F, |
120 α_μ_orig: F, |
| 113 /// Transported mass |
121 /// Transported mass |
| 114 α_γ: F, |
122 α_γ: F, |
| 115 /// Helper for pruning |
123 /// Helper for pruning |
| 116 prune: bool, |
124 prune: bool, |
| |
125 /// Fail count |
| |
126 fail_count: usize, |
| 117 } |
127 } |
| 118 |
128 |
| 119 #[derive(Clone, Debug, Serialize)] |
129 #[derive(Clone, Debug, Serialize)] |
| 120 pub struct Transport<const N: usize, F: Float> { |
130 pub struct Transport<const N: usize, F: Float> { |
| 121 vec: Vec<SingleTransport<N, F>>, |
131 vec: Vec<SingleTransport<N, F>>, |
| 205 &mut self, |
215 &mut self, |
| 206 μ: &RNDM<N, F>, |
216 μ: &RNDM<N, F>, |
| 207 _τ: F, |
217 _τ: F, |
| 208 τθ_or_adaptive: &mut TransportStepLength<F, G>, |
218 τθ_or_adaptive: &mut TransportStepLength<F, G>, |
| 209 v: D, |
219 v: D, |
| |
220 tconfig: &TransportConfig<F>, |
| 210 ) where |
221 ) where |
| 211 G: Fn(F, F) -> F, |
222 G: Fn(F, F) -> F, |
| 212 D: DifferentiableRealMapping<N, F>, |
223 D: DifferentiableRealMapping<N, F>, |
| 213 { |
224 { |
| 214 use TransportStepLength::*; |
225 use TransportStepLength::*; |
| 215 |
226 |
| 216 // Initialise transport structure weights |
227 // Initialise transport structure weights |
| 217 for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) { |
228 for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) { |
| 218 ρ.α_μ_orig = δ.α; |
229 ρ.α_μ_orig = δ.α; |
| 219 ρ.x = δ.x; |
230 ρ.x = δ.x; |
| 220 // If old transport has opposing sign, the new transport will be none. |
231 if ρ.fail_count > tconfig.max_fail { |
| 221 ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) { |
232 ρ.α_γ = 0.0 |
| 222 0.0 |
|
| 223 } else { |
233 } else { |
| 224 δ.α |
234 // If old transport has opposing sign, the new transport will be none. |
| |
235 ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) { |
| |
236 0.0 |
| |
237 } else { |
| |
238 δ.α |
| |
239 } |
| 225 } |
240 } |
| 226 } |
241 } |
| 227 |
242 |
| 228 let γ_prev_len = self.len(); |
243 let γ_prev_len = self.len(); |
| 229 assert!(μ.len() >= γ_prev_len); |
244 assert!(μ.len() >= γ_prev_len); |
| 231 x: δ.x, |
246 x: δ.x, |
| 232 y: δ.x, // Just something, will be filled properly in the next phase |
247 y: δ.x, // Just something, will be filled properly in the next phase |
| 233 α_μ_orig: δ.α, |
248 α_μ_orig: δ.α, |
| 234 α_γ: δ.α, |
249 α_γ: δ.α, |
| 235 prune: false, |
250 prune: false, |
| |
251 fail_count: 0, |
| 236 })); |
252 })); |
| 237 |
253 |
| 238 // Calculate transport rays. |
254 // Calculate transport rays. |
| 239 match *τθ_or_adaptive { |
255 match *τθ_or_adaptive { |
| 240 Fixed(θ) => { |
256 Fixed(θ) => { |
| 241 for ρ in self.iter_mut() { |
257 for ρ in self.iter_mut() { |
| 242 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ); |
258 if ρ.fail_count <= tconfig.max_fail { |
| |
259 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ); |
| |
260 } |
| 243 } |
261 } |
| 244 } |
262 } |
| 245 AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θτ } => { |
263 AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θτ } => { |
| 246 *max_transport = max_transport.max(self.norm(Radon)); |
264 *max_transport = max_transport.max(self.norm(Radon)); |
| 247 let θτ = calculate_θτ(ℓ_F, *max_transport); |
265 let θτ = calculate_θτ(ℓ_F, *max_transport); |
| 248 for ρ in self.iter_mut() { |
266 for ρ in self.iter_mut() { |
| 249 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ); |
267 if ρ.fail_count <= tconfig.max_fail { |
| |
268 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ); |
| |
269 } |
| 250 } |
270 } |
| 251 } |
271 } |
| 252 FullyAdaptive { |
272 FullyAdaptive { |
| 253 l: ref mut adaptive_ℓ_F, |
273 l: ref mut adaptive_ℓ_F, |
| 254 ref mut max_transport, |
274 ref mut max_transport, |
| 259 // Do two runs through the spikes to update θ, breaking if first run did not cause |
279 // Do two runs through the spikes to update θ, breaking if first run did not cause |
| 260 // a change. |
280 // a change. |
| 261 for _i in 0..=1 { |
281 for _i in 0..=1 { |
| 262 let mut changes = false; |
282 let mut changes = false; |
| 263 for ρ in self.iter_mut() { |
283 for ρ in self.iter_mut() { |
| 264 let dv_x = v.differential(&ρ.x); |
284 if ρ.fail_count < tconfig.max_fail { |
| 265 let g = &dv_x * (ρ.α_γ.signum() * θτ); |
285 let dv_x = v.differential(&ρ.x); |
| 266 ρ.y = ρ.x - g; |
286 let g = &dv_x * (ρ.α_γ.signum() * θτ); |
| 267 let n = g.norm2(); |
287 ρ.y = ρ.x - g; |
| 268 if n >= F::EPSILON { |
288 let n = g.norm2(); |
| 269 // Estimate Lipschitz factor of ∇v |
289 if n >= F::EPSILON { |
| 270 let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n; |
290 // Estimate Lipschitz factor of ∇v |
| 271 *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F); |
291 let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n; |
| 272 θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport); |
292 *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F); |
| 273 changes = true |
293 θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport); |
| |
294 changes = true |
| |
295 } |
| 274 } |
296 } |
| 275 } |
297 } |
| 276 if !changes { |
298 if !changes { |
| 277 break; |
299 break; |
| 278 } |
300 } |
| 595 |
625 |
| 596 // Run the algorithm |
626 // Run the algorithm |
| 597 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
627 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 598 // Calculate initial transport |
628 // Calculate initial transport |
| 599 let v = f.differential(&μ); |
629 let v = f.differential(&μ); |
| 600 γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v); |
630 γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v, &config.transport); |
| 601 |
631 |
| 602 let mut attempts = 0; |
632 let mut attempts = 0; |
| 603 |
633 |
| 604 // Solve finite-dimensional subproblem several times until the dual variable for the |
634 // Solve finite-dimensional subproblem several times until the dual variable for the |
| 605 // regularisation term conforms to the assumptions made for the transport above. |
635 // regularisation term conforms to the assumptions made for the transport above. |