| 105 } |
105 } |
| 106 |
106 |
| 107 #[replace_float_literals(F::cast_from(literal))] |
107 #[replace_float_literals(F::cast_from(literal))] |
| 108 impl<F: Float> Default for FBConfig<F> { |
108 impl<F: Float> Default for FBConfig<F> { |
| 109 fn default() -> Self { |
109 fn default() -> Self { |
| 110 FBConfig { |
110 FBConfig { τ0: 0.99, σp0: 0.99, insertion: Default::default() } |
| 111 τ0: 0.99, |
|
| 112 σp0: 0.99, |
|
| 113 insertion: Default::default(), |
|
| 114 } |
|
| 115 } |
111 } |
| 116 } |
112 } |
| 117 |
113 |
| 118 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize { |
114 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize { |
| 119 let n_before_prune = μ.len(); |
115 let n_before_prune = μ.len(); |
| 155 reg: &Reg, |
151 reg: &Reg, |
| 156 prox_penalty: &P, |
152 prox_penalty: &P, |
| 157 fbconfig: &FBConfig<F>, |
153 fbconfig: &FBConfig<F>, |
| 158 iterator: I, |
154 iterator: I, |
| 159 mut plotter: Plot, |
155 mut plotter: Plot, |
| 160 μ0 : Option<RNDM<N, F>>, |
156 μ0: Option<RNDM<N, F>>, |
| 161 ) -> DynResult<RNDM<N, F>> |
157 ) -> DynResult<RNDM<N, F>> |
| 162 where |
158 where |
| 163 F: Float + ToNalgebraRealField, |
159 F: Float + ToNalgebraRealField, |
| 164 I: AlgIteratorFactory<IterInfo<F>>, |
160 I: AlgIteratorFactory<IterInfo<F>>, |
| 165 RNDM<N, F>: SpikeMerging<F>, |
161 RNDM<N, F>: SpikeMerging<F>, |
| 194 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
190 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 195 // Calculate smooth part of surrogate model. |
191 // Calculate smooth part of surrogate model. |
| 196 // TODO: optimise τ to be applied to residual. |
192 // TODO: optimise τ to be applied to residual. |
| 197 let mut τv = f.differential(&μ) * τ; |
193 let mut τv = f.differential(&μ) * τ; |
| 198 |
194 |
| 199 // Save current base point |
195 // Save current base point for merge |
| 200 let μ_base = μ.clone(); |
196 let μ_base_len = μ.len(); |
| |
197 let maybe_μ_base = config.merge_now(&state).then(|| μ.clone()); |
| 201 |
198 |
| 202 // Insert and reweigh |
199 // Insert and reweigh |
| 203 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
200 let (maybe_d, _within_tolerances) = prox_penalty |
| 204 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
201 .insert_and_reweigh(&mut μ, &mut τv, τ, ε, config, ®, &state, &mut stats)?; |
| 205 )?; |
202 |
| |
203 stats.inserted += μ.len() - μ_base_len; |
| 206 |
204 |
| 207 // Prune and possibly merge spikes |
205 // Prune and possibly merge spikes |
| 208 if config.merge_now(&state) { |
206 if let Some(μ_base) = maybe_μ_base { |
| 209 stats.merged += prox_penalty.merge_spikes( |
207 stats.merged += prox_penalty.merge_spikes( |
| 210 &mut μ, |
208 &mut μ, |
| 211 &mut τv, |
209 &mut τv, |
| 212 &μ_base, |
210 &μ_base, |
| 213 None, |
|
| 214 τ, |
211 τ, |
| 215 ε, |
212 ε, |
| 216 config, |
213 config, |
| 217 ®, |
214 ®, |
| 218 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)), |
215 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)), |
| 255 reg: &Reg, |
252 reg: &Reg, |
| 256 prox_penalty: &P, |
253 prox_penalty: &P, |
| 257 fbconfig: &FBConfig<F>, |
254 fbconfig: &FBConfig<F>, |
| 258 iterator: I, |
255 iterator: I, |
| 259 mut plotter: Plot, |
256 mut plotter: Plot, |
| 260 μ0: Option<RNDM<N, F>> |
257 μ0: Option<RNDM<N, F>>, |
| 261 ) -> DynResult<RNDM<N, F>> |
258 ) -> DynResult<RNDM<N, F>> |
| 262 where |
259 where |
| 263 F: Float + ToNalgebraRealField, |
260 F: Float + ToNalgebraRealField, |
| 264 I: AlgIteratorFactory<IterInfo<F>>, |
261 I: AlgIteratorFactory<IterInfo<F>>, |
| 265 RNDM<N, F>: SpikeMerging<F>, |
262 RNDM<N, F>: SpikeMerging<F>, |
| 296 // Run the algorithm |
293 // Run the algorithm |
| 297 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
294 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { |
| 298 // Calculate smooth part of surrogate model. |
295 // Calculate smooth part of surrogate model. |
| 299 let mut τv = f.differential(&μ) * τ; |
296 let mut τv = f.differential(&μ) * τ; |
| 300 |
297 |
| 301 // Save current base point |
298 let μ_base_len = μ.len(); |
| 302 let μ_base = μ.clone(); |
|
| 303 |
299 |
| 304 // Insert new spikes and reweigh |
300 // Insert new spikes and reweigh |
| 305 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( |
301 let (maybe_d, _within_tolerances) = prox_penalty |
| 306 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, &state, &mut stats, |
302 .insert_and_reweigh(&mut μ, &mut τv, τ, ε, config, ®, &state, &mut stats)?; |
| 307 )?; |
303 |
| |
304 stats.inserted += μ.len() - μ_base_len; |
| 308 |
305 |
| 309 // (Do not) merge spikes. |
306 // (Do not) merge spikes. |
| 310 if config.merge_now(&state) && !warned_merging { |
307 if config.merge_now(&state) && !warned_merging { |
| 311 let err = format!("Merging not supported for μFISTA"); |
308 let err = format!("Merging not supported for μFISTA"); |
| 312 println!("{}", err.red()); |
309 println!("{}", err.red()); |