Initialise

Mon, 18 Nov 2019 11:00:54 -0500

author
Tuomo Valkonen <tuomov@iki.fi>
date
Mon, 18 Nov 2019 11:00:54 -0500
changeset 0
888dfd34d24a
child 1
5762e68842c9

Initialise

Manifest.toml file | annotate | diff | comparison | revisions
Project.toml file | annotate | diff | comparison | revisions
src/AlgTools.jl file | annotate | diff | comparison | revisions
src/Gradient.jl file | annotate | diff | comparison | revisions
src/Iterate.jl file | annotate | diff | comparison | revisions
src/LinkedLists.jl file | annotate | diff | comparison | revisions
src/StructTools.jl file | annotate | diff | comparison | revisions
src/Util.jl file | annotate | diff | comparison | revisions
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/Manifest.toml	Mon Nov 18 11:00:54 2019 -0500
@@ -0,0 +1,8 @@
+# This file is machine-generated - editing it directly is not advised
+
+[[Printf]]
+deps = ["Unicode"]
+uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+
+[[Unicode]]
+uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/Project.toml	Mon Nov 18 11:00:54 2019 -0500
@@ -0,0 +1,7 @@
+name = "AlgTools"
+uuid = "c46e2e78-5339-41fd-a966-983ff60ab8e7"
+authors = ["tuomov "]
+version = "0.1.0"
+
+[deps]
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/AlgTools.jl	Mon Nov 18 11:00:54 2019 -0500
@@ -0,0 +1,9 @@
+module AlgTools
+
+include("LinkedLists.jl")
+include("StructTools.jl")
+include("Iterate.jl")
+include("Gradient.jl")
+include("Util.jl")
+
+end
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Gradient.jl	Mon Nov 18 11:00:54 2019 -0500
@@ -0,0 +1,185 @@
+########################
+# Discretised gradients
+########################
+
+module Gradient
+
+##############
+# Our exports
+##############
+
+export ∇₂!, ∇₂ᵀ!, ∇₂fold!,
+       ∇₂_norm₂₂_est, ∇₂_norm₂₂_est²,
+       ∇₂_norm₂∞_est, ∇₂_norm₂∞_est²,
+       ∇₂c!,
+       ∇₃!, ∇₃ᵀ!,
+       vec∇₃!, vec∇₃ᵀ!
+
+##################
+# Helper routines
+##################
+
+@inline function imfold₂′!(f_aa!, f_a0!, f_ab!,
+                           f_0a!, f_00!, f_0b!,
+                           f_ba!, f_b0!, f_bb!,
+                           n, m, state)
+    # First row
+    state = f_aa!(state, (1, 1))
+    for j = 2:m-1
+        state = f_a0!(state, (1, j))
+    end
+    state = f_ab!(state, (1, m))
+
+    # Middle rows
+    for i=2:n-1
+        state = f_0a!(state, (i, 1))
+        for j = 2:m-1
+            state = f_00!(state, (i, j))
+        end
+        state = f_0b!(state, (i, m))
+    end
+
+    # Last row
+    state = f_ba!(state, (n, 1))
+    for  j =2:m-1
+        state = f_b0!(state, (n, j))
+    end
+    return f_bb!(state, (n, m))
+end
+
+#########################
+# 2D forward differences
+#########################
+
+∇₂_norm₂₂_est² = 8
+∇₂_norm₂₂_est = √∇₂_norm₂₂_est²
+∇₂_norm₂∞_est² = 2
+∇₂_norm₂∞_est = √∇₂_norm₂∞_est²
+
+function ∇₂!(u₁, u₂, u)
+    @. @views begin
+        u₁[1:(end-1), :] = u[2:end, :] - u[1:(end-1), :]
+        u₁[end, :, :] = 0
+
+        u₂[:, 1:(end-1)] = u[:, 2:end] - u[:, 1:(end-1)]
+        u₂[:, end] = 0
+    end
+    return u₁, u₂
+end
+
+function ∇₂!(v, u)
+    ∇₂!(@view(v[1, :, :]), @view(v[2, :, :]), u)
+end
+
+@inline function ∇₂fold!(f!::Function, u, state)
+    g! = (state, pt) -> begin
+        (i, j) = pt
+        g = @inbounds [u[i+1, j]-u[i, j], u[i, j+1]-u[i, j]]
+        return f!(g, state, pt)
+    end
+    gr! = (state, pt) -> begin
+        (i, j) = pt
+        g = @inbounds [u[i+1, j]-u[i, j], 0.0]
+        return f!(g, state, pt)
+    end
+    gb! = (state, pt) -> begin
+        (i, j) = pt
+        g = @inbounds [0.0, u[i, j+1]-u[i, j]]
+        return f!(g, state, pt)
+    end
+    g0! = (state, pt) -> begin
+        return f!([0.0, 0.0], state, pt)
+    end
+    return imfold₂′!(g!, g!, gr!,
+                     g!, g!, gr!,
+                     gb!, gb!, g0!,
+                     size(u, 1), size(u, 2), state)
+end
+
+function ∇₂ᵀ!(v, v₁, v₂)
+    @. @views begin
+        v[2:(end-1), :] = v₁[1:(end-2), :] - v₁[2:(end-1), :]
+        v[1, :] = -v₁[1, :]
+        v[end, :] = v₁[end-1, :]
+
+        v[:, 2:(end-1)] += v₂[:, 1:(end-2)] - v₂[:, 2:(end-1)]
+        v[:, 1] += -v₂[:, 1]
+        v[:, end] += v₂[:, end-1]
+    end
+    return v
+end
+
+function ∇₂ᵀ!(u, v)
+    ∇₂ᵀ!(u, @view(v[1, :, :]), @view(v[2, :, :]))
+end
+
+##################################################
+# 2D central differences (partial implementation)
+##################################################
+
+function ∇₂c!(v, u)
+    @. @views begin
+        v[1, 2:(end-1), :] = (u[3:end, :] - u[1:(end-2), :])/2
+        v[1, end, :] = (u[end, :] - u[end-1, :])/2
+        v[1, 1, :] = (u[2, :] - u[1, :])/2
+
+        v[2, :, 2:(end-1)] = (u[:, 3:end] - u[:, 1:(end-2)])/2
+        v[2, :, end] = (u[:, end] - u[:, end-1])/2
+        v[2, :, 1] = (u[:, 2] - u[:, 1])/2
+    end
+end
+
+#########################
+# 3D forward differences
+#########################
+
+function ∇₃!(u₁,u₂,u₃,u)
+    @. @views begin
+        u₁[1:(end-1), :, :] = u[2:end, :, :] - u[1:(end-1), :, :]
+        u₁[end, :, :] = 0
+
+        u₂[:, 1:(end-1), :] = u[:, 2:end, :] - u[:, 1:(end-1), :]
+        u₂[:, end, :] = 0
+
+        u₃[:, :, 1:(end-1)] = u[:, :, 2:end] - u[:, :, 1:(end-1)]
+        u₃[:, :, end] = 0
+    end
+    return u₁, u₂, u₃
+end
+
+function ∇₃ᵀ!(v,v₁,v₂,v₃)
+    @. @views begin
+        v[2:(end-1), :, :] = v₁[1:(end-2), :, :] - v₁[2:(end-1), :, :]
+        v[1, :, :] = -v₁[1, :, :]
+        v[end, :, :] = v₁[end-1, :, :]
+
+        v[:, 2:(end-1), :] += v₂[:, 1:(end-2), :] - v₂[:, 2:(end-1), :]
+        v[:, 1, :] += -v₂[:, 1, :]
+        v[:, end, :] += v₂[:, end-1, :]
+
+        v[:, :, 2:(end-1)] += v₃[:, :, 1:(end-2)] - v₃[:, :, 2:(end-1)]
+        v[:, :, 1] += -v₃[:, :, 1]
+        v[:, :, end] += v₃[:, :, end-1]
+    end
+    return v
+end
+
+###########################################
+# 3D forward differences for vector fields
+###########################################
+
+function vec∇₃!(u₁,u₂,u₃,u)
+    @. @views for j=1:size(u, 1)
+        ∇₃!(u₁[j, :, :, :],u₂[j, :, :, :],u₃[j, :, :, :],u[j, :, :, :])
+    end
+    return u₁, u₂, u₃
+end
+
+function vec∇₃ᵀ!(u,v₁,v₂,v₃)
+    @. @views for j=1:size(u, 1)
+        ∇₃ᵀ!(u[j, :, :, :],v₁[j, :, :, :],v₂[j, :, :, :],v₃[j, :, :, :])
+    end
+    return u
+end
+
+end # Module
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Iterate.jl	Mon Nov 18 11:00:54 2019 -0500
@@ -0,0 +1,52 @@
+#################################
+# Tools for iterative algorithms
+#################################
+
+module Iterate
+
+using Printf
+
+##############
+# Our exports
+##############
+
+export simple_iterate
+
+########################################################################
+# Simple itertion function, calling `step()` `params.maxiter` times and 
+# reporting objective value every `params.verbose_iter` iterations.
+# The function `step` should take as its argument a function that itself
+# takes as its argument a function that calculates the objective value
+# on demand.
+########################################################################
+
+function simple_iterate(step :: Function,
+                        params::NamedTuple)
+    for iter=1:params.maxiter
+        step() do calc_objective
+            if params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0
+                v, _ = calc_objective()
+                @printf("%d/%d J=%f\n", iter, params.maxiter, v)
+                return true
+            end
+        end
+    end
+end
+
+function simple_iterate(step :: Function,
+                        datachannel::Channel{T},
+                        params::NamedTuple) where T
+    for iter=1:params.maxiter
+        d = take!(datachannel)
+        step(d) do calc_objective
+            if params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0
+                v, _ = calc_objective()
+                @printf("%d/%d J=%f\n", iter, params.maxiter, v)
+                return true
+            end
+        end
+    end
+end
+
+end
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/LinkedLists.jl	Mon Nov 18 11:00:54 2019 -0500
@@ -0,0 +1,52 @@
+####################################################################
+# Immutable linked list (different from the mutable lists of
+# https://github.com/ChrisRackauckas/LinkedLists.jl)
+####################################################################
+
+module LinkedLists
+
+##############
+# Our exports
+##############
+
+export LinkedListEntry,
+       LinkedList,
+       unfold_linked_list
+
+#############
+# Data types
+#############
+
+struct LinkedListEntry{T}
+    value :: T
+    next :: Union{LinkedListEntry{T},Nothing}
+end
+
+LinkedList{T} = Union{LinkedListEntry{T},Nothing}
+
+############
+# Functions
+############
+
+function Base.iterate(list::LinkedList{T}) where T
+    return Base.iterate(list, list)
+end
+
+function Base.iterate(list::LinkedList{T}, tail::Nothing) where T
+    return nothing
+end
+
+function Base.iterate(list::LinkedList{T}, tail::LinkedListEntry{T}) where T
+    return tail.value, tail.next
+end
+
+# Return the items in the list with the tail first
+function unfold_linked_list(list::LinkedList{T}) where T
+    res = []
+    for value ∈ list
+        push!(res, value)
+    end
+    return reverse(res)
+end
+
+end
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/StructTools.jl	Mon Nov 18 11:00:54 2019 -0500
@@ -0,0 +1,41 @@
+#################################
+# Tools for working with structs
+#################################
+
+module StructTools
+
+##############
+# Our exports
+##############
+
+export replace,
+       IterableStruct
+
+######################################################
+# Replace entries by those given as keyword arguments
+######################################################
+
+function replace(base::T; kw...) where T
+    k = keys(kw)
+    T([n ∈ k ? kw[n] : getfield(base, n) for n ∈ fieldnames(T)]...)
+end
+
+#########################################################
+# Iteration of structs.
+# One only needs to make them instance of IterableStruct
+#########################################################
+
+abstract type IterableStruct end
+
+function Base.iterate(s::T) where T <: IterableStruct
+    return Base.iterate(s, (0, fieldnames(T)))
+end
+
+function Base.iterate(
+    s::T, st::Tuple{Integer,NTuple{N,Symbol}}
+) where T <: IterableStruct where N
+    (i, k)=st
+    return (i<N ? (getfield(s, i+1), (i+1, k)) : nothing)
+end
+
+end
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Util.jl	Mon Nov 18 11:00:54 2019 -0500
@@ -0,0 +1,148 @@
+#########################
+# Some utility functions
+#########################
+
+module Util
+
+##############
+# Our exports
+##############
+
+export map_first_slice!,
+       reduce_first_slice,
+       norm₂,
+       γnorm₂,
+       norm₂w,
+       norm₂²,
+       norm₂w²,
+       norm₂₁,
+       γnorm₂₁,
+       dot,
+       mean,
+       proj_norm₂₁ball!,
+       curry,
+       ⬿
+
+########################
+# Functional programming
+#########################
+
+curry = (f::Function,y...)->(z...)->f(y...,z...)
+
+###############################
+# For working with NamedTuples
+###############################
+
+⬿ = merge
+
+######
+# map
+######
+
+@inline function map_first_slice!(f!, y)
+    for i in CartesianIndices(size(y)[2:end])
+        @inbounds f!(@view(y[:, i]))
+    end
+end
+
+@inline function map_first_slice!(x, f!, y)
+    for i in CartesianIndices(size(y)[2:end])
+        @inbounds f!(@view(x[:, i]), @view(y[:, i]))
+    end
+end
+
+@inline function reduce_first_slice(f, y; init=0.0)
+    accum=init
+    for i in CartesianIndices(size(y)[2:end])
+        @inbounds accum=f(accum, @view(y[:, i]))
+    end
+    return accum
+end
+
+###########################
+# Norms and inner products
+###########################
+
+@inline function dot(x, y)
+    @assert(length(x)==length(y))
+
+    accum=0
+    for i=1:length(y)
+        @inbounds accum += x[i]*y[i]
+    end
+    return accum
+end
+
+@inline function norm₂w²(y, w)
+    #Insane memory allocs
+    #return @inbounds sum(i -> y[i]*y[i]*w[i], 1:length(y))
+    accum=0
+    for i=1:length(y)
+        @inbounds accum=accum+y[i]*y[i]*w[i]
+    end
+    return accum
+end
+
+@inline function norm₂w(y, w)
+    return √(norm₂w²(y, w))
+end
+
+@inline function norm₂²(y)
+    #Insane memory allocs
+    #return @inbounds sum(i -> y[i]*y[i], 1:length(y))
+    accum=0
+    for i=1:length(y)
+        @inbounds accum=accum+y[i]*y[i]
+    end
+    return accum
+end
+
+@inline function norm₂(y)
+    return √(norm₂²(y))
+end
+
+@inline function γnorm₂(y, γ)
+    hubersq = xsq -> begin
+        x=√xsq
+        return if x > γ
+            x-γ/2
+        elseif x<-γ
+            -x-γ/2
+        else
+            xsq/(2γ)
+        end
+    end
+
+    if γ==0
+        return norm₂(y)
+    else
+        return hubersq(norm₂²(y))
+    end
+end
+
+function norm₂₁(y)
+    return reduce_first_slice((s, x) -> s+norm₂(x), y)
+end
+
+function γnorm₂₁(y,γ)
+    return reduce_first_slice((s, x) -> s+γnorm₂(x, γ), y)
+end
+
+function mean(v)
+    return sum(v)/prod(size(v))
+end
+
+@inline function proj_norm₂₁ball!(y, α)
+    α²=α*α
+    y′=reshape(y, (size(y, 1), prod(size(y)[2:end])))
+
+    @inbounds @simd for i=1:size(y′, 2)# in CartesianIndices(size(y)[2:end])
+        n² = norm₂²(@view(y′[:, i]))
+        if n²>α²
+            y′[:, i] .*= (α/√n²)
+        end
+    end
+end
+
+end # Module
+

mercurial