Add metaprogramming tools and fast multidimensional loops. draft

Wed, 22 Dec 2021 11:14:38 +0200

author
Tuomo Valkonen <tuomov@iki.fi>
date
Wed, 22 Dec 2021 11:14:38 +0200
changeset 35
d881275c6564
parent 34
22a64e826ee7
child 36
6dfa8001eed2

Add metaprogramming tools and fast multidimensional loops.

src/AlgTools.jl file | annotate | diff | comparison | revisions
src/Loops.jl file | annotate | diff | comparison | revisions
src/Metaprogramming.jl file | annotate | diff | comparison | revisions
--- a/src/AlgTools.jl	Tue Dec 07 11:41:07 2021 +0200
+++ b/src/AlgTools.jl	Wed Dec 22 11:14:38 2021 +0200
@@ -5,18 +5,11 @@
 
 `module AlgTools`
 
-This module has the submodules:
-- `FunctionalProgramming`
-- `StructTools`
-- `LinkedLists`
-- `Logger`
-- `Iterate`
-- `VectorMath`
-- `Util`
-- `ThreadUtil`
-- `Comms`
-- `LinOps`
-- `DifferentiableFN`
+This module implements useful code for implementing iterative algorithms.
+For further documentation, see the submodules
+`FunctionalProgramming`, `StructTools`, `LinkedLists`, `Logger`, `Iterate`,
+`VectorMath`, `Util`, `ThreadUtil`, `Comms`, `LinOps`, `DifferentiableFN`,
+`Metaprogramming`, and `Loops`.
 """
 module AlgTools
 
@@ -31,5 +24,7 @@
 include("Comms.jl")
 include("LinOps.jl")
 include("DifferentiableFN.jl")
+include("Metaprogramming.jl")
+include("Loops.jl")
 
 end
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Loops.jl	Wed Dec 22 11:14:38 2021 +0200
@@ -0,0 +1,184 @@
+###############################
+# Fast multi-dimensional loops
+###############################
+
+"""
+`module AlgTools.Loops`
+
+This module implements `for_zip`, `for_indices`, and `for_linspace` generated
+functions for loops amenable to a high level of optimisation by the compiler.
+"""
+module Loops
+
+using ..Metaprogramming
+
+##############
+# Our exports
+##############
+
+export Box,
+       Grid,
+       @Grid_str,
+       for_zip,
+       for_indices,
+       for_linspace
+
+#####################################
+# Looping over many iterators/arrays
+#####################################
+
+"""
+`for_zip(f, a...)`
+
+With `dims` of type `Dims{N}`, loop over all `(i_1,…,i_N)` with `i_j ∈ 1:dims[j]`.
+Typically to be used as
+
+```julia
+for_zip(a, b) do x, y
+    # do something
+end
+```
+"""
+@generated function for_zip(f :: Function, a...)
+    # Generate reference variables to avoid execution of potential
+    # code inputs several times.
+    aref = [gensym("aref_$(j)") for j ∈ 1:length(a)]
+    # Generate length-checking asserts
+    asserts = ( :( length($b)==M ) for b ∈ aref[2:end] )
+
+    quote
+        # Assign input expressions to reference variables
+        local $(lift_exprs(aref)) = (a...,)
+        # Check that all inputs are of equal length
+        M = length($(aref[1]))
+        @assert $(all_exprs(asserts)) "Parameter lengths have to be equal"
+        # Do the loop
+        for i = 1:M
+            f($([ :( $c[i] ) for c ∈ aref]...))
+        end
+    end
+end
+
+#######################
+# Looping over indices
+#######################
+
+"""
+`for_indices(f, N, dims)`
+
+With `dims` of type `Dims{N}`, loop over all `(i_1,…,i_N)` with `i_j ∈ 1:dims[j]`.
+Typically to be used as
+
+```julia
+for_indices(N, dims) do i_1, …, i_N
+    # do something
+end
+```
+"""
+@generated function for_indices(f :: Function, dims :: Dims{N}) where N
+    # This depends on numbers being passed as numbers, not Exprs!
+    local i = [gensym("i_$(j)") for j ∈ 1:N]
+    local di = [gensym("di_$(j)") for j ∈ 1:N]
+
+    local expr = :( f($(i...),) )
+
+    for j = 1:N
+        expr = quote
+            for $(i[j]) = 1:$(di[j])
+                $expr
+            end
+        end
+    end
+    quote
+        $(lift_exprs(di)) = dims
+        $expr
+    end
+end
+
+
+##########################################
+# Looping over indices and transformation
+##########################################
+
+Box{N, T} = NTuple{N, Tuple{T, T}}
+struct Grid{R} end
+
+"""
+`@Grid_str`
+
+A convenience macro for writing Grid{:gridtype} as Grid"gridtype" for use with `for_linspace`.
+"""
+macro Grid_str(s)
+    :(Grid{$(Expr(:quote, Symbol(s)))})
+end
+
+# For type-checking user's code, it would be better to wrap `for_linspace`
+# with a multiple dispatch type-case selector. However, the approach of
+# avoiding a wrapper and calling `gridalign` from `for_linspace` to get
+# the grid parements helps with inlining the code.
+@inline gridalign(:: Type{Type{Grid{:midpoint}}}) = (0.5, 0)
+@inline gridalign(:: Type{Type{Grid{:firstpoint}}}) = (1.0, 0)
+@inline gridalign(:: Type{Type{Grid{:lastpoint}}}) = (0.0, 0)
+@inline gridalign(:: Type{Type{Grid{:linspace}}}) = (1.0, 1)
+
+"""
+`for_linspace(f :: Function, dims :: Dims{N} , box :: Box{N, T}, grid :: Type{Grid{G}} = Grid"linspace")`
+
+With `box` of type `Box{N, T} = NTuple{N, Tuple{T, T}}` with `T <: AbstractFloat`,
+loop over all `((i_1,…,i_N), (x_1,…,x_N))`  with `i_j ∈ 1:dims[j]` and each `x_i`
+sampled in equal-length subintervals of the interval `box[i]`.
+If `grid` is `Grid"midpoint"`, `Grid"firstpoint"`, `Grid"lastpoint"`, the point `x_i` is
+the mid-point, first point, or last point of the corresponding subinterval, and there
+are `dims[j]` subintervals. If `grid` is `Grid"linspace"`, there are `dims[j]-1` subintervals,
+and all the endpoints are generated. Typically to be used as
+
+```julia
+for_linspace(dims, box) do (i_1, …, i_N), (x_1, …, x_N)
+    # do something
+end
+```
+"""
+@inline @generated function
+  for_linspace(f :: Function, dims :: Dims{N}, box :: Box{N, T},
+               grid :: Type{Grid{G}} = Grid"linspace") where {N, T <: AbstractFloat, G}
+
+    shift, extrapoints = gridalign(grid)
+
+    # Generate some local variables of primitive types, storable in registers
+    i = [gensym("i_$j") for j ∈ 1:N]
+    c = [gensym("c_$j") for j ∈ 1:N]
+    d = [gensym("d_$j") for j ∈ 1:N]
+    n = [gensym("n_$j") for j ∈ 1:N]
+
+    # Function call
+    x = (:( $(i[j])*$(c[j])+$(d[j]) ) for j ∈ 1:N)
+    expr = :( f( ($(i...),), ($(x...),) ) )
+
+    # Main recursion
+    for j = 1:N
+        expr = quote
+            for $(i[j]) = 1:$(n[j])
+                $expr
+            end
+        end
+    end
+    
+    # Initialisation of linear transformation and variable references
+    asserts = ( :( @assert ($(n[j]) > extrapoints) "Dimension $j too small" ) for j ∈ 1:N )
+    scales = ( :( (box[$j][2]-box[$j][1])/($(n[j])-extrapoints) ) for j ∈ 1:N )
+    translates = ( :( box[$j][1] - shift*$(c[j]) ) for j ∈ 1:N )
+    expr = quote
+        # Set up local (register-storable) variables for dims, needed in the loop above
+        local $(lift_exprs(n)) = dims
+        # Check everything is ok
+        $(sequence_exprs(asserts))
+        # Set up linear transformation, needed in the loop above
+        local $(lift_exprs(c)) = $(lift_exprs(scales))
+        local $(lift_exprs(d)) = $(lift_exprs(translates))
+        # Execute the loop
+        $expr
+    end
+end
+
+
+end # module
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Metaprogramming.jl	Wed Dec 22 11:14:38 2021 +0200
@@ -0,0 +1,124 @@
+##################
+# Metaprogramming
+##################
+
+"""
+`module AlgTools.MetaProgramming`
+
+This module implements helpers for metaprogramming, used, e.g., by `Loops`.
+It includes the macros `@foldl`, `@all`, and `@map` and the metaprogramming
+functions `foldl_exprs`, `all_exprs`, `map_exprs`, `lift_exprs`, and `sequence_exprs`.
+"""
+
+module Metaprogramming
+
+##############
+# Our exports
+##############
+
+export @foldl,
+       @all,
+       @map,
+       foldl_exprs,
+       all_exprs,
+       map_exprs,
+       lift_exprs,
+       sequence_exprs
+
+###########################################
+# Support functions for macro construction
+###########################################
+
+""""
+`foldl_exprs(f, a)`
+
+Does `foldl` on the `Expr`s in `a` with the `f`, a `Function` or another `Expr`
+representing a function. Returns new `Expr` with the expanded `foldl` operation.
+"""
+function foldl_exprs(f :: Expr, a)
+   return foldl((b, c) -> :($f($b, $c)), a)
+end
+
+function foldl_exprs(f :: Function, a)
+   return foldl((b, c) -> :($(f(b, c))), a)
+end
+
+""""
+`all_exprs(a...)`
+
+Does `all` on the `Expr`s in `a`.
+Returns new `Expr` with the expanded `all` operation.
+"""
+function all_exprs(a)
+   return foldl((b, c) -> :( $b && $c ), a)
+end
+
+""""
+`map_exprs(f, a...)`
+
+Does `map` on the `Expr`s in `a` with the `f`, a `Function` or another `Expr`
+representing a function. Returns a new `NTuple` with `Expr`s for the result of
+applying a to all elements of `a`.
+"""
+function map_exprs(f :: Expr , a)
+   return ([ :( $f($b) ) for b ∈ a ]...,)
+end
+
+function map_exprs(f :: Function , a)
+   return ([ :( $(f(b)) ) for b ∈ a ]...,)
+end
+
+""""
+`lift_exprs(a)`
+
+Turns a `Tuple` `a` of `Expr`s into an `Expr` containing `Tuple`.
+"""
+function lift_exprs(a)
+    return :( ($(a...),) )
+end
+
+"""
+`sequence_exprs(a)`
+
+Turns a `Tuple` `a` of `Expr`s into an `Expr` with each original expr on its own line (separated by `;`).
+"""
+function sequence_exprs(a)
+    return foldl_exprs((b, c) -> :( $b; $c ), a)
+end
+
+
+#########
+# Macros
+#########
+
+""""
+`@foldl(f, a...)`
+
+A basic macro version of `foldl`.
+"""
+macro foldl(f, a...)
+    return foldl_exprs(f, a)
+end
+
+""""
+`@all(a...)`
+
+A macro version of `all`.
+"""
+macro all(a...)
+    # Implementation based on @foldl:
+    #return :( @foldl((b, c) -> b && c, $(a...))  )
+    # Direct implementation produces prettier generated code:
+    return all_exprs(a)
+end
+
+""""
+`map(f, a...)`
+
+A macro version of `map`.
+"""
+macro map(f, a...)
+    return lift_exprs(map_exprs(f, a))
+end
+
+end # module

mercurial