// src/sliding_fb.rs
/*!
Solver for the point source localisation problem using a sliding
forward-backward splitting method.
*/
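//!
//! Outline of one iteration of [`pointsource_sliding_fb_reg`] (summary of the code below):
//!
//!  1. Transport step: shift each spike at x by −θ sign(α) ∇v(x), where v = A_*(Aμ^k − b).
//!  2. A priori adaptation: scale the transport plan γ so that ∑_i ‖∇v(x_i)‖ |γ_i| stays within tolerance.
//!  3. Insertion step: solve finite-dimensional subproblems ([`insert_and_reweigh`]) on the transported measure.
//!  4. A posteriori adaptation: rescale γ if the regularisation dual variable ω violates the transport assumptions, and retry.
//!  5. Prune zero-weight spikes and update the residual.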

use numeric_literals::replace_float_literals;
use serde::{Serialize, Deserialize};
//use colored::Colorize;
//use nalgebra::{DVector, DMatrix};
use itertools::izip;
use std::iter::Iterator;

use alg_tools::iterate::{
    AlgIteratorFactory,
    AlgIteratorState
};
use alg_tools::euclidean::Euclidean;
use alg_tools::sets::Cube;
use alg_tools::loc::Loc;
use alg_tools::mapping::{Apply, Differentiable};
use alg_tools::norms::{Norm, L2};
use alg_tools::bisection_tree::{
    BTFN,
    PreBTFN,
    Bounds,
    BTNodeLookup,
    BTNode,
    BTSearch,
    P2Minimise,
    SupportGenerator,
    LocalAnalysis,
    //Bounded,
};
use alg_tools::mapping::RealMapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;

use crate::types::*;
use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon};
use crate::measures::merging::{
    //SpikeMergingMethod,
    SpikeMerging,
};
use crate::forward_model::ForwardModel;
use crate::seminorms::DiscreteMeasureOp;
//use crate::tolerance::Tolerance;
use crate::plot::{
    SeqPlotter,
    Plotting,
    PlotLookup
};
use crate::fb::*;
use crate::regularisation::SlidingRegTerm;
use crate::dataterm::{
    L2Squared,
    //DataTerm,
    calculate_residual,
    calculate_residual2,
};
use crate::transport::TransportLipschitz;

/// Settings for [`pointsource_sliding_fb_reg`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct SlidingFBConfig<F : Float> {
    /// Step length scaling
    pub τ0 : F,
    /// Transport step length $θ$ normalised to $(0, 1)$.
    pub θ0 : F,
    /// Maximum transport mass scaling.
    // /// The maximum transported mass is this factor times $\norm{b}^2/(2α)$.
    // pub max_transport_scale : F,
    /// Transport tolerance wrt. ω
    pub transport_tolerance_ω : F,
    /// Transport tolerance wrt. ∇v
    pub transport_tolerance_dv : F,
    /// Generic parameters
    pub insertion : FBGenericConfig<F>,
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for SlidingFBConfig<F> {
    fn default() -> Self {
        SlidingFBConfig {
            τ0 : 0.99,
            θ0 : 0.99,
            //max_transport_scale : 10.0,
            transport_tolerance_ω : 1.0, // TODO: no idea what this should be
            transport_tolerance_dv : 1.0, // TODO: no idea what this should be
            insertion : Default::default()
        }
    }
}
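
// A minimal construction sketch (added for illustration; not used by the solver itself):
// override the step length scaling and keep the remaining defaults via struct update
// syntax. The concrete float type `f64` is only an assumption here.
#[allow(dead_code)]
fn example_sliding_fb_config() -> SlidingFBConfig<f64> {
    SlidingFBConfig {
        // A slightly more conservative step length scaling than the default 0.99.
        τ0 : 0.95,
        ..Default::default()
    }
}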

/// Scale each |γ|_i ≠ 0 by q_i = q̄/g(γ_i) (only applied when g(γ_i) has the same sign as the mass).
#[replace_float_literals(F::cast_from(literal))]
fn scale_down<'a, I, F, G, const N : usize>(
    iter : I,
    q̄ : F,
    mut g : G
) where F : Float,
        I : Iterator<Item = &'a mut DeltaMeasure<Loc<F,N>, F>>,
        G : FnMut(&DeltaMeasure<Loc<F,N>, F>) -> F {
    iter.for_each(|δ| {
        if δ.α != 0.0 {
            let b = g(δ);
            if b * δ.α > 0.0 {
                δ.α *= q̄/b;
            }
        }
    });
}
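
// A small sanity check (added; plain floats only, no crate types) of the scaling rule
// used by `scale_down`: with q_i = q̄/g_i and q̄ = t/∑_i |γ|_i, the weighted sum
// ∑_i |γ|_i q_i g_i equals the target t. This mirrors how `scale_down` is invoked in
// `pointsource_sliding_fb_reg` below.
#[cfg(test)]
mod scale_down_arithmetic {
    #[test]
    fn weighted_sum_hits_target() {
        let γ = [1.0_f64, 2.0, 0.5];   // |γ|_i
        let g = [3.0_f64, 0.25, 4.0];  // g(γ_i), e.g. ‖∇v(x_i)‖
        let t = 0.1;                   // target for the weighted sum
        let q̄ = t / γ.iter().sum::<f64>();
        let total : f64 = γ.iter().zip(&g)
            .map(|(&a, &gi)| (a * q̄ / gi) * gi) // |γ|_i q_i g_i
            .sum();
        assert!((total - t).abs() < 1e-12);
    }
}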

/// Iteratively solve the point source localisation problem using sliding forward-backward
/// splitting.
///
/// The parametrisation is as for [`pointsource_fb_reg`].
/// Inertia is currently not supported.
#[replace_float_literals(F::cast_from(literal))]
pub fn pointsource_sliding_fb_reg<'a, F, I, A, GA, 𝒟, BTA, BT𝒟, G𝒟, S, K, Reg, const N : usize>(
    opA : &'a A,
    b : &A::Observable,
    reg : Reg,
    op𝒟 : &'a 𝒟,
    sfbconfig : &SlidingFBConfig<F>,
    iterator : I,
    mut plotter : SeqPlotter<F, N>,
) -> DiscreteMeasure<Loc<F, N>, F>
where F : Float + ToNalgebraRealField,
      I : AlgIteratorFactory<IterInfo<F, N>>,
      for<'b> &'b A::Observable : std::ops::Neg<Output=A::Observable>,
                                  //+ std::ops::Mul<F, Output=A::Observable>,  <-- FIXME: compiler overflow
      A::Observable : std::ops::MulAssign<F>,
      A::PreadjointCodomain : for<'b> Differentiable<&'b Loc<F, N>, Output=Loc<F, N>>,
      GA : SupportGenerator<F, N, SupportType = S, Id = usize> + Clone,
      A : ForwardModel<Loc<F, N>, F, PreadjointCodomain = BTFN<F, GA, BTA, N>>
          + Lipschitz<&'a 𝒟, FloatType=F> + TransportLipschitz<L2Squared, FloatType=F>,
      BTA : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      G𝒟 : SupportGenerator<F, N, SupportType = K, Id = usize> + Clone,
      𝒟 : DiscreteMeasureOp<Loc<F, N>, F, PreCodomain = PreBTFN<F, G𝒟, N>,
                                          Codomain = BTFN<F, G𝒟, BT𝒟, N>>,
      BT𝒟 : BTSearch<F, N, Data=usize, Agg=Bounds<F>>,
      S: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>
         + Differentiable<Loc<F, N>, Output=Loc<F,N>>,
      K: RealMapping<F, N> + LocalAnalysis<F, Bounds<F>, N>,
         //+ Differentiable<Loc<F, N>, Output=Loc<F,N>>,
      BTNodeLookup: BTNode<F, usize, Bounds<F>, N>,
      Cube<F, N>: P2Minimise<Loc<F, N>, F>,
      PlotLookup : Plotting<N>,
      DiscreteMeasure<Loc<F, N>, F> : SpikeMerging<F>,
      Reg : SlidingRegTerm<F, N> {

    assert!(sfbconfig.τ0 > 0.0 &&
            sfbconfig.θ0 > 0.0);

    // Set up parameters
    let config = &sfbconfig.insertion;
    let op𝒟norm = op𝒟.opnorm_bound();
    //let max_transport = sfbconfig.max_transport_scale
    //                    * reg.radon_norm_bound(b.norm2_squared() / 2.0);
    //let tlip = opA.transport_lipschitz_factor(L2Squared) * max_transport;
    //let ℓ = 0.0;
    let θ = sfbconfig.θ0; // (ℓ + tlip);
    let τ = sfbconfig.τ0/opA.lipschitz_factor(&op𝒟).unwrap();
    // We multiply the tolerance by τ for FB, since the subproblems that depend on the tolerance
    // are scaled by τ compared to the conditional gradient approach.
    let tolerance = config.tolerance * τ * reg.tolerance_scaling();
    let mut ε = tolerance.initial();

    // Initialise iterates
    let mut μ = DiscreteMeasure::new();
    let mut γ1 = DiscreteMeasure::new();
    let mut residual = -b;
    let mut stats = IterInfo::new();

    // Run the algorithm
    iterator.iterate(|state| {
        // Calculate smooth part of surrogate model.
        // Using `std::mem::replace` here is not ideal, and assumes that `empty_observable`
        // has no significant overhead. For some reason Rust doesn't allow simply moving
        // the residual out and replacing it below before the end of this closure.
        let r = std::mem::replace(&mut residual, opA.empty_observable());
        let v = opA.preadjoint().apply(r);
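        // At this point v = A_*(Aμ^k - b): the derivative of the data term ½‖Aμ-b‖²
        // at the current iterate. Its differential ∇v drives the transport step below.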

        // Save the current base point and shift μ to new positions. The idea is that
        //  μ_base(_masses) = μ^k (vector of masses)
        //  μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
        //  γ1 = π_♯^1γ^{k+1}
        //  μ = μ^{k+1}
        let μ_base_masses : Vec<F> = μ.iter_masses().collect();
        let mut μ_base_minus_γ0 = μ.clone(); // Weights will be set in the loop below.
        // Construct initial candidates for μ^{k+1} and π_♯^1γ^{k+1}
        let mut sum_norm_dv_times_γinit = 0.0;
        let mut sum_abs_γinit = 0.0;
        //let mut sum_norm_dv = 0.0;
        let γ_prev_len = γ1.len();
        assert!(μ.len() >= γ_prev_len);
        γ1.extend(μ[γ_prev_len..].iter().cloned());
        for (δ, ρ) in izip!(μ.iter_spikes_mut(), γ1.iter_spikes_mut()) {
            let d_v_x = v.differential(&δ.x);
            // If the old transport has an opposing sign, the new transport will be zero.
            ρ.α = if (ρ.α > 0.0 && δ.α < 0.0) || (ρ.α < 0.0 && δ.α > 0.0) {
                0.0
            } else {
                δ.α
            };
            δ.x -= d_v_x * (θ * δ.α.signum()); // This is δ.α.signum() when δ.α ≠ 0.
            ρ.x = δ.x;
            let nrm = d_v_x.norm(L2);
            let a = ρ.α.abs();
            let v = nrm * a;
            if v > 0.0 {
                sum_norm_dv_times_γinit += v;
                sum_abs_γinit += a;
            }
        }

        // A priori transport adaptation based on bounding ∫ ⟨∇v(x), z-y⟩ dλ(x, y, z).
        // This is just one option; there are many.
        let t = ε * sfbconfig.transport_tolerance_dv;
        if sum_norm_dv_times_γinit > t {
            // Scale each |γ|_i by q_i=q̄/‖vx‖_i such that ∑_i |γ|_i q_i ‖vx‖_i = t
            // TODO: store the closure values above?
            scale_down(γ1.iter_spikes_mut(),
                       t / sum_abs_γinit,
                       |δ| v.differential(&δ.x).norm(L2));
        }
        //println!("|γ| = {}, |μ| = {}", γ1.norm(crate::measures::Radon), μ.norm(crate::measures::Radon));

        // Solve the finite-dimensional subproblem several times, until the dual variable of the
        // regularisation term conforms to the assumptions made for the transport above.
        let (d, within_tolerances) = 'adapt_transport: loop {
            // Update weights for μ_base_minus_γ0 = μ^k - π_♯^0γ^{k+1}
            for (δ_γ1, δ_μ_base_minus_γ0, &α_μ_base) in izip!(γ1.iter_spikes(),
                                                              μ_base_minus_γ0.iter_spikes_mut(),
                                                              μ_base_masses.iter()) {
                δ_μ_base_minus_γ0.set_mass(α_μ_base - δ_γ1.get_mass());
            }

            // Calculate transported_minus_τv̆ = -τA_*(A[γ1 + μ_base_minus_γ0] - b)
            let residual_μ̆ = calculate_residual2(&γ1, &μ_base_minus_γ0, opA, b);
            let transported_minus_τv̆ = opA.preadjoint().apply(residual_μ̆ * (-τ));

            // Construct μ^{k+1} by solving the finite-dimensional subproblems and inserting new spikes.
            let (d, within_tolerances) = insert_and_reweigh(
                &mut μ, &transported_minus_τv̆, &γ1, Some(&μ_base_minus_γ0),
                op𝒟, op𝒟norm,
                τ, ε,
                config,
                &reg, state, &mut stats,
            );

            // A posteriori transport adaptation based on bounding (1/τ)∫ ω(z) - ω(y) dλ(x, y, z).
            let all_ok = if false { // Basic check
                // If π_♯^1γ^{k+1} = γ1 has non-zero mass at some point y, but μ = μ^{k+1} does not,
                // then the ansatz ∇w̃_x(y) = w^{k+1}(y) may not be satisfied. So set the mass of γ1
                // at that point to zero, and retry.
                let mut all_ok = true;
                for (α_μ, α_γ1) in izip!(μ.iter_masses(), γ1.iter_masses_mut()) {
                    if α_μ == 0.0 && *α_γ1 != 0.0 {
                        all_ok = false;
                        *α_γ1 = 0.0;
                    }
                }
                all_ok
            } else {
                // TODO: Could maybe optimise, as this is also formed in insert_and_reweigh above.
                let mut minus_ω = op𝒟.apply(γ1.sub_matching(&μ) + &μ_base_minus_γ0);

                // let vpos = γ1.iter_spikes()
                //              .filter(|δ| δ.α > 0.0)
                //              .map(|δ| minus_ω.apply(&δ.x))
                //              .reduce(F::max)
                //              .and_then(|threshold| {
                //                 minus_ω.minimise_below(threshold,
                //                                         ε * config.refinement.tolerance_mult,
                //                                         config.refinement.max_steps)
                //                        .map(|(_z, minus_ω_z)| minus_ω_z)
                //              });

                // let vneg = γ1.iter_spikes()
                //              .filter(|δ| δ.α < 0.0)
                //              .map(|δ| minus_ω.apply(&δ.x))
                //              .reduce(F::min)
                //              .and_then(|threshold| {
                //                 minus_ω.maximise_above(threshold,
                //                                         ε * config.refinement.tolerance_mult,
                //                                         config.refinement.max_steps)
                //                        .map(|(_z, minus_ω_z)| minus_ω_z)
                //              });
                let (_, vpos) = minus_ω.minimise(ε * config.refinement.tolerance_mult,
                                                 config.refinement.max_steps);
                let (_, vneg) = minus_ω.maximise(ε * config.refinement.tolerance_mult,
                                                 config.refinement.max_steps);
            
                let t = τ * ε * sfbconfig.transport_tolerance_ω;
                let val = |δ : &DeltaMeasure<Loc<F, N>, F>| {
                    δ.α * (minus_ω.apply(&δ.x) - if δ.α >= 0.0 { vpos } else { vneg })
                    // match if δ.α >= 0.0 { vpos } else { vneg } {
                    //     None => 0.0,
                    //     Some(v) => δ.α * (minus_ω.apply(&δ.x) - v)
                    // }
                };
                // Calculate the positive/"bad" (rp) values under the integral.
                // Also accumulate the sum of masses w for the corresponding entries.
                let (rp, w) = γ1.iter_spikes().fold((0.0, 0.0), |(p, w), δ| {
                    let v = val(δ);
                    if v <= 0.0 { (p, w) } else { (p + v, w + δ.α.abs()) }
                });

                if rp > t {
                    // TODO: store v above?
                    scale_down(γ1.iter_spikes_mut(), t / w, val);
                    false
                } else {
                    true
                }
            };

            if all_ok {
                break 'adapt_transport (d, within_tolerances)
            }
        };

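        // Accumulate (numerator, denominator) pairs over the iterations for the reported
        // statistics: untransported_fraction tracks ‖μ^k - π_♯^0γ^{k+1}‖ / ‖μ^k‖ (Radon norms),
        // and transport_error tracks ∑_i |μ_i - (γ1)_i| / ‖γ1‖.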
        stats.untransported_fraction = Some({
            assert_eq!(μ_base_masses.len(), γ1.len());
            let (a, b) = stats.untransported_fraction.unwrap_or((0.0, 0.0));
            let source = μ_base_masses.iter().map(|v| v.abs()).sum();
            (a + μ_base_minus_γ0.norm(Radon), b + source)
        });
        stats.transport_error = Some({
            assert_eq!(μ_base_masses.len(), γ1.len());
            let (a, b) = stats.transport_error.unwrap_or((0.0, 0.0));
            let err = izip!(μ.iter_masses(), γ1.iter_masses()).map(|(v,w)| (v-w).abs()).sum();
            (a + err, b + γ1.norm(Radon))
        });

        // Prune spikes with zero weight. To maintain the correct ordering between μ and γ1, the
        // latter also needs to be pruned when μ is.
        // TODO: This could do with a two-vector Vec::retain to avoid copies.
        let μ_new = DiscreteMeasure::from_iter(μ.iter_spikes().filter(|δ| δ.α != F::ZERO).cloned());
        if μ_new.len() != μ.len() {
            let mut μ_iter = μ.iter_spikes();
            γ1.prune_by(|_| μ_iter.next().unwrap().α != F::ZERO);
            μ = μ_new;
        }

        // TODO: how to merge?

        // Update residual
        residual = calculate_residual(&μ, opA, b);

        // Update main tolerance for next iteration
        let ε_prev = ε;
        ε = tolerance.update(ε, state.iteration());
        stats.this_iters += 1;

        // Give function value if needed
        state.if_verbose(|| {
            // Plot if so requested
            plotter.plot_spikes(
                format!("iter {} end; {}", state.iteration(), within_tolerances), &d,
                "start".to_string(), None::<&A::PreadjointCodomain>, // TODO: Should be Some(&((-τ) * v)), but not implemented
                reg.target_bounds(τ, ε_prev), &μ,
            );
            // Calculate mean inner iterations and reset relevant counters.
            // Return the statistics
            let res = IterInfo {
                value : residual.norm2_squared_div2() + reg.apply(&μ),
                n_spikes : μ.len(),
                ε : ε_prev,
                postprocessing: config.postprocessing.then(|| μ.clone()),
                .. stats
            };
            stats = IterInfo::new();
            res
        })
    });

    postprocess(μ, config, L2Squared, opA, b)
}
