src/measures/merging.rs

Tue, 21 Mar 2023 20:31:01 +0200

author
Tuomo Valkonen <tuomov@iki.fi>
date
Tue, 21 Mar 2023 20:31:01 +0200
changeset 25
79943be70720
parent 0
eb3c7813b67a
permissions
-rw-r--r--

Implement non-negativity constraints for the conditional gradient methods

/*!
Spike merging heuristics for [`DiscreteMeasure`]s.

This module primarily provides the [`SpikeMerging`] trait, and within it,
the [`SpikeMerging::merge_spikes`] method. The trait is implemented on
[`DiscreteMeasure<Loc<F, N>, F>`]s in dimensions `N=1` and `N=2`.
*/

use numeric_literals::replace_float_literals;
use std::cmp::Ordering;
use serde::{Serialize, Deserialize};
//use clap::builder::{PossibleValuesParser, PossibleValue};
use alg_tools::nanleast::NaNLeast;

use crate::types::*;
use super::delta::*;
use super::discrete::*;

/// Selection of the spike merging heuristic.
///
/// Parsed from and displayed as either the literal `none` or a bare radius value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[allow(dead_code)]
pub enum SpikeMergingMethod<F> {
    /// Try to merge spikes that lie within the given radius of each other.
    HeuristicRadius(F),
    /// Do not merge spikes at all.
    None,
}

// impl<F : Float> SpikeMergingMethod<F> {
//     /// This is for [`clap`] to display command line help.
//     pub fn value_parser() -> PossibleValuesParser {
//         PossibleValuesParser::new([
//             PossibleValue::new("none").help("No merging"),
//             PossibleValue::new("<radius>").help("Heuristic merging within indicated radius")
//         ])
//     }
// }

impl<F : ClapFloat> std::fmt::Display for SpikeMergingMethod<F> {
    /// Writes `none` for [`SpikeMergingMethod::None`], and otherwise the radius of
    /// [`SpikeMergingMethod::HeuristicRadius`] using the radius type's own `Display` formatting.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        match *self {
            Self::HeuristicRadius(ref radius) => std::fmt::Display::fmt(radius, f),
            Self::None => f.write_str("none"),
        }
    }
}

impl<F : ClapFloat> std::str::FromStr for SpikeMergingMethod<F> {
    type Err = F::Err;

    /// Parses the exact string `none` as [`SpikeMergingMethod::None`]. Any other input is
    /// parsed as a radius for [`SpikeMergingMethod::HeuristicRadius`], and a failure to parse
    /// it as `F` is returned as the error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "none" => Ok(Self::None),
            radius => F::from_str(radius).map(Self::HeuristicRadius),
        }
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for SpikeMergingMethod<F> {
    /// By default, heuristic merging within radius `0.02` is used.
    fn default() -> Self {
        let radius = 0.02;
        Self::HeuristicRadius(radius)
    }
}

/// Trait for dimension-dependent implementation of heuristic peak merging strategies.
pub trait SpikeMerging<F> {
    /// Attempt spike merging according to the given [`SpikeMergingMethod`].
    ///
    /// Each merging candidate is passed to the acceptance decision closure `accept`. The closure
    /// should accept as its only parameter a new candidate measure (it will generally be
    /// internally mutated `self`, although this is not guaranteed). It should return a [`Some`]
    /// of an arbitrary value if the merge is accepted, and [`None`] if the merge is rejected.
    /// This method returns the [`Some`] produced for the *last* accepted merge, or [`None`] if no
    /// merge was accepted or `method` is [`SpikeMergingMethod::None`].
    ///
    /// This method is stable with respect to spike locations: on merge, the weight of existing
    /// spikes is set to zero, and a new one inserted at the end of the spike vector.
    fn merge_spikes<G, V>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> Option<V>
    where G : Fn(&'_ Self) -> Option<V> {
        match method {
            SpikeMergingMethod::HeuristicRadius(ρ) => self.do_merge_spikes_radius(ρ, accept),
            SpikeMergingMethod::None => None,
        }
    }

    /// Attempt to merge spikes based on a value and a fitness function.
    ///
    /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of
    /// `value` and `fitness`. A merge candidate is accepted when its fitness compares `<=` to the
    /// fitness of the initial `self`; note that the comparison is always against the initial
    /// fitness, not the fitness after previously accepted merges. Returns the last return value
    /// of `value` for a merge accepted by `fitness`. If no merge was accepted, `value` applied to
    /// the initial `self` is returned.
    fn merge_spikes_fitness<G, H, V, O>(
        &mut self,
        method : SpikeMergingMethod<F>,
        value : G,
        fitness : H
    ) -> V
    where G : Fn(&'_ Self) -> V,
          H : Fn(&'_ V) -> O,
          O : PartialOrd {
        let initial_res = value(self);
        let initial_fitness = fitness(&initial_res);
        self.merge_spikes(method, |μ| {
            let res = value(μ);
            // Accept (return `Some`) only if the candidate is no worse than the initial state.
            (fitness(&res) <= initial_fitness).then_some(res)
        }).unwrap_or(initial_res)
    }

    /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm).
    ///
    /// This method implements [`SpikeMerging::merge_spikes`] for
    /// [`SpikeMergingMethod::HeuristicRadius`]. The closure `accept` and the return value are
    /// as for that method.
    fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V>
    where G : Fn(&'_ Self) -> Option<V>;
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float, const N : usize>  DiscreteMeasure<Loc<F, N>, F> {
    /// Attempts to merge spikes with indices `i` and `j`.
    ///
    /// This assumes that the weights of the two spikes have already been checked not to be zero.
    ///
    /// The weights of both spikes are set to zero, and a new spike is pushed to the end of the
    /// spike vector. The new spike has weight `αi + αj`, and its location is the average of the
    /// two locations weighted by the *absolute values* of the weights.
    ///
    /// For weights of equal sign this is exactly the usual weighted average. For weights of
    /// opposite sign, using absolute values keeps the merged location on the segment between the
    /// two original locations. It also avoids a division by zero when the weights cancel out.
    /// Should both weights nevertheless be zero, the midpoint of the two locations is used.
    ///
    /// The parameter `res` points to the current “result” for [`SpikeMerging::merge_spikes`].
    /// If the merge is accepted by `accept` returning a [`Some`], `res` is replaced by that
    /// return value and `true` is returned. Otherwise the modification is undone, `res` is left
    /// untouched, and `false` is returned.
    fn attempt_merge<G, V>(
        &mut self,
        res : &mut Option<V>,
        i : usize,
        j : usize,
        accept : &G
    ) -> bool
    where G : Fn(&'_ Self) -> Option<V> {
        let &DeltaMeasure{ x : xi, α : αi } = &self.spikes[i];
        let &DeltaMeasure{ x : xj, α : αj } = &self.spikes[j];

        // Location weights; absolute values guarantee a convex combination (see above).
        let (wi, wj) = (αi.abs(), αj.abs());
        let w = wi + wj;
        let x = if w > 0.0 {
            (xi * wi + xj * wj) / w
        } else {
            (xi + xj) / 2.0
        };

        // Merge inplace
        self.spikes[i].α = 0.0;
        self.spikes[j].α = 0.0;
        self.spikes.push(DeltaMeasure{ α : αi + αj, x });
        match accept(self) {
            some@Some(..) => {
                // Merge accepted, update our return value
                *res = some;
                true
            },
            None => {
                // Merge not accepted, restore modification
                self.spikes[i].α = αi;
                self.spikes[j].α = αj;
                self.spikes.pop();
                false
            }
        }
    }
}

/// Returns the indices `0..slice.len()` ordered so that the referenced elements of `slice`
/// are sorted by `compare`.
///
/// The closure `compare` operates on references to elements of `slice`. The underlying sort
/// is stable: indices of elements that compare equal keep their ascending order.
pub fn sort_indices_by<V, F>(slice : &[V], mut compare : F) -> Vec<usize>
where F : FnMut(&V, &V) -> Ordering
{
    let mut order : Vec<usize> = (0..slice.len()).collect();
    order.sort_by(|&a, &b| compare(&slice[a], &slice[b]));
    order
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> {

    fn do_merge_spikes_radius<G, V>(
        &mut self,
        ρ : F,
        accept : G
    ) -> Option<V>
    where G : Fn(&'_ Self) -> Option<V> {
        // Sort by coordinate into an indexing array.
        let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| {
            let &Loc([x1]) = &δ1.x;
            let &Loc([x2]) = &δ2.x;
            // nan-ignoring ordering of floats
            NaNLeast(x1).cmp(&NaNLeast(x2))
        });

        // Initialise result
        let mut res = None;

        // Scan consecutive pairs and merge if close enough and accepted by `accept`.
        if indices.len() == 0 {
            return res
        }
        for k in 0..(indices.len()-1) {
            let i = indices[k];
            let j = indices[k+1];
            let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i];
            let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j];
            debug_assert!(xi <= xj);
            // If close enough, attempt merging
            if αi != 0.0 && αj != 0.0 && xj <= xi + ρ {
                if self.attempt_merge(&mut res, i, j, &accept) {
                    indices[k+1] = self.spikes.len() - 1;
                }
            }
        }

        res
    }
}

/// Orders `δ1` and `δ2` according to the first coordinate of their locations.
///
/// The coordinates are wrapped in [`NaNLeast`], so that the comparison yields a total
/// [`Ordering`] suitable for sorting even in the presence of NaN coordinates.
fn compare_first_coordinate<F : Float>(
    δ1 : &DeltaMeasure<Loc<F, 2>, F>,
    δ2 : &DeltaMeasure<Loc<F, 2>, F>
) -> Ordering {
    let first = |δ : &DeltaMeasure<Loc<F, 2>, F>| {
        let &Loc([x1, _]) = &δ.x;
        NaNLeast(x1)
    };
    first(δ1).cmp(&first(δ2))
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> {

    /// Merges spikes in two dimensions.
    ///
    /// Spikes are visited in order of increasing first coordinate. For each spike with non-zero
    /// weight, a partner is searched for both before and after it in this ordering. A partner
    /// must have non-zero weight and both coordinates within `ρ`. Among such partners, the
    /// closest in Euclidean distance is chosen, and the merge is offered to `accept`.
    ///
    /// On an accepted merge, the newly created spike takes the partner's position in the
    /// ordering, and the affected range of the ordering is re-sorted.
    fn do_merge_spikes_radius<G, V>(&mut self, ρ : F, accept : G) -> Option<V>
    where G : Fn(&'_ Self) -> Option<V> {
        // Sort by first coordinate into an indexing array.
        let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate);

        // Initialise result
        let mut res = None;
        // First position in `indices` from which to scan for merge partners. Positions before
        // it were found to hold spikes whose first coordinate is too low for an earlier spike.
        // Since spikes are processed in order of increasing first coordinate, those positions
        // are skipped for later spikes too. This starts at 0, so that the first spike in the
        // ordering is also a potential partner for all later spikes.
        let mut scan_from = 0;

        // Scan in order. The last spike is processed as well, since merge partners are also
        // searched for among earlier spikes.
        for k in 0..indices.len() {
            let i = indices[k];
            let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i];

            if αi == 0.0 {
                // Nothing to be done if the weight is already zero
                continue
            }

            let mut closest = None;

            // Scan for second spike. We start from `scan_from` rather than `k + 1`, because
            // the _closest_ mergeable spike might have position less than `k` in `indices`.
            // A merge with it might not have been attempted when that spike was processed,
            // if a different spike was then discovered to be closer.
            'scan_2nd: for l in scan_from..indices.len() {
                if l == k {
                    // Do not attempt to merge a spike with itself
                    continue
                }
                let j = indices[l];
                let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j];

                if xj1 < xi1 - ρ {
                    // Spike `j = indices[l]` has too low first coordinate. Skip it from now on,
                    // and continue scanning.
                    scan_from = l + 1;
                    continue 'scan_2nd
                } else if xj1 > xi1 + ρ {
                    // Break out: spike `j = indices[l]` has already too high first coordinate, no
                    // more close enough spikes can be found due to the sorting of `indices`.
                    break 'scan_2nd
                }

                // If also second coordinate is close enough, attempt merging if closer than
                // previously discovered mergeable spikes.
                let d2 = (xi2-xj2).abs();
                if αj != 0.0 && d2 <= ρ {
                    let r1 = xi1-xj1;
                    let d = (d2*d2 + r1*r1).sqrt();
                    match closest {
                        None => closest = Some((l, j, d)),
                        Some((_, _, r)) if r > d => closest = Some((l, j, d)),
                        _ => {},
                    }
                }
            }

            // Attempt merging closest close-enough spike
            if let Some((l, j, _)) = closest {
                if self.attempt_merge(&mut res, i, j, &accept) {
                    // If merge was successful, make new spike candidate for merging.
                    indices[l] = self.spikes.len() - 1;
                    let compare = |i, j| compare_first_coordinate(&self.spikes[i],
                                                                  &self.spikes[j]);
                    // Re-sort relevant range of indices, excluding the current position `k`.
                    if l < k {
                        indices[l..k].sort_by(|&i, &j| compare(i, j));
                    } else {
                        indices[k+1..=l].sort_by(|&i, &j| compare(i, j));
                    }
                }
            }
        }

        res
    }
}

mercurial