# HG changeset patch # User Tuomo Valkonen # Date 1640164478 -7200 # Node ID d881275c65648552c1f29df322656a39082c9f05 # Parent 22a64e826ee7289faeb7481bb810df7992578d56 Add metaprogramming tools and fast multidimensional loops. diff -r 22a64e826ee7 -r d881275c6564 src/AlgTools.jl --- a/src/AlgTools.jl Tue Dec 07 11:41:07 2021 +0200 +++ b/src/AlgTools.jl Wed Dec 22 11:14:38 2021 +0200 @@ -5,18 +5,11 @@ `module AlgTools` -This module has the submodules: -- `FunctionalProgramming` -- `StructTools` -- `LinkedLists` -- `Logger` -- `Iterate` -- `VectorMath` -- `Util` -- `ThreadUtil` -- `Comms` -- `LinOps` -- `DifferentiableFN` +This module implements useful code for implementing iterative algorithms. +For further documentation, see the submodules +`FunctionalProgramming`, `StructTools`, `LinkedLists`, `Logger`, `Iterate`, +`VectorMath`, `Util`, `ThreadUtil`, `Comms`, `LinOps`, `DifferentiableFN`, +`Metaprogramming`, and `Loops`. """ module AlgTools @@ -31,5 +24,7 @@ include("Comms.jl") include("LinOps.jl") include("DifferentiableFN.jl") +include("Metaprogramming.jl") +include("Loops.jl") end diff -r 22a64e826ee7 -r d881275c6564 src/Loops.jl --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Loops.jl Wed Dec 22 11:14:38 2021 +0200 @@ -0,0 +1,184 @@ +############################### +# Fast multi-dimensional loops +############################### + +""" +`module AlgTools.Loops` + +This module implements `for_zip`, `for_indices`, and `for_linspace` generated +functions for loops amenable to a high level of optimisation by the compiler. +""" +module Loops + +using ..Metaprogramming + +############## +# Our exports +############## + +export Box, + Grid, + @Grid_str, + for_zip, + for_indices, + for_linspace + +##################################### +# Looping over many iterators/arrays +##################################### + +""" +`for_zip(f, a...)` + +With `dims` of type `Dims{N}`, loop over all `(i_1,…,i_N)` with `i_j ∈ 1:dims[j]`. +Typically to be used as + +```julia +for_zip(a, b) do x, y + # do something +end +``` +""" +@generated function for_zip(f :: Function, a...) + # Generate reference variables to avoid execution of potential + # code inputs several times. + aref = [gensym("aref_$(j)") for j ∈ 1:length(a)] + # Generate length-checking asserts + asserts = ( :( length($b)==M ) for b ∈ aref[2:end] ) + + quote + # Assign input expressions to reference variables + local $(lift_exprs(aref)) = (a...,) + # Check that all inputs are of equal length + M = length($(aref[1])) + @assert $(all_exprs(asserts)) "Parameter lengths have to be equal" + # Do the loop + for i = 1:M + f($([ :( $c[i] ) for c ∈ aref]...)) + end + end +end + +####################### +# Looping over indices +####################### + +""" +`for_indices(f, N, dims)` + +With `dims` of type `Dims{N}`, loop over all `(i_1,…,i_N)` with `i_j ∈ 1:dims[j]`. +Typically to be used as + +```julia +for_indices(N, dims) do i_1, …, i_N + # do something +end +``` +""" +@generated function for_indices(f :: Function, dims :: Dims{N}) where N + # This depends on numbers being passed as numbers, not Exprs! + local i = [gensym("i_$(j)") for j ∈ 1:N] + local di = [gensym("di_$(j)") for j ∈ 1:N] + + local expr = :( f($(i...),) ) + + for j = 1:N + expr = quote + for $(i[j]) = 1:$(di[j]) + $expr + end + end + end + quote + $(lift_exprs(di)) = dims + $expr + end +end + + +########################################## +# Looping over indices and transformation +########################################## + +Box{N, T} = NTuple{N, Tuple{T, T}} +struct Grid{R} end + +""" +`@Grid_str` + +A convenience macro for writing Grid{:gridtype} as Grid"gridtype" for use with `for_linspace`. +""" +macro Grid_str(s) + :(Grid{$(Expr(:quote, Symbol(s)))}) +end + +# For type-checking user's code, it would be better to wrap `for_linspace` +# with a multiple dispatch type-case selector. However, the approach of +# avoiding a wrapper and calling `gridalign` from `for_linspace` to get +# the grid parements helps with inlining the code. +@inline gridalign(:: Type{Type{Grid{:midpoint}}}) = (0.5, 0) +@inline gridalign(:: Type{Type{Grid{:firstpoint}}}) = (1.0, 0) +@inline gridalign(:: Type{Type{Grid{:lastpoint}}}) = (0.0, 0) +@inline gridalign(:: Type{Type{Grid{:linspace}}}) = (1.0, 1) + +""" +`for_linspace(f :: Function, dims :: Dims{N} , box :: Box{N, T}, grid :: Type{Grid{G}} = Grid"linspace")` + +With `box` of type `Box{N, T} = NTuple{N, Tuple{T, T}}` with `T <: AbstractFloat`, +loop over all `((i_1,…,i_N), (x_1,…,x_N))` with `i_j ∈ 1:dims[j]` and each `x_i` +sampled in equal-length subintervals of the interval `box[i]`. +If `grid` is `Grid"midpoint"`, `Grid"firstpoint"`, `Grid"lastpoint"`, the point `x_i` is +the mid-point, first point, or last point of the corresponding subinterval, and there +are `dims[j]` subintervals. If `grid` is `Grid"linspace"`, there are `dims[j]-1` subintervals, +and all the endpoints are generated. Typically to be used as + +```julia +for_linspace(dims, box) do (i_1, …, i_N), (x_1, …, x_N) + # do something +end +``` +""" +@inline @generated function + for_linspace(f :: Function, dims :: Dims{N}, box :: Box{N, T}, + grid :: Type{Grid{G}} = Grid"linspace") where {N, T <: AbstractFloat, G} + + shift, extrapoints = gridalign(grid) + + # Generate some local variables of primitive types, storable in registers + i = [gensym("i_$j") for j ∈ 1:N] + c = [gensym("c_$j") for j ∈ 1:N] + d = [gensym("d_$j") for j ∈ 1:N] + n = [gensym("n_$j") for j ∈ 1:N] + + # Function call + x = (:( $(i[j])*$(c[j])+$(d[j]) ) for j ∈ 1:N) + expr = :( f( ($(i...),), ($(x...),) ) ) + + # Main recursion + for j = 1:N + expr = quote + for $(i[j]) = 1:$(n[j]) + $expr + end + end + end + + # Initialisation of linear transformation and variable references + asserts = ( :( @assert ($(n[j]) > extrapoints) "Dimension $j too small" ) for j ∈ 1:N ) + scales = ( :( (box[$j][2]-box[$j][1])/($(n[j])-extrapoints) ) for j ∈ 1:N ) + translates = ( :( box[$j][1] - shift*$(c[j]) ) for j ∈ 1:N ) + expr = quote + # Set up local (register-storable) variables for dims, needed in the loop above + local $(lift_exprs(n)) = dims + # Check everything is ok + $(sequence_exprs(asserts)) + # Set up linear transformation, needed in the loop above + local $(lift_exprs(c)) = $(lift_exprs(scales)) + local $(lift_exprs(d)) = $(lift_exprs(translates)) + # Execute the loop + $expr + end +end + + +end # module \ No newline at end of file diff -r 22a64e826ee7 -r d881275c6564 src/Metaprogramming.jl --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Metaprogramming.jl Wed Dec 22 11:14:38 2021 +0200 @@ -0,0 +1,124 @@ +################## +# Metaprogramming +################## + +""" +`module AlgTools.MetaProgramming` + +This module implements helpers for metaprogramming, used, e.g., by `Loops`. +It includes the macros `@foldl`, `@all`, and `@map` and the metaprogramming +functions `foldl_exprs`, `all_exprs`, `map_exprs`, `lift_exprs`, and `sequence_exprs`. +""" + +module Metaprogramming + +############## +# Our exports +############## + +export @foldl, + @all, + @map, + foldl_exprs, + all_exprs, + map_exprs, + lift_exprs, + sequence_exprs + +########################################### +# Support functions for macro construction +########################################### + +"""" +`foldl_exprs(f, a)` + +Does `foldl` on the `Expr`s in `a` with the `f`, a `Function` or another `Expr` +representing a function. Returns new `Expr` with the expanded `foldl` operation. +""" +function foldl_exprs(f :: Expr, a) + return foldl((b, c) -> :($f($b, $c)), a) +end + +function foldl_exprs(f :: Function, a) + return foldl((b, c) -> :($(f(b, c))), a) +end + +"""" +`all_exprs(a...)` + +Does `all` on the `Expr`s in `a`. +Returns new `Expr` with the expanded `all` operation. +""" +function all_exprs(a) + return foldl((b, c) -> :( $b && $c ), a) +end + +"""" +`map_exprs(f, a...)` + +Does `map` on the `Expr`s in `a` with the `f`, a `Function` or another `Expr` +representing a function. Returns a new `NTuple` with `Expr`s for the result of +applying a to all elements of `a`. +""" +function map_exprs(f :: Expr , a) + return ([ :( $f($b) ) for b ∈ a ]...,) +end + +function map_exprs(f :: Function , a) + return ([ :( $(f(b)) ) for b ∈ a ]...,) +end + +"""" +`lift_exprs(a)` + +Turns a `Tuple` `a` of `Expr`s into an `Expr` containing `Tuple`. +""" +function lift_exprs(a) + return :( ($(a...),) ) +end + +""" +`sequence_exprs(a)` + +Turns a `Tuple` `a` of `Expr`s into an `Expr` with each original expr on its own line (separated by `;`). +""" +function sequence_exprs(a) + return foldl_exprs((b, c) -> :( $b; $c ), a) +end + + +######### +# Macros +######### + +"""" +`@foldl(f, a...)` + +A basic macro version of `foldl`. +""" +macro foldl(f, a...) + return foldl_exprs(f, a) +end + +"""" +`@all(a...)` + +A macro version of `all`. +""" +macro all(a...) + # Implementation based on @foldl: + #return :( @foldl((b, c) -> b && c, $(a...)) ) + # Direct implementation produces prettier generated code: + return all_exprs(a) +end + +"""" +`map(f, a...)` + +A macro version of `map`. +""" +macro map(f, a...) + return lift_exprs(map_exprs(f, a)) +end + +end # module