src/sliding_fb.rs

/*!
Solver for the point source localisation problem using a sliding
forward-backward splitting method.
*/

use numeric_literals::replace_float_literals;
use serde::{Serialize, Deserialize};
//use colored::Colorize;
//use nalgebra::{DVector, DMatrix};
use itertools::izip;
use std::iter::Iterator;

use alg_tools::iterate::AlgIteratorFactory;
use alg_tools::euclidean::Euclidean;
use alg_tools::sets::Cube;
use alg_tools::loc::Loc;
use alg_tools::mapping::{Mapping, DifferentiableMapping, Instance};
use alg_tools::norms::Norm;
use alg_tools::bisection_tree::{
    BTFN,
    PreBTFN,
    Bounds,
    BTNodeLookup,
    BTNode,
    BTSearch,
    P2Minimise,
    SupportGenerator,
    LocalAnalysis,
    //Bounded,
};
use alg_tools::mapping::RealMapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;
use alg_tools::norms::{L2, Linfinity};

use crate::types::*;
use crate::measures::{DiscreteMeasure, Radon, RNDM};
use crate::measures::merging::{
    SpikeMergingMethod,
    SpikeMerging,
};
use crate::forward_model::{
    ForwardModel,
    AdjointProductBoundedBy,
    LipschitzValues,
};
use crate::seminorms::DiscreteMeasureOp;
//use crate::tolerance::Tolerance;
use crate::plot::{
    SeqPlotter,
    Plotting,
    PlotLookup
};
use crate::fb::*;
use crate::regularisation::SlidingRegTerm;
use crate::dataterm::{
    L2Squared,
    //DataTerm,
    calculate_residual,
    calculate_residual2,
};
//use crate::transport::TransportLipschitz;

/// Transport settings for [`pointsource_sliding_fb_reg`].
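///
/// A hypothetical construction overriding some of the defaults might look as follows
/// (the values here are purely illustrative, not recommendations):
/// ```ignore
/// let tconfig : TransportConfig<f64> = TransportConfig {
///     θ0 : 0.05,
///     adaptation : 0.5,
///     .. Default::default()
/// };
/// tconfig.check();
/// ```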
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct TransportConfig<F : Float> {
    /// Transport step length $θ$ normalised to $(0, 1)$.
    pub θ0 : F,
    /// Factor in $(0, 1)$ for decreasing transport to adapt to tolerance.
    pub adaptation : F,
    /// Transport tolerance wrt. ω
    pub tolerance_ω : F,
    /// Transport tolerance wrt. ∇v
    pub tolerance_dv : F,
}

#[replace_float_literals(F::cast_from(literal))]
impl <F : Float> TransportConfig<F> {
    /// Check that the parameters are ok. Panics if not.
    pub fn check(&self) {
        assert!(self.θ0 > 0.0);
        assert!(0.0 < self.adaptation && self.adaptation < 1.0);
        assert!(self.tolerance_dv > 0.0);
        assert!(self.tolerance_ω > 0.0);
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for TransportConfig<F> {
    fn default() -> Self {
        TransportConfig {
            θ0 : 0.01,
            adaptation : 0.9,
            tolerance_ω : 1000.0, // TODO: no idea what this should be
            tolerance_dv : 1000.0, // TODO: no idea what this should be
        }
    }
}

/// Settings for [`pointsource_sliding_fb_reg`].
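///
/// As a sketch, a configuration with a smaller step length factor and custom transport
/// settings could be assembled as follows (values illustrative only):
/// ```ignore
/// let config = SlidingFBConfig {
///     τ0 : 0.5,
///     transport : TransportConfig { θ0 : 0.05, .. Default::default() },
///     .. Default::default()
/// };
/// ```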
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct SlidingFBConfig<F : Float> {
    /// Step length scaling
    pub τ0 : F,
    /// Transport parameters
    pub transport : TransportConfig<F>,
    /// Generic parameters
    pub insertion : FBGenericConfig<F>,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for SlidingFBConfig<F> {
    fn default() -> Self {
        SlidingFBConfig {
            τ0 : 0.99,
            transport : Default::default(),
            insertion : Default::default()
        }
    }
}

/// Internal type of adaptive transport step length calculation
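///
/// As a sketch (cf. the actual setup in [`pointsource_sliding_fb_reg`]), with `θ0`, `τ`
/// and a transport Lipschitz bound `ℓ` given, `g` could compute the step length from an
/// estimated Lipschitz factor `l` of ∇v as
/// ```ignore
/// let calculate_θ = |ℓ_v, _max_transport| θ0 / (τ * (ℓ + ℓ_v));
/// ```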
pub(crate) enum TransportStepLength<F : Float, G : Fn(F, F) -> F> {
    /// Fixed, known step length
    Fixed(F),
    /// Adaptive step length, only wrt. maximum transport.
    /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
    AdaptiveMax{ l : F, max_transport : F, g : G },
    /// Adaptive step length.
    /// Content of `l` depends on use case, while `g` calculates the step length from `l`.
    FullyAdaptive{ l : F, max_transport : F, g : G },
}

/// Construction and a priori transport adaptation.
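///
/// Returns the original masses of `μ` (to be used as `μ_base_masses`) together with
/// `μ_base_minus_γ0` = μ^k − π_♯^0γ^{k+1}; on return, `μ` and `γ1` hold the initial
/// candidates for μ^{k+1} and π_♯^1γ^{k+1}.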
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn initial_transport<F, G, D, Observable, const N : usize>(
    γ1 : &mut RNDM<F, N>,
    μ : &mut RNDM<F, N>,
    opAapply : impl Fn(&RNDM<F, N>) -> Observable,
    ε : F,
    τ : F,
    θ_or_adaptive : &mut TransportStepLength<F, G>,
    opAnorm : F,
    v : D,
    tconfig : &TransportConfig<F>
) -> (Vec<F>, RNDM<F, N>)
where
    F : Float + ToNalgebraRealField,
    G : Fn(F, F) -> F,
    Observable : Euclidean<F, Output=Observable>,
    for<'a> &'a Observable : Instance<Observable>,
    //for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
    D : DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F, N>>,
{

    use TransportStepLength::*;

    // Save current base point and shift μ to new positions. The idea is that
    //  μ_base(_masses) = μ^k (vector of masses)
    //  μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
    //  γ1 = π_♯^1γ^{k+1}
    //  μ = μ^{k+1}
    let μ_base_masses : Vec<F> = μ.iter_masses().collect();
    let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below.
    // Construct μ^{k+1} and π_♯^1γ^{k+1} initial candidates
    //let mut sum_norm_dv = 0.0;
    let γ_prev_len = γ1.len();
    assert!(μ.len() >= γ_prev_len);
    γ1.extend(μ[γ_prev_len..].iter().cloned());

    // Calculate initial transport and step length.
    // First calculate initial transported weights
    for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
        // If old transport has opposing sign, the new transport will be none.
        ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) {
            0.0
        } else {
            δ.α
        };
    };

    // A priori transport adaptation based on bounding 2 ‖A‖ ‖A(γ₁-γ₀)‖‖γ‖ by scaling γ.
    // 1. Calculate transport rays.
    //    If the Lipschitz factor of ∇v for the values v=∇F(μ) is not known, estimate it.
    match *θ_or_adaptive {
        Fixed(θ) => {
            let θτ = τ * θ;
            for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
                ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
            }
        },
        AdaptiveMax{ l : ℓ_v, ref mut max_transport, g : ref calculate_θ } => {
            *max_transport = max_transport.max(γ1.norm(Radon));
            let θτ = τ * calculate_θ(ℓ_v, *max_transport);
            for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
                ρ.x = δ.x - v.differential(&δ.x) * (ρ.α.signum() * θτ);
            }
        },
        FullyAdaptive{ l : ref mut adaptive_ℓ_v, ref mut max_transport, g : ref calculate_θ } => {
            *max_transport = max_transport.max(γ1.norm(Radon));
            let mut θ = calculate_θ(*adaptive_ℓ_v, *max_transport);
            loop {
                let θτ = τ * θ;
                for (δ, ρ) in izip!(μ.iter_spikes(), γ1.iter_spikes_mut()) {
                    let dv_x = v.differential(&δ.x);
                    ρ.x = δ.x - &dv_x * (ρ.α.signum() * θτ);
                    // Estimate Lipschitz factor of ∇v
                    let this_ℓ_v = (dv_x - v.differential(&ρ.x)).norm2();
                    *adaptive_ℓ_v = adaptive_ℓ_v.max(this_ℓ_v);
                }
                let new_θ = calculate_θ(*adaptive_ℓ_v / tconfig.adaptation, *max_transport);
                if new_θ <= θ {
                    break
                }
                θ = new_θ;
            }
        }
    }

    // 2. Adjust transport mass, if needed.
    // This tries to remove the smallest transport masses first.
    if true {
        // Alternative 1 : subtract same amount from all transport rays until reaching zero
        loop {
            let nr = γ1.norm(Radon);
            let n = τ * 2.0 * opAnorm * (opAapply(&*γ1)-opAapply(&*μ)).norm2();
            if n <= 0.0 || nr <= 0.0 {
                break
            }
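            // The aim is n‖γ‖_ℳ ≤ ε · tolerance_dv, i.e. ‖γ‖_ℳ ≤ ε · tolerance_dv / n,
            // so the Radon norm needs to be reduced by the excess over that bound.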
            let reduction_needed = nr - (ε * tconfig.tolerance_dv / n);
            if reduction_needed <= 0.0 {
                break
            }
            let (min_nonzero, n_nonzero) = γ1.iter_masses()
                                            .map(|α| α.abs())
                                            .filter(|α| *α > F::EPSILON)
                                            .fold((F::INFINITY, 0), |(a, n), b| (a.min(b), n+1));
            assert!(n_nonzero > 0);
            // Reduction that can be done in all nonzero spikes simultaneously
            let h = (reduction_needed / F::cast_from(n_nonzero)).min(min_nonzero);
            for (δ, ρ) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes_mut()) {
                ρ.α = ρ.α.signum() * (ρ.α.abs() - h).max(0.0);
                δ.α = ρ.α;
            }
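            // If the shared reduction h covered the whole requirement, we are done;
            // otherwise some spikes hit zero first and another round is needed.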
            if min_nonzero * F::cast_from(n_nonzero) >= reduction_needed {
                break
            }
        }
    } else {
        // Alternative 2: first reduce transport rays with greater effect based on differential.
        // This is an inefficient quick-and-dirty implementation.
        loop {
            let nr = γ1.norm(Radon);
            let a = opAapply(&*γ1)-opAapply(&*μ);
            let na = a.norm2();
            let n = τ * 2.0 * opAnorm * na;
            if n <= 0.0 || nr <= 0.0 {
                break
            }
            let reduction_needed = nr - (ε * tconfig.tolerance_dv / n);
            if reduction_needed <= 0.0 {
                break
            }
            let mut max_d = 0.0;
            let mut max_d_ind = 0;
            for (δ, ρ, i) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), 0..) {
                // Calculate differential of  ‖A(γ₁-γ₀)‖‖γ‖  wrt. each spike
                let s = δ.α.signum();
                // TODO: this is very inefficient implementation due to the limitations
                // of the closure parameters.
                let δ1 = DiscreteMeasure::from([(ρ.x, s)]);
                let δ2 = DiscreteMeasure::from([(δ.x, s)]);
                let a_part = opAapply(&δ1)-opAapply(&δ2);
                let d = a.dot(&a_part)/na * nr + 2.0 * na;
                if d > max_d {
                    max_d = d;
                    max_d_ind = i;
                }
            }
            // Just set mass to zero for the transport ray with the greatest differential
            assert!(max_d > 0.0);
            γ1[max_d_ind].α = 0.0;
            μ[max_d_ind].α = 0.0;
        }
    }

    // Set initial guess for μ=μ^{k+1}.
    for (δ, ρ, &β) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes(), μ_base_masses.iter()) {
        if ρ.α.abs() > F::EPSILON {
            δ.x = ρ.x;
            //δ.α = ρ.α; // already set above
        } else {
            δ.α = β;
        }
    }
    // Calculate μ^k - π_♯^0γ^{k+1}; the caller uses this to compute
    // v̆ = A_*(A[μ_transported + μ_transported_base]-b).
    μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses())
                                                   .map(|(&a,b)| a - b));
    (μ_base_masses, μ_base_minus_γ0)
}

/// A posteriori transport adaptation.
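///
/// Returns `true` if the transport plan `γ1` satisfied the a posteriori criteria as-is,
/// and `false` if it (and `μ_base_minus_γ0`) had to be modified, in which case the caller
/// is expected to re-solve the subproblem.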
#[replace_float_literals(F::cast_from(literal))]
pub(crate) fn aposteriori_transport<F, const N : usize>(
    γ1 : &mut RNDM<F, N>,
    μ : &mut RNDM<F, N>,
    μ_base_minus_γ0 : &mut RNDM<F, N>,
    μ_base_masses : &Vec<F>,
    ε : F,
    tconfig : &TransportConfig<F>
) -> bool
where F : Float + ToNalgebraRealField {

    // 1. If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
    // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
    // at that point to zero, and retry.
    let mut all_ok = true;
    for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) {
        if α_μ == 0.0 && *α_γ1 != 0.0 {
            all_ok = false;
            *α_γ1 = 0.0;
        }
    }

    // 2. Bound ∫ B_ω(y, z) dλ(x, y, z) through the estimate ≤ C ‖Δ‖‖γ^{k+1}‖
    //    for Δ := μ^{k+1}-μ^k-(π_♯^1-π_♯^0)γ^{k+1}, which holds for some C if the
    //    convolution kernel in 𝒟 has a Lipschitz gradient.
    let nγ = γ1.norm(Radon);
    let nΔ = μ_base_minus_γ0.norm(Radon) + μ.dist_matching(&γ1);
    let t = ε * tconfig.tolerance_ω;
    if nγ*nΔ > t {
        // Since t/(nγ*nΔ)<1, and the constant tconfig.adaptation < 1,
        // this will guarantee that eventually ‖γ‖ decreases sufficiently that we
        // will not enter here.
        *γ1 *= tconfig.adaptation * t / ( nγ * nΔ );
        all_ok = false
    }

    if !all_ok {
        // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
        μ_base_minus_γ0.set_masses(μ_base_masses.iter().zip(γ1.iter_masses())
                                                        .map(|(&a,b)| a - b));

    }

    all_ok
}

/// Iteratively solve the point source localisation problem using sliding forward-backward
/// splitting.
///
/// The parametrisation is as for [`pointsource_fb_reg`].
/// Inertia is currently not supported.
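///
/// A hypothetical invocation could look as follows; the construction of the forward
/// operator `opA`, data `b`, regulariser `reg`, seminorm `op𝒟`, iterator factory, and
/// plotter is omitted, as it depends on the experimental setup:
/// ```ignore
/// let config = SlidingFBConfig::default();
/// let μ = pointsource_sliding_fb_reg(
///     &opA, &b, reg, &op𝒟, &config, iterator, plotter,
/// );
/// ```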
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>(
    opA : &'a A,
    b : &A::Observable,
    reg : Reg,
    op𝒟 : &'a 𝒟,
    config : &SlidingFBConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
) -> RNDM<F, N>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable> + Instance<A::Observable>,
      for<'b> A::Preadjoint<'b> : LipschitzValues<FloatType=F>,
      A::PreadjointCodomain : DifferentiableMapping<
        Loc<F, N>, DerivativeDomain=Loc<F, N>, Codomain=F
      >,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<RNDM<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
          + AdjointProductBoundedBy<RNDM<F, N>, 𝒟, FloatType=F>,
          //+ TransportLipschitz<L2Squared, FloatType=F>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>,
                                          Codomain = BTFN<F, G𝒟, BT𝒟, N>>,
      BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>
         + DifferentiableMapping<Loc<F, N>, DerivativeDomain=Loc<F,N>>,
      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
         //+ Differentiable<Loc<F, N>, Derivative=Loc<F,N>>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      RNDM<F, N> : SpikeMerging<F>,
      Reg : SlidingRegTerm<F, N> {

    // Check parameters
    assert!(config.τ0 > 0.0, "Invalid step length parameter");
    config.transport.check();

    // Initialise iterates
    let mut μ = DiscreteMeasure::new();
    let mut γ1 = DiscreteMeasure::new();
    let mut residual = -b; // Has to equal $Aμ-b$.

    // Set up parameters
    let op𝒟norm = op𝒟.opnorm_bound(Radon, Linfinity);
    let opAnorm = opA.opnorm_bound(Radon, L2);
    //let max_transport = config.max_transport.scale
    //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0);
    //let ℓ = opA.transport.lipschitz_factor(L2Squared) * max_transport;
    let ℓ = 0.0;
    let τ = config.τ0 / opA.adjoint_product_bound(&op𝒟).unwrap();
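    // The step length θ is chosen such that τθ(ℓ + ℓ_v) equals the normalised θ0 ∈ (0, 1).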
    let calculate_θ = |ℓ_v, _| config.transport.θ0 / (τ*(ℓ + ℓ_v));
    let mut θ_or_adaptive = match opA.preadjoint().value_diff_unit_lipschitz_factor() {
        // We only estimate w (the uniform Lipschitz factor of v) if we also estimate ℓ_v
        // (the uniform Lipschitz factor of ∇v).
        // We assume that the residual is decreasing.
        Some(ℓ_v0) => TransportStepLength::Fixed(calculate_θ(ℓ_v0 * residual.norm2(), 0.0)),
        None => TransportStepLength::FullyAdaptive {
            l : 0.0,
            max_transport : 0.0,
            g : calculate_θ
        },
    };
    // We multiply the tolerance by τ for FB, since our tolerance-dependent subproblems are
    // scaled by τ compared to the conditional gradient approach.
    let tolerance = config.insertion.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Statistics
    let full_stats = |residual : &A::Observable,
                      μ : &RNDM<F, N>,
                      ε, stats| IterInfo {
        value : residual.norm2_squared_div2() + reg.apply(μ),
        n_spikes : μ.len(),
        ε,
        // postprocessing: config.insertion.postprocessing.then(|| μ.clone()),
        .. stats
    };
    let mut stats = IterInfo::new();

    // Run the algorithm
    for state in iterator.iter_init(|| full_stats(&residual, &μ, ε, stats.clone())) {
        // Calculate initial transport
        let v = opA.preadjoint().apply(residual);
        let (μ_base_masses, mut μ_base_minus_γ0) = initial_transport(
            &mut γ1, &mut μ, |ν| opA.apply(ν),
            ε, τ, &mut θ_or_adaptive, opAnorm,
            v, &config.transport,
        );

        // Solve finite-dimensional subproblem several times until the dual variable for the
        // regularisation term conforms to the assumptions made for the transport above.
        let (d, _within_tolerances, τv̆) = 'adapt_transport: loop {
            // Calculate τv̆ = τA_*(A[μ_transported + μ_transported_base]-b)
            let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
            let τv̆ = opA.preadjoint().apply(residual_μ̆ * τ);

            // Construct μ^{k+1} by solving finite-dimensional subproblems and insert new spikes.
            let (d, within_tolerances) = insert_and_reweigh(
                &mut μ, &τv̆, &γ1, Some(&μ_base_minus_γ0),
                op𝒟, op𝒟norm,
                τ, ε, &config.insertion,
                &reg, &state, &mut stats,
            );

            // A posteriori transport adaptation.
            if aposteriori_transport(
                &mut γ1, &mut μ, &mut μ_base_minus_γ0, &μ_base_masses,
                ε, &config.transport
            ) {
                break 'adapt_transport (d, within_tolerances, τv̆)
            }
        };

        stats.untransported_fraction = Some({
            assert_eq!(μ_base_masses.len(), γ1.len());
            let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));
            let source = μ_base_masses.iter().map(|v| v.abs()).sum();
            (a + μ_base_minus_γ0.norm(Radon), b + source)
        });
        stats.transport_error = Some({
            assert_eq!(μ_base_masses.len(), γ1.len());
            let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
            (a + μ.dist_matching(&γ1), b + γ1.norm(Radon))
        });

        // Merge spikes.
        // This expects the prune below to also prune γ1.
        // TODO: This may not work correctly in all cases.
        let ins = &config.insertion;
        if ins.merge_now(&state) {
            if let SpikeMergingMethod::None = ins.merging {
            } else {
                stats.merged += μ.merge_spikes(ins.merging, |μ_candidate| {
                    let ν = μ_candidate.sub_matching(&γ1)-&μ_base_minus_γ0;
                    let mut d = &τv̆ + op𝒟.preapply(ν);
                    reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, ins)
                });
            }
        }

        // Prune spikes with zero weight. To maintain correct ordering between μ and γ1, also the
        // latter needs to be pruned when μ is.
        // TODO: This could do with a two-vector Vec::retain to avoid copies.
        let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
        if μ_new.len() != μ.len() {
            let mut μ_iter = μ.iter_spikes();
            γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
            stats.pruned += μ.len() - μ_new.len();
            μ = μ_new;
        }

        // Update residual
        residual = calculate_residual(&μ, opA, b);

        let iter = state.iteration();
        stats.this_iters += 1;

        // Give statistics if requested
        state.if_verbose(|| {
            plotter.plot_spikes(iter, Some(&d), Some(&τv̆), &μ);
            full_stats(&residual, &μ, ε, std::mem::replace(&mut stats, IterInfo::new()))
        });

        // Update main tolerance for next iteration
        ε = tolerance.update(ε, iter);
    }

    postprocess(μ, &config.insertion, L2Squared, opA, b)
}
