Mon, 18 Nov 2019 11:00:54 -0500
Initialise
Manifest.toml | file | annotate | diff | comparison | revisions | |
Project.toml | file | annotate | diff | comparison | revisions | |
src/AlgTools.jl | file | annotate | diff | comparison | revisions | |
src/Gradient.jl | file | annotate | diff | comparison | revisions | |
src/Iterate.jl | file | annotate | diff | comparison | revisions | |
src/LinkedLists.jl | file | annotate | diff | comparison | revisions | |
src/StructTools.jl | file | annotate | diff | comparison | revisions | |
src/Util.jl | file | annotate | diff | comparison | revisions |
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Manifest.toml Mon Nov 18 11:00:54 2019 -0500 @@ -0,0 +1,8 @@ +# This file is machine-generated - editing it directly is not advised + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Project.toml Mon Nov 18 11:00:54 2019 -0500 @@ -0,0 +1,7 @@ +name = "AlgTools" +uuid = "c46e2e78-5339-41fd-a966-983ff60ab8e7" +authors = ["tuomov "] +version = "0.1.0" + +[deps] +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/AlgTools.jl Mon Nov 18 11:00:54 2019 -0500 @@ -0,0 +1,9 @@ +module AlgTools + +include("LinkedLists.jl") +include("StructTools.jl") +include("Iterate.jl") +include("Gradient.jl") +include("Util.jl") + +end
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Gradient.jl Mon Nov 18 11:00:54 2019 -0500 @@ -0,0 +1,185 @@ +######################## +# Discretised gradients +######################## + +module Gradient + +############## +# Our exports +############## + +export ∇₂!, ∇₂ᵀ!, ∇₂fold!, + ∇₂_norm₂₂_est, ∇₂_norm₂₂_est², + ∇₂_norm₂∞_est, ∇₂_norm₂∞_est², + ∇₂c!, + ∇₃!, ∇₃ᵀ!, + vec∇₃!, vec∇₃ᵀ! + +################## +# Helper routines +################## + +@inline function imfold₂′!(f_aa!, f_a0!, f_ab!, + f_0a!, f_00!, f_0b!, + f_ba!, f_b0!, f_bb!, + n, m, state) + # First row + state = f_aa!(state, (1, 1)) + for j = 2:m-1 + state = f_a0!(state, (1, j)) + end + state = f_ab!(state, (1, m)) + + # Middle rows + for i=2:n-1 + state = f_0a!(state, (i, 1)) + for j = 2:m-1 + state = f_00!(state, (i, j)) + end + state = f_0b!(state, (i, m)) + end + + # Last row + state = f_ba!(state, (n, 1)) + for j =2:m-1 + state = f_b0!(state, (n, j)) + end + return f_bb!(state, (n, m)) +end + +######################### +# 2D forward differences +######################### + +∇₂_norm₂₂_est² = 8 +∇₂_norm₂₂_est = √∇₂_norm₂₂_est² +∇₂_norm₂∞_est² = 2 +∇₂_norm₂∞_est = √∇₂_norm₂∞_est² + +function ∇₂!(u₁, u₂, u) + @. @views begin + u₁[1:(end-1), :] = u[2:end, :] - u[1:(end-1), :] + u₁[end, :, :] = 0 + + u₂[:, 1:(end-1)] = u[:, 2:end] - u[:, 1:(end-1)] + u₂[:, end] = 0 + end + return u₁, u₂ +end + +function ∇₂!(v, u) + ∇₂!(@view(v[1, :, :]), @view(v[2, :, :]), u) +end + +@inline function ∇₂fold!(f!::Function, u, state) + g! = (state, pt) -> begin + (i, j) = pt + g = @inbounds [u[i+1, j]-u[i, j], u[i, j+1]-u[i, j]] + return f!(g, state, pt) + end + gr! = (state, pt) -> begin + (i, j) = pt + g = @inbounds [u[i+1, j]-u[i, j], 0.0] + return f!(g, state, pt) + end + gb! = (state, pt) -> begin + (i, j) = pt + g = @inbounds [0.0, u[i, j+1]-u[i, j]] + return f!(g, state, pt) + end + g0! = (state, pt) -> begin + return f!([0.0, 0.0], state, pt) + end + return imfold₂′!(g!, g!, gr!, + g!, g!, gr!, + gb!, gb!, g0!, + size(u, 1), size(u, 2), state) +end + +function ∇₂ᵀ!(v, v₁, v₂) + @. @views begin + v[2:(end-1), :] = v₁[1:(end-2), :] - v₁[2:(end-1), :] + v[1, :] = -v₁[1, :] + v[end, :] = v₁[end-1, :] + + v[:, 2:(end-1)] += v₂[:, 1:(end-2)] - v₂[:, 2:(end-1)] + v[:, 1] += -v₂[:, 1] + v[:, end] += v₂[:, end-1] + end + return v +end + +function ∇₂ᵀ!(u, v) + ∇₂ᵀ!(u, @view(v[1, :, :]), @view(v[2, :, :])) +end + +################################################## +# 2D central differences (partial implementation) +################################################## + +function ∇₂c!(v, u) + @. @views begin + v[1, 2:(end-1), :] = (u[3:end, :] - u[1:(end-2), :])/2 + v[1, end, :] = (u[end, :] - u[end-1, :])/2 + v[1, 1, :] = (u[2, :] - u[1, :])/2 + + v[2, :, 2:(end-1)] = (u[:, 3:end] - u[:, 1:(end-2)])/2 + v[2, :, end] = (u[:, end] - u[:, end-1])/2 + v[2, :, 1] = (u[:, 2] - u[:, 1])/2 + end +end + +######################### +# 3D forward differences +######################### + +function ∇₃!(u₁,u₂,u₃,u) + @. @views begin + u₁[1:(end-1), :, :] = u[2:end, :, :] - u[1:(end-1), :, :] + u₁[end, :, :] = 0 + + u₂[:, 1:(end-1), :] = u[:, 2:end, :] - u[:, 1:(end-1), :] + u₂[:, end, :] = 0 + + u₃[:, :, 1:(end-1)] = u[:, :, 2:end] - u[:, :, 1:(end-1)] + u₃[:, :, end] = 0 + end + return u₁, u₂, u₃ +end + +function ∇₃ᵀ!(v,v₁,v₂,v₃) + @. @views begin + v[2:(end-1), :, :] = v₁[1:(end-2), :, :] - v₁[2:(end-1), :, :] + v[1, :, :] = -v₁[1, :, :] + v[end, :, :] = v₁[end-1, :, :] + + v[:, 2:(end-1), :] += v₂[:, 1:(end-2), :] - v₂[:, 2:(end-1), :] + v[:, 1, :] += -v₂[:, 1, :] + v[:, end, :] += v₂[:, end-1, :] + + v[:, :, 2:(end-1)] += v₃[:, :, 1:(end-2)] - v₃[:, :, 2:(end-1)] + v[:, :, 1] += -v₃[:, :, 1] + v[:, :, end] += v₃[:, :, end-1] + end + return v +end + +########################################### +# 3D forward differences for vector fields +########################################### + +function vec∇₃!(u₁,u₂,u₃,u) + @. @views for j=1:size(u, 1) + ∇₃!(u₁[j, :, :, :],u₂[j, :, :, :],u₃[j, :, :, :],u[j, :, :, :]) + end + return u₁, u₂, u₃ +end + +function vec∇₃ᵀ!(u,v₁,v₂,v₃) + @. @views for j=1:size(u, 1) + ∇₃ᵀ!(u[j, :, :, :],v₁[j, :, :, :],v₂[j, :, :, :],v₃[j, :, :, :]) + end + return u +end + +end # Module
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Iterate.jl Mon Nov 18 11:00:54 2019 -0500 @@ -0,0 +1,52 @@ +################################# +# Tools for iterative algorithms +################################# + +module Iterate + +using Printf + +############## +# Our exports +############## + +export simple_iterate + +######################################################################## +# Simple itertion function, calling `step()` `params.maxiter` times and +# reporting objective value every `params.verbose_iter` iterations. +# The function `step` should take as its argument a function that itself +# takes as its argument a function that calculates the objective value +# on demand. +######################################################################## + +function simple_iterate(step :: Function, + params::NamedTuple) + for iter=1:params.maxiter + step() do calc_objective + if params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0 + v, _ = calc_objective() + @printf("%d/%d J=%f\n", iter, params.maxiter, v) + return true + end + end + end +end + +function simple_iterate(step :: Function, + datachannel::Channel{T}, + params::NamedTuple) where T + for iter=1:params.maxiter + d = take!(datachannel) + step(d) do calc_objective + if params.verbose_iter!=0 && mod(iter, params.verbose_iter) == 0 + v, _ = calc_objective() + @printf("%d/%d J=%f\n", iter, params.maxiter, v) + return true + end + end + end +end + +end +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/LinkedLists.jl Mon Nov 18 11:00:54 2019 -0500 @@ -0,0 +1,52 @@ +#################################################################### +# Immutable linked list (different from the mutable lists of +# https://github.com/ChrisRackauckas/LinkedLists.jl) +#################################################################### + +module LinkedLists + +############## +# Our exports +############## + +export LinkedListEntry, + LinkedList, + unfold_linked_list + +############# +# Data types +############# + +struct LinkedListEntry{T} + value :: T + next :: Union{LinkedListEntry{T},Nothing} +end + +LinkedList{T} = Union{LinkedListEntry{T},Nothing} + +############ +# Functions +############ + +function Base.iterate(list::LinkedList{T}) where T + return Base.iterate(list, list) +end + +function Base.iterate(list::LinkedList{T}, tail::Nothing) where T + return nothing +end + +function Base.iterate(list::LinkedList{T}, tail::LinkedListEntry{T}) where T + return tail.value, tail.next +end + +# Return the items in the list with the tail first +function unfold_linked_list(list::LinkedList{T}) where T + res = [] + for value ∈ list + push!(res, value) + end + return reverse(res) +end + +end
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/StructTools.jl Mon Nov 18 11:00:54 2019 -0500 @@ -0,0 +1,41 @@ +################################# +# Tools for working with structs +################################# + +module StructTools + +############## +# Our exports +############## + +export replace, + IterableStruct + +###################################################### +# Replace entries by those given as keyword arguments +###################################################### + +function replace(base::T; kw...) where T + k = keys(kw) + T([n ∈ k ? kw[n] : getfield(base, n) for n ∈ fieldnames(T)]...) +end + +######################################################### +# Iteration of structs. +# One only needs to make them instance of IterableStruct +######################################################### + +abstract type IterableStruct end + +function Base.iterate(s::T) where T <: IterableStruct + return Base.iterate(s, (0, fieldnames(T))) +end + +function Base.iterate( + s::T, st::Tuple{Integer,NTuple{N,Symbol}} +) where T <: IterableStruct where N + (i, k)=st + return (i<N ? (getfield(s, i+1), (i+1, k)) : nothing) +end + +end
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Util.jl Mon Nov 18 11:00:54 2019 -0500 @@ -0,0 +1,148 @@ +######################### +# Some utility functions +######################### + +module Util + +############## +# Our exports +############## + +export map_first_slice!, + reduce_first_slice, + norm₂, + γnorm₂, + norm₂w, + norm₂², + norm₂w², + norm₂₁, + γnorm₂₁, + dot, + mean, + proj_norm₂₁ball!, + curry, + ⬿ + +######################## +# Functional programming +######################### + +curry = (f::Function,y...)->(z...)->f(y...,z...) + +############################### +# For working with NamedTuples +############################### + +⬿ = merge + +###### +# map +###### + +@inline function map_first_slice!(f!, y) + for i in CartesianIndices(size(y)[2:end]) + @inbounds f!(@view(y[:, i])) + end +end + +@inline function map_first_slice!(x, f!, y) + for i in CartesianIndices(size(y)[2:end]) + @inbounds f!(@view(x[:, i]), @view(y[:, i])) + end +end + +@inline function reduce_first_slice(f, y; init=0.0) + accum=init + for i in CartesianIndices(size(y)[2:end]) + @inbounds accum=f(accum, @view(y[:, i])) + end + return accum +end + +########################### +# Norms and inner products +########################### + +@inline function dot(x, y) + @assert(length(x)==length(y)) + + accum=0 + for i=1:length(y) + @inbounds accum += x[i]*y[i] + end + return accum +end + +@inline function norm₂w²(y, w) + #Insane memory allocs + #return @inbounds sum(i -> y[i]*y[i]*w[i], 1:length(y)) + accum=0 + for i=1:length(y) + @inbounds accum=accum+y[i]*y[i]*w[i] + end + return accum +end + +@inline function norm₂w(y, w) + return √(norm₂w²(y, w)) +end + +@inline function norm₂²(y) + #Insane memory allocs + #return @inbounds sum(i -> y[i]*y[i], 1:length(y)) + accum=0 + for i=1:length(y) + @inbounds accum=accum+y[i]*y[i] + end + return accum +end + +@inline function norm₂(y) + return √(norm₂²(y)) +end + +@inline function γnorm₂(y, γ) + hubersq = xsq -> begin + x=√xsq + return if x > γ + x-γ/2 + elseif x<-γ + -x-γ/2 + else + xsq/(2γ) + end + end + + if γ==0 + return norm₂(y) + else + return hubersq(norm₂²(y)) + end +end + +function norm₂₁(y) + return reduce_first_slice((s, x) -> s+norm₂(x), y) +end + +function γnorm₂₁(y,γ) + return reduce_first_slice((s, x) -> s+γnorm₂(x, γ), y) +end + +function mean(v) + return sum(v)/prod(size(v)) +end + +@inline function proj_norm₂₁ball!(y, α) + α²=α*α + y′=reshape(y, (size(y, 1), prod(size(y)[2:end]))) + + @inbounds @simd for i=1:size(y′, 2)# in CartesianIndices(size(y)[2:end]) + n² = norm₂²(@view(y′[:, i])) + if n²>α² + y′[:, i] .*= (α/√n²) + end + end +end + +end # Module +