src/prox_penalty/radon_squared.rs

branch
dev
changeset 39
6316d68b58af
parent 37
c5d8bd1a7728
--- a/src/prox_penalty/radon_squared.rs	Thu Jan 23 23:35:28 2025 +0100
+++ b/src/prox_penalty/radon_squared.rs	Thu Jan 23 23:34:05 2025 +0100
@@ -12,7 +12,7 @@
     AlgIteratorIteration,
     AlgIterator
 };
-use alg_tools::norms::L2;
+use alg_tools::norms::{L2, Norm};
 use alg_tools::linops::Mapping;
 use alg_tools::bisection_tree::{
     BTFN,
@@ -75,10 +75,10 @@
     where
         I : AlgIterator
     {
-        assert!(ν_delta.is_none(), "Transport not implemented for Radon-squared prox term");
+        let mut y = μ_base.masses_dvector();
 
-        let mut y = μ_base.masses_vec();
-
+        assert!(μ_base.len() <= μ.len());
+        
         'i_and_w: for i in 0..=1 {
             // Optimise weights
             if μ.len() > 0 {
@@ -90,10 +90,12 @@
                                                μ.iter_locations()
                                                 .map(|ζ| - F::to_nalgebra_mixed(τv.apply(ζ))));
                 let mut x = μ.masses_dvector();
-                // Ugly hack because DVector::push doesn't push but copies.
-                let yvec = DVector::from_column_slice(y.as_slice());
+                y.extend(std::iter::repeat(0.0.to_nalgebra_mixed()).take(0.max(x.len()-y.len())));
+                assert_eq!(y.len(), x.len());
                 // Solve finite-dimensional subproblem.
-                stats.inner_iters += reg.solve_findim_l1squared(&yvec, &g̃, τ, &mut x, ε, config);
+                // TODO: This assumes that ν_delta has no common locations with μ-μ_base, to
+                // ignore it.
+                stats.inner_iters += reg.solve_findim_l1squared(&y, &g̃, τ, &mut x, ε, config);
 
                 // Update masses of μ based on solution of finite-dimensional subproblem.
                 μ.set_masses_dvector(&x);
@@ -107,7 +109,8 @@
             }
             
             // Calculate ‖μ - μ_base‖_ℳ
-            let n = μ.dist_matching(μ_base);
+            // TODO: This assumes that ν_delta has no common locations with μ-μ_base.
+            let n = μ.dist_matching(μ_base) + ν_delta.map_or(0.0, |ν| ν.norm(Radon));
         
             // Find a spike to insert, if needed.
             // This only check the overall tolerances, not tolerances on support of μ-μ_base or μ,
@@ -118,8 +121,6 @@
                     // Weight is found out by running the finite-dimensional optimisation algorithm
                     // above
                     *μ += DeltaMeasure { x : ξ, α : 0.0 };
-                    //*μ_base += DeltaMeasure { x : ξ, α : 0.0 };
-                    y.push(0.0.to_nalgebra_mixed());
                     stats.inserted += 1;
                 }
             };
@@ -133,12 +134,20 @@
         μ : &mut RNDM<F, N>,
         τv : &mut BTFN<F, GA, BTA, N>,
         μ_base : &RNDM<F, N>,
+        ν_delta: Option<&RNDM<F, N>>,
         τ : F,
         ε : F,
         config : &FBGenericConfig<F>,
         reg : &Reg,
+        fitness : Option<impl Fn(&RNDM<F, N>) -> F>,
     ) -> usize
     {
+        if config.fitness_merging {
+            if let Some(f) = fitness {
+                return μ.merge_spikes_fitness(config.merging, f, |&v| v)
+                        .1
+            }
+        }
         μ.merge_spikes(config.merging, |μ_candidate| {
             // Important: μ_candidate's new points are afterwards,
             // and do not conflict with μ_base.
@@ -147,7 +156,10 @@
             // after μ_candidate's extra points.
             // TODO: doesn't seem to work, maybe need to merge μ_base as well?
             // Although that doesn't seem to make sense.
-            let μ_radon = μ_candidate.sub_matching(μ_base);
+            let μ_radon = match ν_delta {
+                None => μ_candidate.sub_matching(μ_base),
+                Some(ν) => μ_candidate.sub_matching(μ_base) - ν,
+            };
             reg.verify_merge_candidate_radonsq(τv, μ_candidate, τ, ε, &config, &μ_radon)
             //let n = μ_candidate.dist_matching(μ_base);
             //reg.find_tolerance_violation_slack(τv, τ, ε, false, config, n).is_none()

mercurial