| 134 calculate_residual, |
134 calculate_residual, |
| 135 L2Squared, |
135 L2Squared, |
| 136 DataTerm, |
136 DataTerm, |
| 137 }; |
137 }; |
| 138 |
138 |
| 139 /// Method for constructing $μ$ on each iteration |
|
| 140 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
|
| 141 #[allow(dead_code)] |
|
| 142 pub enum InsertionStyle { |
|
| 143 /// Resuse previous $μ$ from previous iteration, optimising weights |
|
| 144 /// before inserting new spikes. |
|
| 145 Reuse, |
|
| 146 /// Start each iteration with $μ=0$. |
|
| 147 Zero, |
|
| 148 } |
|
| 149 |
|
| 150 /// Settings for [`pointsource_fb_reg`]. |
139 /// Settings for [`pointsource_fb_reg`]. |
| 151 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
140 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 152 #[serde(default)] |
141 #[serde(default)] |
| 153 pub struct FBConfig<F : Float> { |
142 pub struct FBConfig<F : Float> { |
| 154 /// Step length scaling |
143 /// Step length scaling |
| 155 pub τ0 : F, |
144 pub τ0 : F, |
| 156 /// Generic parameters |
145 /// Generic parameters |
| 157 pub insertion : FBGenericConfig<F>, |
146 pub generic : FBGenericConfig<F>, |
| 158 } |
147 } |
| 159 |
148 |
| 160 /// Settings for the solution of the stepwise optimality condition in algorithms based on |
149 /// Settings for the solution of the stepwise optimality condition in algorithms based on |
| 161 /// [`generic_pointsource_fb_reg`]. |
150 /// [`generic_pointsource_fb_reg`]. |
| 162 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
151 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)] |
| 163 #[serde(default)] |
152 #[serde(default)] |
| 164 pub struct FBGenericConfig<F : Float> { |
153 pub struct FBGenericConfig<F : Float> { |
| 165 /// Method for constructing $μ$ on each iteration; see [`InsertionStyle`]. |
|
| 166 pub insertion_style : InsertionStyle, |
|
| 167 /// Tolerance for point insertion. |
154 /// Tolerance for point insertion. |
| 168 pub tolerance : Tolerance<F>, |
155 pub tolerance : Tolerance<F>, |
| |
156 |
| 169 /// Stop looking for predual maximum (where to isert a new point) below |
157 /// Stop looking for predual maximum (where to isert a new point) below |
| 170 /// `tolerance` multiplied by this factor. |
158 /// `tolerance` multiplied by this factor. |
| |
159 /// |
| |
160 /// Not used by [`super::radon_fb`]. |
| 171 pub insertion_cutoff_factor : F, |
161 pub insertion_cutoff_factor : F, |
| |
162 |
| 172 /// Settings for branch and bound refinement when looking for predual maxima |
163 /// Settings for branch and bound refinement when looking for predual maxima |
| 173 pub refinement : RefinementSettings<F>, |
164 pub refinement : RefinementSettings<F>, |
| |
165 |
| 174 /// Maximum insertions within each outer iteration |
166 /// Maximum insertions within each outer iteration |
| |
167 /// |
| |
168 /// Not used by [`super::radon_fb`]. |
| 175 pub max_insertions : usize, |
169 pub max_insertions : usize, |
| |
170 |
| 176 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. |
171 /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations. |
| |
172 /// |
| |
173 /// Not used by [`super::radon_fb`]. |
| 177 pub bootstrap_insertions : Option<(usize, usize)>, |
174 pub bootstrap_insertions : Option<(usize, usize)>, |
| |
175 |
| 178 /// Inner method settings |
176 /// Inner method settings |
| 179 pub inner : InnerSettings<F>, |
177 pub inner : InnerSettings<F>, |
| |
178 |
| 180 /// Spike merging method |
179 /// Spike merging method |
| 181 pub merging : SpikeMergingMethod<F>, |
180 pub merging : SpikeMergingMethod<F>, |
| |
181 |
| 182 /// Tolerance multiplier for merges |
182 /// Tolerance multiplier for merges |
| 183 pub merge_tolerance_mult : F, |
183 pub merge_tolerance_mult : F, |
| |
184 |
| 184 /// Spike merging method after the last step |
185 /// Spike merging method after the last step |
| 185 pub final_merging : SpikeMergingMethod<F>, |
186 pub final_merging : SpikeMergingMethod<F>, |
| |
187 |
| 186 /// Iterations between merging heuristic tries |
188 /// Iterations between merging heuristic tries |
| 187 pub merge_every : usize, |
189 pub merge_every : usize, |
| |
190 |
| 188 /// Save $μ$ for postprocessing optimisation |
191 /// Save $μ$ for postprocessing optimisation |
| 189 pub postprocessing : bool |
192 pub postprocessing : bool |
| 190 } |
193 } |
| 191 |
194 |
| 192 #[replace_float_literals(F::cast_from(literal))] |
195 #[replace_float_literals(F::cast_from(literal))] |
| 193 impl<F : Float> Default for FBConfig<F> { |
196 impl<F : Float> Default for FBConfig<F> { |
| 194 fn default() -> Self { |
197 fn default() -> Self { |
| 195 FBConfig { |
198 FBConfig { |
| 196 τ0 : 0.99, |
199 τ0 : 0.99, |
| 197 insertion : Default::default() |
200 generic : Default::default(), |
| 198 } |
201 } |
| 199 } |
202 } |
| 200 } |
203 } |
| 201 |
204 |
| 202 #[replace_float_literals(F::cast_from(literal))] |
205 #[replace_float_literals(F::cast_from(literal))] |
| 203 impl<F : Float> Default for FBGenericConfig<F> { |
206 impl<F : Float> Default for FBGenericConfig<F> { |
| 204 fn default() -> Self { |
207 fn default() -> Self { |
| 205 FBGenericConfig { |
208 FBGenericConfig { |
| 206 insertion_style : InsertionStyle::Reuse, |
|
| 207 tolerance : Default::default(), |
209 tolerance : Default::default(), |
| 208 insertion_cutoff_factor : 1.0, |
210 insertion_cutoff_factor : 1.0, |
| 209 refinement : Default::default(), |
211 refinement : Default::default(), |
| 210 max_insertions : 100, |
212 max_insertions : 100, |
| 211 //bootstrap_insertions : None, |
213 //bootstrap_insertions : None, |
| 212 bootstrap_insertions : Some((10, 1)), |
214 bootstrap_insertions : Some((10, 1)), |
| 213 inner : InnerSettings { |
215 inner : InnerSettings { |
| 214 method : InnerMethod::SSN, |
216 method : InnerMethod::Default, |
| 215 .. Default::default() |
217 .. Default::default() |
| 216 }, |
218 }, |
| 217 merging : SpikeMergingMethod::None, |
219 merging : SpikeMergingMethod::None, |
| 218 //merging : Default::default(), |
220 //merging : Default::default(), |
| 219 final_merging : Default::default(), |
221 final_merging : Default::default(), |
| 222 postprocessing : false, |
224 postprocessing : false, |
| 223 } |
225 } |
| 224 } |
226 } |
| 225 } |
227 } |
| 226 |
228 |
| 227 #[replace_float_literals(F::cast_from(literal))] |
229 /// TODO: document. |
| 228 pub(crate) fn μ_diff<F : Float, const N : usize>( |
230 /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same spike |
| 229 μ_new : &DiscreteMeasure<Loc<F, N>, F>, |
231 /// locations, while `ν_delta` may have different locations. |
| 230 μ_base : &DiscreteMeasure<Loc<F, N>, F>, |
|
| 231 ν_delta : Option<&DiscreteMeasure<Loc<F, N>, F>>, |
|
| 232 config : &FBGenericConfig<F> |
|
| 233 ) -> DiscreteMeasure<Loc<F, N>, F> { |
|
| 234 let mut ν : DiscreteMeasure<Loc<F, N>, F> = match config.insertion_style { |
|
| 235 InsertionStyle::Reuse => { |
|
| 236 μ_new.iter_spikes() |
|
| 237 .zip(μ_base.iter_masses().chain(std::iter::repeat(0.0))) |
|
| 238 .map(|(δ, α_base)| (δ.x, α_base - δ.α)) |
|
| 239 .collect() |
|
| 240 }, |
|
| 241 InsertionStyle::Zero => { |
|
| 242 μ_new.iter_spikes() |
|
| 243 .map(|δ| -δ) |
|
| 244 .chain(μ_base.iter_spikes().copied()) |
|
| 245 .collect() |
|
| 246 } |
|
| 247 }; |
|
| 248 ν.prune(); // Potential small performance improvement |
|
| 249 // Add ν_delta if given |
|
| 250 match ν_delta { |
|
| 251 None => ν, |
|
| 252 Some(ν_d) => ν + ν_d, |
|
| 253 } |
|
| 254 } |
|
| 255 |
|
| 256 #[replace_float_literals(F::cast_from(literal))] |
232 #[replace_float_literals(F::cast_from(literal))] |
| 257 pub(crate) fn insert_and_reweigh< |
233 pub(crate) fn insert_and_reweigh< |
| 258 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize |
234 'a, F, GA, 𝒟, BTA, G𝒟, S, K, Reg, State, const N : usize |
| 259 >( |
235 >( |
| 260 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
236 μ : &mut DiscreteMeasure<Loc<F, N>, F>, |
| 282 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
258 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
| 283 Reg : RegTerm<F, N>, |
259 Reg : RegTerm<F, N>, |
| 284 State : AlgIteratorState { |
260 State : AlgIteratorState { |
| 285 |
261 |
| 286 // Maximum insertion count and measure difference calculation depend on insertion style. |
262 // Maximum insertion count and measure difference calculation depend on insertion style. |
| 287 let (m, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
263 let (max_insertions, warn_insertions) = match (state.iteration(), config.bootstrap_insertions) { |
| 288 (i, Some((l, k))) if i <= l => (k, false), |
264 (i, Some((l, k))) if i <= l => (k, false), |
| 289 _ => (config.max_insertions, !state.is_quiet()), |
265 _ => (config.max_insertions, !state.is_quiet()), |
| 290 }; |
266 }; |
| 291 let max_insertions = match config.insertion_style { |
267 |
| 292 InsertionStyle::Zero => { |
268 // TODO: should avoid a copy of μ_base here. |
| 293 todo!("InsertionStyle::Zero does not currently work with FISTA, so diabled."); |
|
| 294 // let n = μ.len(); |
|
| 295 // μ = DiscreteMeasure::new(); |
|
| 296 // n + m |
|
| 297 }, |
|
| 298 InsertionStyle::Reuse => m, |
|
| 299 }; |
|
| 300 |
|
| 301 // TODO: should avoid a second copy of μ here; μ_base already stores a copy. |
|
| 302 let ω0 = op𝒟.apply(match ν_delta { |
269 let ω0 = op𝒟.apply(match ν_delta { |
| 303 None => μ.clone(), |
270 None => μ_base.clone(), |
| 304 Some(ν_d) => &*μ + ν_d, |
271 Some(ν_d) => &*μ_base + ν_d, |
| 305 }); |
272 }); |
| 306 |
273 |
| 307 // Add points to support until within error tolerance or maximum insertion count reached. |
274 // Add points to support until within error tolerance or maximum insertion count reached. |
| 308 let mut count = 0; |
275 let mut count = 0; |
| 309 let (within_tolerances, d) = 'insertion: loop { |
276 let (within_tolerances, d) = 'insertion: loop { |
| 402 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
372 BTNodeLookup: BTNode<F, usize, Bounds<F>, N>, |
| 403 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
373 DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>, |
| 404 Reg : RegTerm<F, N>, |
374 Reg : RegTerm<F, N>, |
| 405 State : AlgIteratorState { |
375 State : AlgIteratorState { |
| 406 if state.iteration() % config.merge_every == 0 { |
376 if state.iteration() % config.merge_every == 0 { |
| 407 let n_before_merge = μ.len(); |
377 stats.merged += μ.merge_spikes(config.merging, |μ_candidate| { |
| 408 μ.merge_spikes(config.merging, |μ_candidate| { |
378 let mut d = minus_τv + op𝒟.preapply(μ_base.sub_matching(&μ_candidate)); |
| 409 let μd = μ_diff(&μ_candidate, &μ_base, None, config); |
|
| 410 let mut d = minus_τv + op𝒟.preapply(μd); |
|
| 411 |
|
| 412 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
379 reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, &config) |
| 413 .then_some(()) |
|
| 414 }); |
380 }); |
| 415 debug_assert!(μ.len() >= n_before_merge); |
|
| 416 stats.merged += μ.len() - n_before_merge; |
|
| 417 } |
381 } |
| 418 |
382 |
| 419 let n_before_prune = μ.len(); |
383 let n_before_prune = μ.len(); |
| 420 μ.prune(); |
384 μ.prune(); |
| 421 debug_assert!(μ.len() <= n_before_prune); |
385 debug_assert!(μ.len() <= n_before_prune); |