src/convex.rs

Thu, 01 May 2025 02:28:28 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Thu, 01 May 2025 02:28:28 -0500
branch
dev
changeset 121
fc7d923ff6e7
parent 112
ed8124f1af1d
child 124
6aa955ad8122
permissions
-rw-r--r--

another overflow

/*!
Some convex analysis basics
*/

use crate::error::DynResult;
use crate::euclidean::Euclidean;
use crate::instance::{DecompositionMut, Instance, InstanceMut};
use crate::linops::{IdOp, Scaled};
use crate::mapping::{DifferentiableImpl, LipschitzDifferentiableImpl, Mapping, Space};
use crate::norms::*;
use crate::operator_arithmetic::{Constant, Weighted};
use crate::types::*;
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;

/// Trait for convex mappings `Domain → F`.
///
/// Apart from [`ConvexMapping::factor_of_strong_convexity`], which has a default
/// implementation, it adds no functionality; it mainly serves as a trait bound marking
/// a [`Mapping`] as convex.
///
/// TODO: should constrain `Mapping::Codomain` to implement a partial order,
/// but this makes everything complicated with little benefit.
pub trait ConvexMapping<Domain: Normed<F>, F: Num = f64>: Mapping<Domain, Codomain = F> {
    /// Returns (a lower estimate of) the factor of strong convexity in the norm of `Domain`.
    ///
    /// The default implementation returns zero, i.e., it does not claim any strong
    /// convexity. Strongly convex implementors should override it.
    fn factor_of_strong_convexity(&self) -> F {
        F::ZERO
    }
}

/// Trait for mappings with a Fenchel conjugate
///
/// The conjugate type has to implement [`ConvexMapping`], but a `Conjugable` mapping need
/// not be convex.
pub trait Conjugable<Domain: HasDual<F>, F: Num = f64>: Mapping<Domain, Codomain = F> {
    /// Type of the conjugate: a convex mapping on the dual space `Domain::DualSpace`.
    ///
    /// (A Fenchel conjugate is always convex, which is why the bound is imposed here.)
    /// It may borrow from `self` for the lifetime `'a`.
    type Conjugate<'a>: ConvexMapping<Domain::DualSpace, F>
    where
        Self: 'a;

    /// Returns the Fenchel conjugate of `self`.
    fn conjugate(&self) -> Self::Conjugate<'_>;
}

/// Trait for mappings with a Fenchel preconjugate
///
/// In contrast to [`Conjugable`], the preconjugate need not implement [`ConvexMapping`],
/// but a `Preconjugable` mapping has to be convex.
///
/// `Predual` is the space on which the preconjugate is defined. NOTE(review): the bounds
/// only require `Predual: HasDual<F>`; they do not check that the dual space of `Predual`
/// actually is `Domain`.
pub trait Preconjugable<Domain, Predual, F: Num = f64>: ConvexMapping<Domain, F>
where
    Domain: Normed<F>,
    Predual: HasDual<F>,
{
    /// Type of the preconjugate: a mapping `Predual → F`.
    ///
    /// It may borrow from `self` for the lifetime `'a`.
    type Preconjugate<'a>: Mapping<Predual, Codomain = F>
    where
        Self: 'a;

    /// Returns the Fenchel preconjugate of `self`.
    fn preconjugate(&self) -> Self::Preconjugate<'_>;
}

/// Trait for mappings with a proximal map
///
/// For a mapping `f` and a weight `τ`, the proximal map is
/// `prox_{τf}(z) = argmin_x τ f(x) + ½‖x - z‖²`.
/// The weight `τ` has the type of the mapping's [`Mapping::Codomain`].
pub trait Prox<Domain: Space>: Mapping<Domain> {
    /// Type of the proximal mapping returned by [`Prox::prox_mapping`].
    ///
    /// It may borrow from `self` for the lifetime `'a`.
    type Prox<'a>: Mapping<Domain, Codomain = Domain>
    where
        Self: 'a;

    /// Returns a proximal mapping with weight τ
    fn prox_mapping(&self, τ: Self::Codomain) -> Self::Prox<'_>;

    /// Calculate the proximal mapping with weight τ
    ///
    /// The default implementation constructs [`Prox::prox_mapping`] and applies it to `z`.
    fn prox<I: Instance<Domain>>(&self, τ: Self::Codomain, z: I) -> Domain {
        self.prox_mapping(τ).apply(z)
    }

    /// Calculate the proximal mapping with weight τ in-place
    ///
    /// The default implementation is not truly in-place: it evaluates [`Prox::prox`] on a
    /// borrow of `*y`, producing a new `Domain` value, and then overwrites `*y` with it.
    /// Implementors may override this with a genuinely in-place computation.
    fn prox_mut<'b>(&self, τ: Self::Codomain, y: &'b mut Domain)
    where
        &'b mut Domain: InstanceMut<Domain>,
        Domain::Decomp: DecompositionMut<Domain>,
        for<'a> &'a Domain: Instance<Domain>,
    {
        *y = self.prox(τ, &*y);
    }
}

/// Indicator function of the closed ball of radius `radius` in the norm described by `E`.
///
/// As a [`Mapping`], it evaluates to zero on the ball and to infinity outside of it.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct NormConstraint<F: Float, E: NormExponent> {
    /// Radius of the ball.
    radius: F,
    /// The norm defining the ball.
    norm: NormMapping<F, E>,
}

/// Norms are convex, so a [`NormMapping`] into `F` is a [`ConvexMapping`].
/// The trait's default factor of strong convexity (zero) is kept.
impl<F: Float, E: NormExponent, Domain: Normed<F>> ConvexMapping<Domain, F>
    for NormMapping<F, E>
where
    Self: Mapping<Domain, Codomain = F>,
{
}

/// Evaluates the ball indicator: zero when the norm of the argument is at most
/// `radius`, and infinity otherwise.
impl<F: Float, E: NormExponent, Domain: Space + Norm<F, E>> Mapping<Domain>
    for NormConstraint<F, E>
{
    type Codomain = F;

    fn apply<I: Instance<Domain>>(&self, d: I) -> F {
        let exponent = self.norm.exponent;
        let inside_ball = d.eval(|x| x.norm(exponent)) <= self.radius;
        if inside_ball {
            F::ZERO
        } else {
            F::INFINITY
        }
    }
}

/// A norm ball is a convex set, so its indicator is convex.
/// The trait's default factor of strong convexity (zero) is kept.
impl<F: Float, E: NormExponent, Domain: Normed<F>> ConvexMapping<Domain, F>
    for NormConstraint<F, E>
where
    Self: Mapping<Domain, Codomain = F>,
{
}

/// The Fenchel conjugate of a norm is the indicator of the unit ball of the dual norm,
/// represented as a [`NormConstraint`] of radius one in the dual exponent.
impl<E, F, Domain> Conjugable<Domain, F> for NormMapping<F, E>
where
    E: HasDualExponent,
    F: Float,
    Domain: HasDual<F> + Norm<F, E> + Normed<F>,
    <Domain as HasDual<F>>::DualSpace: Norm<F, E::DualExp>,
{
    type Conjugate<'a>
        = NormConstraint<F, E::DualExp>
    where
        Self: 'a;

    fn conjugate(&self) -> Self::Conjugate<'_> {
        let dual_norm = self.exponent.dual_exponent().as_mapping();
        NormConstraint {
            norm: dual_norm,
            radius: F::ONE,
        }
    }
}

/// The Fenchel conjugate of `α‖·‖` is the indicator of the dual-norm ball of radius `α`.
///
/// Note: for a negative weight the radius is negative. The resulting constraint then
/// evaluates to infinity everywhere (norms being nonnegative). That is indeed the
/// conjugate of the then-concave mapping.
impl<C, E, F, Domain> Conjugable<Domain, F> for Weighted<NormMapping<F, E>, C>
where
    C: Constant<Type = F>,
    E: HasDualExponent,
    F: Float,
    Domain: HasDual<F> + Norm<F, E> + Space,
    <Domain as HasDual<F>>::DualSpace: Norm<F, E::DualExp>,
{
    type Conjugate<'a>
        = NormConstraint<F, E::DualExp>
    where
        Self: 'a;

    fn conjugate(&self) -> Self::Conjugate<'_> {
        let radius = self.weight.value();
        let norm = self.base_fn.exponent.dual_exponent().as_mapping();
        NormConstraint { radius, norm }
    }
}

/// The proximal map of a ball indicator is the projection onto that ball, independently
/// of the weight τ.
impl<Domain, E, F> Prox<Domain> for NormConstraint<F, E>
where
    Domain: Space + Norm<F, E>,
    E: NormExponent,
    F: Float,
    NormProjection<F, E>: Mapping<Domain, Codomain = Domain>,
{
    type Prox<'a>
        = NormProjection<F, E>
    where
        Self: 'a;

    #[inline]
    fn prox_mapping(&self, _τ: Self::Codomain) -> Self::Prox<'_> {
        let radius = self.radius;
        // A ball of negative radius is empty, so projecting onto it is undefined.
        assert!(radius >= F::ZERO);
        let exponent = self.norm.exponent;
        NormProjection { radius, exponent }
    }
}

/// Projection onto the ball of radius `radius` in the norm described by `E`.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct NormProjection<F: Float, E: NormExponent> {
    /// Radius of the ball to project onto.
    radius: F,
    /// Exponent describing the norm of the ball.
    exponent: E,
}

/*
impl<F, Domain> Mapping<Domain> for NormProjection<F, L2>
where
    Domain : Space + Euclidean<F> + std::ops::MulAssign<F>,
    F : Float,
{
    type Codomain = Domain;

    fn apply<I : Instance<Domain>>(&self, d : I) -> Domain {
        d.own().proj_ball2(self.radius)
    }
}
*/

/// Applies the projection by taking ownership of the argument and delegating to the
/// domain's `proj_ball` with the stored radius and exponent.
impl<F: Float, E: NormExponent, Domain: Normed<F> + Projection<F, E>> Mapping<Domain>
    for NormProjection<F, E>
{
    type Codomain = Domain;

    fn apply<I: Instance<Domain>>(&self, d: I) -> Domain {
        let point: Domain = d.own();
        point.proj_ball(self.radius, self.exponent)
    }
}

/// The zero mapping `Domain → F`, i.e., the constant function with value zero.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct Zero<Domain: Space, F: Num>(PhantomData<(Domain, F)>);

impl<Domain: Space, F: Num> Zero<Domain, F> {
    /// Creates the zero mapping.
    pub fn new() -> Self {
        Zero(PhantomData)
    }
}

// Implemented manually rather than derived: `#[derive(Default)]` would needlessly
// require `Domain: Default` and `F: Default`.
impl<Domain: Space, F: Num> Default for Zero<Domain, F> {
    fn default() -> Self {
        Self::new()
    }
}

/// The zero mapping evaluates to zero everywhere, ignoring its argument.
impl<Domain, F> Mapping<Domain> for Zero<Domain, F>
where
    Domain: Space,
    F: Num,
{
    type Codomain = F;

    /// Returns zero, regardless of the argument.
    #[inline]
    fn apply<I: Instance<Domain>>(&self, _x: I) -> F {
        F::ZERO
    }
}

/// The zero mapping is convex but not strongly convex, so the trait's default factor of
/// strong convexity (zero) is kept.
impl<Domain, F> ConvexMapping<Domain, F> for Zero<Domain, F>
where
    Domain: Normed<F>,
    F: Float,
{
}

impl<Domain: HasDual<F>, F: Float> Conjugable<Domain, F> for Zero<Domain, F> {
    type Conjugate<'a>
        = ZeroIndicator<Domain::DualSpace, F>
    where
        Self: 'a;

    #[inline]
    fn conjugate(&self) -> Self::Conjugate<'_> {
        ZeroIndicator::new()
    }
}

impl<Domain, Predual, F: Float> Preconjugable<Domain, Predual, F> for Zero<Domain, F>
where
    Domain: Normed<F>,
    Predual: HasDual<F>,
{
    type Preconjugate<'a>
        = ZeroIndicator<Predual, F>
    where
        Self: 'a;

    #[inline]
    fn preconjugate(&self) -> Self::Preconjugate<'_> {
        ZeroIndicator::new()
    }
}

impl<Domain: Space + Clone, F: Num> Prox<Domain> for Zero<Domain, F> {
    type Prox<'a>
        = IdOp<Domain>
    where
        Self: 'a;

    #[inline]
    fn prox_mapping(&self, _τ: Self::Codomain) -> Self::Prox<'_> {
        IdOp::new()
    }
}

/// The indicator function of the origin: zero at the origin of `Domain`, infinity
/// elsewhere.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct ZeroIndicator<Domain: Space, F: Num>(PhantomData<(Domain, F)>);

impl<Domain: Space, F: Num> ZeroIndicator<Domain, F> {
    /// Creates the indicator of the origin.
    pub fn new() -> Self {
        ZeroIndicator(PhantomData)
    }
}

// Implemented manually rather than derived: `#[derive(Default)]` would needlessly
// require `Domain: Default` and `F: Default`.
impl<Domain: Space, F: Num> Default for ZeroIndicator<Domain, F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Evaluates the indicator of the origin.
impl<Domain, F> Mapping<Domain> for ZeroIndicator<Domain, F>
where
    Domain: Normed<F>,
    F: Float,
{
    type Codomain = F;

    /// Returns zero if the argument is zero, and infinity otherwise.
    fn apply<I: Instance<Domain>>(&self, x: I) -> F {
        x.eval(|point| match point.is_zero() {
            true => F::ZERO,
            false => F::INFINITY,
        })
    }
}

/// The indicator of the single point `{0}` is convex. It satisfies the strong convexity
/// inequality for every factor, so the factor is reported as infinite.
impl<Domain, F> ConvexMapping<Domain, F> for ZeroIndicator<Domain, F>
where
    Domain: Normed<F>,
    F: Float,
{
    #[inline]
    fn factor_of_strong_convexity(&self) -> F {
        F::INFINITY
    }
}

impl<Domain: HasDual<F>, F: Float> Conjugable<Domain, F> for ZeroIndicator<Domain, F> {
    type Conjugate<'a>
        = Zero<Domain::DualSpace, F>
    where
        Self: 'a;

    #[inline]
    fn conjugate(&self) -> Self::Conjugate<'_> {
        Zero::new()
    }
}

impl<Domain, Predual, F: Float> Preconjugable<Domain, Predual, F> for ZeroIndicator<Domain, F>
where
    Domain: Normed<F>,
    Predual: HasDual<F>,
{
    type Preconjugate<'a>
        = Zero<Predual, F>
    where
        Self: 'a;

    #[inline]
    fn preconjugate(&self) -> Self::Preconjugate<'_> {
        Zero::new()
    }
}

/// The squared Euclidean norm divided by two, `x ↦ ‖x‖²/2`.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct Norm222<F: Float>(PhantomData<F>);

impl<F: Float> Norm222<F> {
    /// Creates the mapping `x ↦ ‖x‖²/2`.
    pub fn new() -> Self {
        Norm222(PhantomData)
    }
}

// Implemented manually rather than derived: `#[derive(Default)]` would needlessly
// require `F: Default`.
impl<F: Float> Default for Norm222<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Evaluates `‖x‖²/2` using the Euclidean squared norm of the argument.
impl<F, X> Mapping<X> for Norm222<F>
where
    F: Float,
    X: Euclidean<F>,
{
    type Codomain = F;

    /// Returns half of the squared Euclidean norm of `x`.
    fn apply<I: Instance<X>>(&self, x: I) -> F {
        x.eval(|v| {
            let squared = v.norm2_squared();
            squared / F::TWO
        })
    }
}

/// `x ↦ ‖x‖²/2` is 1-strongly convex with respect to the Euclidean norm.
impl<F, X> ConvexMapping<X, F> for Norm222<F>
where
    F: Float,
    X: Euclidean<F>,
{
    #[inline]
    fn factor_of_strong_convexity(&self) -> F {
        F::ONE
    }
}

impl<X: Euclidean<F>, F: Float> Conjugable<X, F> for Norm222<F> {
    type Conjugate<'a>
        = Self
    where
        Self: 'a;

    #[inline]
    fn conjugate(&self) -> Self::Conjugate<'_> {
        Self::new()
    }
}

impl<X: Euclidean<F>, F: Float> Preconjugable<X, X, F> for Norm222<F> {
    type Preconjugate<'a>
        = Self
    where
        Self: 'a;

    #[inline]
    fn preconjugate(&self) -> Self::Preconjugate<'_> {
        Self::new()
    }
}

/// Proximal map of `x ↦ ‖x‖²/2`.
///
/// Minimising `τ‖x‖²/2 + ‖x - z‖²/2` over `x` gives `x = z/(1+τ)`, so the proximal map
/// is a scaling by `1/(1+τ)`.
impl<F: Float, X: Euclidean<F, Output = X>> Prox<X> for Norm222<F> {
    type Prox<'a>
        = Scaled<F>
    where
        Self: 'a;

    fn prox_mapping(&self, τ: F) -> Self::Prox<'_> {
        let shrink = F::ONE / (F::ONE + τ);
        Scaled(shrink)
    }
}

/// The gradient of `x ↦ ‖x‖²/2` is the identity `x ↦ x`.
impl<F: Float, X: Euclidean<F, Output = X>> DifferentiableImpl<X> for Norm222<F> {
    type Derivative = X;

    /// Returns the argument itself, as an owned value.
    #[inline]
    fn differential_impl<I: Instance<X>>(&self, x: I) -> Self::Derivative {
        x.own()
    }
}

/// The gradient of `x ↦ ‖x‖²/2` is the identity, which is 1-Lipschitz in the
/// Euclidean (`L2`) norm.
impl<F: Float, X: Euclidean<F, Output = X>> LipschitzDifferentiableImpl<X, L2>
    for Norm222<F>
{
    type FloatType = F;

    #[inline]
    fn diff_lipschitz_factor(&self, _: L2) -> DynResult<F> {
        Ok(F::ONE)
    }
}

mercurial