src/sliding_pdps.rs

branch
dev
changeset 63
7a8a55fd41c0
parent 61
4f468d35fa29
child 66
fe47ad484deb
--- a/src/sliding_pdps.rs	Thu Feb 26 11:38:43 2026 -0500
+++ b/src/sliding_pdps.rs	Thu Feb 26 11:36:22 2026 -0500
@@ -6,13 +6,11 @@
 use crate::fb::*;
 use crate::forward_model::{BoundedCurvature, BoundedCurvatureGuess};
 use crate::measures::merging::SpikeMerging;
-use crate::measures::{DiscreteMeasure, Radon, RNDM};
+use crate::measures::{DiscreteMeasure, RNDM};
 use crate::plot::Plotter;
 use crate::prox_penalty::{ProxPenalty, StepLengthBoundPair};
 use crate::regularisation::SlidingRegTerm;
-use crate::sliding_fb::{
-    aposteriori_transport, initial_transport, SlidingFBConfig, TransportConfig, TransportStepLength,
-};
+use crate::sliding_fb::{SlidingFBConfig, Transport, TransportConfig, TransportStepLength};
 use crate::types::*;
 use alg_tools::convex::{Conjugable, Prox, Zero};
 use alg_tools::direct_product::Pair;
@@ -24,13 +22,10 @@
 };
 use alg_tools::mapping::{DifferentiableMapping, DifferentiableRealMapping, Instance};
 use alg_tools::nalgebra_support::ToNalgebraRealField;
-use alg_tools::norms::{Norm, L2};
+use alg_tools::norms::L2;
 use anyhow::ensure;
 use numeric_literals::replace_float_literals;
 use serde::{Deserialize, Serialize};
-//use colored::Colorize;
-//use nalgebra::{DVector, DMatrix};
-use std::iter::Iterator;
 
 /// Settings for [`pointsource_sliding_pdps_pair`].
 #[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
@@ -148,7 +143,7 @@
 
     // Initialise iterates
     let mut μ = μ0.unwrap_or_else(|| DiscreteMeasure::new());
-    let mut γ1 = DiscreteMeasure::new();
+    let mut γ = Transport::new();
     //let zero_z = z.similar_origin();
 
     // Set up parameters
@@ -186,22 +181,25 @@
     let κ = τ * σ_d * ψ / ((1.0 - β) * ψ - τ * σ_d * bigM);
     //  The factor two in the manuscript disappears due to the definition of 𝚹 being
     // for ‖x-y‖₂² instead of c_2(x, y)=‖x-y‖₂²/2.
-    let (maybe_ℓ_F, maybe_transport_lip) = f.curvature_bound_components(config.guess);
-    let transport_lip = maybe_transport_lip?;
-    let calculate_θ = |ℓ_F, max_transport| {
-        let ℓ_r = transport_lip * max_transport;
-        config.transport.θ0 / (τ * (ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport)
-    };
-    let mut θ_or_adaptive = match maybe_ℓ_F {
-        // We assume that the residual is decreasing.
-        Ok(ℓ_F) => TransportStepLength::AdaptiveMax {
-            l: ℓ_F, // TODO: could estimate computing the real reesidual
-            max_transport: 0.0,
-            g: calculate_θ,
-        },
-        Err(_) => {
-            TransportStepLength::FullyAdaptive {
-                l: F::EPSILON, max_transport: 0.0, g: calculate_θ
+
+    let mut θ_or_adaptive = match f.curvature_bound_components(config.guess) {
+        (_, Err(_)) => TransportStepLength::Fixed(config.transport.θ0),
+        (maybe_ℓ_F, Ok(transport_lip)) => {
+            let calculate_θτ = move |ℓ_F, max_transport| {
+                let ℓ_r = transport_lip * max_transport;
+                config.transport.θ0 / ((ℓ + ℓ_F + ℓ_r) + κ * bigθ * max_transport / τ)
+            };
+            match maybe_ℓ_F {
+                Ok(ℓ_F) => TransportStepLength::AdaptiveMax {
+                    l: ℓ_F, // TODO: could estimate computing the real reesidual
+                    max_transport: 0.0,
+                    g: calculate_θτ,
+                },
+                Err(_) => TransportStepLength::FullyAdaptive {
+                    l: F::EPSILON, // Start with something very small to estimate differentials
+                    max_transport: 0.0,
+                    g: calculate_θτ,
+                },
             }
         }
     };
@@ -243,26 +241,25 @@
 
         //dbg!(&μ);
 
-        let (μ_base_masses, mut μ_base_minus_γ0) =
-            initial_transport(&mut γ1, &mut μ, τ, &mut θ_or_adaptive, v);
+        γ.initial_transport(&μ, τ, &mut θ_or_adaptive, v);
+
+        let mut attempts = 0;
 
         // Solve finite-dimensional subproblem several times until the dual variable for the
         // regularisation term conforms to the assumptions made for the transport above.
-        let (maybe_d, _within_tolerances, mut τv̆, z_new) = 'adapt_transport: loop {
+        let (maybe_d, _within_tolerances, mut τv̆, z_new, μ̆) = 'adapt_transport: loop {
+            // Set initial guess for μ=μ^{k+1}.
+            γ.μ̆_into(&mut μ);
+            let μ̆ = μ.clone();
+
             // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
-            // let residual_μ̆ =
-            //     calculate_residual2(Pair(&γ1, &z), Pair(&μ_base_minus_γ0, &zero_z), opA, b);
-            // let Pair(mut τv̆, τz̆) = opA.preadjoint().apply(residual_μ̆ * τ);
-            // TODO: might be able to optimise the measure sum working as calculate_residual2 above.
-            let Pair(mut τv̆, τz̆) = f.differential(Pair(&γ1 + &μ_base_minus_γ0, &z)) * τ;
+            let Pair(mut τv̆, τz̆) = f.differential(Pair(&μ̆, &z)) * τ;
             // opKμ.preadjoint().gemv(&mut τv̆, τ, y, 1.0);
 
             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
             let (maybe_d, within_tolerances) = prox_penalty.insert_and_reweigh(
                 &mut μ,
                 &mut τv̆,
-                &γ1,
-                Some(&μ_base_minus_γ0),
                 τ,
                 ε,
                 &config.insertion,
@@ -277,59 +274,37 @@
             z_new = fnR.prox(σ_p, z_new + &z);
 
             // A posteriori transport adaptation.
-            if aposteriori_transport(
-                &mut γ1,
-                &mut μ,
-                &mut μ_base_minus_γ0,
-                &μ_base_masses,
+            if γ.aposteriori_transport(
+                &μ,
+                &μ̆,
+                &mut τv̆,
                 Some(z_new.dist2(&z)),
                 ε,
                 &config.transport,
+                &mut attempts,
             ) {
-                break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new);
+                break 'adapt_transport (maybe_d, within_tolerances, τv̆, z_new, μ̆);
             }
         };
 
-        stats.untransported_fraction = Some({
-            assert_eq!(μ_base_masses.len(), γ1.len());
-            let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));
-            let source = μ_base_masses.iter().map(|v| v.abs()).sum();
-            (a + μ_base_minus_γ0.norm(Radon), b + source)
-        });
-        stats.transport_error = Some({
-            assert_eq!(μ_base_masses.len(), γ1.len());
-            let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
-            (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
-        });
+        γ.get_transport_stats(&mut stats, &μ);
 
         // Merge spikes.
         // This crucially expects the merge routine to be stable with respect to spike locations,
         // and not to performing any pruning. That is be to done below simultaneously for γ.
-        let ins = &config.insertion;
-        if ins.merge_now(&state) {
+        if config.insertion.merge_now(&state) {
             stats.merged += prox_penalty.merge_spikes_no_fitness(
                 &mut μ,
                 &mut τv̆,
-                &γ1,
-                Some(&μ_base_minus_γ0),
+                &μ̆,
                 τ,
                 ε,
-                ins,
+                &config.insertion,
                 &reg,
-                //Some(|μ̃ : &RNDM<N, F>| calculate_residual(Pair(μ̃, &z), opA, b).norm2_squared_div2()),
             );
         }
 
-        // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
-        // latter needs to be pruned when μ is.
-        // TODO: This could do with a two-vector Vec::retain to avoid copies.
-        let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
-        if μ_new.len() != μ.len() {
-            let mut μ_iter = μ.iter_spikes();
-            γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
-            stats.pruned += μ.len() - μ_new.len();
-            μ = μ_new;
-        }
+        γ.prune_compat(&mut μ, &mut stats);
 
         // Do dual update
         // opKμ.gemv(&mut y, σ_d*(1.0 + ω), &μ, 1.0);    // y = y + σ_d K[(1+ω)(μ,z)^{k+1}]

mercurial