src/prox_penalty/wave.rs

Wed, 22 Apr 2026 23:43:00 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Wed, 22 Apr 2026 23:43:00 -0500
branch
dev
changeset 69
3ad8879ee6e1
parent 63
7a8a55fd41c0
permissions
-rw-r--r--

Bump versions

/*!
Basic proximal penalty based on convolution operators $𝒟$.
 */

use super::{InsertionConfig, ProxPenalty, ProxTerm, StepLengthBound, StepLengthBoundPD};
use crate::dataterm::QuadraticDataTerm;
use crate::forward_model::ForwardModel;
use crate::measures::merging::SpikeMerging;
use crate::measures::{DeltaMeasure, DiscreteMeasure, Radon};
use crate::regularisation::RegTerm;
use crate::seminorms::DiscreteMeasureOp;
use crate::types::IterInfo;
use alg_tools::bounds::MinMaxMapping;
use alg_tools::error::DynResult;
use alg_tools::iterate::{AlgIterator, AlgIteratorIteration};
use alg_tools::linops::BoundedLinear;
use alg_tools::mapping::{Instance, Mapping, Space};
use alg_tools::nalgebra_support::ToNalgebraRealField;
use alg_tools::norms::{Linfinity, Norm, NormExponent, L2};
use alg_tools::types::*;
use colored::Colorize;
use nalgebra::DVector;
use numeric_literals::replace_float_literals;

/// [`ProxPenalty`] implementation for any discrete-measure operator 𝒟, reported as
/// [`ProxTerm::Wave`].
///
/// The returned mapping `O` is the sum `&M + 𝒟::PreCodomain`, which is how
/// `d = τv + 𝒟(μ - μ^k)` is formed below.
#[replace_float_literals(F::cast_from(literal))]
impl<F, M, Reg, 𝒟, O, Domain> ProxPenalty<Domain, M, Reg, F> for 𝒟
where
    Domain: Space + Clone + PartialEq + 'static,
    for<'a> &'a Domain: Instance<Domain>,
    F: Float + ToNalgebraRealField,
    𝒟: DiscreteMeasureOp<Domain, F>,
    𝒟::Codomain: Mapping<Domain, Codomain = F>,
    M: Mapping<Domain, Codomain = F>,
    for<'a> &'a M: std::ops::Add<𝒟::PreCodomain, Output = O>,
    O: MinMaxMapping<Domain, F>,
    Reg: RegTerm<Domain, F>,
    DiscreteMeasure<Domain, F>: SpikeMerging<F>,
{
    type ReturnMapping = O;

    fn prox_type() -> ProxTerm {
        ProxTerm::Wave
    }

    /// Reweighs the masses of `μ` and inserts new spikes until `reg` reports no tolerance
    /// violation, or the insertion limit is reached.
    ///
    /// Each round does the following:
    /// 1. Re-optimises the masses on the current support with `reg.solve_findim`.
    /// 2. Forms `d = τv + 𝒟(μ - μ^k)`, where `μ^k` is the value of `μ` on entry.
    /// 3. Asks `reg.find_tolerance_violation` for a violating point, which is then inserted
    ///    with zero mass.
    ///
    /// The insertion limit depends on the iteration:
    /// - While the iteration number is at most `l`, the limit is `k` from
    ///   `config.bootstrap_insertions = Some((l, k))`.
    /// - Otherwise it is `config.max_insertions`. Only in this case is a warning printed on
    ///   failure, and only when the iterator is not quiet.
    ///
    /// Returns `Some(d)` together with a flag:
    /// - `true` when no tolerance violation remained;
    /// - otherwise, the `in_bounds` flag of the last violation check that hit the limit.
    ///
    /// # Errors
    ///
    /// Propagates a failure of `self.opnorm_bound(Radon, Linfinity)`.
    fn insert_and_reweigh<I>(
        &self,
        μ: &mut DiscreteMeasure<Domain, F>,
        τv: &mut M,
        τ: F,
        ε: F,
        config: &InsertionConfig<F>,
        reg: &Reg,
        state: &AlgIteratorIteration<I>,
        stats: &mut IterInfo<F>,
    ) -> DynResult<(Option<Self::ReturnMapping>, bool)>
    where
        I: AlgIterator,
    {
        let op𝒟norm = self.opnorm_bound(Radon, Linfinity)?;

        // Maximum insertion count and measure difference calculation depend on insertion style.
        let (max_insertions, warn_insertions) =
            match (state.iteration(), config.bootstrap_insertions) {
                (i, Some((l, k))) if i <= l => (k, false),
                _ => (config.max_insertions, !state.is_quiet()),
            };

        // Snapshot of μ^k; its image ω0 = 𝒟μ^k stays fixed for the whole insertion loop.
        let μ_base = μ.clone();
        let ω0 = self.apply(&μ_base);

        // Add points to support until within error tolerance or maximum insertion count reached.
        let mut count = 0;
        let (within_tolerances, d) = 'insertion: loop {
            if μ.len() > 0 {
                // Form finite-dimensional subproblem. The original μ^k from the beginning of the
                // iteration enters it only through the fixed ω0 = 𝒟μ^k contained in g̃.
                // TODO: observe negation of -τv after switch from minus_τv: finite-dimensional
                // problems have not yet been updated to sign change.
                let à = self.findim_matrix(μ.iter_locations());
                let g̃ = DVector::from_iterator(
                    μ.len(),
                    μ.iter_locations()
                        .map(|ζ| ω0.apply(ζ) - τv.apply(ζ))
                        .map(F::to_nalgebra_mixed),
                );
                let mut x = μ.masses_dvector();

                // The gradient of the forward component of the inner objective is C^*𝒟Cx - g̃.
                // We have |C^*𝒟Cx|_2 = sup_{|z|_2 ≤ 1} ⟨z, C^*𝒟Cx⟩ = sup_{|z|_2 ≤ 1} ⟨Cz|𝒟Cx⟩
                // ≤ sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟Cx|_∞ ≤  sup_{|z|_2 ≤ 1} |Cz|_ℳ |𝒟| |Cx|_ℳ
                // ≤ sup_{|z|_2 ≤ 1} |z|_1 |𝒟| |x|_1 ≤ sup_{|z|_2 ≤ 1} n |z|_2 |𝒟| |x|_2
                // = n |𝒟| |x|_2, where n is the number of points. Therefore
                let Ã_normest = op𝒟norm * F::cast_from(μ.len());

                // Solve finite-dimensional subproblem.
                stats.inner_iters += reg.solve_findim(&Ã, &g̃, τ, &mut x, Ã_normest, ε, config);

                // Update masses of μ based on solution of finite-dimensional subproblem.
                μ.set_masses_dvector(&x);
            }

            // Form d = τv + 𝒟μ - ω0 = τv + 𝒟(μ - μ^k) for checking the proximate optimality
            // conditions in the predual space, and finding new points for insertion, if necessary.
            let mut d = &*τv + self.preapply(μ.sub_matching(&μ_base));

            // If no merging heuristic is used, let's be more conservative about spike insertion,
            // and skip it after first round. If merging is done, being more greedy about spike
            // insertion also seems to improve performance.
            let skip_by_rough_check = !config.merging.enabled && count > 0;

            // Find a spike to insert, if needed
            let (ξ, _v_ξ, in_bounds) =
                match reg.find_tolerance_violation(&mut d, τ, ε, skip_by_rough_check, config) {
                    None => break 'insertion (true, d),
                    Some(res) => res,
                };

            // Break if maximum insertion count reached
            if count >= max_insertions {
                break 'insertion (in_bounds, d);
            }

            // No point in optimising the weight here; the finite-dimensional algorithm is fast.
            *μ += DeltaMeasure { x: ξ, α: 0.0 };
            count += 1;
        };

        if !within_tolerances && warn_insertions {
            // Complain (but continue) if we failed to get within tolerances
            // by inserting more points. The literal is coloured directly; no need to
            // allocate an intermediate `String` via an argument-less `format!`.
            println!(
                "{}",
                "Maximum insertions reached without achieving \
                 subproblem solution tolerance"
                    .red()
            );
        }

        Ok((Some(d), within_tolerances))
    }

    /// Merges spikes of `μ`, returning the count reported by the underlying merge routine.
    ///
    /// If `config.fitness_merging` is set and a `fitness` function is supplied, merging is
    /// driven by `μ.merge_spikes_fitness`. Otherwise each merge candidate is accepted or
    /// rejected by `reg.verify_merge_candidate`, evaluated on
    /// `d = τv + 𝒟(μ_candidate - μ_base)`.
    fn merge_spikes(
        &self,
        μ: &mut DiscreteMeasure<Domain, F>,
        τv: &mut M,
        μ_base: &DiscreteMeasure<Domain, F>,
        τ: F,
        ε: F,
        config: &InsertionConfig<F>,
        reg: &Reg,
        fitness: Option<impl Fn(&DiscreteMeasure<Domain, F>) -> F>,
    ) -> usize {
        if config.fitness_merging {
            if let Some(f) = fitness {
                return μ.merge_spikes_fitness(config.merging, f, |&v| v).1;
            }
        }
        μ.merge_spikes(config.merging, |μ_candidate| {
            let mut d = &*τv + self.preapply(μ_candidate.sub_matching(μ_base));
            reg.verify_merge_candidate(&mut d, μ_candidate, τ, ε, config)
        })
    }
}

/// Step length bound for a quadratic data term whose forward operator `A` acts on discrete
/// measures, with the proximal operator 𝒟 supplying the norm on the measure space.
#[replace_float_literals(F::cast_from(literal))]
impl<F, A, 𝒟, Domain> StepLengthBound<F, QuadraticDataTerm<F, DiscreteMeasure<Domain, F>, A>>
    for 𝒟
where
    Domain: Space + Clone + PartialEq + 'static,
    F: Float + ToNalgebraRealField,
    𝒟: DiscreteMeasureOp<Domain, F>,
    A: ForwardModel<DiscreteMeasure<Domain, F>, F>
        + for<'b> BoundedLinear<DiscreteMeasure<Domain, F>, &'b 𝒟, L2, F>,
    DiscreteMeasure<Domain, F>: for<'b> Norm<&'b 𝒟, F>,
    for<'b> &'b 𝒟: NormExponent,
{
    /// Returns the square of the operator-norm bound of `f.operator()`, measured from the
    /// 𝒟-norm on discrete measures into L2.
    ///
    /// # Errors
    ///
    /// Propagates any error from `opnorm_bound`.
    fn step_length_bound(
        &self,
        f: &QuadraticDataTerm<F, DiscreteMeasure<Domain, F>, A>,
    ) -> DynResult<F> {
        // TODO: direct squared calculation
        Ok(f.operator().opnorm_bound(self, L2)?.powi(2))
    }
}

/// Primal–dual step length bound for a linear operator `A` on discrete measures, where the
/// proximal operator 𝒟 supplies the norm on the measure space.
#[replace_float_literals(F::cast_from(literal))]
impl<F, A, 𝒟, Domain> StepLengthBoundPD<F, A, DiscreteMeasure<Domain, F>> for 𝒟
where
    F: Float + ToNalgebraRealField,
    Domain: Space + Clone + PartialEq + 'static,
    𝒟: DiscreteMeasureOp<Domain, F>,
    for<'r> &'r 𝒟: NormExponent,
    DiscreteMeasure<Domain, F>: for<'r> Norm<&'r 𝒟, F>,
    A: for<'r> BoundedLinear<DiscreteMeasure<Domain, F>, &'r 𝒟, L2, F>,
{
    /// Returns the operator-norm bound of `opA` from the 𝒟-norm on discrete measures into L2.
    /// The bound is returned as-is, without squaring.
    ///
    /// # Errors
    ///
    /// Whatever error `opnorm_bound` reports is passed through unchanged.
    fn step_length_bound_pd(&self, opA: &A) -> DynResult<F> {
        // `self` (the operator 𝒟) serves as the norm exponent on the domain side.
        opA.opnorm_bound(self, L2)
    }
}

mercurial