134 calculate_residual, |
134 calculate_residual, |
135 L2Squared, |
135 L2Squared, |
136 DataTerm, |
136 DataTerm, |
137 }; |
137 }; |
138 |
138 |
139 /// Method for constructing $μ$ on each iteration |
|
140 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
|
141 #[allow(dead_code)] |
|
142 pub enum InsertionStyle { |
|
143 /// Resuse previous $μ$ from previous iteration, optimising weights |
|
144 /// before inserting new spikes. |
|
145 Reuse, |
|
146 /// Start each iteration with $μ=0$. |
|
147 Zero, |
|
148 } |
|
149 |
|
150 /// Settings for [`pointsource_fb_reg`]. |
139 /// Settings for [`pointsource_fb_reg`]. |
151 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
140 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
152 #[serde(default)] |
141 #[serde(default)] |
153 pub struct FBConfig<F : Float> { |
142 pub struct FBConfig<F : Float> { |
154 /// Step length scaling |
143 /// Step length scaling |
155 pub τ0 : F, |
144 pub τ0 : F, |
156 /// Generic parameters |
145 /// Generic parameters |
157 pub insertion : FBGenericConfig<F>, |
146 pub generic : FBGenericConfig<F>, |
158 } |
147 } |
159 |
148 |
160 /// Settings for the solution of the stepwise optimality condition in algorithms based on |
149 /// Settings for the solution of the stepwise optimality condition in algorithms based on |
161 /// [`generic_pointsource_fb_reg`]. |
150 /// [`generic_pointsource_fb_reg`]. |
162 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
151 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
163 #[serde(default)] |
152 #[serde(default)] |
164 pub struct FBGenericConfig<F : Float> { |
153 pub struct FBGenericConfig<F : Float> { |
165 /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`]. |
|
166 pub insertion_style : InsertionStyle, |
|
167 /// Tolerance for point insertion. |
154 /// Tolerance for point insertion. |
168 pub tolerance : Tolerance<F>, |
155 pub tolerance : Tolerance<F>, |
|
156 |
169 /// Stop looking for predual maximum (where to isert a new point) below |
157 /// Stop looking for predual maximum (where to isert a new point) below |
170 /// `tolerance` multiplied by this factor. |
158 /// `tolerance` multiplied by this factor. |
|
159 /// |
|
160 /// Not used by [`super::radon_fb`]. |
171 pub insertion_cutoff_factor : F, |
161 pub insertion_cutoff_factor : F, |
|
162 |
172 /// Settings for branch and bound refinement when looking for predual maxima |
163 /// Settings for branch and bound refinement when looking for predual maxima |
173 pub refinement : RefinementSettings<F>, |
164 pub refinement : RefinementSettings<F>, |
|
165 |
174 /// Maximum insertions within each outer iteration |
166 /// Maximum insertions within each outer iteration |
|
167 /// |
|
168 /// Not used by [`super::radon_fb`]. |
175 pub max_insertions : usize, |
169 pub max_insertions : usize, |
|
170 |
176 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. |
171 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. |
|
172 /// |
|
173 /// Not used by [`super::radon_fb`]. |
177 pub bootstrap_insertions : Option<(usize, usize)>, |
174 pub bootstrap_insertions : Option<(usize, usize)>, |
|
175 |
178 /// Inner method settings |
176 /// Inner method settings |
179 pub inner : InnerSettings<F>, |
177 pub inner : InnerSettings<F>, |
|
178 |
180 /// Spike merging method |
179 /// Spike merging method |
181 pub merging : SpikeMergingMethod<F>, |
180 pub merging : SpikeMergingMethod<F>, |
|
181 |
182 /// Tolerance multiplier for merges |
182 /// Tolerance multiplier for merges |
183 pub merge_tolerance_mult : F, |
183 pub merge_tolerance_mult : F, |
|
184 |
184 /// Spike merging method after the last step |
185 /// Spike merging method after the last step |
185 pub final_merging : SpikeMergingMethod<F>, |
186 pub final_merging : SpikeMergingMethod<F>, |
|
187 |
186 /// Iterations between merging heuristic tries |
188 /// Iterations between merging heuristic tries |
187 pub merge_every : usize, |
189 pub merge_every : usize, |
|
190 |
188 /// Save $μ$ for postprocessing optimisation |
191 /// Save $μ$ for postprocessing optimisation |
189 pub postprocessing : bool |
192 pub postprocessing : bool |
190 } |
193 } |
191 |
194 |
192 #[replace_float_literals(F::cast_from(literal))] |
195 #[replace_float_literals(F::cast_from(literal))] |
193 impl<F : Float> Default for FBConfig<F> { |
196 impl<F : Float> Default for FBConfig<F> { |
194 fn default() -> Self { |
197 fn default() -> Self { |
195 FBConfig { |
198 FBConfig { |
196 τ0 : 0.99, |
199 τ0 : 0.99, |
197 insertion : Default::default() |
200 generic : Default::default(), |
198 } |
201 } |
199 } |
202 } |
200 } |
203 } |
201 |
204 |
202 #[replace_float_literals(F::cast_from(literal))] |
205 #[replace_float_literals(F::cast_from(literal))] |
203 impl<F : Float> Default for FBGenericConfig<F> { |
206 impl<F : Float> Default for FBGenericConfig<F> { |
204 fn default() -> Self { |
207 fn default() -> Self { |
205 FBGenericConfig { |
208 FBGenericConfig { |
206 insertion_style : InsertionStyle::Reuse, |
|
207 tolerance : Default::default(), |
209 tolerance : Default::default(), |
208 insertion_cutoff_factor : 1.0, |
210 insertion_cutoff_factor : 1.0, |
209 refinement : Default::default(), |
211 refinement : Default::default(), |
210 max_insertions : 100, |
212 max_insertions : 100, |
211 //bootstrap_insertions : None, |
213 //bootstrap_insertions : None, |
212 bootstrap_insertions : Some((10, 1)), |
214 bootstrap_insertions : Some((10, 1)), |
213 inner : InnerSettings { |
215 inner : InnerSettings { |
214 method : InnerMethod::SSN, |
216 method : InnerMethod::Default, |
215 .. Default::default() |
217 .. Default::default() |
216 }, |
218 }, |
217 merging : SpikeMergingMethod::None, |
219 merging : SpikeMergingMethod::None, |
218 //merging : Default::default(), |
220 //merging : Default::default(), |
219 final_merging : Default::default(), |
221 final_merging : Default::default(), |
222 postprocessing : false, |
224 postprocessing : false, |
223 } |
225 } |
224 } |
226 } |
225 } |
227 } |
226 |
228 |
227 #[replace_float_literals(F::cast_from(literal))] |
229 /// TODO: document. |
228 pub(crate) fn μ_diff<F : Float, const N : usize>( |
230 /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike |
229 μ_new : &DiscreteMeasure<Loc<F, N>, F>, |
231 /// locations, while `ν_delta` may have different locations. |
230 μ_base : &DiscreteMeasure<Loc<F, N>, F>, |
|
231 ν_delta : Option<&DiscreteMeasure<Loc<F, N>, F>>, |
|
232 config : &FBGenericConfig<F> |
|
233 ) -> DiscreteMeasure<Loc<F, N>, F> { |
|
234 let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style { |
|
235 InsertionStyle::Reuse => { |
|
236 μ_new.iter_spikes() |
|
237 .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0))) |
|
238 .map(|(δ, α_base)| (δ.x, α_base - δ.α)) |
|
239 .collect() |
|
240 }, |
|
241 InsertionStyle::Zero => { |
|
242 μ_new.iter_spikes() |
|
243 .map(|δ| -δ) |
|
244 .chain(μ_base.iter_spikes().copied()) |
|
245 .collect() |
|
246 } |
|
247 }; |
|
248 ν.prune(); // Potential small performance improvement |
|
249 // Add ν_delta if given |
|
250 match ν_delta { |
|
251 None => ν, |
|
252 Some(ν_d) => ν + ν_d, |
|
253 } |
|
254 } |
|
255 |
|
256 #[replace_float_literals(F::cast_from(literal))] |
232 #[replace_float_literals(F::cast_from(literal))] |
257 pub(crate) fn insert_and_reweigh< |
233 pub(crate) fn insert_and_reweigh< |
258 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize |
234 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize |
259 >( |
235 >( |
260 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
236 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
282 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
258 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
283 Reg : RegTerm<F, N>, |
259 Reg : RegTerm<F, N>, |
284 State : AlgIteratorState { |
260 State : AlgIteratorState { |
285 |
261 |
286 // Maximum insertion count and measure difference calculation depend on insertion style. |
262 // Maximum insertion count and measure difference calculation depend on insertion style. |
287 let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
263 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
288 (i, Some((l, k))) if i <= l => (k, false), |
264 (i, Some((l, k))) if i <= l => (k, false), |
289 _ => (config.max_insertions, !state.is_quiet()), |
265 _ => (config.max_insertions, !state.is_quiet()), |
290 }; |
266 }; |
291 let max_insertions = match config.insertion_style { |
267 |
292 InsertionStyle::Zero => { |
268 // TODO: should avoid a copy of μ_base here. |
293 todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled."); |
|
294 // let n = μ.len(); |
|
295 // μ = DiscreteMeasure::new(); |
|
296 // n + m |
|
297 }, |
|
298 InsertionStyle::Reuse => m, |
|
299 }; |
|
300 |
|
301 // TODO: should avoid a second copy of μ here; μ_base already stores a copy. |
|
302 let ω0 = op𝒟.apply(match ν_delta { |
269 let ω0 = op𝒟.apply(match ν_delta { |
303 None => μ.clone(), |
270 None => μ_base.clone(), |
304 Some(ν_d) => &*μ + ν_d, |
271 Some(ν_d) => &*μ_base + ν_d, |
305 }); |
272 }); |
306 |
273 |
307 // Add points to support until within error tolerance or maximum insertion count reached. |
274 // Add points to support until within error tolerance or maximum insertion count reached. |
308 let mut count = 0; |
275 let mut count = 0; |
309 let (within_tolerances, d) = 'insertion: loop { |
276 let (within_tolerances, d) = 'insertion: loop { |
402 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
372 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
403 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
373 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
404 Reg : RegTerm<F, N>, |
374 Reg : RegTerm<F, N>, |
405 State : AlgIteratorState { |
375 State : AlgIteratorState { |
406 if state.iteration() % config.merge_every == 0 { |
376 if state.iteration() % config.merge_every == 0 { |
407 let n_before_merge = μ.len(); |
377 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { |
408 μ.merge_spikes(config.merging, |μ_candidate| { |
378 let mut d = minus_τv + op𝒟.preapply(μ_base.sub_matching(&μ_candidate)); |
409 let μd = μ_diff(&μ_candidate, &μ_base, None, config); |
|
410 let mut d = minus_τv + op𝒟.preapply(μd); |
|
411 |
|
412 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
379 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
413 .then_some(()) |
|
414 }); |
380 }); |
415 debug_assert!(μ.len() >= n_before_merge); |
|
416 stats.merged += μ.len() - n_before_merge; |
|
417 } |
381 } |
418 |
382 |
419 let n_before_prune = μ.len(); |
383 let n_before_prune = μ.len(); |
420 μ.prune(); |
384 μ.prune(); |
421 debug_assert!(μ.len() <= n_before_prune); |
385 debug_assert!(μ.len() <= n_before_prune); |