src/sliding_fb.rs

branch
dev
changeset 68
00d0881f89a6
parent 63
7a8a55fd41c0
equal deleted inserted replaced
67:95bb12bdb6ac 68:00d0881f89a6
38 pub adaptation: F, 38 pub adaptation: F,
39 /// A posteriori transport tolerance multiplier (C_pos) 39 /// A posteriori transport tolerance multiplier (C_pos)
40 pub tolerance_mult_con: F, 40 pub tolerance_mult_con: F,
41 /// maximum number of adaptation iterations, until cancelling transport. 41 /// maximum number of adaptation iterations, until cancelling transport.
42 pub max_attempts: usize, 42 pub max_attempts: usize,
43 /// Maximum number of failed transportations for a single source point
44 pub max_fail: usize,
43 } 45 }
44 46
45 #[replace_float_literals(F::cast_from(literal))] 47 #[replace_float_literals(F::cast_from(literal))]
46 impl<F: Float> TransportConfig<F> { 48 impl<F: Float> TransportConfig<F> {
47 /// Check that the parameters are ok. Panics if not. 49 /// Check that the parameters are ok. Panics if not.
54 } 56 }
55 57
56 #[replace_float_literals(F::cast_from(literal))] 58 #[replace_float_literals(F::cast_from(literal))]
57 impl<F: Float> Default for TransportConfig<F> { 59 impl<F: Float> Default for TransportConfig<F> {
58 fn default() -> Self { 60 fn default() -> Self {
59 TransportConfig { θ0: 0.9, adaptation: 0.9, tolerance_mult_con: 100.0, max_attempts: 2 } 61 TransportConfig {
62 θ0: 0.9,
63 adaptation: 0.9,
64 tolerance_mult_con: 100.0,
65 max_attempts: 2,
66 max_fail: usize::MAX,
67 }
60 } 68 }
61 } 69 }
62 70
63 /// Settings for [`pointsource_sliding_fb_reg`]. 71 /// Settings for [`pointsource_sliding_fb_reg`].
64 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 72 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
112 α_μ_orig: F, 120 α_μ_orig: F,
113 /// Transported mass 121 /// Transported mass
114 α_γ: F, 122 α_γ: F,
115 /// Helper for pruning 123 /// Helper for pruning
116 prune: bool, 124 prune: bool,
125 /// Fail count
126 fail_count: usize,
117 } 127 }
118 128
119 #[derive(Clone, Debug, Serialize)] 129 #[derive(Clone, Debug, Serialize)]
120 pub struct Transport<const N: usize, F: Float> { 130 pub struct Transport<const N: usize, F: Float> {
121 vec: Vec<SingleTransport<N, F>>, 131 vec: Vec<SingleTransport<N, F>>,
205 &mut self, 215 &mut self,
206 μ: &RNDM<N, F>, 216 μ: &RNDM<N, F>,
207 _τ: F, 217 _τ: F,
208 τθ_or_adaptive: &mut TransportStepLength<F, G>, 218 τθ_or_adaptive: &mut TransportStepLength<F, G>,
209 v: D, 219 v: D,
220 tconfig: &TransportConfig<F>,
210 ) where 221 ) where
211 G: Fn(F, F) -> F, 222 G: Fn(F, F) -> F,
212 D: DifferentiableRealMapping<N, F>, 223 D: DifferentiableRealMapping<N, F>,
213 { 224 {
214 use TransportStepLength::*; 225 use TransportStepLength::*;
215 226
216 // Initialise transport structure weights 227 // Initialise transport structure weights
217 for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) { 228 for (δ, ρ) in izip!(μ.iter_spikes(), self.iter_mut()) {
218 ρ.α_μ_orig = δ.α; 229 ρ.α_μ_orig = δ.α;
219 ρ.x = δ.x; 230 ρ.x = δ.x;
220 // If old transport has opposing sign, the new transport will be none. 231 if ρ.fail_count > tconfig.max_fail {
221 ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) { 232 ρ.α_γ = 0.0
222 0.0
223 } else { 233 } else {
224 δ.α 234 // If old transport has opposing sign, the new transport will be none.
235 ρ.α_γ = if (ρ.α_γ > 0.0 && δ.α < 0.0) || (ρ.α_γ < 0.0 && δ.α > 0.0) {
236 0.0
237 } else {
238 δ.α
239 }
225 } 240 }
226 } 241 }
227 242
228 let γ_prev_len = self.len(); 243 let γ_prev_len = self.len();
229 assert!(μ.len() >= γ_prev_len); 244 assert!(μ.len() >= γ_prev_len);
231 x: δ.x, 246 x: δ.x,
232 y: δ.x, // Just something, will be filled properly in the next phase 247 y: δ.x, // Just something, will be filled properly in the next phase
233 α_μ_orig: δ.α, 248 α_μ_orig: δ.α,
234 α_γ: δ.α, 249 α_γ: δ.α,
235 prune: false, 250 prune: false,
251 fail_count: 0,
236 })); 252 }));
237 253
238 // Calculate transport rays. 254 // Calculate transport rays.
239 match *τθ_or_adaptive { 255 match *τθ_or_adaptive {
240 Fixed(θ) => { 256 Fixed(θ) => {
241 for ρ in self.iter_mut() { 257 for ρ in self.iter_mut() {
242 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ); 258 if ρ.fail_count <= tconfig.max_fail {
259 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θ);
260 }
243 } 261 }
244 } 262 }
245 AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θτ } => { 263 AdaptiveMax { l: ℓ_F, ref mut max_transport, g: ref calculate_θτ } => {
246 *max_transport = max_transport.max(self.norm(Radon)); 264 *max_transport = max_transport.max(self.norm(Radon));
247 let θτ = calculate_θτ(ℓ_F, *max_transport); 265 let θτ = calculate_θτ(ℓ_F, *max_transport);
248 for ρ in self.iter_mut() { 266 for ρ in self.iter_mut() {
249 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ); 267 if ρ.fail_count <= tconfig.max_fail {
268 ρ.y = ρ.x - v.differential(&ρ.x) * (ρ.α_γ.signum() * θτ);
269 }
250 } 270 }
251 } 271 }
252 FullyAdaptive { 272 FullyAdaptive {
253 l: ref mut adaptive_ℓ_F, 273 l: ref mut adaptive_ℓ_F,
254 ref mut max_transport, 274 ref mut max_transport,
259 // Do two runs through the spikes to update θ, breaking if first run did not cause 279 // Do two runs through the spikes to update θ, breaking if first run did not cause
260 // a change. 280 // a change.
261 for _i in 0..=1 { 281 for _i in 0..=1 {
262 let mut changes = false; 282 let mut changes = false;
263 for ρ in self.iter_mut() { 283 for ρ in self.iter_mut() {
264 let dv_x = v.differential(&ρ.x); 284 if ρ.fail_count < tconfig.max_fail {
265 let g = &dv_x * (ρ.α_γ.signum() * θτ); 285 let dv_x = v.differential(&ρ.x);
266 ρ.y = ρ.x - g; 286 let g = &dv_x * (ρ.α_γ.signum() * θτ);
267 let n = g.norm2(); 287 ρ.y = ρ.x - g;
268 if n >= F::EPSILON { 288 let n = g.norm2();
269 // Estimate Lipschitz factor of ∇v 289 if n >= F::EPSILON {
270 let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n; 290 // Estimate Lipschitz factor of ∇v
271 *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F); 291 let this_ℓ_F = (dv_x - v.differential(&ρ.y)).norm2() / n;
272 θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport); 292 *adaptive_ℓ_F = adaptive_ℓ_F.max(this_ℓ_F);
273 changes = true 293 θτ = calculate_θτ(*adaptive_ℓ_F, *max_transport);
294 changes = true
295 }
274 } 296 }
275 } 297 }
276 if !changes { 298 if !changes {
277 break; 299 break;
278 } 300 }
429 for ρ in self.iter_mut() { 451 for ρ in self.iter_mut() {
430 ρ.α_γ = 0.0; 452 ρ.α_γ = 0.0;
431 } 453 }
432 } 454 }
433 455
456 for ρ in self.iter_mut() {
457 if ρ.α_γ == 0.0 {
458 ρ.fail_count += 1;
459 } else if all_ok {
460 ρ.fail_count = 0;
461 }
462 }
463
434 all_ok 464 all_ok
435 } 465 }
436 466
437 /// Returns $‖μ\^k - π\_♯\^0γ\^{k+1}‖$ 467 /// Returns $‖μ\^k - π\_♯\^0γ\^{k+1}‖$
438 pub(crate) fn μ0_minus_γ0_radon(&self) -> F { 468 pub(crate) fn μ0_minus_γ0_radon(&self) -> F {
595 625
596 // Run the algorithm 626 // Run the algorithm
597 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { 627 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
598 // Calculate initial transport 628 // Calculate initial transport
599 let v = f.differential(&μ); 629 let v = f.differential(&μ);
600 γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v); 630 γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v, &config.transport);
601 631
602 let mut attempts = 0; 632 let mut attempts = 0;
603 633
604 // Solve finite-dimensional subproblem several times until the dual variable for the 634 // Solve finite-dimensional subproblem several times until the dual variable for the
605 // regularisation term conforms to the assumptions made for the transport above. 635 // regularisation term conforms to the assumptions made for the transport above.

mercurial