src/fb.rs

changeset 1
d4fd5f32d10e
parent 0
eb3c7813b67a
equal deleted inserted replaced
0:eb3c7813b67a 1:d4fd5f32d10e
82 use numeric_literals::replace_float_literals; 82 use numeric_literals::replace_float_literals;
83 use std::cmp::Ordering::*; 83 use std::cmp::Ordering::*;
84 use serde::{Serialize, Deserialize}; 84 use serde::{Serialize, Deserialize};
85 use colored::Colorize; 85 use colored::Colorize;
86 use nalgebra::DVector; 86 use nalgebra::DVector;
87 use clap::Parser;
87 88
88 use alg_tools::iterate::{ 89 use alg_tools::iterate::{
89 AlgIteratorFactory, 90 AlgIteratorFactory,
90 AlgIteratorState, 91 AlgIteratorState,
91 }; 92 };
144 Reuse, 145 Reuse,
145 /// Start each iteration with $μ=0$. 146 /// Start each iteration with $μ=0$.
146 Zero, 147 Zero,
147 } 148 }
148 149
150 impl Default for InsertionStyle {
151 fn default() -> Self {
152 Self::Reuse
153 }
154 }
149 /// Meta-algorithm type 155 /// Meta-algorithm type
150 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 156 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
151 #[allow(dead_code)] 157 #[allow(dead_code)]
152 pub enum FBMetaAlgorithm { 158 pub enum FBMetaAlgorithm {
153 /// No meta-algorithm 159 /// No meta-algorithm
164 NonErgodic, 170 NonErgodic,
165 /// Bound after `n`th iteration to `factor` times value on that iteration. 171 /// Bound after `n`th iteration to `factor` times value on that iteration.
166 AfterNth{ n : usize, factor : F }, 172 AfterNth{ n : usize, factor : F },
167 } 173 }
168 174
175 impl<F : ClapFloat> Default for ErgodicTolerance<F> {
176 fn default() -> Self {
177 Self::NonErgodic
178 }
179 }
180
169 /// Settings for [`pointsource_fb`]. 181 /// Settings for [`pointsource_fb`].
170 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 182 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
171 #[serde(default)] 183 #[serde(default)]
172 pub struct FBConfig<F : Float> { 184 pub struct FBConfig<F : ClapFloat> {
173 /// Step length scaling 185 /// Step length scaling
174 pub τ0 : F, 186 pub τ0 : F,
175 /// Meta-algorithm to apply 187 /// Meta-algorithm to apply
176 pub meta : FBMetaAlgorithm, 188 pub meta : FBMetaAlgorithm,
177 /// Generic parameters 189 /// Generic parameters
178 pub insertion : FBGenericConfig<F>, 190 pub insertion : FBGenericConfig<F>,
179 } 191 }
180 192
181 /// Settings for the solution of the stepwise optimality condition in algorithms based on 193 /// Settings for the solution of the stepwise optimality condition in algorithms based on
182 /// [`generic_pointsource_fb`]. 194 /// [`generic_pointsource_fb`].
183 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 195 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug, Parser)]
184 #[serde(default)] 196 #[serde(default)]
185 pub struct FBGenericConfig<F : Float> { 197 pub struct FBGenericConfig<F : ClapFloat> {
198 #[clap(skip)]
186 /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`]. 199 /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`].
187 pub insertion_style : InsertionStyle, 200 pub insertion_style : InsertionStyle,
201 #[clap(skip)]
188 /// Tolerance for point insertion. 202 /// Tolerance for point insertion.
189 pub tolerance : Tolerance<F>, 203 pub tolerance : Tolerance<F>,
190 /// Stop looking for predual maximum (where to isert a new point) below 204 /// Stop looking for predual maximum (where to isert a new point) below
191 /// `tolerance` multiplied by this factor. 205 /// `tolerance` multiplied by this factor.
192 pub insertion_cutoff_factor : F, 206 pub insertion_cutoff_factor : F,
207 #[clap(skip)]
193 /// Apply tolerance ergodically 208 /// Apply tolerance ergodically
194 pub ergodic_tolerance : ErgodicTolerance<F>, 209 pub ergodic_tolerance : ErgodicTolerance<F>,
210 #[clap(skip)]
195 /// Settings for branch and bound refinement when looking for predual maxima 211 /// Settings for branch and bound refinement when looking for predual maxima
196 pub refinement : RefinementSettings<F>, 212 pub refinement : RefinementSettings<F>,
197 /// Maximum insertions within each outer iteration 213 /// Maximum insertions within each outer iteration
198 pub max_insertions : usize, 214 pub max_insertions : usize,
199 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. 215 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
200 pub bootstrap_insertions : Option<(usize, usize)>, 216 pub bootstrap_insertions : Option<(usize, usize)>,
217 #[clap(skip)]
201 /// Inner method settings 218 /// Inner method settings
202 pub inner : InnerSettings<F>, 219 pub inner : InnerSettings<F>,
203 /// Spike merging method 220 /// Spike merging method
204 pub merging : SpikeMergingMethod<F>, 221 pub merging : SpikeMergingMethod<F>,
205 /// Tolerance multiplier for merges 222 /// Tolerance multiplier for merges
211 /// Save $μ$ for postprocessing optimisation 228 /// Save $μ$ for postprocessing optimisation
212 pub postprocessing : bool 229 pub postprocessing : bool
213 } 230 }
214 231
215 #[replace_float_literals(F::cast_from(literal))] 232 #[replace_float_literals(F::cast_from(literal))]
216 impl<F : Float> Default for FBConfig<F> { 233 impl<F : ClapFloat> Default for FBConfig<F> {
217 fn default() -> Self { 234 fn default() -> Self {
218 FBConfig { 235 FBConfig {
219 τ0 : 0.99, 236 τ0 : 0.99,
220 meta : FBMetaAlgorithm::None, 237 meta : FBMetaAlgorithm::None,
221 insertion : Default::default() 238 insertion : Default::default()
222 } 239 }
223 } 240 }
224 } 241 }
225 242
226 #[replace_float_literals(F::cast_from(literal))] 243 #[replace_float_literals(F::cast_from(literal))]
227 impl<F : Float> Default for FBGenericConfig<F> { 244 impl<F : ClapFloat> Default for FBGenericConfig<F> {
228 fn default() -> Self { 245 fn default() -> Self {
229 FBGenericConfig { 246 FBGenericConfig {
230 insertion_style : InsertionStyle::Reuse, 247 insertion_style : InsertionStyle::Reuse,
231 tolerance : Default::default(), 248 tolerance : Default::default(),
232 insertion_cutoff_factor : 1.0, 249 insertion_cutoff_factor : 1.0,
455 op𝒟 : &'a 𝒟, 472 op𝒟 : &'a 𝒟,
456 config : &FBConfig<F>, 473 config : &FBConfig<F>,
457 iterator : I, 474 iterator : I,
458 plotter : SeqPlotter<F, N> 475 plotter : SeqPlotter<F, N>
459 ) -> DiscreteMeasure<Loc<F, N>, F> 476 ) -> DiscreteMeasure<Loc<F, N>, F>
460 where F : Float + ToNalgebraRealField, 477 where F : ClapFloat + ToNalgebraRealField,
461 I : AlgIteratorFactory<IterInfo<F, N>>, 478 I : AlgIteratorFactory<IterInfo<F, N>>,
462 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>, 479 for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
463 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow 480 //+ std::ops::Mul<F, Output=A::Observable>, <-- FIXME: compiler overflow
464 A::Observable : std::ops::MulAssign<F>, 481 A::Observable : std::ops::MulAssign<F>,
465 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 482 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
519 iterator : I, 536 iterator : I,
520 mut plotter : SeqPlotter<F, N>, 537 mut plotter : SeqPlotter<F, N>,
521 mut residual : A::Observable, 538 mut residual : A::Observable,
522 mut specialisation : Spec, 539 mut specialisation : Spec,
523 ) -> DiscreteMeasure<Loc<F, N>, F> 540 ) -> DiscreteMeasure<Loc<F, N>, F>
524 where F : Float + ToNalgebraRealField, 541 where F : ClapFloat + ToNalgebraRealField,
525 I : AlgIteratorFactory<IterInfo<F, N>>, 542 I : AlgIteratorFactory<IterInfo<F, N>>,
526 Spec : FBSpecialisation<F, A::Observable, N>, 543 Spec : FBSpecialisation<F, A::Observable, N>,
527 A::Observable : std::ops::MulAssign<F>, 544 A::Observable : std::ops::MulAssign<F>,
528 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone, 545 GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
529 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>> 546 A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>

mercurial