src/fb.rs

branch
dev
changeset 39
6316d68b58af
parent 37
c5d8bd1a7728
child 51
0693cc9ba9f0
--- a/src/fb.rs	Thu Jan 23 23:35:28 2025 +0100
+++ b/src/fb.rs	Thu Jan 23 23:34:05 2025 +0100
@@ -93,10 +93,7 @@
     DiscreteMeasure,
     RNDM,
 };
-use crate::measures::merging::{
-    SpikeMergingMethod,
-    SpikeMerging,
-};
+use crate::measures::merging::SpikeMerging;
 use crate::forward_model::{
     ForwardModel,
     AdjointProductBoundedBy,
@@ -164,7 +161,7 @@
     RNDM<F, N> : SpikeMerging<F>,
     for<'a> &'a RNDM<F, N> : Instance<RNDM<F, N>>,
 {
-    μ.merge_spikes_fitness(config.merging,
+    μ.merge_spikes_fitness(config.final_merging_method(),
                            |μ̃| dataterm.calculate_fit_op(μ̃, opA, b),
                            |&v| v);
     μ.prune();
@@ -255,7 +252,10 @@
 
         // Prune and possibly merge spikes
         if config.merge_now(&state) {
-            stats.merged += prox_penalty.merge_spikes(&mut μ, &mut τv, &μ_base, τ, ε, config, &reg);
+            stats.merged += prox_penalty.merge_spikes(
+                &mut μ, &mut τv, &μ_base, None, τ, ε, config, &reg,
+                Some(|μ̃ : &RNDM<F, N>| L2Squared.calculate_fit_op(μ̃, opA, b)),
+            );
         }
 
         stats.pruned += prune_with_stats(&mut μ);
@@ -363,15 +363,10 @@
         );
 
         // (Do not) merge spikes.
-        if config.merge_now(&state) {
-            match config.merging {
-                SpikeMergingMethod::None => { },
-                _ => if !warned_merging {
-                    let err = format!("Merging not supported for μFISTA");
-                    println!("{}", err.red());
-                    warned_merging = true;
-                }
-            }
+        if config.merge_now(&state) && !warned_merging {
+            let err = format!("Merging not supported for μFISTA");
+            println!("{}", err.red());
+            warned_merging = true;
         }
 
         // Update inertial prameters
@@ -387,6 +382,9 @@
         // stored in μ_prev.
         let n_before_prune = μ.len();
         μ.pruning_sub(1.0 + θ, θ, &mut μ_prev);
+        //let μ_new = (&μ * (1.0 + θ)).sub_matching(&(&μ_prev * θ));
+        // μ_prev = μ;
+        // μ = μ_new;
         debug_assert!(μ.len() <= n_before_prune);
         stats.pruned += n_before_prune - μ.len();
 

mercurial