Wed, 22 Dec 2021 11:14:38 +0200
Add metaprogramming tools and fast multidimensional loops.
###############################
# Fast multi-dimensional loops
###############################

"""
`module AlgTools.Loops`

This module implements `for_zip`, `for_indices`, and `for_linspace` generated
functions for loops amenable to a high level of optimisation by the compiler.
"""
module Loops

using ..Metaprogramming

##############
# Our exports
##############

export Box,
       Grid,
       @Grid_str,
       for_zip,
       for_indices,
       for_linspace

#####################################
# Looping over many iterators/arrays
#####################################

"""
`for_zip(f, a...)`

Loop simultaneously over the indexable collections `a...`: with `M` the length
of the first collection, call `f(a_1[i], …, a_K[i])` for every `i ∈ 1:M`.
The lengths of the remaining collections are compared against `M` in an
`@assert`. At least one collection has to be provided, as the length of the
first one determines the loop range.

Typically to be used as
```julia
for_zip(a, b) do x, y
    # do something
end
```
"""
@generated function for_zip(f :: Function, a...)
    # Generate reference variables to avoid execution of potential
    # code inputs several times.
    aref = [gensym("aref_$(j)") for j ∈ 1:length(a)]

    # Generate length-checking asserts (one comparison against M for every
    # input after the first)
    asserts = ( :( length($b)==M ) for b ∈ aref[2:end] )

    quote
        # Assign input expressions to reference variables
        local $(lift_exprs(aref)) = (a...,)
        # Check that all inputs are of equal length
        M = length($(aref[1]))
        @assert $(all_exprs(asserts)) "Parameter lengths have to be equal"
        # Do the loop
        for i = 1:M
            f($([ :( $c[i] ) for c ∈ aref]...))
        end
    end
end

#######################
# Looping over indices
#######################

"""
`for_indices(f, dims)`

With `dims` of type `Dims{N}`, loop over all `(i_1,…,i_N)` with `i_j ∈ 1:dims[j]`,
calling `f(i_1, …, i_N)` for each of them. The loops are nested such that `i_1`
belongs to the innermost loop and `i_N` to the outermost one.

Typically to be used as
```julia
for_indices(dims) do i_1, …, i_N
    # do something
end
```
"""
@generated function for_indices(f :: Function, dims :: Dims{N}) where N
    # This depends on numbers being passed as numbers, not Exprs!
    local i = [gensym("i_$(j)") for j ∈ 1:N]
    local di = [gensym("di_$(j)") for j ∈ 1:N]

    # Start from the call and wrap it in one loop per dimension; the first
    # dimension gets wrapped first and therefore ends up innermost.
    local expr = :( f($(i...),) )

    for j = 1:N
        expr = quote
            for $(i[j]) = 1:$(di[j])
                $expr
            end
        end
    end
    quote
        # Unpack the dimensions into individual variables used as loop bounds
        $(lift_exprs(di)) = dims
        $expr
    end
end

##########################################
# Looping over indices and transformation
##########################################

# An N-dimensional box: one `(lower, upper)` pair of type `T` per dimension.
Box{N, T} = NTuple{N, Tuple{T, T}}

# Tag type selecting the grid alignment through its type parameter `R`.
struct Grid{R}
end

"""
`@Grid_str`

A convenience macro for writing Grid{:gridtype} as Grid"gridtype" for use with
`for_linspace`.
"""
macro Grid_str(s)
    :(Grid{$(Expr(:quote, Symbol(s)))})
end

# For type-checking user's code, it would be better to wrap `for_linspace`
# with a multiple dispatch type-case selector. However, the approach of
# avoiding a wrapper and calling `gridalign` from `for_linspace` to get
# the grid parameters helps with inlining the code.
#
# `gridalign` is called from within the generator of `for_linspace`, where the
# argument `grid` is bound to its *type* `Type{Grid{G}}`; hence the doubled
# `Type{Type{…}}` in the signatures. The return value is `(shift, extrapoints)`:
# the sample with index `i` is placed at `box[j][1] + (i - shift)*c`, where the
# step is `c = (box[j][2] - box[j][1])/(dims[j] - extrapoints)`.
@inline gridalign(:: Type{Type{Grid{:midpoint}}}) = (0.5, 0)
@inline gridalign(:: Type{Type{Grid{:firstpoint}}}) = (1.0, 0)
@inline gridalign(:: Type{Type{Grid{:lastpoint}}}) = (0.0, 0)
@inline gridalign(:: Type{Type{Grid{:linspace}}}) = (1.0, 1)

"""
`for_linspace(f :: Function, dims :: Dims{N}, box :: Box{N, T}, grid :: Type{Grid{G}} = Grid"linspace")`

With `box` of type `Box{N, T} = NTuple{N, Tuple{T, T}}` with `T <: AbstractFloat`,
loop over all `((i_1,…,i_N), (x_1,…,x_N))` with `i_j ∈ 1:dims[j]` and each `x_j`
sampled in equal-length subintervals of the interval `box[j]`, calling `f` with
the index tuple and the coordinate tuple as its two arguments.
If `grid` is `Grid"midpoint"`, `Grid"firstpoint"`, `Grid"lastpoint"`, the point `x_j`
is the mid-point, first point, or last point of the corresponding subinterval,
and there are `dims[j]` subintervals. If `grid` is `Grid"linspace"`, there are
`dims[j]-1` subintervals, and all the endpoints are generated.

Every `dims[j]` is checked with an `@assert` to be large enough for the chosen
grid (positive in general, at least two for `Grid"linspace"`).

Typically to be used as
```julia
for_linspace(dims, box) do (i_1, …, i_N), (x_1, …, x_N)
    # do something
end
```
"""
@inline @generated function for_linspace(f :: Function,
                                         dims :: Dims{N},
                                         box :: Box{N, T},
                                         grid :: Type{Grid{G}} = Grid"linspace") where {N, T <: AbstractFloat, G}
    # Inside the generator `grid` is the type of the argument, which is what
    # the `gridalign` methods dispatch on.
    shift, extrapoints = gridalign(grid)
    # Keep the arithmetic in the floating point type of the box: a Float64
    # shift would otherwise promote the coordinates of e.g. a Float32 box.
    shift_T = T(shift)

    # Generate names for per-dimension local variables of primitive types
    # (index, scale, translation, and dimension)
    i = [gensym("i_$j") for j ∈ 1:N]
    c = [gensym("c_$j") for j ∈ 1:N]
    d = [gensym("d_$j") for j ∈ 1:N]
    n = [gensym("n_$j") for j ∈ 1:N]

    # Function call: x_j = i_j*c_j + d_j
    x = (:( $(i[j])*$(c[j])+$(d[j]) ) for j ∈ 1:N)
    expr = :( f( ($(i...),), ($(x...),) ) )

    # Main recursion: wrap the call into one loop per dimension,
    # the first dimension being innermost.
    for j = 1:N
        expr = quote
            for $(i[j]) = 1:$(n[j])
                $expr
            end
        end
    end

    # Initialisation of linear transformation and variable references.
    #
    # `extrapoints` and the shift only exist while this generator runs, so
    # their values have to be interpolated into the emitted code with `$`;
    # a bare name would be looked up as an (undefined) global at run time.
    # Likewise the assertion message is built here, where `j` is known:
    # a `$j` within a string literal inside a quoted expression would be
    # deferred to run time, where no `j` exists.
    asserts = ( :( @assert ($(n[j]) > $extrapoints) $("Dimension $j too small") ) for j ∈ 1:N )
    scales = ( :( (box[$j][2]-box[$j][1])/($(n[j])-$extrapoints) ) for j ∈ 1:N )
    translates = ( :( box[$j][1] - $shift_T*$(c[j]) ) for j ∈ 1:N )

    expr = quote
        # Set up local variables for dims, needed in the loop above
        local $(lift_exprs(n)) = dims
        # Check everything is ok
        $(sequence_exprs(asserts))
        # Set up linear transformation, needed in the loop above
        local $(lift_exprs(c)) = $(lift_exprs(scales))
        local $(lift_exprs(d)) = $(lift_exprs(translates))
        # Execute the loop
        $expr
    end
end

end # module