| 78 We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by |
78 We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by |
| 79 [`InnerSettings`] in [`FBGenericConfig::inner`]. |
79 [`InnerSettings`] in [`FBGenericConfig::inner`]. |
| 80 */ |
80 */ |
| 81 |
81 |
| 82 use numeric_literals::replace_float_literals; |
82 use numeric_literals::replace_float_literals; |
| 83 use std::cmp::Ordering::*; |
|
| 84 use serde::{Serialize, Deserialize}; |
83 use serde::{Serialize, Deserialize}; |
| 85 use colored::Colorize; |
84 use colored::Colorize; |
| 86 use nalgebra::DVector; |
85 use nalgebra::DVector; |
| 87 |
86 |
| 88 use alg_tools::iterate::{ |
87 use alg_tools::iterate::{ |
| 154 None, |
153 None, |
| 155 /// FISTA-style inertia |
154 /// FISTA-style inertia |
| 156 InertiaFISTA, |
155 InertiaFISTA, |
| 157 } |
156 } |
| 158 |
157 |
| 159 /// Ergodic tolerance application style |
|
| 160 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
|
| 161 #[allow(dead_code)] |
|
| 162 pub enum ErgodicTolerance<F> { |
|
| 163 /// Non-ergodic iteration-wise tolerance |
|
| 164 NonErgodic, |
|
| 165 /// Bound after `n`th iteration to `factor` times value on that iteration. |
|
| 166 AfterNth{ n : usize, factor : F }, |
|
| 167 } |
|
| 168 |
|
| 169 /// Settings for [`pointsource_fb`]. |
158 /// Settings for [`pointsource_fb`]. |
| 170 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
159 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 171 #[serde(default)] |
160 #[serde(default)] |
| 172 pub struct FBConfig<F : Float> { |
161 pub struct FBConfig<F : Float> { |
| 173 /// Step length scaling |
162 /// Step length scaling |
| 188 /// Tolerance for point insertion. |
177 /// Tolerance for point insertion. |
| 189 pub tolerance : Tolerance<F>, |
178 pub tolerance : Tolerance<F>, |
| 190 /// Stop looking for predual maximum (where to insert a new point) below |
179 /// Stop looking for predual maximum (where to insert a new point) below |
| 191 /// `tolerance` multiplied by this factor. |
180 /// `tolerance` multiplied by this factor. |
| 192 pub insertion_cutoff_factor : F, |
181 pub insertion_cutoff_factor : F, |
| 193 /// Apply tolerance ergodically |
|
| 194 pub ergodic_tolerance : ErgodicTolerance<F>, |
|
| 195 /// Settings for branch and bound refinement when looking for predual maxima |
182 /// Settings for branch and bound refinement when looking for predual maxima |
| 196 pub refinement : RefinementSettings<F>, |
183 pub refinement : RefinementSettings<F>, |
| 197 /// Maximum insertions within each outer iteration |
184 /// Maximum insertions within each outer iteration |
| 198 pub max_insertions : usize, |
185 pub max_insertions : usize, |
| 199 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. |
186 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. |
| 228 fn default() -> Self { |
215 fn default() -> Self { |
| 229 FBGenericConfig { |
216 FBGenericConfig { |
| 230 insertion_style : InsertionStyle::Reuse, |
217 insertion_style : InsertionStyle::Reuse, |
| 231 tolerance : Default::default(), |
218 tolerance : Default::default(), |
| 232 insertion_cutoff_factor : 1.0, |
219 insertion_cutoff_factor : 1.0, |
| 233 ergodic_tolerance : ErgodicTolerance::NonErgodic, |
|
| 234 refinement : Default::default(), |
220 refinement : Default::default(), |
| 235 max_insertions : 100, |
221 max_insertions : 100, |
| 236 //bootstrap_insertions : None, |
222 //bootstrap_insertions : None, |
| 237 bootstrap_insertions : Some((10, 1)), |
223 bootstrap_insertions : Some((10, 1)), |
| 238 inner : InnerSettings { |
224 inner : InnerSettings { |
| 552 let preadjA = opA.preadjoint(); |
538 let preadjA = opA.preadjoint(); |
| 553 |
539 |
| 554 // Initialise iterates |
540 // Initialise iterates |
| 555 let mut μ = DiscreteMeasure::new(); |
541 let mut μ = DiscreteMeasure::new(); |
| 556 |
542 |
| 557 let mut after_nth_bound = F::INFINITY; |
|
| 558 // FIXME: Don't allocate if not needed. |
|
| 559 let mut after_nth_accum = opA.zero_observable(); |
|
| 560 |
|
| 561 let mut inner_iters = 0; |
543 let mut inner_iters = 0; |
| 562 let mut this_iters = 0; |
544 let mut this_iters = 0; |
| 563 let mut pruned = 0; |
545 let mut pruned = 0; |
| 564 let mut merged = 0; |
546 let mut merged = 0; |
| 565 |
547 |
| 626 }, |
608 }, |
| 627 InsertionStyle::Reuse => m, |
609 InsertionStyle::Reuse => m, |
| 628 }; |
610 }; |
| 629 |
611 |
| 630 // Calculate smooth part of surrogate model. |
612 // Calculate smooth part of surrogate model. |
| 631 residual *= -τ; |
|
| 632 if let ErgodicTolerance::AfterNth{ .. } = config.ergodic_tolerance { |
|
| 633 // Negative residual times τ expected here, as set above. |
|
| 634 // TODO: is this the correct location? |
|
| 635 after_nth_accum += &residual; |
|
| 636 } |
|
| 637 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
613 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
| 638 // has no significant overhead. For some reason Rust doesn't allow us simply moving |
614 // has no significant overhead. For some reason Rust doesn't allow us simply moving |
| 639 // the residual and replacing it below before the end of this closure. |
615 // the residual and replacing it below before the end of this closure. |
| |
616 residual *= -τ; |
| 640 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
617 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
| 641 let minus_τv = preadjA.apply(r); // minus_τv = -τA^*(Aμ^k-b) |
618 let minus_τv = preadjA.apply(r); // minus_τv = -τA^*(Aμ^k-b) |
| 642 // TODO: should avoid a second copy of μ here; μ_base already stores a copy. |
619 // TODO: should avoid a second copy of μ here; μ_base already stores a copy. |
| 643 let ω0 = op𝒟.apply(μ.clone()); // 𝒟μ^k |
620 let ω0 = op𝒟.apply(μ.clone()); // 𝒟μ^k |
| 644 //let g = &minus_τv + ω0; // Linear term of surrogate model |
621 //let g = &minus_τv + ω0; // Linear term of surrogate model |
| 686 false |
663 false |
| 687 } else { |
664 } else { |
| 688 count > 0 |
665 count > 0 |
| 689 }; |
666 }; |
| 690 |
667 |
| 691 // First do a rough check whether we are within bounds and can stop. |
|
| 692 let in_bounds = match config.ergodic_tolerance { |
|
| 693 ErgodicTolerance::NonErgodic => { |
|
| 694 target_bounds.superset(&d.bounds()) |
|
| 695 }, |
|
| 696 ErgodicTolerance::AfterNth{ n, factor } => { |
|
| 697 // Bound -τ∑_{k=0}^{N-1}[A_*(Aμ^k-b)+α] from above. |
|
| 698 match state.iteration().cmp(&n) { |
|
| 699 Less => true, |
|
| 700 Equal => { |
|
| 701 let iter = F::cast_from(state.iteration()); |
|
| 702 let mut tmp = preadjA.apply(&after_nth_accum); |
|
| 703 let (_, v0) = tmp.maximise(refinement_tolerance, |
|
| 704 config.refinement.max_steps); |
|
| 705 let v = v0 - iter * τ * α; |
|
| 706 after_nth_bound = factor * v; |
|
| 707 println!("{}", format!("Set ergodic tolerance to {}", after_nth_bound)); |
|
| 708 true |
|
| 709 }, |
|
| 710 Greater => { |
|
| 711 // TODO: can divide after_nth_accum by N, so use basic tolerance on that. |
|
| 712 let iter = F::cast_from(state.iteration()); |
|
| 713 let mut tmp = preadjA.apply(&after_nth_accum); |
|
| 714 tmp.has_upper_bound(after_nth_bound + iter * τ * α, |
|
| 715 refinement_tolerance, |
|
| 716 config.refinement.max_steps) |
|
| 717 } |
|
| 718 } |
|
| 719 } |
|
| 720 }; |
|
| 721 |
|
| 722 // If preliminary check indicates that we are in bounds, and if it otherwise matches |
668 // If preliminary check indicates that we are in bounds, and if it otherwise matches |
| 723 // the insertion strategy, skip insertion. |
669 // the insertion strategy, skip insertion. |
| 724 if may_break && in_bounds { |
670 if may_break && target_bounds.superset(&d.bounds()) { |
| 725 break 'insertion (true, d) |
671 break 'insertion (true, d) |
| 726 } |
672 } |
| 727 |
673 |
| 728 // If the rough check didn't indicate stopping, find maximising point, maintaining for |
674 // If the rough check didn't indicate stopping, find maximising point, maintaining for |
| 729 // the calculations in the beginning of the loop that v_ξ = (ω0-τv-𝒟μ)(ξ) = d(ξ), |
675 // the calculations in the beginning of the loop that v_ξ = (ω0-τv-𝒟μ)(ξ) = d(ξ), |