src/fb.rs

branch
dev
changeset 34
efa60bc4f743
parent 32
56c8adc32b09
child 35
b087e3eab191
equal deleted inserted replaced
33:aec67cdd6b14 34:efa60bc4f743
134 calculate_residual, 134 calculate_residual,
135 L2Squared, 135 L2Squared,
136 DataTerm, 136 DataTerm,
137 }; 137 };
138 138
139 /// Method for constructing $μ$ on each iteration
140 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
141 #[allow(dead_code)]
142 pub enum InsertionStyle {
143 /// Resuse previous $μ$ from previous iteration, optimising weights
144 /// before inserting new spikes.
145 Reuse,
146 /// Start each iteration with $μ=0$.
147 Zero,
148 }
149
150 /// Settings for [`pointsource_fb_reg`]. 139 /// Settings for [`pointsource_fb_reg`].
151 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 140 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
152 #[serde(default)] 141 #[serde(default)]
153 pub struct FBConfig<F : Float> { 142 pub struct FBConfig<F : Float> {
154 /// Step length scaling 143 /// Step length scaling
155 pub τ0 : F, 144 pub τ0 : F,
156 /// Generic parameters 145 /// Generic parameters
157 pub insertion : FBGenericConfig<F>, 146 pub generic : FBGenericConfig<F>,
158 } 147 }
159 148
160 /// Settings for the solution of the stepwise optimality condition in algorithms based on 149 /// Settings for the solution of the stepwise optimality condition in algorithms based on
161 /// [`generic_pointsource_fb_reg`]. 150 /// [`generic_pointsource_fb_reg`].
162 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] 151 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
163 #[serde(default)] 152 #[serde(default)]
164 pub struct FBGenericConfig<F : Float> { 153 pub struct FBGenericConfig<F : Float> {
165 /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`].
166 pub insertion_style : InsertionStyle,
167 /// Tolerance for point insertion. 154 /// Tolerance for point insertion.
168 pub tolerance : Tolerance<F>, 155 pub tolerance : Tolerance<F>,
156
169 /// Stop looking for predual maximum (where to isert a new point) below 157 /// Stop looking for predual maximum (where to isert a new point) below
170 /// `tolerance` multiplied by this factor. 158 /// `tolerance` multiplied by this factor.
159 ///
160 /// Not used by [`super::radon_fb`].
171 pub insertion_cutoff_factor : F, 161 pub insertion_cutoff_factor : F,
162
172 /// Settings for branch and bound refinement when looking for predual maxima 163 /// Settings for branch and bound refinement when looking for predual maxima
173 pub refinement : RefinementSettings<F>, 164 pub refinement : RefinementSettings<F>,
165
174 /// Maximum insertions within each outer iteration 166 /// Maximum insertions within each outer iteration
167 ///
168 /// Not used by [`super::radon_fb`].
175 pub max_insertions : usize, 169 pub max_insertions : usize,
170
176 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. 171 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
172 ///
173 /// Not used by [`super::radon_fb`].
177 pub bootstrap_insertions : Option<(usize, usize)>, 174 pub bootstrap_insertions : Option<(usize, usize)>,
175
178 /// Inner method settings 176 /// Inner method settings
179 pub inner : InnerSettings<F>, 177 pub inner : InnerSettings<F>,
178
180 /// Spike merging method 179 /// Spike merging method
181 pub merging : SpikeMergingMethod<F>, 180 pub merging : SpikeMergingMethod<F>,
181
182 /// Tolerance multiplier for merges 182 /// Tolerance multiplier for merges
183 pub merge_tolerance_mult : F, 183 pub merge_tolerance_mult : F,
184
184 /// Spike merging method after the last step 185 /// Spike merging method after the last step
185 pub final_merging : SpikeMergingMethod<F>, 186 pub final_merging : SpikeMergingMethod<F>,
187
186 /// Iterations between merging heuristic tries 188 /// Iterations between merging heuristic tries
187 pub merge_every : usize, 189 pub merge_every : usize,
190
188 /// Save $μ$ for postprocessing optimisation 191 /// Save $μ$ for postprocessing optimisation
189 pub postprocessing : bool 192 pub postprocessing : bool
190 } 193 }
191 194
192 #[replace_float_literals(F::cast_from(literal))] 195 #[replace_float_literals(F::cast_from(literal))]
193 impl<F : Float> Default for FBConfig<F> { 196 impl<F : Float> Default for FBConfig<F> {
194 fn default() -> Self { 197 fn default() -> Self {
195 FBConfig { 198 FBConfig {
196 τ0 : 0.99, 199 τ0 : 0.99,
197 insertion : Default::default() 200 generic : Default::default(),
198 } 201 }
199 } 202 }
200 } 203 }
201 204
202 #[replace_float_literals(F::cast_from(literal))] 205 #[replace_float_literals(F::cast_from(literal))]
203 impl<F : Float> Default for FBGenericConfig<F> { 206 impl<F : Float> Default for FBGenericConfig<F> {
204 fn default() -> Self { 207 fn default() -> Self {
205 FBGenericConfig { 208 FBGenericConfig {
206 insertion_style : InsertionStyle::Reuse,
207 tolerance : Default::default(), 209 tolerance : Default::default(),
208 insertion_cutoff_factor : 1.0, 210 insertion_cutoff_factor : 1.0,
209 refinement : Default::default(), 211 refinement : Default::default(),
210 max_insertions : 100, 212 max_insertions : 100,
211 //bootstrap_insertions : None, 213 //bootstrap_insertions : None,
212 bootstrap_insertions : Some((10, 1)), 214 bootstrap_insertions : Some((10, 1)),
213 inner : InnerSettings { 215 inner : InnerSettings {
214 method : InnerMethod::SSN, 216 method : InnerMethod::Default,
215 .. Default::default() 217 .. Default::default()
216 }, 218 },
217 merging : SpikeMergingMethod::None, 219 merging : SpikeMergingMethod::None,
218 //merging : Default::default(), 220 //merging : Default::default(),
219 final_merging : Default::default(), 221 final_merging : Default::default(),
222 postprocessing : false, 224 postprocessing : false,
223 } 225 }
224 } 226 }
225 } 227 }
226 228
227 #[replace_float_literals(F::cast_from(literal))] 229 /// TODO: document.
228 pub(crate) fn μ_diff<F : Float, const N : usize>( 230 /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike
229 μ_new : &DiscreteMeasure<Loc<F, N>, F>, 231 /// locations, while `ν_delta` may have different locations.
230 μ_base : &DiscreteMeasure<Loc<F, N>, F>,
231 ν_delta : Option<&DiscreteMeasure<Loc<F, N>, F>>,
232 config : &FBGenericConfig<F>
233 ) -> DiscreteMeasure<Loc<F, N>, F> {
234 let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style {
235 InsertionStyle::Reuse => {
236 μ_new.iter_spikes()
237 .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0)))
238 .map(|(δ, α_base)| (δ.x, α_base - δ.α))
239 .collect()
240 },
241 InsertionStyle::Zero => {
242 μ_new.iter_spikes()
243 .map(|δ| -δ)
244 .chain(μ_base.iter_spikes().copied())
245 .collect()
246 }
247 };
248 ν.prune(); // Potential small performance improvement
249 // Add ν_delta if given
250 match ν_delta {
251 None => ν,
252 Some(ν_d) => ν + ν_d,
253 }
254 }
255
256 #[replace_float_literals(F::cast_from(literal))] 232 #[replace_float_literals(F::cast_from(literal))]
257 pub(crate) fn insert_and_reweigh< 233 pub(crate) fn insert_and_reweigh<
258 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize 234 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize
259 >( 235 >(
260 μ : &mut DiscreteMeasure<Loc<F, N>, F>, 236 μ : &mut DiscreteMeasure<Loc<F, N>, F>,
282 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 258 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
283 Reg : RegTerm<F, N>, 259 Reg : RegTerm<F, N>,
284 State : AlgIteratorState { 260 State : AlgIteratorState {
285 261
286 // Maximum insertion count and measure difference calculation depend on insertion style. 262 // Maximum insertion count and measure difference calculation depend on insertion style.
287 let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { 263 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) {
288 (i, Some((l, k))) if i <= l => (k, false), 264 (i, Some((l, k))) if i <= l => (k, false),
289 _ => (config.max_insertions, !state.is_quiet()), 265 _ => (config.max_insertions, !state.is_quiet()),
290 }; 266 };
291 let max_insertions = match config.insertion_style { 267
292 InsertionStyle::Zero => { 268 // TODO: should avoid a copy of μ_base here.
293 todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled.");
294 // let n = μ.len();
295 // μ = DiscreteMeasure::new();
296 // n + m
297 },
298 InsertionStyle::Reuse => m,
299 };
300
301 // TODO: should avoid a second copy of μ here; μ_base already stores a copy.
302 let ω0 = op𝒟.apply(match ν_delta { 269 let ω0 = op𝒟.apply(match ν_delta {
303 None => μ.clone(), 270 None => μ_base.clone(),
304 Some(ν_d) => &*μ + ν_d, 271 Some(ν_d) => &*μ_base + ν_d,
305 }); 272 });
306 273
307 // Add points to support until within error tolerance or maximum insertion count reached. 274 // Add points to support until within error tolerance or maximum insertion count reached.
308 let mut count = 0; 275 let mut count = 0;
309 let (within_tolerances, d) = 'insertion: loop { 276 let (within_tolerances, d) = 'insertion: loop {
331 μ.set_masses_dvector(&x); 298 μ.set_masses_dvector(&x);
332 } 299 }
333 300
334 // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality 301 // Form d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv for checking the proximate optimality
335 // conditions in the predual space, and finding new points for insertion, if necessary. 302 // conditions in the predual space, and finding new points for insertion, if necessary.
336 let mut d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_base, ν_delta, config)); 303 let mut d = minus_τv + op𝒟.preapply(match ν_delta {
304 None => μ_base.sub_matching(μ),
305 Some(ν) => μ_base.sub_matching(μ) + ν
306 });
337 307
338 // If no merging heuristic is used, let's be more conservative about spike insertion, 308 // If no merging heuristic is used, let's be more conservative about spike insertion,
339 // and skip it after first round. If merging is done, being more greedy about spike 309 // and skip it after first round. If merging is done, being more greedy about spike
340 // insertion also seems to improve performance. 310 // insertion also seems to improve performance.
341 let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging { 311 let skip_by_rough_check = if let SpikeMergingMethod::None = config.merging {
402 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, 372 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
403 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 373 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
404 Reg : RegTerm<F, N>, 374 Reg : RegTerm<F, N>,
405 State : AlgIteratorState { 375 State : AlgIteratorState {
406 if state.iteration() % config.merge_every == 0 { 376 if state.iteration() % config.merge_every == 0 {
407 let n_before_merge = μ.len(); 377 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| {
408 μ.merge_spikes(config.merging, |μ_candidate| { 378 let mut d = minus_τv + op𝒟.preapply(μ_base.sub_matching(&μ_candidate));
409 let μd = μ_diff(&μ_candidate, &μ_base, None, config);
410 let mut d = minus_τv + op𝒟.preapply(μd);
411
412 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) 379 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config)
413 .then_some(())
414 }); 380 });
415 debug_assert!(μ.len() >= n_before_merge);
416 stats.merged += μ.len() - n_before_merge;
417 } 381 }
418 382
419 let n_before_prune = μ.len(); 383 let n_before_prune = μ.len();
420 μ.prune(); 384 μ.prune();
421 debug_assert!(μ.len() <= n_before_prune); 385 debug_assert!(μ.len() <= n_before_prune);
493 PlotLookup : Plotting<N>, 457 PlotLookup : Plotting<N>,
494 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 458 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
495 Reg : RegTerm<F, N> { 459 Reg : RegTerm<F, N> {
496 460
497 // Set up parameters 461 // Set up parameters
498 let config = &fbconfig.insertion; 462 let config = &fbconfig.generic;
499 let op𝒟norm = op𝒟.opnorm_bound(); 463 let op𝒟norm = op𝒟.opnorm_bound();
500 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); 464 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
501 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 465 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
502 // by τ compared to the conditional gradient approach. 466 // by τ compared to the conditional gradient approach.
503 let tolerance = config.tolerance * τ * reg.tolerance_scaling(); 467 let tolerance = config.tolerance * τ * reg.tolerance_scaling();
619 PlotLookup : Plotting<N>, 583 PlotLookup : Plotting<N>,
620 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, 584 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
621 Reg : RegTerm<F, N> { 585 Reg : RegTerm<F, N> {
622 586
623 // Set up parameters 587 // Set up parameters
624 let config = &fbconfig.insertion; 588 let config = &fbconfig.generic;
625 let op𝒟norm = op𝒟.opnorm_bound(); 589 let op𝒟norm = op𝒟.opnorm_bound();
626 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap(); 590 let τ = fbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
627 let mut λ = 1.0; 591 let mut λ = 1.0;
628 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled 592 // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
629 // by τ compared to the conditional gradient approach. 593 // by τ compared to the conditional gradient approach.

mercurial