78 We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by |
78 We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by |
79 [`InnerSettings`] in [`FBGenericConfig::inner`]. |
79 [`InnerSettings`] in [`FBGenericConfig::inner`]. |
80 */ |
80 */ |
81 |
81 |
82 use numeric_literals::replace_float_literals; |
82 use numeric_literals::replace_float_literals; |
83 use std::cmp::Ordering::*; |
|
84 use serde::{Serialize, Deserialize}; |
83 use serde::{Serialize, Deserialize}; |
85 use colored::Colorize; |
84 use colored::Colorize; |
86 use nalgebra::DVector; |
85 use nalgebra::DVector; |
87 |
86 |
88 use alg_tools::iterate::{ |
87 use alg_tools::iterate::{ |
154 None, |
153 None, |
155 /// FISTA-style inertia |
154 /// FISTA-style inertia |
156 InertiaFISTA, |
155 InertiaFISTA, |
157 } |
156 } |
158 |
157 |
159 /// Ergodic tolerance application style |
|
160 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
|
161 #[allow(dead_code)] |
|
162 pub enum ErgodicTolerance<F> { |
|
163 /// Non-ergodic iteration-wise tolerance |
|
164 NonErgodic, |
|
165 /// Bound after `n`th iteration to `factor` times value on that iteration. |
|
166 AfterNth{ n : usize, factor : F }, |
|
167 } |
|
168 |
|
169 /// Settings for [`pointsource_fb`]. |
158 /// Settings for [`pointsource_fb`]. |
170 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
159 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
171 #[serde(default)] |
160 #[serde(default)] |
172 pub struct FBConfig<F : Float> { |
161 pub struct FBConfig<F : Float> { |
173 /// Step length scaling |
162 /// Step length scaling |
188 /// Tolerance for point insertion. |
177 /// Tolerance for point insertion. |
189 pub tolerance : Tolerance<F>, |
178 pub tolerance : Tolerance<F>, |
190 /// Stop looking for predual maximum (where to insert a new point) below |
179 /// Stop looking for predual maximum (where to insert a new point) below |
191 /// `tolerance` multiplied by this factor. |
180 /// `tolerance` multiplied by this factor. |
192 pub insertion_cutoff_factor : F, |
181 pub insertion_cutoff_factor : F, |
193 /// Apply tolerance ergodically |
|
194 pub ergodic_tolerance : ErgodicTolerance<F>, |
|
195 /// Settings for branch and bound refinement when looking for predual maxima |
182 /// Settings for branch and bound refinement when looking for predual maxima |
196 pub refinement : RefinementSettings<F>, |
183 pub refinement : RefinementSettings<F>, |
197 /// Maximum insertions within each outer iteration |
184 /// Maximum insertions within each outer iteration |
198 pub max_insertions : usize, |
185 pub max_insertions : usize, |
199 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. |
186 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. |
228 fn default() -> Self { |
215 fn default() -> Self { |
229 FBGenericConfig { |
216 FBGenericConfig { |
230 insertion_style : InsertionStyle::Reuse, |
217 insertion_style : InsertionStyle::Reuse, |
231 tolerance : Default::default(), |
218 tolerance : Default::default(), |
232 insertion_cutoff_factor : 1.0, |
219 insertion_cutoff_factor : 1.0, |
233 ergodic_tolerance : ErgodicTolerance::NonErgodic, |
|
234 refinement : Default::default(), |
220 refinement : Default::default(), |
235 max_insertions : 100, |
221 max_insertions : 100, |
236 //bootstrap_insertions : None, |
222 //bootstrap_insertions : None, |
237 bootstrap_insertions : Some((10, 1)), |
223 bootstrap_insertions : Some((10, 1)), |
238 inner : InnerSettings { |
224 inner : InnerSettings { |
552 let preadjA = opA.preadjoint(); |
538 let preadjA = opA.preadjoint(); |
553 |
539 |
554 // Initialise iterates |
540 // Initialise iterates |
555 let mut μ = DiscreteMeasure::new(); |
541 let mut μ = DiscreteMeasure::new(); |
556 |
542 |
557 let mut after_nth_bound = F::INFINITY; |
|
558 // FIXME: Don't allocate if not needed. |
|
559 let mut after_nth_accum = opA.zero_observable(); |
|
560 |
|
561 let mut inner_iters = 0; |
543 let mut inner_iters = 0; |
562 let mut this_iters = 0; |
544 let mut this_iters = 0; |
563 let mut pruned = 0; |
545 let mut pruned = 0; |
564 let mut merged = 0; |
546 let mut merged = 0; |
565 |
547 |
626 }, |
608 }, |
627 InsertionStyle::Reuse => m, |
609 InsertionStyle::Reuse => m, |
628 }; |
610 }; |
629 |
611 |
630 // Calculate smooth part of surrogate model. |
612 // Calculate smooth part of surrogate model. |
631 residual *= -τ; |
|
632 if let ErgodicTolerance::AfterNth{ .. } = config.ergodic_tolerance { |
|
633 // Negative residual times τ expected here, as set above. |
|
634 // TODO: is this the correct location? |
|
635 after_nth_accum += &residual; |
|
636 } |
|
637 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
613 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` |
638 // has no significant overhead. For some reason Rust doesn't allow us simply moving |
614 // has no significant overhead. For some reason Rust doesn't allow us simply moving |
639 // the residual and replacing it below before the end of this closure. |
615 // the residual and replacing it below before the end of this closure. |
|
616 residual *= -τ; |
640 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
617 let r = std::mem::replace(&mut residual, opA.empty_observable()); |
641 let minus_τv = preadjA.apply(r); // minus_τv = -τA^*(Aμ^k-b) |
618 let minus_τv = preadjA.apply(r); // minus_τv = -τA^*(Aμ^k-b) |
642 // TODO: should avoid a second copy of μ here; μ_base already stores a copy. |
619 // TODO: should avoid a second copy of μ here; μ_base already stores a copy. |
643 let ω0 = op𝒟.apply(μ.clone()); // 𝒟μ^k |
620 let ω0 = op𝒟.apply(μ.clone()); // 𝒟μ^k |
644 //let g = &minus_τv + ω0; // Linear term of surrogate model |
621 //let g = &minus_τv + ω0; // Linear term of surrogate model |
686 false |
663 false |
687 } else { |
664 } else { |
688 count > 0 |
665 count > 0 |
689 }; |
666 }; |
690 |
667 |
691 // First do a rough check whether we are within bounds and can stop. |
|
692 let in_bounds = match config.ergodic_tolerance { |
|
693 ErgodicTolerance::NonErgodic => { |
|
694 target_bounds.superset(&d.bounds()) |
|
695 }, |
|
696 ErgodicTolerance::AfterNth{ n, factor } => { |
|
697 // Bound -τ∑_{k=0}^{N-1}[A_*(Aμ^k-b)+α] from above. |
|
698 match state.iteration().cmp(&n) { |
|
699 Less => true, |
|
700 Equal => { |
|
701 let iter = F::cast_from(state.iteration()); |
|
702 let mut tmp = preadjA.apply(&after_nth_accum); |
|
703 let (_, v0) = tmp.maximise(refinement_tolerance, |
|
704 config.refinement.max_steps); |
|
705 let v = v0 - iter * τ * α; |
|
706 after_nth_bound = factor * v; |
|
707 println!("{}", format!("Set ergodic tolerance to {}", after_nth_bound)); |
|
708 true |
|
709 }, |
|
710 Greater => { |
|
711 // TODO: can divide after_nth_accum by N, so use basic tolerance on that. |
|
712 let iter = F::cast_from(state.iteration()); |
|
713 let mut tmp = preadjA.apply(&after_nth_accum); |
|
714 tmp.has_upper_bound(after_nth_bound + iter * τ * α, |
|
715 refinement_tolerance, |
|
716 config.refinement.max_steps) |
|
717 } |
|
718 } |
|
719 } |
|
720 }; |
|
721 |
|
722 // If preliminary check indicates that we are in bounds, and if it otherwise matches |
668 // If preliminary check indicates that we are in bounds, and if it otherwise matches |
723 // the insertion strategy, skip insertion. |
669 // the insertion strategy, skip insertion. |
724 if may_break && in_bounds { |
670 if may_break && target_bounds.superset(&d.bounds()) { |
725 break 'insertion (true, d) |
671 break 'insertion (true, d) |
726 } |
672 } |
727 |
673 |
728 // If the rough check didn't indicate stopping, find maximising point, maintaining for |
674 // If the rough check didn't indicate stopping, find maximising point, maintaining for |
729 // the calculations in the beginning of the loop that v_ξ = (ω0-τv-𝒟μ)(ξ) = d(ξ), |
675 // the calculations in the beginning of the loop that v_ξ = (ω0-τv-𝒟μ)(ξ) = d(ξ), |