###############################
# Fast multi-dimensional loops
###############################

"""
`module AlgTools.Loops`

This module implements `for_zip`, `for_indices`, and `for_linspace` generated
functions for loops amenable to a high level of optimisation by the compiler.
"""
module Loops

using ..Metaprogramming

##############
# Our exports
##############

export Box,
       Grid,
       @Grid_str,
       for_zip,
       for_indices,
       for_linspace

#####################################
# Looping over many iterators/arrays
#####################################

"""
    for_zip(f, a...)

Loop over several equally long indexable collections in lockstep: with `M` the common
length of the collections `a...`, call `f(a[1][i], …, a[K][i])` for each `i ∈ 1:M`.

Elements are read as `x[i]` with `i ∈ 1:length(x)`, so every collection has to support
`length` and 1-based `getindex`.

# Throws
- `ArgumentError` if no collection is given.
- `AssertionError` if the collections do not all have the same length.

# Examples
```julia
for_zip(a, b) do x, y
    # do something
end
```
"""
@generated function for_zip(f :: Function, a...)
    K = length(a)
    if K == 0
        # There is no first collection to take the loop length from.
        return :( throw(ArgumentError("for_zip needs at least one collection to loop over")) )
    end

    # One local variable per collection, bound once before the loop.
    refs = [gensym("aref_$(j)") for j ∈ 1:K]
    bindings = [:( local $(refs[j]) = a[$j] ) for j ∈ 1:K]

    # Conjunction of the length checks against the first collection;
    # stays a plain `true` when there is only one collection.
    samelength = true
    for r ∈ refs[2:end]
        samelength = :( $samelength && length($r) == M )
    end

    items = [:( $r[i] ) for r ∈ refs]

    return quote
        # Bind the inputs to local reference variables
        $(bindings...)
        # Check that all inputs are of equal length
        M = length($(refs[1]))
        @assert $samelength "Parameter lengths have to be equal"
        # Do the loop
        for i = 1:M
            f($(items...))
        end
    end
end

#######################
# Looping over indices
#######################

"""
    for_indices(f, dims)

With `dims` of type `Dims{N}`, call `f(i_1, …, i_N)` for all `(i_1,…,i_N)` with
`i_j ∈ 1:dims[j]`. The loops are nested so that `i_1` is the innermost (fastest varying)
index and `i_N` the outermost one.

Typically to be used as

```julia
for_indices(dims) do i_1, …, i_N
    # do something
end
```
"""
@generated function for_indices(f :: Function, dims :: Dims{N}) where N
    # Loop variables and per-dimension extents
    idx = [gensym("i_$(j)") for j ∈ 1:N]
    extent = [gensym("di_$(j)") for j ∈ 1:N]

    # Copy every extent out of the tuple into its own local before the loop nest starts
    setup = [:( local $(extent[j]) = dims[$j] ) for j ∈ 1:N]

    # Build the loop nest inside-out: each pass wraps what has been built so far,
    # so dimension 1 ends up innermost.
    nest = :( f($(idx...)) )
    for j ∈ 1:N
        nest = quote
            for $(idx[j]) = 1:$(extent[j])
                $nest
            end
        end
    end

    return quote
        $(setup...)
        $nest
    end
end


##########################################
# Looping over indices and transformation
##########################################

# A box in `N` dimensions: one pair of interval end points per dimension.
const Box{N, T} = NTuple{N, Tuple{T, T}}

# Tag type; the symbol `R` selects how `for_linspace` places the sample points.
struct Grid{R} end

"""
    @Grid_str

A convenience macro for writing `Grid{:gridtype}` as `Grid"gridtype"` for use with `for_linspace`.
"""
macro Grid_str(s)
    :(Grid{$(Expr(:quote, Symbol(s)))})
end

# For type-checking user's code, it would be better to wrap `for_linspace`
# with a multiple dispatch type-case selector. However, the approach of
# avoiding a wrapper and calling `gridalign` from `for_linspace` to get
# the grid parameters helps with inlining the code.
#
# `gridalign` is called while `for_linspace` generates its code, where `grid` is bound
# to the *type* of the argument; hence the methods dispatch on `Type{Type{Grid{…}}}`.
# Each method returns `(shift, extrapoints)`: with `n` points on the interval `(a, b)`,
# the step is `c = (b - a)/(n - extrapoints)` and point `i` is `a + (i - shift)*c`.
@inline gridalign(:: Type{Type{Grid{:midpoint}}}) = (0.5, 0)
@inline gridalign(:: Type{Type{Grid{:firstpoint}}}) = (1.0, 0)
@inline gridalign(:: Type{Type{Grid{:lastpoint}}}) = (0.0, 0)
@inline gridalign(:: Type{Type{Grid{:linspace}}}) = (1.0, 1)

"""
    for_linspace(f :: Function, dims :: Dims{N}, box :: Box{N, T},
                 grid :: Type{Grid{G}} = Grid"linspace")

With `box` of type `Box{N, T} = NTuple{N, Tuple{T, T}}` with `T <: AbstractFloat`,
loop over all `((i_1,…,i_N), (x_1,…,x_N))` with `i_j ∈ 1:dims[j]` and each `x_j`
sampled in equal-length subintervals of the interval `box[j]`, calling `f` with the
index tuple and the coordinate tuple. The index `i_1` varies fastest.

If `grid` is `Grid"midpoint"`, `Grid"firstpoint"`, `Grid"lastpoint"`, the point `x_j` is
the mid-point, first point, or last point of the corresponding subinterval, and there
are `dims[j]` subintervals. If `grid` is `Grid"linspace"`, there are `dims[j]-1` subintervals,
and all the endpoints are generated.

# Throws
- `AssertionError` ("Dimension j too small") if `dims[j] < 1`, or `dims[j] < 2` for
  `Grid"linspace"`.
- `MethodError` if `gridalign` has no method for the requested grid type.

# Examples
```julia
for_linspace(dims, box) do (i_1, …, i_N), (x_1, …, x_N)
    # do something
end
```
"""
@inline @generated function for_linspace(
        f :: Function, dims :: Dims{N}, box :: Box{N, T},
        grid :: Type{Grid{G}} = Grid"linspace") where {N, T <: AbstractFloat, G}

    # These exist only while the code is being generated, so they have to be spliced
    # into the generated expressions with `\$`; the generated body cannot see them by name.
    # NOTE(review): `shift` is a Float64, so for a narrower `T` (e.g. Float32) the
    # coordinates passed to `f` are promoted to Float64 — confirm this is intended.
    shift, extrapoints = gridalign(grid)

    # Generate some local variables of primitive types, storable in registers
    i = [gensym("i_$j") for j ∈ 1:N]
    c = [gensym("c_$j") for j ∈ 1:N]
    d = [gensym("d_$j") for j ∈ 1:N]
    n = [gensym("n_$j") for j ∈ 1:N]

    # Function call: x_j = i_j*c_j + d_j
    x = [:( $(i[j])*$(c[j]) + $(d[j]) ) for j ∈ 1:N]
    expr = :( f( ($(i...),), ($(x...),) ) )

    # Main recursion: wrap inside-out, so that dimension 1 is the innermost loop
    for j = 1:N
        expr = quote
            for $(i[j]) = 1:$(n[j])
                $expr
            end
        end
    end

    # Initialisation of linear transformation and variable references
    sizes = [:( local $(n[j]) = dims[$j] ) for j ∈ 1:N]
    checks = map(1:N) do j
        # Build the message now: `\$j` inside a quoted string literal would only be
        # looked up when the generated code runs, where no `j` exists.
        msg = "Dimension $j too small"
        :( @assert ($(n[j]) > $extrapoints) $msg )
    end
    scales = [:( local $(c[j]) = (box[$j][2] - box[$j][1]) / ($(n[j]) - $extrapoints) )
              for j ∈ 1:N]
    translates = [:( local $(d[j]) = box[$j][1] - $shift * $(c[j]) ) for j ∈ 1:N]

    return quote
        # Set up local (register-storable) variables for dims, needed in the loop above
        $(sizes...)
        # Check everything is ok
        $(checks...)
        # Set up linear transformation, needed in the loop above
        $(scales...)
        $(translates...)
        # Execute the loop
        $expr
    end
end


end # module