Skip to content

Commit 58ed8cc

Browse files
committed
Forbid divergent execution of work-group barriers
1 parent 0dcdc8b commit 58ed8cc

File tree

1 file changed: +6 -24 lines changed

1 file changed

+6
-24
lines changed

src/macros.jl

+6 -24
Original file line number | Diff line number | Diff line change
@@ -86,11 +86,8 @@ function transform_gpu!(def, constargs, force_inbounds, unsafe_indices)
8686
end
8787

8888
struct WorkgroupLoop
89-
indices::Vector{Any}
9089
stmts::Vector{Any}
9190
allocations::Vector{Any}
92-
private_allocations::Vector{Any}
93-
private::Set{Symbol}
9491
terminated_in_sync::Bool
9592
end
9693

@@ -111,26 +108,18 @@ function find_sync(stmt)
111108
end
112109

113110
# TODO proper handling of LineInfo
114-
function split(
115-
stmts,
116-
indices = Any[], private = Set{Symbol}(),
117-
)
111+
function split(stmts)
118112
# 1. Split the code into blocks separated by `@synchronize`
119-
# 2. Aggregate `@index` expressions
120-
# 3. Hoist allocations
121-
# 4. Hoist uniforms
122113

123114
current = Any[]
124115
allocations = Any[]
125-
private_allocations = Any[]
126116
new_stmts = Any[]
127117
for stmt in stmts
128118
has_sync = find_sync(stmt)
129119
if has_sync
130-
loop = WorkgroupLoop(deepcopy(indices), current, allocations, private_allocations, deepcopy(private), is_sync(stmt))
120+
loop = WorkgroupLoop(current, allocations, is_sync(stmt))
131121
push!(new_stmts, emit(loop))
132122
allocations = Any[]
133-
private_allocations = Any[]
134123
current = Any[]
135124
is_sync(stmt) && continue
136125

@@ -142,7 +131,7 @@ function split(
142131
function recurse(expr::Expr)
143132
expr = unblock(expr)
144133
if is_scope_construct(expr) && any(find_sync, expr.args)
145-
new_args = unblock(split(expr.args, deepcopy(indices), deepcopy(private)))
134+
new_args = unblock(split(expr.args))
146135
return Expr(expr.head, new_args...)
147136
else
148137
return Expr(expr.head, map(recurse, expr.args)...)
@@ -156,14 +145,10 @@ function split(
156145
push!(allocations, stmt)
157146
continue
158147
elseif @capture(stmt, @private lhs_ = rhs_)
159-
push!(private, lhs)
160-
push!(private_allocations, :($lhs = $rhs))
148+
push!(allocations, :($lhs = $rhs))
161149
continue
162150
elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
163-
if @capture(rhs, @index(args__))
164-
push!(indices, stmt)
165-
continue
166-
elseif @capture(rhs, @localmem(args__) | @uniform(args__))
151+
if @capture(rhs, @localmem(args__) | @uniform(args__))
167152
push!(allocations, stmt)
168153
continue
169154
elseif @capture(rhs, @private(T_, dims_))
@@ -175,7 +160,6 @@ function split(
175160
end
176161
alloc = :($Scratchpad(__ctx__, $T, Val($dims)))
177162
push!(allocations, :($lhs = $alloc))
178-
push!(private, lhs)
179163
continue
180164
end
181165
end
@@ -185,7 +169,7 @@ function split(
185169

186170
# everything since the last `@synchronize`
187171
if !isempty(current)
188-
loop = WorkgroupLoop(deepcopy(indices), current, allocations, private_allocations, deepcopy(private), false)
172+
loop = WorkgroupLoop(current, allocations, false)
189173
push!(new_stmts, emit(loop))
190174
end
191175
return new_stmts
@@ -197,9 +181,7 @@ function emit(loop)
197181
body = Expr(:block, loop.stmts...)
198182
loopexpr = quote
199183
$(loop.allocations...)
200-
$(loop.private_allocations...)
201184
if __active_lane__
202-
$(loop.indices...)
203185
$(unblock(body))
204186
end
205187
end

0 commit comments

Comments
 (0)