src/sliding_fb.rs

branch:      dev
changeset:   34:efa60bc4f743
parent:      32:56c8adc32b09
child:       35:b087e3eab191
--- a/src/sliding_fb.rs	Tue Aug 01 10:32:12 2023 +0300
+++ b/src/sliding_fb.rs	Thu Aug 29 00:00:00 2024 -0500
@@ -8,19 +8,17 @@
 //use colored::Colorize;
 //use nalgebra::{DVector, DMatrix};
 use itertools::izip;
-use std::iter::{Map, Flatten};
+use std::iter::Iterator;
 
 use alg_tools::iterate::{
     AlgIteratorFactory,
     AlgIteratorState
 };
-use alg_tools::euclidean::{
-    Euclidean,
-    Dot
-};
+use alg_tools::euclidean::Euclidean;
 use alg_tools::sets::Cube;
 use alg_tools::loc::Loc;
 use alg_tools::mapping::{Apply, Differentiable};
+use alg_tools::norms::{Norm, L2};
 use alg_tools::bisection_tree::{
     BTFN,
     PreBTFN,
@@ -37,10 +35,7 @@
 use alg_tools::nalgebra_support::ToNalgebraRealField;
 
 use crate::types::*;
-use crate::measures::{
-    DiscreteMeasure,
-    DeltaMeasure,
-};
+use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon};
 use crate::measures::merging::{
     //SpikeMergingMethod,
     SpikeMerging,
@@ -69,15 +64,15 @@
 pub struct SlidingFBConfig<F : Float> {
     /// Step length scaling
     pub τ0 : F,
-    /// Transport smoothness assumption
-    pub ℓ0 : F,
-    /// Inverse of the scaling factor $θ$ of the 2-norm-squared transport cost.
-    /// This means that $τθ$ is the step length for the transport step.
-    pub inverse_transport_scaling : F,
-    /// Factor for deciding transport reduction based on smoothness assumption violation
-    pub minimum_goodness_factor : F,
-    /// Maximum rays to retain in transports from each source.
-    pub maximum_rays : usize,
+    /// Transport step length $θ$ normalised to $(0, 1)$.
+    pub θ0 : F,
+    // /// Maximum transport mass scaling.
+    // /// The maximum transported mass is this factor times $\norm{b}^2/(2α)$.
+    // pub max_transport_scale : F,
+    /// Transport tolerance wrt. ω
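+    /// (Scales ε in the a posteriori bound on (1/τ)∫ ω(z) − ω(y) dλ; see the iteration loop.)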
+    pub transport_tolerance_ω : F,
+    /// Transport tolerance wrt. ∇v
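+    /// (Scales ε in the a priori bound on ∫ ⟨∇v(x), z − y⟩ dλ; see the iteration loop.)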
+    pub transport_tolerance_dv : F,
     /// Generic parameters
     pub insertion : FBGenericConfig<F>,
 }
@@ -87,129 +82,32 @@
     fn default() -> Self {
         SlidingFBConfig {
             τ0 : 0.99,
-            ℓ0 : 1.5,
-            inverse_transport_scaling : 1.0,
-            minimum_goodness_factor : 1.0, // TODO: totally arbitrary choice,
-                                           // should be scaled by problem data?
-            maximum_rays : 10,
+            θ0 : 0.99,
+            //max_transport_scale : 10.0,
+            transport_tolerance_ω : 1.0, // TODO: no idea what this should be
+            transport_tolerance_dv : 1.0, // TODO: no idea what this should be
             insertion : Default::default()
         }
     }
 }
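+
+// Example (sketch, not from this changeset): overriding only the transport tolerances while
+// keeping the remaining defaults. The field values are illustrative, not recommendations.
+//
+//     let sfbconfig = SlidingFBConfig::<f64> {
+//         transport_tolerance_ω : 0.1,
+//         transport_tolerance_dv : 0.1,
+//         .. Default::default()
+//     };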
 
-/// A transport ray (including various additional computational information).
-#[derive(Clone, Debug)]
-pub struct Ray<Domain, F : Num> {
-    /// The destination of the ray, and the mass. The source is indicated in a [`RaySet`].
-    δ : DeltaMeasure<Domain, F>,
-    /// Goodness of the data term for the ray: $v(z)-v(y)-⟨∇v(x), z-y⟩ + ℓ‖z-y‖^2$.
-    goodness : F,
-    /// Goodness of the regularisation term for the ray: $w(z)-w(y)$.
-    /// Initially zero until $w$ can be constructed.
-    reg_goodness : F,
-    /// Indicates that this ray also forms a component in γ^{k+1} with the mass `to_return`.
-    to_return : F,
-}
-
-/// A set of transport rays with the same source point.
-#[derive(Clone, Debug)]
-pub struct RaySet<Domain, F : Num> {
-    /// Source of every ray in the set
-    source : Domain,
-    /// Mass of the diagonal ray, with destination the same as the source.
-    diagonal: F,
-    /// Goodness of the data term for the diagonal ray with $z=x$:
-    /// $v(x)-v(y)-⟨∇v(x), x-y⟩ + ℓ‖x-y‖^2$.
-    diagonal_goodness : F,
-    /// Goodness of the data term for the diagonal ray with $z=x$: $w(x)-w(y)$.
-    diagonal_reg_goodness : F,
-    /// The non-diagonal rays.
-    rays : Vec<Ray<Domain, F>>,
-}
-
+/// Scale each |γ|_i ≠ 0 by q_i=q̄/g(γ_i)
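+/// Spikes with zero mass, or whose mass and g-value do not share the same sign, are left unchanged.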
 #[replace_float_literals(F::cast_from(literal))]
-impl<Domain, F : Float> RaySet<Domain, F> {
-    fn non_diagonal_mass(&self) -> F {
-        self.rays
-            .iter()
-            .map(|Ray{ δ : DeltaMeasure{ α, .. }, .. }| *α)
-            .sum()
-    }
-
-    fn total_mass(&self) -> F {
-        self.non_diagonal_mass() + self.diagonal
-    }
-
-    fn targets<'a>(&'a self)
-    -> Map<
-        std::slice::Iter<'a, Ray<Domain, F>>,
-        fn(&'a Ray<Domain, F>) -> &'a DeltaMeasure<Domain, F>
-    > {
-        fn get_δ<'b, Domain, F : Float>(Ray{ δ, .. }: &'b Ray<Domain, F>)
-        -> &'b DeltaMeasure<Domain, F> {
-            δ
+fn scale_down<'a, I, F, G, const N : usize>(
+    iter : I,
+    q̄ : F,
+    mut g : G
+) where F : Float,
+        I : Iterator<Item = &'a mut DeltaMeasure<Loc<F,N>, F>>,
+        G : FnMut(&DeltaMeasure<Loc<F,N>, F>) -> F {
+    iter.for_each(|δ| {
+        if δ.α != 0.0 {
+            let b = g(δ);
+            if b * δ.α > 0.0 {
+                δ.α *= q̄/b;
+            }
         }
-        self.rays
-            .iter()
-            .map(get_δ)
-    }
-
-    // fn non_diagonal_goodness(&self) -> F {
-    //     self.rays
-    //         .iter()
-    //         .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| {
-    //             α * (goodness + reg_goodness)
-    //         })
-    //         .sum()
-    // }
-
-    // fn total_goodness(&self) -> F {
-    //     self.non_diagonal_goodness() + (self.diagonal_goodness + self.diagonal_reg_goodness)
-    // }
-
-    fn non_diagonal_badness(&self) -> F {
-        self.rays
-            .iter()
-            .map(|&Ray{ δ : DeltaMeasure{ α, .. }, goodness, reg_goodness, .. }| {
-                0.0.max(- α * (goodness + reg_goodness))
-            })
-            .sum()
-    }
-
-    fn total_badness(&self) -> F {
-        self.non_diagonal_badness()
-        + 0.0.max(- self.diagonal * (self.diagonal_goodness + self.diagonal_reg_goodness))
-    }
-
-    fn total_return(&self) -> F {
-        self.rays
-            .iter()
-            .map(|&Ray{ to_return, .. }| to_return)
-            .sum()
-    }
-}
-
-#[replace_float_literals(F::cast_from(literal))]
-impl<Domain : Clone, F : Num> RaySet<Domain, F> {
-    fn return_targets<'a>(&'a self)
-    -> Flatten<Map<
-        std::slice::Iter<'a, Ray<Domain, F>>,
-        fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>>
-    >> {
-        fn get_return<'b, Domain : Clone, F : Num>(ray: &'b Ray<Domain, F>)
-        -> Option<DeltaMeasure<Domain, F>> {
-            (ray.to_return != 0.0).then_some(
-                DeltaMeasure{x : ray.δ.x.clone(), α : ray.to_return}
-            )
-        }
-        let tmp : Map<
-            std::slice::Iter<'a, Ray<Domain, F>>,
-            fn(&'a Ray<Domain, F>) -> Option<DeltaMeasure<Domain, F>>
-        > = self.rays
-                .iter()
-                .map(get_return);
-        tmp.flatten()
-    }
+    });
 }
 
 /// Iteratively solve the pointsource localisation problem using sliding forward-backward
@@ -218,7 +116,7 @@
 /// The parametrisation is as for [`pointsource_fb_reg`].
 /// Inertia is currently not supported.
 #[replace_float_literals(F::cast_from(literal))]
-pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, G𝒟, S, K, Reg, const N : usize>(
+pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>(
     opA : &'a A,
     b : &A::Observable,
     reg : Reg,
@@ -238,8 +136,9 @@
           + Lipschitz<&'a 𝒟, FloatType=F> + TransportLipschitz<L2Squared, FloatType=F>,
       BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
-      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>>,
-      𝒟::Codomain : RealMapping<F, N>,
+      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>,
+                                          Codomain = BTFN<F, G𝒟, BT𝒟, N>>,
+      BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
       S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>
          + Differentiable<Loc<F, N>, Output=Loc<F,N>>,
       K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
@@ -251,25 +150,25 @@
       Reg : SlidingRegTerm<F, N> {
 
     assert!(sfbconfig.τ0 > 0.0 &&
-            sfbconfig.inverse_transport_scaling > 0.0 &&
-            sfbconfig.ℓ0 > 0.0);
+            sfbconfig.θ0 > 0.0);
 
     // Set up parameters
     let config = &sfbconfig.insertion;
     let op𝒟norm = op𝒟.opnorm_bound();
-    let θ = sfbconfig.inverse_transport_scaling;
-    let τ = sfbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap()
-                            .max(opA.transport_lipschitz_factor(L2Squared) * θ);
-    let ℓ = sfbconfig.ℓ0; // TODO: v scaling?
+    //let max_transport = sfbconfig.max_transport_scale
+    //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0);
+    //let tlip = opA.transport_lipschitz_factor(L2Squared) * max_transport;
+    //let ℓ = 0.0;
+    let θ = sfbconfig.θ0; // (ℓ + tlip);
+    let τ = sfbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
     // We multiply tolerance by τ for FB since our subproblems depending on tolerances are scaled
     // by τ compared to the conditional gradient approach.
     let tolerance = config.tolerance * τ * reg.tolerance_scaling();
     let mut ε = tolerance.initial();
 
     // Initialise iterates
-    let mut μ : DiscreteMeasure<Loc<F, N>, F> = DiscreteMeasure::new();
-    let mut μ_transported_base = DiscreteMeasure::new();
-    let mut γ_hat : Vec<RaySet<Loc<F, N>, F>> = Vec::new();   // γ̂_k and extra info
+    let mut μ = DiscreteMeasure::new();
+    let mut γ1 = DiscreteMeasure::new();
     let mut residual = -b;
     let mut stats = IterInfo::new();
 
@@ -279,272 +178,170 @@
         // Using `std::mem::replace` here is not ideal, and expects that `empty_observable`
         // has no significant overhead. For some reason Rust doesn't let us simply move
         // the residual and replace it below before the end of this closure.
-        residual *= -τ;
         let r = std::mem::replace(&mut residual, opA.empty_observable());
-        let minus_τv = opA.preadjoint().apply(r);
-
-        // Save current base point and shift μ to new positions.
-        let μ_base = μ.clone();
-        for δ in μ.iter_spikes_mut() {
-            δ.x += minus_τv.differential(&δ.x) * θ;
-        }
-        let mut μ_transported = μ.clone();
-
-        assert_eq!(μ.len(), γ_hat.len());
+        let v = opA.preadjoint().apply(r);
 
-        // Calculate the goodness λ formed from γ_hat (≈ γ̂_k) and γ^{k+1}, where the latter
-        // transports points x from μ_base to points y in μ as shifted above, or “returns”
-        // them “home” to z given by the rays in γ_hat. Returning is necessary if the rays
-        // are not “good” for the smoothness assumptions, or if γ_hat has more mass than
-        // μ_base.
-        let mut total_goodness = 0.0;     // data term goodness
-        let mut total_reg_goodness = 0.0; // regulariser goodness
-        let minimum_goodness = - ε * sfbconfig.minimum_goodness_factor;
-
-        for (δ, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) {
-            // Calculate data term goodness for all rays.
-            let &DeltaMeasure{ x : ref y, α : δ_mass } = δ;
-            let x = &r.source;
-            let mvy = minus_τv.apply(y);
-            let mdvx = minus_τv.differential(x);
-            let mut r_total_mass = 0.0; // Total mass of all rays with source r.source.
-            let mut bad_mass = 0.0;
-            let mut calc_goodness = |goodness : &mut F, reg_goodness : &mut F, α, z : &Loc<F, N>| {
-                *reg_goodness = 0.0; // Initial guess
-                *goodness = mvy - minus_τv.apply(z) + mdvx.dot(&(z-y))
-                            + ℓ * z.dist2_squared(&y);
-                total_goodness += *goodness * α;
-                r_total_mass += α; // TODO: should this include to_return from staging? (Probably not)
-                if *goodness < 0.0 {
-                    bad_mass += α;
-                }
+        // Save current base point and shift μ to new positions. Idea is that
+        //  μ_base(_masses) = μ^k (vector of masses)
+        //  μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
+        //  γ1 = π_♯^1γ^{k+1}
+        //  μ = μ^{k+1}
+        let μ_base_masses : Vec<F> = μ.iter_masses().collect();
+        let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below.
+        // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
+        let mut sum_norm_dv_times_γinit = 0.0;
+        let mut sum_abs_γinit = 0.0;
+        //let mut sum_norm_dv = 0.0;
+        let γ_prev_len = γ1.len();
+        assert!(μ.len() >= γ_prev_len);
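+        // Extend γ1 with copies of the spikes appended to μ since the previous iteration,
+        // so that μ and γ1 stay index-aligned in the loops below.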
+        γ1.extend(μ[γ_prev_len..].iter().cloned());
+        for (δ, ρ) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes_mut()) {
+            let d_v_x = v.differential(&δ.x);
+            // If old transport has opposing sign, the new transport will be none.
+            ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) {
+                0.0
+            } else {
+                δ.α
             };
-            for ray in r.rays.iter_mut() {
-                calc_goodness(&mut ray.goodness, &mut ray.reg_goodness, ray.δ.α, &ray.δ.x);
-            }
-            calc_goodness(&mut r.diagonal_goodness, &mut r.diagonal_reg_goodness, r.diagonal, x);
-
-            // If the total mass of the ray set is less than that of μ at the same source,
-            // a diagonal component needs to be added to be able to (attempt to) transport
-            // all mass of μ. In the opposite case, we need to construct γ_{k+1} to ‘return’
-            // the extra mass of γ̂_k to the target z. We return mass from the oldest “bad”
-            // rays in the set.
-            if δ_mass >= r_total_mass {
-                r.diagonal += δ_mass - r_total_mass;
-            } else {
-                let mut reduce_transport = r_total_mass - δ_mass;
-                let mut good_needed = (bad_mass - reduce_transport).max(0.0);
-                // NOTE: reg_goodness is zero at this point, so it is not used in this code.
-                let mut reduce_ray = |goodness, to_return : Option<&mut F>, α : &mut F| {
-                    if reduce_transport > 0.0 {
-                        let return_amount = if goodness < 0.0 {
-                            α.min(reduce_transport)
-                        } else {
-                            let amount = α.min(good_needed);
-                            good_needed -= amount;
-                            amount
-                        };
-
-                        if return_amount > 0.0 {
-                            reduce_transport -= return_amount;
-                            // Adjust total goodness by returned amount
-                            total_goodness -= goodness * return_amount;
-                            to_return.map(|tr| *tr += return_amount);
-                            *α -= return_amount;
-                            *α > 0.0
-                        } else {
-                            true
-                        }
-                    } else {
-                        true
-                    }
-                };
-                r.rays.retain_mut(|ray| {
-                    reduce_ray(ray.goodness, Some(&mut ray.to_return), &mut ray.δ.α)
-                });
-                // A bad diagonal is simply reduced without any 'return'.
-                // It was, after all, just added to match μ, but there is no need to match it.
-                // It's just a heuristic.
-                // TODO: Maybe a bad diagonal should be the first to go.
-                reduce_ray(r.diagonal_goodness, None, &mut r.diagonal);
+            δ.x -= d_v_x * (θ * δ.α.signum()); // This is δ.α.signum() when δ.α ≠ 0.
+            ρ.x = δ.x;
+            let nrm = d_v_x.norm(L2);
+            let a = ρ.α.abs();
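+            // NB: this `v` shadows, for the rest of this loop body, the dual variable `v`
+            // defined before the loop.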
+            let v = nrm * a;
+            if v > 0.0 {
+                sum_norm_dv_times_γinit += v;
+                sum_abs_γinit += a;
             }
         }
 
+        // A priori transport adaptation based on bounding ∫ ⟨∇v(x), z-y⟩ dλ(x, y, z).
+        // This is just one option, there are many.
+        let t = ε * sfbconfig.transport_tolerance_dv;
+        if sum_norm_dv_times_γinit > t {
+            // Scale each |γ|_i by q_i=q̄/‖vx‖_i such that ∑_i |γ|_i q_i ‖vx‖_i = t
+            // TODO: store the closure values above?
+            scale_down(γ1.iter_spikes_mut(),
+                       t / sum_abs_γinit,
+                       |δ| v.differential(&δ.x).norm(L2));
+        }
+        //println!("|γ| = {}, |μ| = {}", γ1.norm(crate::measures::Radon), μ.norm(crate::measures::Radon));
+
         // Solve finite-dimensional subproblem several times until the dual variable for the
         // regularisation term conforms to the assumptions made for the transport above.
         let (d, within_tolerances) = 'adapt_transport: loop {
-            // If transport violates goodness requirements, shift it to ‘return’ mass to z,
-            // forcing y = z. Based on the badness of each ray set (sum of bad rays' goodness),
-            // we proportionally distribute the reductions to each ray set, and within each ray
-            // set, prioritise reducing the oldest bad rays' weight.
-            let tg = total_goodness + total_reg_goodness;
-            let adaptation_needed = minimum_goodness - tg;
-            if adaptation_needed > 0.0 {
-                let total_badness = γ_hat.iter().map(|r| r.total_badness()).sum();
-
-                let mut return_ray = |goodness : F,
-                                      reg_goodness : F,
-                                      to_return : Option<&mut F>,
-                                      α : &mut F,
-                                      left_to_return : &mut F| {
-                    let g = goodness + reg_goodness;
-                    assert!(*α >= 0.0 && *left_to_return >= 0.0);
-                    if *left_to_return > 0.0 && g < 0.0 {
-                        let return_amount = (*left_to_return / (-g)).min(*α);
-                        *left_to_return -= (-g) * return_amount;
-                        total_goodness -= goodness * return_amount;
-                        total_reg_goodness -= reg_goodness * return_amount;
-                        to_return.map(|tr| *tr += return_amount);
-                        *α -= return_amount;
-                        *α > 0.0
-                    } else {
-                        true
-                    }
-                };
-                
-                for r in γ_hat.iter_mut() {
-                    let mut left_to_return = adaptation_needed * r.total_badness() / total_badness;
-                    if left_to_return > 0.0 {
-                        for ray in r.rays.iter_mut() {
-                            return_ray(ray.goodness, ray.reg_goodness,
-                                       Some(&mut ray.to_return), &mut ray.δ.α, &mut left_to_return);
-                        }
-                        return_ray(r.diagonal_goodness, r.diagonal_reg_goodness,
-                                   None, &mut r.diagonal, &mut left_to_return);
-                    }
-                }
+            // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
+            for (δ_γ1, δ_μ_base_minus_γ0, &α_μ_base) in izip!(γ1.iter_spikes(),
+                                                              μ_base_minus_γ0.iter_spikes_mut(),
+                                                              μ_base_masses.iter()) {
+                δ_μ_base_minus_γ0.set_mass(α_μ_base - δ_γ1.get_mass());
             }
 
-            // Construct μ_k + (π_#^1-π_#^0)γ_{k+1}.
-            // This can be broken down into
-            //
-            // μ_transported_base = [μ - π_#^0 (γ_shift + γ_return)] + π_#^1 γ_return, and
-            // μ_transported = π_#^1 γ_shift
-            //
-            // where γ_shift is our “true” γ_{k+1}, and γ_return is the return component.
-            // The former can be constructed from δ.x and δ_new.x for δ in μ_base and δ_new in μ
-            // (which has already been shifted), and the mass stored in a γ_hat ray's δ measure
-            // The latter can be constructed from γ_hat rays' source and destination with the
-            // to_return mass.
-            //
-            // Note that μ_transported is constructed to have the same spike locations as μ, but
-            // to have same length as μ_base. This loop does not iterate over the spikes of μ
-            // (and corresponding transports of γ_hat) that have been newly added in the current
-            // 'adapt_transport loop.
-            for (δ, δ_transported, r) in izip!(μ_base.iter_spikes(),
-                                               μ_transported.iter_spikes_mut(),
-                                               γ_hat.iter()) {
-                let &DeltaMeasure{ref x, α} = δ;
-                debug_assert_eq!(*x, r.source);
-                let shifted_mass = r.total_mass();
-                let ret_mass = r.total_return();
-                // μ - π_#^0 (γ_shift + γ_return)
-                μ_transported_base += DeltaMeasure { x : *x, α : α - shifted_mass - ret_mass };
-                // π_#^1 γ_return
-                μ_transported_base.extend(r.return_targets());
-                // π_#^1 γ_shift
-                δ_transported.set_mass(shifted_mass);
-            }
             // Calculate transported_minus_τv = -τA_*(A[μ_transported + μ_transported_base]-b)
-            let transported_residual = calculate_residual2(&μ_transported,
-                                                           &μ_transported_base,
-                                                           opA, b);
-            let transported_minus_τv = opA.preadjoint()
-                                          .apply(transported_residual);
+            let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
+            let transported_minus_τv̆ = opA.preadjoint().apply(residual_μ̆ * (-τ));
 
             // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
-            let (mut d, within_tolerances) = insert_and_reweigh(
-                &mut μ, &transported_minus_τv, &μ_transported, Some(&μ_transported_base),
+            let (d, within_tolerances) = insert_and_reweigh(
+                &mut μ, &transported_minus_τv̆, &γ1, Some(&μ_base_minus_γ0),
                 op𝒟, op𝒟norm,
                 τ, ε,
-                config, &reg, state, &mut stats
+                config,
+                &reg, state, &mut stats,
             );
 
-            // We have  d = ω0 - τv - 𝒟μ = -𝒟(μ - μ^k) - τv; more precisely
-            //          d = minus_τv + op𝒟.preapply(μ_diff(μ, μ_transported, config));
-            // We “essentially” assume that the subdifferential w of the regularisation term
-            // satisfies w'(y)=0, so for a “goodness” estimate τ[w(y)-w(z)-w'(y)(z-y)]
-            // that incorporates the assumption, we need to calculate τ[w(z) - w(y)] for
-            // some w in the subdifferential of the regularisation term, such that
-            // -ε ≤ τw - d ≤ ε. This is done by [`RegTerm::goodness`].
-            for r in γ_hat.iter_mut() {
-                for ray in r.rays.iter_mut() {
-                    ray.reg_goodness = reg.goodness(&mut d, &μ, &r.source, &ray.δ.x, τ, ε, config);
-                    total_reg_goodness += ray.reg_goodness * ray.δ.α;
+            // A posteriori transport adaptation based on bounding (1/τ)∫ ω(z) - ω(y) dλ(x, y, z).
+            let all_ok = if false { // Basic check
+                // If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
+                // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
+                // at that point to zero, and retry.
+                let mut all_ok = true;
+                for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) {
+                    if α_μ == 0.0 && *α_γ1 != 0.0 {
+                        all_ok = false;
+                        *α_γ1 = 0.0;
+                    }
                 }
-            }
+                all_ok
+            } else {
+                // TODO: Could maybe optimise, as this is also formed in insert_and_reweigh above.
+                let mut minus_ω = op𝒟.apply(γ1.sub_matching(&μ) + &μ_base_minus_γ0);
+
+                // let vpos = γ1.iter_spikes()
+                //              .filter(|δ| δ.α > 0.0)
+                //              .map(|δ| minus_ω.apply(&δ.x))
+                //              .reduce(F::max)
+                //              .and_then(|threshold| {
+                //                 minus_ω.minimise_below(threshold,
+                //                                         ε * config.refinement.tolerance_mult,
+                //                                         config.refinement.max_steps)
+                //                        .map(|(_z, minus_ω_z)| minus_ω_z)
+                //              });
 
-            // If update of regularisation term goodness didn't invalidate minimum goodness
-            // requirements, we have found our step. Otherwise we need to keep reducing
-            // transport by repeating the loop.
-            if total_goodness + total_reg_goodness >= minimum_goodness {
+                // let vneg = γ1.iter_spikes()
+                //              .filter(|δ| δ.α < 0.0)
+                //              .map(|δ| minus_ω.apply(&δ.x))
+                //              .reduce(F::min)
+                //              .and_then(|threshold| {
+                //                 minus_ω.maximise_above(threshold,
+                //                                         ε * config.refinement.tolerance_mult,
+                //                                         config.refinement.max_steps)
+                //                        .map(|(_z, minus_ω_z)| minus_ω_z)
+                //              });
+                let (_, vpos) = minus_ω.minimise(ε * config.refinement.tolerance_mult,
+                                                 config.refinement.max_steps);
+                let (_, vneg) = minus_ω.maximise(ε * config.refinement.tolerance_mult,
+                                                 config.refinement.max_steps);
+            
+                let t = τ * ε * sfbconfig.transport_tolerance_ω;
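+                // vpos and vneg are the (approximate) minimum and maximum of minus_ω found above.
+                // val(δ) weights, by the spike mass, the gap between minus_ω at the spike and the
+                // extremal value matching the sign of the mass.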
+                let val = |δ : &DeltaMeasure<Loc<F, N>, F>| {
+                    δ.α * (minus_ω.apply(&δ.x) - if δ.α >= 0.0 { vpos } else { vneg })
+                    // match if δ.α >= 0.0 { vpos } else { vneg } {
+                    //     None => 0.0,
+                    //     Some(v) => δ.α * (minus_ω.apply(&δ.x) - v)
+                    // }
+                };
+                // Calculate positive/bad (rp) values under the integral.
+                // Also store sum of masses for the positive entries.
+                let (rp, w) = γ1.iter_spikes().fold((0.0, 0.0), |(p, w), δ| {
+                    let v = val(δ);
+                    if v <= 0.0 { (p, w) } else { (p + v, w + δ.α.abs()) }
+                });
+
+                if rp > t {
+                    // TODO: store v above?
+                    scale_down(γ1.iter_spikes_mut(), t / w, val);
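+                    // Out of tolerance: reduce the transport and redo the subproblem.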
+                    false
+                } else {
+                    true
+                }
+            };
+
+            if all_ok {
                 break 'adapt_transport (d, within_tolerances)
             }
         };
 
-        // Update γ_hat to new location
-        for (δ_new, r) in izip!(μ.iter_spikes(), γ_hat.iter_mut()) {
-            // Prune rays that only had a return component, as the return component becomes
-            // a diagonal in γ̂^{k+1}.
-            r.rays.retain(|ray| ray.δ.α != 0.0);
-            // Otherwise zero out the return component, or stage rays for pruning
-            // to keep memory and computational demands reasonable.
-            let n_rays = r.rays.len();
-            for (ray, ir) in izip!(r.rays.iter_mut(), (0..n_rays).rev()) {
-                if ir >= sfbconfig.maximum_rays {
-                    // Only keep sfbconfig.maximum_rays - 1 previous rays, staging others for
-                    // pruning in next step.
-                    ray.to_return = ray.δ.α;
-                    ray.δ.α = 0.0;
-                } else {
-                    ray.to_return = 0.0;
-                }
-                ray.goodness = 0.0; // TODO: probably not needed
-                ray.reg_goodness = 0.0;
-            }
-            // Add a new ray for the currently diagonal component
-            if r.diagonal > 0.0 {
-                r.rays.push(Ray{
-                    δ : DeltaMeasure{x : r.source, α : r.diagonal},
-                    goodness : 0.0,
-                    reg_goodness : 0.0,
-                    to_return : 0.0,
-                });
-                // TODO: Maybe this does not need to be done here, and it is sufficient to do it where
-                // the goodness is calculated.
-                r.diagonal = 0.0;
-            }
-            r.diagonal_goodness = 0.0;
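+        // Diagnostics: accumulate the untransported mass |μ^k − π_♯^0γ^{k+1}| against the total
+        // mass of μ^k, and the mismatch between μ^{k+1} and π_♯^1γ^{k+1} against the total mass
+        // of π_♯^1γ^{k+1}.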
+        stats.untransported_fraction = Some({
+            assert_eq!(μ_base_masses.len(), γ1.len());
+            let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));
+            let source = μ_base_masses.iter().map(|v| v.abs()).sum();
+            (a + μ_base_minus_γ0.norm(Radon), b + source)
+        });
+        stats.transport_error = Some({
+            assert_eq!(μ_base_masses.len(), γ1.len());
+            let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
+            let err = izip!(μ.iter_masses(), γ1.iter_masses()).map(|(v,w)| (v-w).abs()).sum();
+            (a + err, b + γ1.norm(Radon))
+        });
 
-            // Shift source
-            r.source = δ_new.x;
-        }
-        // Extend to new spikes
-        γ_hat.extend(μ[γ_hat.len()..].iter().map(|δ_new| {
-            RaySet{
-                source : δ_new.x,
-                rays : [].into(),
-                diagonal : 0.0,
-                diagonal_goodness : 0.0,
-                diagonal_reg_goodness : 0.0
-            }
-        }));
-
-        // Prune spikes with zero weight. This also moves the marginal differences of corresponding
-        // transports from γ_hat to γ_pruned_marginal_diff.
-        // TODO: optimise standard prune with swap_remove.
-        μ_transported_base.clear();
-        let mut i = 0;
-        assert_eq!(μ.len(), γ_hat.len());
-        while i < μ.len() {
-            if μ[i].α == F::ZERO {
-                μ.swap_remove(i);
-                let r = γ_hat.swap_remove(i);
-                μ_transported_base.extend(r.targets().cloned());
-                μ_transported_base -= DeltaMeasure{ α : r.non_diagonal_mass(), x : r.source };
-            } else {
-                i += 1;
-            }
+        // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, the
+        // latter also needs to be pruned when μ is.
+        // TODO: This could do with a two-vector Vec::retain to avoid copies.
+        let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
+        if μ_new.len() != μ.len() {
+            let mut μ_iter = μ.iter_spikes();
+            γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
+            μ = μ_new;
         }
 
         // TODO: how to merge?
@@ -562,7 +359,7 @@
             // Plot if so requested
             plotter.plot_spikes(
                 format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
-                "start".to_string(), Some(&minus_τv),
+                "start".to_string(), None::<&A::PreadjointCodomain>, // TODO: Should be Some(&((-τ) * v)), but not implemented
                 reg.target_bounds(τ, ε_prev), &μ,
             );
             // Calculate mean inner iterations and reset relevant counters.
