Skip to content

Commit

Permalink
Allow opt-out of implicit bounds-checking
Browse files Browse the repository at this point in the history
KernelAbstractions currently creates kernels that look like:

```
if __validindex(ctx)
   # Body
end
```

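This is problematic due to the convergence requirement on
`@synchronize`: every work-item in a workgroup is expected to reach the
barrier, but work-items that fail the `__validindex` check skip the body
(and any `@synchronize` inside it) entirely. Kernels can now opt out of the
implicit check with `@kernel implicit_validindex = false`.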
  • Loading branch information
vchuravy committed Feb 10, 2025
1 parent f038d8c commit 4dd0acc
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
2 changes: 1 addition & 1 deletion examples/histogram.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function create_histogram(input)
end

# This is a 1D histogram kernel where the histogramming happens in shared memory (shmem)
@kernel function histogram_kernel!(histogram_output, input)
@kernel implicit_validindex = false function histogram_kernel!(histogram_output, input)
tid = @index(Global, Linear)
lid = @index(Local, Linear)

Expand Down
10 changes: 7 additions & 3 deletions src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ synchronize(backend)
```
"""
macro kernel(expr)
return __kernel(expr, #=force_inbounds=# false)
return __kernel(expr, #=force_inbounds=# false, #=implicit_validindex=# true)
end

"""
Expand All @@ -71,8 +71,9 @@ This allows for two different configurations:
"""
macro kernel(ex...)
if length(ex) == 1
return __kernel(ex[1], false)
return __kernel(ex[1], false, true)
else
implicit_validindex = true
force_inbounds = false
for i in 1:(length(ex) - 1)
if ex[i] isa Expr && ex[i].head == :(=) &&
Expand All @@ -81,6 +82,9 @@ macro kernel(ex...)
elseif ex[i] isa Expr && ex[i].head == :(=) &&
ex[i].args[1] == :inbounds && ex[i].args[2] isa Bool
force_inbounds = ex[i].args[2]
elseif ex[i] isa Expr && ex[i].head == :(=) &&
ex[i].args[1] == :implicit_validindex && ex[i].args[2] isa Bool
implicit_validindex = ex[i].args[2]
else
error(
"Configuration should be of form:\n" *
Expand All @@ -90,7 +94,7 @@ macro kernel(ex...)
)
end
end
return __kernel(ex[end], force_inbounds)
return __kernel(ex[end], force_inbounds, implicit_validindex)
end
end

Expand Down
19 changes: 13 additions & 6 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function find_return(stmt)
end

# XXX: Proper errors
function __kernel(expr, force_inbounds = false)
function __kernel(expr, force_inbounds = false, implicit_validindex = true)

Check warning on line 13 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L13

Added line #L13 was not covered by tests
def = splitdef(expr)
name = def[:name]
args = def[:args]
Expand All @@ -30,7 +30,7 @@ function __kernel(expr, force_inbounds = false)

def_gpu = deepcopy(def)
def_gpu[:name] = gpu_name = Symbol(:gpu_, name)
transform_gpu!(def_gpu, constargs, force_inbounds)
transform_gpu!(def_gpu, constargs, force_inbounds, implicit_validindex)

Check warning on line 33 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L33

Added line #L33 was not covered by tests
gpu_function = combinedef(def_gpu)

# create constructor functions
Expand All @@ -50,7 +50,7 @@ end

# The easy case, transform the function for GPU execution
# - mark constant arguments by applying `constify`.
function transform_gpu!(def, constargs, force_inbounds)
function transform_gpu!(def, constargs, force_inbounds, implicit_validindex)

Check warning on line 53 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L53

Added line #L53 was not covered by tests
let_constargs = Expr[]
for (i, arg) in enumerate(def[:args])
if constargs[i]
Expand All @@ -64,11 +64,18 @@ function transform_gpu!(def, constargs, force_inbounds)
@inbounds $(body)
end
end
body = quote
if $__validindex(__ctx__)
if implicit_validindex
body = quote
if $__validindex(__ctx__)
$(body)

Check warning on line 70 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L67-L70

Added lines #L67 - L70 were not covered by tests
end
return nothing

Check warning on line 72 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L72

Added line #L72 was not covered by tests
end
else
body = quote

Check warning on line 75 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L75

Added line #L75 was not covered by tests
$(body)
return nothing

Check warning on line 77 in src/macros.jl

View check run for this annotation

Codecov / codecov/patch

src/macros.jl#L77

Added line #L77 was not covered by tests
end
return nothing
end
def[:body] = Expr(
:let,
Expand Down

0 comments on commit 4dd0acc

Please sign in to comment.