src/measures/merging.rs

Sun, 26 Jan 2025 11:58:02 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Sun, 26 Jan 2025 11:58:02 -0500
branch
dev
changeset 42
6a7365d73e4c
parent 39
6316d68b58af
child 51
0693cc9ba9f0
permissions
-rw-r--r--

Fixes to Radon norm prox term inner algorithm

/*!
Spike merging heuristics for [`DiscreteMeasure`]s.

This module primarily provides the [`SpikeMerging`] trait, and within it,
the [`SpikeMerging::merge_spikes`] method. The trait is implemented on
[`DiscreteMeasure<Loc<F, N>, F>`]s in dimensions `N=1` and `N=2`.
*/

use numeric_literals::replace_float_literals;
use std::cmp::Ordering;
use serde::{Serialize, Deserialize};
//use clap::builder::{PossibleValuesParser, PossibleValue};
use alg_tools::nanleast::NaNLeast;

use crate::types::*;
use super::delta::*;
use super::discrete::*;

/// Spike merging heuristic selection
///
/// Bundles the settings consumed by [`SpikeMerging::merge_spikes`].
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[allow(dead_code)]
pub struct SpikeMergingMethod<F> {
    /// Merging radius: passed on as the radius `ρ` of
    /// [`SpikeMerging::do_merge_spikes_radius`].
    pub(crate) radius : F,
    /// Whether merging is enabled. When `false`, [`SpikeMerging::merge_spikes`] does nothing
    /// and reports zero merges.
    pub(crate) enabled : bool,
    /// Whether to interpolate merged points: passed on as the `interp` flag of
    /// [`SpikeMerging::do_merge_spikes_radius`].
    pub(crate) interp : bool,
}


#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> Default for SpikeMergingMethod<F> {
    /// The default configuration: merging disabled, with interpolation of merged points
    /// selected and a small merging radius for when merging is switched on.
    fn default() -> Self {
        // The float literal is converted to `F` by `replace_float_literals`.
        let radius = 0.01;
        Self { radius, enabled : false, interp : true }
    }
}

/// Trait for dimension-dependent implementation of heuristic peak merging strategies.
pub trait SpikeMerging<F> {
    /// Attempt spike merging according to `method`.
    ///
    /// If `method.enabled` is set, this forwards to [`SpikeMerging::do_merge_spikes_radius`]
    /// with the radius and interpolation flag of `method`, and returns its result: the number
    /// of merges performed. Otherwise nothing is done, and zero is returned.
    ///
    /// The merging candidate acceptance decision closure `accept` should accept as its only
    /// parameter a new candidate measure (it will generally be internally mutated `self`,
    /// although this is not guaranteed), and return `true` if the merge is accepted, and
    /// `false` if it is rejected.
    ///
    /// This method is stable with respect to spike locations: on merge, the weights of existing
    /// removed spikes are set to zero, and new ones are inserted at the end of the spike vector.
    /// The merge may also be performed by increasing the weights of the existing spikes,
    /// without inserting new spikes.
    fn merge_spikes<G>(&mut self, method : SpikeMergingMethod<F>, accept : G) -> usize
    where G : FnMut(&'_ Self) -> bool {
        if method.enabled {
            self.do_merge_spikes_radius(method.radius, method.interp, accept)
        } else {
            0
        }
    }

    /// Attempt to merge spikes based on a value and a fitness function.
    ///
    /// Calls [`SpikeMerging::merge_spikes`] with `accept` constructed from the composition of
    /// `value` and `fitness`: a candidate is accepted when its fitness is at most the fitness
    /// of the initial `self`.
    ///
    /// Returns a pair consisting of
    ///  * the return value of `value` for the last merge accepted by `fitness`, or `value`
    ///    applied to the initial `self` if no merge was accepted; and
    ///  * the number of merges, as reported by [`SpikeMerging::merge_spikes`].
    fn merge_spikes_fitness<G, H, V, O>(
        &mut self,
        method : SpikeMergingMethod<F>,
        value : G,
        fitness : H
    ) -> (V, usize)
    where G : Fn(&'_ Self) -> V,
          H : Fn(&'_ V) -> O,
          O : PartialOrd {
        let mut res = value(self);
        let initial_fitness = fitness(&res);
        let count = self.merge_spikes(method, |μ| {
            let candidate = value(μ);
            let accepted = fitness(&candidate) <= initial_fitness;
            if accepted {
                // Only remember values of accepted candidates: a rejected candidate is rolled
                // back by the merging routine, so its value does not describe the final `self`.
                res = candidate;
            }
            accepted
        });
        (res, count)
    }

    /// Attempt to merge spikes that are within radius $ρ$ of each other (unspecified norm).
    ///
    /// This method implements [`SpikeMerging::merge_spikes`] when merging is enabled.
    /// The flag `interp` is taken from the `interp` field of [`SpikeMergingMethod`]; the closure
    /// `accept` and the return value are as for [`SpikeMerging::merge_spikes`].
    fn do_merge_spikes_radius<G>(&mut self, ρ : F, interp : bool, accept : G) -> usize
    where G : FnMut(&'_ Self) -> bool;
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float, const N : usize>  DiscreteMeasure<Loc<F, N>, F> {
    /// Attempts to merge the spikes with indices `i` and `j`.
    ///
    /// This assumes that the weights of the two spikes have already been checked not to be zero.
    ///
    /// With `interp` set, both spikes have their weight zeroed, and a single candidate spike is
    /// appended to `self.spikes`. Its weight is the sum of the two weights, and its location
    /// the average of the two locations weighted by the absolute values of the weights.
    /// Without `interp`, the summed weight is first placed on spike `i` (zeroing `j`), and if
    /// that is rejected, on spike `j` (zeroing `i`).
    ///
    /// Each candidate is offered to `accept`, which returns `true` to accept it.
    ///
    /// Returns the index in `self.spikes` of the spike carrying the merged weight if a
    /// candidate was accepted. Otherwise the modifications are undone, and `None` is returned.
    fn attempt_merge<G>(
        &mut self,
        i : usize,
        j : usize,
        interp : bool,
        accept : &mut G
    ) -> Option<usize>
    where G : FnMut(&'_ Self) -> bool {
        let &DeltaMeasure{ x : pos_i, α : mass_i } = &self.spikes[i];
        let &DeltaMeasure{ x : pos_j, α : mass_j } = &self.spikes[j];

        if !interp {
            // First in-place candidate: all of the mass on spike `i`.
            self.spikes[i].α = mass_i + mass_j;
            self.spikes[j].α = 0.0;
            if accept(self) {
                return Some(i)
            }

            // Second in-place candidate: all of the mass on spike `j`.
            self.spikes[i].α = 0.0;
            self.spikes[j].α = mass_i + mass_j;
            if accept(self) {
                return Some(j)
            }

            // Both candidates were rejected: put the original weights back.
            self.spikes[i].α = mass_i;
            self.spikes[j].α = mass_j;
            return None
        }

        // Interpolating merge: retire both spikes and append the combined one.
        self.spikes[i].α = 0.0;
        self.spikes[j].α = 0.0;
        let (abs_i, abs_j) = (mass_i.abs(), mass_j.abs());
        let new_index = self.spikes.len();
        self.spikes.push(DeltaMeasure{
            α : mass_i + mass_j,
            x : (pos_i * abs_i + pos_j * abs_j) / (abs_i + abs_j),
        });

        if accept(self) {
            Some(new_index)
        } else {
            // Candidate rejected: put the original weights back and drop the appended spike.
            self.spikes[i].α = mass_i;
            self.spikes[j].α = mass_j;
            self.spikes.pop();
            None
        }
    }
}

/// Computes the permutation of indices that sorts `slice` according to `compare`.
///
/// The closure `compare` operates on references to elements of `slice`; `slice` itself is
/// left untouched. The sort is stable, so indices of elements that compare equal keep their
/// relative order.
///
/// Returns the sorted vector of indices into `slice`.
pub fn sort_indices_by<V, F>(slice : &[V], mut compare : F) -> Vec<usize>
where F : FnMut(&V, &V) -> Ordering
{
    let mut order : Vec<usize> = (0..slice.len()).collect();
    order.sort_by(|&a, &b| {
        let (lhs, rhs) = (&slice[a], &slice[b]);
        compare(lhs, rhs)
    });
    order
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 1>, F> {

    fn do_merge_spikes_radius<G>(
        &mut self,
        ρ : F,
        interp : bool,
        mut accept : G
    ) -> usize
    where G : FnMut(&'_ Self) -> bool {
        // Sort by coordinate into an indexing array.
        let mut indices = sort_indices_by(&self.spikes, |&δ1, &δ2| {
            let &Loc([x1]) = &δ1.x;
            let &Loc([x2]) = &δ2.x;
            // nan-ignoring ordering of floats
            NaNLeast(x1).cmp(&NaNLeast(x2))
        });

        // Initialise result
        let mut count = 0;

        // Scan consecutive pairs and merge if close enough and accepted by `accept`.
        if indices.len() == 0 {
            return count
        }
        for k in 0..(indices.len()-1) {
            let i = indices[k];
            let j = indices[k+1];
            let &DeltaMeasure{ x : Loc([xi]), α : αi } = &self.spikes[i];
            let &DeltaMeasure{ x : Loc([xj]), α : αj } = &self.spikes[j];
            debug_assert!(xi <= xj);
            // If close enough, attempt merging
            if αi != 0.0 && αj != 0.0 && xj <= xi + ρ {
                if let Some(l) = self.attempt_merge(i, j, interp, &mut accept) {
                    // For this to work (the debug_assert! to not trigger above), the new
                    // coordinate produced by attempt_merge has to be at most xj.
                    indices[k+1] = l;
                    count += 1
                }
            }
        }

        count
    }
}

/// Orders `δ1` and `δ2` according to the first coordinate of their locations.
///
/// The comparison is done through [`NaNLeast`], so that it is defined also for NaNs.
fn compare_first_coordinate<F : Float>(
    δ1 : &DeltaMeasure<Loc<F, 2>, F>,
    δ2 : &DeltaMeasure<Loc<F, 2>, F>
) -> Ordering {
    // Only the leading coordinate of each location takes part in the comparison.
    let (&Loc([lhs, ..]), &Loc([rhs, ..])) = (&δ1.x, &δ2.x);
    NaNLeast(lhs).cmp(&NaNLeast(rhs))
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float> SpikeMerging<F> for DiscreteMeasure<Loc<F, 2>, F> {

    /// Two-dimensional radius-based merging.
    ///
    /// The spikes are ordered by their first coordinate. For each spike with non-zero weight,
    /// the closest (in Euclidean distance) other spike with non-zero weight, and with both
    /// coordinates differing by at most `ρ`, is sought. A merge of the two is then attempted
    /// through `attempt_merge`, with `accept` deciding on the candidate.
    ///
    /// Returns the number of accepted merges.
    fn do_merge_spikes_radius<G>(&mut self, ρ : F, interp : bool, mut accept : G) -> usize
    where G : FnMut(&'_ Self) -> bool {
        // Sort by first coordinate into an indexing array.
        let mut indices = sort_indices_by(&self.spikes, compare_first_coordinate);

        // Initialise result
        let mut count = 0;
        // NOTE(review): as this starts from zero, and scanning from `start_scan_2nd + 1`,
        // position 0 of `indices` is never considered as a merging partner — confirm that
        // this is intended.
        let mut start_scan_2nd = 0;

        // Scan in order
        if indices.is_empty() {
            return count
        }
        for k in 0..indices.len()-1 {
            let i = indices[k];
            let &DeltaMeasure{ x : Loc([xi1, xi2]), α : αi } = &self[i];

            if αi == 0.0 {
                // Nothing to be done if the weight is already zero
                continue
            }

            let mut closest = None;

            // Scan for second spike. We start from `start_scan_2nd + 1` with `start_scan_2nd`
            // the last position found to have too low first coordinate, instead of from `k + 1`,
            // because the _closest_ mergeable spike might have index less than `k` in `indices`,
            // and a merge with it might have not been attempted with this spike if a different
            // closer spike was discovered based on the second coordinate.
            'scan_2nd: for l in (start_scan_2nd+1)..indices.len() {
                let j = indices[l];
                if l == k || j == i {
                    // Do not attempt to merge a spike with itself. The spike indices have to be
                    // compared in addition to the positions in `indices`: after an accepted
                    // non-interpolating merge into spike `i`, that spike is stored at both of
                    // the positions `k` and `l` of that round. Merging it with itself would
                    // zero or double its weight.
                    continue
                }
                let &DeltaMeasure{ x : Loc([xj1, xj2]), α : αj } = &self[j];

                if xj1 < xi1 - ρ {
                    // Spike `j = indices[l]` has too low first coordinate. Update starting index
                    // for next iteration, and continue scanning.
                    start_scan_2nd = l;
                    continue 'scan_2nd
                } else if xj1 > xi1 + ρ {
                    // Break out: spike `j = indices[l]` has already too high first coordinate, no
                    // more close enough spikes can be found due to the sorting of `indices`.
                    break 'scan_2nd
                }

                // If also second coordinate is close enough, attempt merging if closer than
                // previously discovered mergeable spikes.
                let d2 = (xi2-xj2).abs();
                if αj != 0.0 && d2 <= ρ {
                    let r1 = xi1-xj1;
                    let d = (d2*d2 + r1*r1).sqrt();
                    match closest {
                        None => closest = Some((l, j, d)),
                        Some((_, _, r)) if r > d => closest = Some((l, j, d)),
                        _ => {},
                    }
                }
            }

            // Attempt merging closest close-enough spike
            if let Some((l, j, _)) = closest {
                if let Some(n) = self.attempt_merge(i, j, interp, &mut accept) {
                    // If merge was successful, make new spike candidate for merging.
                    indices[l] = n;
                    count += 1;
                    let compare = |i, j| compare_first_coordinate(&self.spikes[i],
                                                                  &self.spikes[j]);
                    // Re-sort relevant range of indices
                    if l < k {
                        indices[l..k].sort_by(|&i, &j| compare(i, j));
                    } else {
                        indices[k+1..=l].sort_by(|&i, &j| compare(i, j));
                    }
                }
            }
        }

        count
    }
}

mercurial