Skip to content

Commit 5e3e1f4

Browse files
committed
Forbid divergent execution of work-group barriers
1 parent b435bb2 commit 5e3e1f4

File tree

3 files changed

+61
-15
lines changed

3 files changed

+61
-15
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KernelAbstractions"
22
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
33
authors = ["Valentin Churavy <v.churavy@gmail.com> and contributors"]
4-
version = "0.9.33"
4+
version = "0.9.34"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/KernelAbstractions.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ end
284284
After a `@synchronize` statement all reads and writes to global and local memory
285285
from each thread in the workgroup are visible to all other threads in the
286286
workgroup.
287+
288+
!!! note
289+
`@synchronize()` must be encountered by all workitems of a work-group executing the kernel or by none at all.
287290
"""
288291
macro synchronize()
289292
return quote
@@ -301,10 +304,15 @@ workgroup. `cond` is not allowed to have any visible side effects.
301304
# Platform differences
302305
- `GPU`: This synchronization will only occur if `cond` evaluates to `true`.
303306
- `CPU`: This synchronization will always occur.
307+
308+
!!! warning
309+
This variant of the `@synchronize` macro violates the requirement that `@synchronize` must be encountered
310+
by all workitems of a work-group executing the kernel or by none at all.
311+
Since v`0.9.34`, this version of the macro is deprecated and lowers to `@synchronize()`.
304312
"""
305313
macro synchronize(cond)
306314
return quote
307-
$(esc(cond)) && $__synchronize()
315+
$__synchronize()
308316
end
309317
end
310318

src/macros.jl

+51-13
Original file line numberDiff line numberDiff line change
@@ -86,22 +86,24 @@ function transform_gpu!(def, constargs, force_inbounds)
8686
end
8787
end
8888
pushfirst!(def[:args], :__ctx__)
89-
body = def[:body]
89+
new_stmts = Expr[]
90+
body = MacroTools.flatten(def[:body])
91+
stmts = body.args
92+
push!(new_stmts, Expr(:aliasscope))
93+
push!(new_stmts, :(__active_lane__ = $__validindex(__ctx__)))
9094
if force_inbounds
91-
body = quote
92-
@inbounds $(body)
93-
end
95+
push!(new_stmts, Expr(:inbounds, true))
9496
end
95-
body = quote
96-
if $__validindex(__ctx__)
97-
$(body)
98-
end
99-
return nothing
97+
append!(new_stmts, split(emit_gpu, body.args))
98+
if force_inbounds
99+
push!(new_stmts, Expr(:inbounds, :pop))
100100
end
101+
push!(new_stmts, Expr(:popaliasscope))
102+
push!(new_stmts, :(return nothing))
101103
def[:body] = Expr(
102104
:let,
103105
Expr(:block, let_constargs...),
104-
body,
106+
Expr(:block, new_stmts...),
105107
)
106108
return
107109
end
@@ -127,7 +129,7 @@ function transform_cpu!(def, constargs, force_inbounds)
127129
if force_inbounds
128130
push!(new_stmts, Expr(:inbounds, true))
129131
end
130-
append!(new_stmts, split(body.args))
132+
append!(new_stmts, split(emit_cpu, body.args))
131133
if force_inbounds
132134
push!(new_stmts, Expr(:inbounds, :pop))
133135
end
@@ -167,6 +169,7 @@ end
167169

168170
# TODO proper handling of LineInfo
169171
function split(
172+
emit,
170173
stmts,
171174
indicies = Any[], private = Set{Symbol}(),
172175
)
@@ -197,7 +200,7 @@ function split(
197200
function recurse(expr::Expr)
198201
expr = unblock(expr)
199202
if is_scope_construct(expr) && any(find_sync, expr.args)
200-
new_args = unblock(split(expr.args, deepcopy(indicies), deepcopy(private)))
203+
new_args = unblock(split(emit, expr.args, deepcopy(indicies), deepcopy(private)))
201204
return Expr(expr.head, new_args...)
202205
else
203206
return Expr(expr.head, map(recurse, expr.args)...)
@@ -246,7 +249,7 @@ function split(
246249
return new_stmts
247250
end
248251

249-
function emit(loop)
252+
function emit_cpu(loop)
250253
idx = gensym(:I)
251254
for stmt in loop.indicies
252255
# splice index into the i = @index(Cartesian, $idx)
@@ -300,3 +303,38 @@ function emit(loop)
300303

301304
return unblock(Expr(:block, stmts...))
302305
end
306+
307+
function emit_gpu(loop)
308+
stmts = Any[]
309+
append!(stmts, loop.allocations)
310+
for stmt in loop.private_allocations
311+
if @capture(stmt, lhs_ = rhs_)
312+
push!(stmts, :($lhs = $rhs))
313+
else
314+
error("@private $stmt not an assignment")
315+
end
316+
end
317+
318+
# don't emit empty loops
319+
if !(isempty(loop.stmts) || all(s -> s isa LineNumberNode, loop.stmts))
320+
body = Expr(:block, loop.stmts...)
321+
body = postwalk(body) do expr
322+
if @capture(expr, lhs_ = rhs_)
323+
if lhs in loop.private
324+
error("Can't assign to variables marked private")
325+
end
326+
end
327+
return expr
328+
end
329+
loopexpr = quote
330+
if __active_lane__
331+
$(loop.indicies...)
332+
$(unblock(body))
333+
end
334+
$__synchronize()
335+
end
336+
push!(stmts, loopexpr)
337+
end
338+
339+
return unblock(Expr(:block, stmts...))
340+
end

0 commit comments

Comments
 (0)