src/regularisation.rs

Tue, 01 Aug 2023 10:32:12 +0300

author
Tuomo Valkonen <tuomov@iki.fi>
date
Tue, 01 Aug 2023 10:32:12 +0300
branch
dev
changeset 33
aec67cdd6b14
parent 32
56c8adc32b09
child 34
efa60bc4f743
permissions
-rw-r--r--

merge

/*!
Regularisation terms
*/

use numeric_literals::replace_float_literals;
use serde::{Serialize, Deserialize};
use alg_tools::norms::Norm;
use alg_tools::linops::Apply;
use alg_tools::loc::Loc;
use crate::types::*;
use crate::measures::{
    DiscreteMeasure,
    DeltaMeasure,
    Radon
};
use crate::fb::FBGenericConfig;
#[allow(unused_imports)] // Used by documentation.
use crate::fb::pointsource_fb_reg;
#[allow(unused_imports)] // Used by documentation.
use crate::sliding_fb::pointsource_sliding_fb_reg;

use nalgebra::{DVector, DMatrix};
use alg_tools::nalgebra_support::ToNalgebraRealField;
use alg_tools::mapping::Mapping;
use alg_tools::bisection_tree::{
    BTFN,
    Bounds,
    BTSearch,
    P2Minimise,
    SupportGenerator,
    LocalAnalysis,
    Bounded,
};
use crate::subproblem::{
    nonneg::quadratic_nonneg,
    unconstrained::quadratic_unconstrained,
};
use alg_tools::iterate::AlgIteratorFactory;

/// Nonnegativity-constrained Radon-norm regulariser $α\\|μ\\|\_{ℳ(Ω)} + δ_{≥ 0}(μ)$, used by
/// [`pointsource_fb_reg`] and other algorithms.
///
/// This is a tuple struct wrapping a single value: the regularisation parameter α.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct NonnegRadonRegTerm<F : Float>(
    /// The regularisation parameter α.
    pub F
);

impl<F : Float> NonnegRadonRegTerm<F> {
    /// Returns the regularisation parameter α stored in this term.
    pub fn α(&self) -> F {
        let &NonnegRadonRegTerm(α) = self;
        α
    }
}

impl<'a, F : Float, const N : usize> Apply<&'a DiscreteMeasure<Loc<F, N>, F>>
for NonnegRadonRegTerm<F> {
    type Output = F;
    
    fn apply(&self, μ : &'a DiscreteMeasure<Loc<F, N>, F>) -> F {
        self.α() * μ.norm(Radon)
    }
}


/// The regularisation term $α\\|μ\\|\_{ℳ(Ω)}$ for [`pointsource_fb_reg`] and other
/// algorithms.
///
/// Unlike [`NonnegRadonRegTerm`], there is no nonnegativity constraint, so signed measures
/// are allowed. The only member of the struct is the regularisation parameter α.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct RadonRegTerm<F : Float>(pub F /* α */);


impl<F : Float> RadonRegTerm<F> {
    /// Returns the regularisation parameter α stored in this term.
    pub fn α(&self) -> F {
        let &RadonRegTerm(α) = self;
        α
    }
}

impl<'a, F : Float, const N : usize> Apply<&'a DiscreteMeasure<Loc<F, N>, F>>
for RadonRegTerm<F> {
    type Output = F;
    
    fn apply(&self, μ : &'a DiscreteMeasure<Loc<F, N>, F>) -> F {
        self.α() * μ.norm(Radon)
    }
}

/// Regularisation term configuration.
///
/// Each variant carries its regularisation parameter α.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Regularisation<F : Float> {
    /// Plain Radon norm: $α \\|μ\\|\_{ℳ(Ω)}$
    Radon(F),
    /// Radon norm with nonnegativity constraint: $α\\|μ\\|\_{ℳ(Ω)} + δ_{≥ 0}(μ)$
    NonnegRadon(F),
}

impl<'a, F : Float, const N : usize> Apply<&'a DiscreteMeasure<Loc<F, N>, F>>
for Regularisation<F> {
    type Output = F;
    
    fn apply(&self, μ : &'a DiscreteMeasure<Loc<F, N>, F>) -> F {
        match *self {
            Self::Radon(α) => RadonRegTerm(α).apply(μ),
            Self::NonnegRadon(α) => NonnegRadonRegTerm(α).apply(μ),
        }
    }
}

/// Abstraction of regularisation terms for [`pointsource_fb_reg`].
///
/// Besides evaluating the term on a discrete measure (the [`Apply`] supertrait), an
/// implementation provides the regulariser-specific steps: approximately solving the
/// finite-dimensional subproblem, and checking a function `d` against the tolerance bounds
/// implied by the regulariser.
pub trait RegTerm<F : Float + ToNalgebraRealField, const N : usize>
: for<'a> Apply<&'a DiscreteMeasure<Loc<F, N>, F>, Output = F> {
    /// Approximately solve the problem
    /// <div>$$
    ///     \min_{x ∈ ℝ^n} \frac{1}{2} x^⊤Ax - g^⊤ x + τ G(x)
    /// $$</div>
    /// for $G$ depending on the trait implementation.
    ///
    /// The parameter `mA` is $A$. An estimate for its operator norm should be provided in
    /// `mA_normest`. The initial iterate and output is `x`. The current main tolerance is `ε`,
    /// and `config` supplies the inner-solver settings.
    ///
    /// Returns the number of iterations taken.
    fn solve_findim(
        &self,
        mA : &DMatrix<F::MixedType>,
        g : &DVector<F::MixedType>,
        τ : F,
        x : &mut DVector<F::MixedType>,
        mA_normest : F,
        ε : F,
        config : &FBGenericConfig<F>
    ) -> usize;

    /// Find a point where `d` may violate the tolerance `ε`.
    ///
    /// If `skip_by_rough_check` is set, do not find the point if a rough check indicates that we
    /// are in bounds. `ε` is the current main tolerance and `τ` a scaling factor for the
    /// regulariser. `config` supplies the refinement and insertion settings.
    ///
    /// Returns `None` if `d` is in bounds either based on the rough check, or a more precise check
    /// terminating early. Otherwise returns a possibly violating point, the value of `d` there,
    /// and a boolean indicating whether the found point is in bounds.
    fn find_tolerance_violation<G, BT>(
        &self,
        d : &mut BTFN<F, G, BT, N>,
        τ : F,
        ε : F,
        skip_by_rough_check : bool,
        config : &FBGenericConfig<F>,
    ) -> Option<(Loc<F, N>, F, bool)>
    where BT : BTSearch<F, N, Agg=Bounds<F>>,
          G : SupportGenerator<F, N, Id=BT::Data>,
          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
                           + LocalAnalysis<F, Bounds<F>, N>;

    /// Verify that `d` is in bounds `ε` for a merge candidate `μ`.
    ///
    /// `ε` is the current main tolerance and `τ` a scaling factor for the regulariser.
    /// Returns `true` if the checks pass.
    fn verify_merge_candidate<G, BT>(
        &self,
        d : &mut BTFN<F, G, BT, N>,
        μ : &DiscreteMeasure<Loc<F, N>, F>,
        τ : F,
        ε : F,
        config : &FBGenericConfig<F>,
    ) -> bool
    where BT : BTSearch<F, N, Agg=Bounds<F>>,
          G : SupportGenerator<F, N, Id=BT::Data>,
          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
                           + LocalAnalysis<F, Bounds<F>, N>;

    /// Returns target bounds for `d`, given the regulariser scaling `τ` and tolerance `ε`, or
    /// `None` if the term has no such bounds.
    ///
    /// In this module, [`NonnegRadonRegTerm`] returns `[τα - ε, τα + ε]` and
    /// [`RadonRegTerm`] returns `[-τα - ε, τα + ε]`, where α is the regularisation parameter.
    // NOTE(review): how callers consume these bounds is not visible here — confirm before
    // documenting a more specific meaning.
    fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>>;

    /// Returns a scaling factor for the tolerance sequence.
    ///
    /// Typically this is the regularisation parameter.
    fn tolerance_scaling(&self) -> F;
}

/// Abstraction of regularisation terms for [`pointsource_sliding_fb_reg`].
///
/// Extends [`RegTerm`] with the quantity needed by the sliding method.
pub trait SlidingRegTerm<F : Float + ToNalgebraRealField, const N : usize>
: RegTerm<F, N> {
    /// Calculate $τ[w(z) - w(y)]$ for some w in the subdifferential of the regularisation
    /// term, such that $-ε ≤ τw - d ≤ ε$.
    ///
    /// `y` and `z` are the two points at which `w` is compared. `τ` is the regulariser scaling
    /// and `ε` the current main tolerance. The implementations in this module do not read `μ`
    /// or `config`.
    fn goodness<G, BT>(
        &self,
        d : &mut BTFN<F, G, BT, N>,
        μ : &DiscreteMeasure<Loc<F, N>, F>,
        y : &Loc<F, N>,
        z : &Loc<F, N>,
        τ : F,
        ε : F,
        config : &FBGenericConfig<F>,
    ) -> F
    where BT : BTSearch<F, N, Agg=Bounds<F>>,
          G : SupportGenerator<F, N, Id=BT::Data>,
          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
                           + LocalAnalysis<F, Bounds<F>, N>;
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N>
for NonnegRadonRegTerm<F>
where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
    /// Solves the finite-dimensional subproblem with the nonnegativity-constrained quadratic
    /// solver [`quadratic_nonneg`], using the regularisation weight `τ * α`.
    fn solve_findim(
        &self,
        mA : &DMatrix<F::MixedType>,
        g : &DVector<F::MixedType>,
        τ : F,
        x : &mut DVector<F::MixedType>,
        mA_normest : F,
        ε : F,
        config : &FBGenericConfig<F>
    ) -> usize {
        let inner_tolerance = ε * config.inner.tolerance_mult;
        let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
        // Inner step length is scaled by the estimate of the operator norm of A.
        let inner_τ = config.inner.τ0 / mA_normest;
        quadratic_nonneg(config.inner.method, mA, g, τ * self.α(), x,
                         inner_τ, inner_it)
    }

    /// Checks `d` against the one-sided bound `τα + ε`; only maximisation is needed because
    /// the term imposes no lower bound on `d` away from the support.
    #[inline]
    fn find_tolerance_violation<G, BT>(
        &self,
        d : &mut BTFN<F, G, BT, N>,
        τ : F,
        ε : F,
        skip_by_rough_check : bool,
        config : &FBGenericConfig<F>,
    ) -> Option<(Loc<F, N>, F, bool)>
    where BT : BTSearch<F, N, Agg=Bounds<F>>,
          G : SupportGenerator<F, N, Id=BT::Data>,
          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
                           + LocalAnalysis<F, Bounds<F>, N> {
        let τα = τ * self.α();
        let keep_below = τα + ε;
        let maximise_above = τα + ε * config.insertion_cutoff_factor;
        let refinement_tolerance = ε * config.refinement.tolerance_mult;

        // If the preliminary check indicates that we are in bounds, and if it otherwise matches
        // the insertion strategy, skip insertion.
        if skip_by_rough_check && d.bounds().upper() <= keep_below {
            None
        } else {
            // If the rough check didn't indicate that no insertion is needed, find the
            // maximising point.
            d.maximise_above(maximise_above, refinement_tolerance, config.refinement.max_steps)
             .map(|(ξ, v_ξ)| (ξ, v_ξ, v_ξ <= keep_below))
        }
    }

    /// Checks that `d ≥ τα - merge_tolerance` on the support of `μ` (ignoring zero spikes) and
    /// that `d ≤ τα + merge_tolerance` everywhere.
    fn verify_merge_candidate<G, BT>(
        &self,
        d : &mut BTFN<F, G, BT, N>,
        μ : &DiscreteMeasure<Loc<F, N>, F>,
        τ : F,
        ε : F,
        config : &FBGenericConfig<F>,
    ) -> bool
    where BT : BTSearch<F, N, Agg=Bounds<F>>,
          G : SupportGenerator<F, N, Id=BT::Data>,
          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
                           + LocalAnalysis<F, Bounds<F>, N> {
        let τα = τ * self.α();
        let refinement_tolerance = ε * config.refinement.tolerance_mult;
        let merge_tolerance = config.merge_tolerance_mult * ε;
        let keep_below = τα + merge_tolerance;
        let keep_supp_above = τα - merge_tolerance;
        let bnd = d.bounds();

        // The global lower bound may already certify the support condition; otherwise check
        // `d` at every spike with non-zero mass.
        let support_ok = bnd.lower() >= keep_supp_above
            || μ.iter_spikes().all(|&DeltaMeasure{ α : β, ref x }| {
                (β == 0.0) || d.apply(x) >= keep_supp_above
            });

        // Only refine the upper bound if the support condition holds and the rough bound
        // is not already sufficient.
        support_ok && (
            bnd.upper() <= keep_below
            ||
            d.has_upper_bound(keep_below, refinement_tolerance, config.refinement.max_steps)
        )
    }

    /// Returns `[τα - ε, τα + ε]`.
    fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> {
        let τα = τ * self.α();
        Some(Bounds(τα - ε,  τα + ε))
    }

    /// Returns the regularisation parameter α.
    fn tolerance_scaling(&self) -> F {
        self.α()
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float + ToNalgebraRealField, const N : usize> SlidingRegTerm<F, N>
for NonnegRadonRegTerm<F>
where Cube<F, N> : P2Minimise<Loc<F, N>, F> {

    /// Computes $τ[w(z) - w(y)]$ with $w(x) = \min\\{1, (ε + d(x))/(τα)\\}$.
    ///
    /// Multiplying through by τ gives the evaluated form $τw(x) = \min\\{τ, (ε + d(x))/α\\}$.
    fn goodness<G, BT>(
        &self,
        d : &mut BTFN<F, G, BT, N>,
        _μ : &DiscreteMeasure<Loc<F, N>, F>,
        y : &Loc<F, N>,
        z : &Loc<F, N>,
        τ : F,
        ε : F,
        _config : &FBGenericConfig<F>,
    ) -> F
    where BT : BTSearch<F, N, Agg=Bounds<F>>,
          G : SupportGenerator<F, N, Id=BT::Data>,
          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
                           + LocalAnalysis<F, Bounds<F>, N> {
        let α = self.α();
        let scaled_w = |x| τ.min((ε + d.apply(x)) / α);
        let at_z = scaled_w(z);
        let at_y = scaled_w(y);
        at_z - at_y
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float + ToNalgebraRealField, const N : usize> RegTerm<F, N> for RadonRegTerm<F>
where Cube<F, N> : P2Minimise<Loc<F, N>, F> {
    /// Solves the finite-dimensional subproblem with the unconstrained quadratic solver
    /// [`quadratic_unconstrained`], using the regularisation weight `τ * α`.
    fn solve_findim(
        &self,
        mA : &DMatrix<F::MixedType>,
        g : &DVector<F::MixedType>,
        τ : F,
        x : &mut DVector<F::MixedType>,
        mA_normest : F,
        ε : F,
        config : &FBGenericConfig<F>
    ) -> usize {
        let inner_tolerance = ε * config.inner.tolerance_mult;
        let inner_it = config.inner.iterator_options.stop_target(inner_tolerance);
        // Inner step length is scaled by the estimate of the operator norm of A.
        let inner_τ = config.inner.τ0 / mA_normest;
        quadratic_unconstrained(config.inner.method, mA, g, τ * self.α(), x,
                                inner_τ, inner_it)
    }

    /// Checks `d` against the two-sided bounds `[-τα - ε, τα + ε]`.
    ///
    /// If both the upper and the lower bound appear violated, the point with the larger
    /// violation is returned.
    fn find_tolerance_violation<G, BT>(
        &self,
        d : &mut BTFN<F, G, BT, N>,
        τ : F,
        ε : F,
        skip_by_rough_check : bool,
        config : &FBGenericConfig<F>,
    ) -> Option<(Loc<F, N>, F, bool)>
    where BT : BTSearch<F, N, Agg=Bounds<F>>,
          G : SupportGenerator<F, N, Id=BT::Data>,
          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
                           + LocalAnalysis<F, Bounds<F>, N> {
        let τα = τ * self.α();
        let keep_below = τα + ε;
        let keep_above = -τα - ε;
        let maximise_above = τα + ε * config.insertion_cutoff_factor;
        let minimise_below = -τα - ε * config.insertion_cutoff_factor;
        let refinement_tolerance = ε * config.refinement.tolerance_mult;

        // If the preliminary check indicates that we are in bounds, and if it otherwise matches
        // the insertion strategy, skip insertion.
        if skip_by_rough_check && Bounds(keep_above, keep_below).superset(&d.bounds()) {
            None
        } else {
            // If the rough check didn't indicate that no insertion is needed, look for both a
            // maximising and a minimising point.
            let mx = d.maximise_above(maximise_above, refinement_tolerance,
                                      config.refinement.max_steps);
            let mi = d.minimise_below(minimise_below, refinement_tolerance,
                                      config.refinement.max_steps);

            match (mx, mi) {
                (None, None) => None,
                (Some((ξ, v_ξ)), None) => Some((ξ, v_ξ, keep_below >= v_ξ)),
                (None, Some((ζ, v_ζ))) => Some((ζ, v_ζ, keep_above <= v_ζ)),
                (Some((ξ, v_ξ)), Some((ζ, v_ζ))) => {
                    // Compare the excess `v_ξ - τα` above the upper bound τα with the excess
                    // `-τα - v_ζ` below the lower bound -τα. (Using `τα - v_ζ` here would
                    // bias the choice towards ζ by 2τα.)
                    if v_ξ - τα > -τα - v_ζ {
                        Some((ξ, v_ξ, keep_below >= v_ξ))
                    } else {
                        Some((ζ, v_ζ, keep_above <= v_ζ))
                    }
                }
            }
        }
    }

    /// Checks that, on the support of `μ`, `d ≥ τα - merge_tolerance` at positive spikes and
    /// `d ≤ -τα + merge_tolerance` at negative spikes, and that
    /// `-τα - merge_tolerance ≤ d ≤ τα + merge_tolerance` everywhere.
    fn verify_merge_candidate<G, BT>(
        &self,
        d : &mut BTFN<F, G, BT, N>,
        μ : &DiscreteMeasure<Loc<F, N>, F>,
        τ : F,
        ε : F,
        config : &FBGenericConfig<F>,
    ) -> bool
    where BT : BTSearch<F, N, Agg=Bounds<F>>,
          G : SupportGenerator<F, N, Id=BT::Data>,
          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
                           + LocalAnalysis<F, Bounds<F>, N> {
        let τα = τ * self.α();
        let refinement_tolerance = ε * config.refinement.tolerance_mult;
        let merge_tolerance = config.merge_tolerance_mult * ε;
        let keep_below = τα + merge_tolerance;
        let keep_above = -τα - merge_tolerance;
        let keep_supp_pos_above = τα - merge_tolerance;
        let keep_supp_neg_below = -τα + merge_tolerance;
        let bnd = d.bounds();

        // The global bounds can only certify the support condition for spikes of either sign
        // at once; otherwise check each spike. Zero (and NaN) masses are not checked.
        let support_ok = (bnd.lower() >= keep_supp_pos_above
                          && bnd.upper() <= keep_supp_neg_below)
            || μ.iter_spikes().all(|&DeltaMeasure{ α : β, ref x }| {
                use std::cmp::Ordering::*;
                match β.partial_cmp(&0.0) {
                    Some(Greater) => d.apply(x) >= keep_supp_pos_above,
                    Some(Less) => d.apply(x) <= keep_supp_neg_below,
                    _ => true,
                }
            });

        // Refine the upper and lower bounds only when the rough bounds are not sufficient.
        support_ok && (
            bnd.upper() <= keep_below
            ||
            d.has_upper_bound(keep_below, refinement_tolerance,
                              config.refinement.max_steps)
        ) && (
            bnd.lower() >= keep_above
            ||
            d.has_lower_bound(keep_above, refinement_tolerance,
                              config.refinement.max_steps)
        )
    }

    /// Returns `[-τα - ε, τα + ε]`.
    fn target_bounds(&self, τ : F, ε : F) -> Option<Bounds<F>> {
        let τα = τ * self.α();
        Some(Bounds(-τα - ε,  τα + ε))
    }

    /// Returns the regularisation parameter α.
    fn tolerance_scaling(&self) -> F {
        self.α()
    }
}

#[replace_float_literals(F::cast_from(literal))]
impl<F : Float + ToNalgebraRealField, const N : usize> SlidingRegTerm<F, N>
for RadonRegTerm<F>
where Cube<F, N> : P2Minimise<Loc<F, N>, F> {

    /// Computes $τ[w(z) - w(y)]$, where $w(x)$ is the projection of 1 onto the tolerance
    /// interval $[(d(x) - ε)/(τα), (d(x) + ε)/(τα)]$.
    fn goodness<G, BT>(
        &self,
        d : &mut BTFN<F, G, BT, N>,
        _μ : &DiscreteMeasure<Loc<F, N>, F>,
        y : &Loc<F, N>,
        z : &Loc<F, N>,
        τ : F,
        ε : F,
        _config : &FBGenericConfig<F>,
    ) -> F
    where BT : BTSearch<F, N, Agg=Bounds<F>>,
          G : SupportGenerator<F, N, Id=BT::Data>,
          G::SupportType : Mapping<Loc<F, N>,Codomain=F>
                           + LocalAnalysis<F, Bounds<F>, N> {

        let α = self.α();
        // w(x)  = max{(d(x) - ε)/(τα), min{1, (d(x) + ε)/(τα)}}, hence
        // τw(x) = max{(d(x) - ε)/α,   min{τ, (d(x) + ε)/α}}.
        // The division by α must apply to ε + d(x) *before* taking the minimum with τ
        // (as in the nonnegative case); `τ.min(ε + dx)/α` would parse as
        // `(τ.min(ε + dx))/α` and compare τ with a value of d.
        let τw = |x| {
            let dx = d.apply(x);
            ((-ε + dx)/α).max(τ.min((ε + dx)/α))
        };
        τw(z) - τw(y)
    }
}

mercurial