src/VectorMath.jl

Wed, 15 Dec 2021 01:09:09 +0200

author
Tuomo Valkonen <tuomov@iki.fi>
date
Wed, 15 Dec 2021 01:09:09 +0200
changeset 36
6dfa8001eed2
parent 33
a60d2f12ef93
permissions
-rw-r--r--

Implement ZipArrays

###########################
# Norms, projections, etc.
###########################

__precompile__()

"""
Norms, inner products, means and in-place projections for arrays of numbers.
"""
module VectorMath

##############
# Our exports
##############

export norm₁,
       norm₂,
       γnorm₂,
       norm₂w,
       norm₂²,
       norm₂w²,
       norm₂₁,
       γnorm₂₁,
       dot,
       mean,
       proj_norm₂₁ball!,
       proj_nonneg!

###########################
# Norms and inner products
###########################

# The reductions below deliberately use explicit loops: an earlier
# `sum(i -> ..., 1:length(y))` formulation caused heavy memory allocation.
# Accumulators start from a zero of the result's element type, so the loop
# variable keeps one concrete type and empty inputs return a typed zero.

"""
    dot(x, y)

Return `Σᵢ x[i]*y[i]`, accumulated sequentially in iteration order.
No complex conjugation is applied to either argument.

Throws `DimensionMismatch` if `x` and `y` have different lengths.
"""
@inline function dot(x, y)
    length(x) == length(y) ||
        throw(DimensionMismatch("dot: lengths $(length(x)) and $(length(y)) differ"))
    accum = zero(eltype(x)) * zero(eltype(y))
    for (a, b) in zip(x, y)
        accum += a * b
    end
    return accum
end

"""
    norm₂w²(y, w)

Return the weighted squared 2-norm `Σᵢ |y[i]|² w[i]`.

Throws `DimensionMismatch` if `y` and `w` have different lengths.
"""
@inline function norm₂w²(y, w)
    length(y) == length(w) ||
        throw(DimensionMismatch(
            "norm₂w²: lengths $(length(y)) and $(length(w)) differ"))
    accum = abs2(zero(eltype(y))) * zero(eltype(w))
    for (v, ω) in zip(y, w)
        accum += abs2(v) * ω
    end
    return accum
end

"""
    norm₂w(y, w)

Return the weighted 2-norm `√(norm₂w²(y, w))`.
"""
@inline norm₂w(y, w) = √(norm₂w²(y, w))

"""
    norm₂²(y)

Return the squared 2-norm `Σᵢ |y[i]|²`.
"""
@inline function norm₂²(y)
    accum = abs2(zero(eltype(y)))
    for v in y
        accum += abs2(v)
    end
    return accum
end

"""
    norm₂(y)

Return the 2-norm `√(norm₂²(y))`.
"""
@inline norm₂(y) = √(norm₂²(y))

"""
    norm₁(y)

Return the 1-norm `Σᵢ |y[i]|`.
"""
@inline function norm₁(y)
    accum = abs(zero(eltype(y)))
    for v in y
        accum += abs(v)
    end
    return accum
end

"""
    γnorm₂(y, γ)

Return the Huber-regularised 2-norm of `y` with parameter `γ`: `‖y‖₂ - γ/2`
when `‖y‖₂ > γ`, and `‖y‖₂²/(2γ)` otherwise. For `γ == 0` this is `norm₂(y)`.
"""
@inline function γnorm₂(y, γ)
    γ == 0 && return norm₂(y)
    n² = norm₂²(y)
    n = √n²
    # n ≥ 0 always, so a separate `n < -γ` case can never occur.
    return n > γ ? n - γ/2 : n² / (2γ)
end

# Return `init + Σⱼ f(y[:, j])`, where `j` ranges over all trailing indices
# of `y` (i.e. `y` viewed as a matrix with `size(y, 1)` rows).
function _sum_first_slices(f, y, init)
    y′ = reshape(y, (size(y, 1), prod(size(y)[2:end])))
    accum = init
    for j in axes(y′, 2)
        accum += f(@view(y′[:, j]))
    end
    return accum
end

"""
    norm₂₁(y)

Return the 2,1-norm `Σⱼ ‖y[:, j]‖₂`, summing the 2-norms of the
first-dimension slices of `y` over all trailing indices.
"""
function norm₂₁(y)
    return _sum_first_slices(norm₂, y, √(abs2(zero(eltype(y)))))
end

"""
    γnorm₂₁(y, γ)

Return `Σⱼ γnorm₂(y[:, j], γ)`, the Huber-regularised analogue of `norm₂₁`.
"""
function γnorm₂₁(y, γ)
    # Typed zero wide enough for every branch of γnorm₂.
    init = √(abs2(zero(eltype(y)))) + zero(γ) / 2
    return _sum_first_slices(x -> γnorm₂(x, γ), y, init)
end

"""
    mean(v)

Return the arithmetic mean `sum(v) / length(v)`.
"""
function mean(v)
    return sum(v) / length(v)
end

##############
# Projections
##############

"""
    proj_norm₂₁ball!(y, α)

Project `y` in place onto `{y : ‖y[:, j]‖₂ ≤ α for all j}`, where `j` ranges
over all trailing indices. Every first-dimension slice whose 2-norm exceeds `α`
is rescaled to have norm `α`; the other slices are left unchanged.
Returns `y`.
"""
@inline function proj_norm₂₁ball!(y, α)
    α² = α * α

    if ndims(y) == 3 && size(y, 1) == 2
        # Specialised path for 2×m×n arrays. The innermost loop runs over the
        # second index so that memory is traversed in column-major order.
        k₁, k₂ = axes(y, 1)
        @inbounds for j in axes(y, 3)
            @simd for i in axes(y, 2)
                n² = abs2(y[k₁, i, j]) + abs2(y[k₂, i, j])
                if n² > α²
                    v = α / √n²
                    y[k₁, i, j] *= v
                    y[k₂, i, j] *= v
                end
            end
        end
    else
        # General path: treat y as a matrix whose columns are the slices.
        y′ = reshape(y, (size(y, 1), prod(size(y)[2:end])))
        @inbounds for j in axes(y′, 2)
            col = @view y′[:, j]
            n² = norm₂²(col)
            if n² > α²
                col .*= α / √n²
            end
        end
    end
    return y
end

"""
    proj_nonneg!(y)

Project `y` in place onto the non-negative orthant by replacing every negative
entry with zero. Returns `y`.
"""
@inline function proj_nonneg!(y)
    @inbounds @simd for i in eachindex(y)
        if y[i] < 0
            y[i] = 0
        end
    end
    return y
end

end # Module

mercurial