src/fb.rs

changeset 7
c32171f7cce5
parent 0
eb3c7813b67a
child 8
ea3ca78873e8
equal deleted inserted replaced
6:bcb508479948 7:c32171f7cce5
78 We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by 78 We solve this with either SSN or FB via [`quadratic_nonneg`] as determined by
79 [`InnerSettings`] in [`FBGenericConfig::inner`]. 79 [`InnerSettings`] in [`FBGenericConfig::inner`].
80 */ 80 */
81 81
82 use numeric_literals::replace_float_literals; 82 use numeric_literals::replace_float_literals;
83 use std::cmp::Ordering::*;
84 use serde::{Serialize, Deserialize}; 83 use serde::{Serialize, Deserialize};
85 use colored::Colorize; 84 use colored::Colorize;
86 use nalgebra::DVector; 85 use nalgebra::DVector;
87 86
88 use alg_tools::iterate::{ 87 use alg_tools::iterate::{
154 None, 153 None,
155 /// FISTA-style inertia 154 /// FISTA-style inertia
156 InertiaFISTA, 155 InertiaFISTA,
157 } 156 }
158 157
159 /// Ergodic tolerance application style
160 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
161 #[allow(dead_code)]
162 pub enum ErgodicTolerance<F> {
163 /// Non-ergodic iteration-wise tolerance
164 NonErgodic,
165 /// Bound after `n`th iteration to `factor` times value on that iteration.
166 AfterNth{ n : usize, factor : F },
167 }
168
169 /// Settings for [`pointsource_fb`]. 158 /// Settings for [`pointsource_fb`].
170 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 159 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
171 #[serde(default)] 160 #[serde(default)]
172 pub struct FBConfig<F : Float> { 161 pub struct FBConfig<F : Float> {
173 /// Step length scaling 162 /// Step length scaling
188 /// Tolerance for point insertion. 177 /// Tolerance for point insertion.
189 pub tolerance : Tolerance<F>, 178 pub tolerance : Tolerance<F>,
190 /// Stop looking for predual maximum (where to isert a new point) below 179 /// Stop looking for predual maximum (where to isert a new point) below
191 /// `tolerance` multiplied by this factor. 180 /// `tolerance` multiplied by this factor.
192 pub insertion_cutoff_factor : F, 181 pub insertion_cutoff_factor : F,
193 /// Apply tolerance ergodically
194 pub ergodic_tolerance : ErgodicTolerance<F>,
195 /// Settings for branch and bound refinement when looking for predual maxima 182 /// Settings for branch and bound refinement when looking for predual maxima
196 pub refinement : RefinementSettings<F>, 183 pub refinement : RefinementSettings<F>,
197 /// Maximum insertions within each outer iteration 184 /// Maximum insertions within each outer iteration
198 pub max_insertions : usize, 185 pub max_insertions : usize,
199 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. 186 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
228 fn default() -> Self { 215 fn default() -> Self {
229 FBGenericConfig { 216 FBGenericConfig {
230 insertion_style : InsertionStyle::Reuse, 217 insertion_style : InsertionStyle::Reuse,
231 tolerance : Default::default(), 218 tolerance : Default::default(),
232 insertion_cutoff_factor : 1.0, 219 insertion_cutoff_factor : 1.0,
233 ergodic_tolerance : ErgodicTolerance::NonErgodic,
234 refinement : Default::default(), 220 refinement : Default::default(),
235 max_insertions : 100, 221 max_insertions : 100,
236 //bootstrap_insertions : None, 222 //bootstrap_insertions : None,
237 bootstrap_insertions : Some((10, 1)), 223 bootstrap_insertions : Some((10, 1)),
238 inner : InnerSettings { 224 inner : InnerSettings {
552 let preadjA = opA.preadjoint(); 538 let preadjA = opA.preadjoint();
553 539
554 // Initialise iterates 540 // Initialise iterates
555 let mut μ = DiscreteMeasure::new(); 541 let mut μ = DiscreteMeasure::new();
556 542
557 let mut after_nth_bound = F::INFINITY;
558 // FIXME: Don't allocate if not needed.
559 let mut after_nth_accum = opA.zero_observable();
560
561 let mut inner_iters = 0; 543 let mut inner_iters = 0;
562 let mut this_iters = 0; 544 let mut this_iters = 0;
563 let mut pruned = 0; 545 let mut pruned = 0;
564 let mut merged = 0; 546 let mut merged = 0;
565 547
626 }, 608 },
627 InsertionStyle::Reuse => m, 609 InsertionStyle::Reuse => m,
628 }; 610 };
629 611
630 // Calculate smooth part of surrogate model. 612 // Calculate smooth part of surrogate model.
631 residual *= -τ;
632 if let ErgodicTolerance::AfterNth{ .. } = config.ergodic_tolerance {
633 // Negative residual times τ expected here, as set above.
634 // TODO: is this the correct location?
635 after_nth_accum += &residual;
636 }
637 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable` 613 // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
638 // has no significant overhead. For some reosn Rust doesn't allow us simply moving 614 // has no significant overhead. For some reosn Rust doesn't allow us simply moving
639 // the residual and replacing it below before the end of this closure. 615 // the residual and replacing it below before the end of this closure.
616 residual *= -τ;
640 let r = std::mem::replace(&mut residual, opA.empty_observable()); 617 let r = std::mem::replace(&mut residual, opA.empty_observable());
641 let minus_τv = preadjA.apply(r); // minus_τv = -τA^*(Aμ^k-b) 618 let minus_τv = preadjA.apply(r); // minus_τv = -τA^*(Aμ^k-b)
642 // TODO: should avoid a second copy of μ here; μ_base already stores a copy. 619 // TODO: should avoid a second copy of μ here; μ_base already stores a copy.
643 let ω0 = op𝒟.apply(μ.clone()); // 𝒟μ^k 620 let ω0 = op𝒟.apply(μ.clone()); // 𝒟μ^k
644 //let g = &minus_τv + ω0; // Linear term of surrogate model 621 //let g = &minus_τv + ω0; // Linear term of surrogate model
686 false 663 false
687 } else { 664 } else {
688 count > 0 665 count > 0
689 }; 666 };
690 667
691 // First do a rough check whether we are within bounds and can stop.
692 let in_bounds = match config.ergodic_tolerance {
693 ErgodicTolerance::NonErgodic => {
694 target_bounds.superset(&d.bounds())
695 },
696 ErgodicTolerance::AfterNth{ n, factor } => {
697 // Bound -τ∑_{k=0}^{N-1}[A_*(Aμ^k-b)+α] from above.
698 match state.iteration().cmp(&n) {
699 Less => true,
700 Equal => {
701 let iter = F::cast_from(state.iteration());
702 let mut tmp = preadjA.apply(&after_nth_accum);
703 let (_, v0) = tmp.maximise(refinement_tolerance,
704 config.refinement.max_steps);
705 let v = v0 - iter * τ * α;
706 after_nth_bound = factor * v;
707 println!("{}", format!("Set ergodic tolerance to {}", after_nth_bound));
708 true
709 },
710 Greater => {
711 // TODO: can divide after_nth_accum by N, so use basic tolerance on that.
712 let iter = F::cast_from(state.iteration());
713 let mut tmp = preadjA.apply(&after_nth_accum);
714 tmp.has_upper_bound(after_nth_bound + iter * τ * α,
715 refinement_tolerance,
716 config.refinement.max_steps)
717 }
718 }
719 }
720 };
721
722 // If preliminary check indicates that we are in bonds, and if it otherwise matches 668 // If preliminary check indicates that we are in bonds, and if it otherwise matches
723 // the insertion strategy, skip insertion. 669 // the insertion strategy, skip insertion.
724 if may_break && in_bounds { 670 if may_break && target_bounds.superset(&d.bounds()) {
725 break 'insertion (true, d) 671 break 'insertion (true, d)
726 } 672 }
727 673
728 // If the rough check didn't indicate stopping, find maximising point, maintaining for 674 // If the rough check didn't indicate stopping, find maximising point, maintaining for
729 // the calculations in the beginning of the loop that v_ξ = (ω0-τv-𝒟μ)(ξ) = d(ξ), 675 // the calculations in the beginning of the loop that v_ξ = (ω0-τv-𝒟μ)(ξ) = d(ξ),

mercurial