162 ) -> RNDM<F, N> |
159 ) -> RNDM<F, N> |
163 where |
160 where |
164 RNDM<F, N> : SpikeMerging<F>, |
161 RNDM<F, N> : SpikeMerging<F>, |
165 for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>, |
162 for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>, |
166 { |
163 { |
167 μ.merge_spikes_fitness(config.merging, |
164 μ.merge_spikes_fitness(config.final_merging_method(), |
168 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
165 |μ̃| dataterm.calculate_fit_op(μ̃, opA, b), |
169 |&v| v); |
166 |&v| v); |
170 μ.prune(); |
167 μ.prune(); |
171 μ |
168 μ |
172 } |
169 } |
253 config, ®, &state, &mut stats |
250 config, ®, &state, &mut stats |
254 ); |
251 ); |
255 |
252 |
256 // Prune and possibly merge spikes |
253 // Prune and possibly merge spikes |
257 if config.merge_now(&state) { |
254 if config.merge_now(&state) { |
258 stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, ®); |
255 stats.merged += prox_penalty.merge_spikes( |
|
256 &mut μ, &mut τv, &μ_base, None, τ, ε, config, ®, |
|
257 Some(|μ̃ : &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)), |
|
258 ); |
259 } |
259 } |
260 |
260 |
261 stats.pruned += prune_with_stats(&mut μ); |
261 stats.pruned += prune_with_stats(&mut μ); |
262 |
262 |
263 // Update residual |
263 // Update residual |
361 τ, ε, |
361 τ, ε, |
362 config, ®, &state, &mut stats |
362 config, ®, &state, &mut stats |
363 ); |
363 ); |
364 |
364 |
365 // (Do not) merge spikes. |
365 // (Do not) merge spikes. |
366 if config.merge_now(&state) { |
366 if config.merge_now(&state) && !warned_merging { |
367 match config.merging { |
367 let err = format!("Merging not supported for μFISTA"); |
368 SpikeMergingMethod::None => { }, |
368 println!("{}", err.red()); |
369 _ => if !warned_merging { |
369 warned_merging = true; |
370 let err = format!("Merging not supported for μFISTA"); |
|
371 println!("{}", err.red()); |
|
372 warned_merging = true; |
|
373 } |
|
374 } |
|
375 } |
370 } |
376 |
371 |
377 // Update inertial parameters |
372 // Update inertial parameters |
378 let λ_prev = λ; |
373 let λ_prev = λ; |
379 λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); |
374 λ = 2.0 * λ_prev / ( λ_prev + (4.0 + λ_prev * λ_prev).sqrt() ); |
385 // subproblem with a proximal projection step, this is likely to happen when the |
380 // subproblem with a proximal projection step, this is likely to happen when the |
386 // spike is not needed. A copy of the pruned μ without arithmetic performed is |
381 // spike is not needed. A copy of the pruned μ without arithmetic performed is |
387 // stored in μ_prev. |
382 // stored in μ_prev. |
388 let n_before_prune = μ.len(); |
383 let n_before_prune = μ.len(); |
389 μ.pruning_sub(1.0 + θ, θ, &mut μ_prev); |
384 μ.pruning_sub(1.0 + θ, θ, &mut μ_prev); |
|
385 //let μ_new = (&μ * (1.0 + θ)).sub_matching(&(&μ_prev * θ)); |
|
386 // μ_prev = μ; |
|
387 // μ = μ_new; |
390 debug_assert!(μ.len() <= n_before_prune); |
388 debug_assert!(μ.len() <= n_before_prune); |
391 stats.pruned += n_before_prune - μ.len(); |
389 stats.pruned += n_before_prune - μ.len(); |
392 |
390 |
393 // Update residual |
391 // Update residual |
394 residual = calculate_residual(&μ, opA, b); |
392 residual = calculate_residual(&μ, opA, b); |