src/fb.rs

branch
dev
changeset 63
7a8a55fd41c0
parent 61
4f468d35fa29
equal deleted inserted replaced
61:4f468d35fa29 63:7a8a55fd41c0
105 } 105 }
106 106
107 #[replace_float_literals(F::cast_from(literal))] 107 #[replace_float_literals(F::cast_from(literal))]
108 impl<F: Float> Default for FBConfig<F> { 108 impl<F: Float> Default for FBConfig<F> {
109 fn default() -> Self { 109 fn default() -> Self {
110 FBConfig { 110 FBConfig { τ0: 0.99, σp0: 0.99, insertion: Default::default() }
111 τ0: 0.99,
112 σp0: 0.99,
113 insertion: Default::default(),
114 }
115 } 111 }
116 } 112 }
117 113
118 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize { 114 pub(crate) fn prune_with_stats<F: Float, const N: usize>(μ: &mut RNDM<N, F>) -> usize {
119 let n_before_prune = μ.len(); 115 let n_before_prune = μ.len();
155 reg: &Reg, 151 reg: &Reg,
156 prox_penalty: &P, 152 prox_penalty: &P,
157 fbconfig: &FBConfig<F>, 153 fbconfig: &FBConfig<F>,
158 iterator: I, 154 iterator: I,
159 mut plotter: Plot, 155 mut plotter: Plot,
160 μ0 : Option<RNDM<N, F>>, 156 μ0: Option<RNDM<N, F>>,
161 ) -> DynResult<RNDM<N, F>> 157 ) -> DynResult<RNDM<N, F>>
162 where 158 where
163 F: Float + ToNalgebraRealField, 159 F: Float + ToNalgebraRealField,
164 I: AlgIteratorFactory<IterInfo<F>>, 160 I: AlgIteratorFactory<IterInfo<F>>,
165 RNDM<N, F>: SpikeMerging<F>, 161 RNDM<N, F>: SpikeMerging<F>,
194 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { 190 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
195 // Calculate smooth part of surrogate model. 191 // Calculate smooth part of surrogate model.
196 // TODO: optimise τ to be applied to residual. 192 // TODO: optimise τ to be applied to residual.
197 let mut τv = f.differential(&μ) * τ; 193 let mut τv = f.differential(&μ) * τ;
198 194
199 // Save current base point 195 // Save current base point for merge
200 let μ_base = μ.clone(); 196 let μ_base_len = μ.len();
197 let maybe_μ_base = config.merge_now(&state).then(|| μ.clone());
201 198
202 // Insert and reweigh 199 // Insert and reweigh
203 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( 200 let (maybe_d, _within_tolerances) = prox_penalty
204 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats, 201 .insert_and_reweigh(&mut μ, &mut τv, τ, ε, config, &reg, &state, &mut stats)?;
205 )?; 202
203 stats.inserted += μ.len() - μ_base_len;
206 204
207 // Prune and possibly merge spikes 205 // Prune and possibly merge spikes
208 if config.merge_now(&state) { 206 if let Some(μ_base) = maybe_μ_base {
209 stats.merged += prox_penalty.merge_spikes( 207 stats.merged += prox_penalty.merge_spikes(
210 &mut μ, 208 &mut μ,
211 &mut τv, 209 &mut τv,
212 &μ_base, 210 &μ_base,
213 None,
214 τ, 211 τ,
215 ε, 212 ε,
216 config, 213 config,
217 &reg, 214 &reg,
218 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)), 215 Some(|μ̃: &RNDM<N, F>| f.apply(μ̃)),
255 reg: &Reg, 252 reg: &Reg,
256 prox_penalty: &P, 253 prox_penalty: &P,
257 fbconfig: &FBConfig<F>, 254 fbconfig: &FBConfig<F>,
258 iterator: I, 255 iterator: I,
259 mut plotter: Plot, 256 mut plotter: Plot,
260 μ0: Option<RNDM<N, F>> 257 μ0: Option<RNDM<N, F>>,
261 ) -> DynResult<RNDM<N, F>> 258 ) -> DynResult<RNDM<N, F>>
262 where 259 where
263 F: Float + ToNalgebraRealField, 260 F: Float + ToNalgebraRealField,
264 I: AlgIteratorFactory<IterInfo<F>>, 261 I: AlgIteratorFactory<IterInfo<F>>,
265 RNDM<N, F>: SpikeMerging<F>, 262 RNDM<N, F>: SpikeMerging<F>,
296 // Run the algorithm 293 // Run the algorithm
297 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) { 294 for state in iterator.iter_init(|| full_stats(&μ, ε, stats.clone())) {
298 // Calculate smooth part of surrogate model. 295 // Calculate smooth part of surrogate model.
299 let mut τv = f.differential(&μ) * τ; 296 let mut τv = f.differential(&μ) * τ;
300 297
301 // Save current base point 298 let μ_base_len = μ.len();
302 let μ_base = μ.clone();
303 299
304 // Insert new spikes and reweigh 300 // Insert new spikes and reweigh
305 let (maybe_d, _within_tolerances) = prox_penalty.insert_and_reweigh( 301 let (maybe_d, _within_tolerances) = prox_penalty
306 &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg, &state, &mut stats, 302 .insert_and_reweigh(&mut μ, &mut τv, τ, ε, config, &reg, &state, &mut stats)?;
307 )?; 303
304 stats.inserted += μ.len() - μ_base_len;
308 305
309 // (Do not) merge spikes. 306 // (Do not) merge spikes.
310 if config.merge_now(&state) && !warned_merging { 307 if config.merge_now(&state) && !warned_merging {
311 let err = format!("Merging not supported for μFISTA"); 308 let err = format!("Merging not supported for μFISTA");
312 println!("{}", err.red()); 309 println!("{}", err.red());

mercurial