src/prox_penalty.rs

Thu, 26 Feb 2026 11:38:43 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 26 Feb 2026 11:38:43 -0500
branch
dev
changeset 61
4f468d35fa29
parent 51
0693cc9ba9f0
child 62
32328a74c790
child 63
7a8a55fd41c0
permissions
-rw-r--r--

General forward operators, separation of measures into own crate, and other architecture improvements to support the pointsource_pde crate.

/*!
Proximal penalty abstraction
*/

use alg_tools::types::*;
use numeric_literals::replace_float_literals;
use serde::{Deserialize, Serialize};

use crate::measures::merging::SpikeMergingMethod;
use crate::measures::DiscreteMeasure;
use crate::regularisation::RegTerm;
use crate::subproblem::InnerSettings;
use crate::tolerance::Tolerance;
use crate::types::{IterInfo, RefinementSettings};
use alg_tools::error::DynResult;
use alg_tools::instance::Space;
use alg_tools::iterate::{AlgIterator, AlgIteratorIteration};
use alg_tools::mapping::Mapping;
use alg_tools::nalgebra_support::ToNalgebraRealField;

pub mod radon_squared;
pub mod wave;
pub use radon_squared::RadonSquared;

/// Settings for the solution of the stepwise optimality condition.
///
/// Because of `#[serde(default)]`, any field missing from deserialised input takes its
/// value from the [`Default`] implementation of this struct.
#[derive(Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Debug)]
#[serde(default)]
pub struct InsertionConfig<F: Float> {
    /// Tolerance for point insertion.
    pub tolerance: Tolerance<F>,

    /// Stop looking for a predual maximum (where to insert a new point) below
    /// `tolerance` multiplied by this factor.
    ///
    /// Not used by [`crate::prox_penalty::radon_squared`].
    pub insertion_cutoff_factor: F,

    /// Settings for branch and bound refinement when looking for predual maxima.
    pub refinement: RefinementSettings<F>,

    /// Maximum number of insertions within each outer iteration.
    ///
    /// Not used by [`crate::prox_penalty::radon_squared`].
    pub max_insertions: usize,

    /// Pair `(n, m)` for maximum insertions `m` on first `n` iterations.
    ///
    /// Not used by [`crate::prox_penalty::radon_squared`].
    pub bootstrap_insertions: Option<(usize, usize)>,

    /// Inner method settings.
    pub inner: InnerSettings<F>,

    /// Spike merging method. Periodic merging is only attempted when its `enabled`
    /// flag is set; see [`InsertionConfig::merge_now`].
    pub merging: SpikeMergingMethod<F>,

    /// Tolerance multiplier for merges.
    pub merge_tolerance_mult: F,

    /// Merge spikes after the last step, even if merging is not generally enabled;
    /// see [`InsertionConfig::final_merging_method`].
    pub final_merging: bool,

    /// Use fitness (objective value) as the merging criterion. Implies worse
    /// convergence guarantees.
    pub fitness_merging: bool,

    /// Number of iterations between attempts of the merging heuristic;
    /// see [`InsertionConfig::merge_now`].
    pub merge_every: usize,
    // /// Save $μ$ for postprocessing optimisation
    // pub postprocessing : bool
}

#[replace_float_literals(F::cast_from(literal))]
impl<F: Float> Default for InsertionConfig<F> {
    /// Default insertion settings.
    ///
    /// At most 100 insertions are made per outer iteration, but only one per iteration
    /// during the first 10 iterations. Merging is attempted every 10 iterations, using
    /// a tolerance multiplier of 2, and final merging is enabled. Fitness-based merging
    /// is off. All remaining settings use their own defaults.
    fn default() -> Self {
        // (first `n` iterations, maximum insertions `m` during them)
        let bootstrap = (10, 1);
        Self {
            tolerance: Default::default(),
            insertion_cutoff_factor: 1.0,
            refinement: Default::default(),
            max_insertions: 100,
            bootstrap_insertions: Some(bootstrap),
            inner: Default::default(),
            merging: Default::default(),
            merge_tolerance_mult: 2.0,
            final_merging: true,
            fitness_merging: false,
            merge_every: 10,
        }
    }
}

impl<F: Float> InsertionConfig<F> {
    /// Check if merging should be attempted this iteration.
    ///
    /// Returns `true` when `self.merging.enabled` is set and the current iteration
    /// number is a multiple of `self.merge_every`. A `merge_every` of zero disables
    /// periodic merging. Previously such a value caused a division-by-zero panic,
    /// and it can arrive through a deserialised configuration.
    pub fn merge_now<I: AlgIterator>(&self, state: &AlgIteratorIteration<I>) -> bool {
        // `checked_rem` yields `None` for a zero divisor instead of panicking.
        self.merging.enabled && state.iteration().checked_rem(self.merge_every) == Some(0)
    }

    /// Returns the final merging method.
    ///
    /// This is `self.merging` with its `enabled` flag replaced by `self.final_merging`,
    /// so final merging can be performed even if merging is not generally enabled.
    pub fn final_merging_method(&self) -> SpikeMergingMethod<F> {
        SpikeMergingMethod { enabled: self.final_merging, ..self.merging }
    }
}

/// Available proximal terms.
///
/// Identifies the kind of a proximal penalty, as returned by [`ProxPenalty::prox_type`].
#[derive(Copy, Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub enum ProxTerm {
    /// Partial-to-wave operator 𝒟 (see the [`wave`] module).
    Wave,
    /// Radon-norm squared (see [`RadonSquared`]).
    RadonSquared,
}

/// Trait for proximal penalties.
///
/// Implementations insert new spikes into, and merge spikes of, a [`DiscreteMeasure`]
/// with respect to a regularisation term `Reg`. [`Self::prox_type`] reports which
/// [`ProxTerm`] the implementation corresponds to.
pub trait ProxPenalty<Domain, PreadjointCodomain, Reg, F = f64>
where
    F: Float + ToNalgebraRealField,
    Reg: RegTerm<Domain, F>,
    Domain: Space + Clone,
{
    /// Mapping type optionally returned by [`Self::insert_and_reweigh`], representing
    /// `τv + w` for `w` a subdifferential of the regularisation term.
    type ReturnMapping: Mapping<Domain, Codomain = F>;

    /// Returns the type of this proximal penalty.
    fn prox_type() -> ProxTerm;

    /// Insert new spikes into `μ` to approximately satisfy optimality conditions
    /// with the forward step term fixed to `τv`.
    ///
    /// May return `τv + w` for `w` a subdifferential of the regularisation term `reg`,
    /// as well as an indication of whether the tolerance bounds `ε` are satisfied.
    ///
    /// `μ_base + ν_delta` is the base point, where `μ` and `μ_base` are assumed to have the same
    /// spike locations, while `ν_delta` may have different locations.
    ///
    /// `τv` is mutable to allow [`alg_tools::bounds::MinMaxMapping`] optimisation to
    /// refine data. Actual values of `τv` are not supposed to be mutated.
    fn insert_and_reweigh<I>(
        &self,
        μ: &mut DiscreteMeasure<Domain, F>,
        τv: &mut PreadjointCodomain,
        μ_base: &DiscreteMeasure<Domain, F>,
        ν_delta: Option<&DiscreteMeasure<Domain, F>>,
        τ: F,
        ε: F,
        config: &InsertionConfig<F>,
        reg: &Reg,
        state: &AlgIteratorIteration<I>,
        stats: &mut IterInfo<F>,
    ) -> DynResult<(Option<Self::ReturnMapping>, bool)>
    where
        I: AlgIterator;

    /// Merge spikes, if possible.
    ///
    /// Either optimality condition merging or objective value (fitness) merging
    /// may be used, the latter only if `fitness` is provided and `config.fitness_merging`
    /// is set.
    fn merge_spikes(
        &self,
        μ: &mut DiscreteMeasure<Domain, F>,
        τv: &mut PreadjointCodomain,
        μ_base: &DiscreteMeasure<Domain, F>,
        ν_delta: Option<&DiscreteMeasure<Domain, F>>,
        τ: F,
        ε: F,
        config: &InsertionConfig<F>,
        reg: &Reg,
        fitness: Option<impl Fn(&DiscreteMeasure<Domain, F>) -> F>,
    ) -> usize;

    /// Merge spikes, if possible.
    ///
    /// Unlike [`Self::merge_spikes`], this variant only supports optimality condition
    /// based merging: it forwards to [`Self::merge_spikes`] with `fitness` set to `None`.
    #[inline]
    fn merge_spikes_no_fitness(
        &self,
        μ: &mut DiscreteMeasure<Domain, F>,
        τv: &mut PreadjointCodomain,
        μ_base: &DiscreteMeasure<Domain, F>,
        ν_delta: Option<&DiscreteMeasure<Domain, F>>,
        τ: F,
        ε: F,
        config: &InsertionConfig<F>,
        reg: &Reg,
    ) -> usize {
        // `merge_spikes` takes an `impl Fn` argument, so even `None` needs a concrete
        // type. A function pointer type implements the required `Fn` signature.
        let no_fitness: Option<fn(&DiscreteMeasure<Domain, F>) -> F> = None;
        self.merge_spikes(μ, τv, μ_base, ν_delta, τ, ε, config, reg, no_fitness)
    }
}

/// Trait to calculate step length bound by `Dat` when the proximal penalty is `Self`,
/// which is typically also a [`ProxPenalty`]. If it is given by a (squared) norm $\|.\|_*$,
/// and `Dat` represents the function $f$, then this trait should calculate `L` such that
/// $\|f'(x)-f'(y)\| ≤ L\|x-y\|_*$, where the step length is supposed to satisfy $τ L ≤ 1$.
pub trait StepLengthBound<F: Float, Dat> {
    /// Returns $L$ for the function represented by `f`.
    fn step_length_bound(&self, f: &Dat) -> DynResult<F>;
}

/// A variant of [`StepLengthBound`] for step length parameters for pairs of variables
/// (`Pair`s).
pub trait StepLengthBoundPair<F: Float, Dat> {
    /// Returns a pair of step length bounds, analogous to
    /// [`StepLengthBound::step_length_bound`].
    // NOTE(review): presumably one bound per component of the pair — confirm in
    // implementations.
    fn step_length_bound_pair(&self, f: &Dat) -> DynResult<(F, F)>;
}

/// Trait to calculate step length bound by the operator `A` when the proximal penalty is `Self`,
/// which is typically also a [`ProxPenalty`]. If it is given by a (squared) norm $\|.\|_*$,
/// then this trait should calculate `L` such that
/// $\|Ax\| ≤ L\|x\|_*$, where the primal-dual step lengths are supposed to satisfy $τσ L^2 ≤ 1$.
/// The domain needs to be specified here, because `A` can operate on various domains.
pub trait StepLengthBoundPD<F: Float, A, Domain> {
    /// Returns $L$ for the operator `f`.
    fn step_length_bound_pd(&self, f: &A) -> DynResult<F>;
}

mercurial