Skip to content

Commit 7e448d1

Browse files
committed
Forbid divergent execution of work-group barriers
1 parent 0121280 commit 7e448d1

File tree

1 file changed

+6
-24
lines changed

1 file changed

+6
-24
lines changed

src/macros.jl

+6-24
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,8 @@ function transform_gpu!(def, constargs, force_inbounds, unsafe_indicies)
8787
end
8888

8989
struct WorkgroupLoop
90-
indicies::Vector{Any}
9190
stmts::Vector{Any}
9291
allocations::Vector{Any}
93-
private_allocations::Vector{Any}
94-
private::Set{Symbol}
9592
terminated_in_sync::Bool
9693
end
9794

@@ -112,26 +109,18 @@ function find_sync(stmt)
112109
end
113110

114111
# TODO proper handling of LineInfo
115-
function split(
116-
stmts,
117-
indicies = Any[], private = Set{Symbol}(),
118-
)
112+
function split(stmts)
119113
# 1. Split the code into blocks separated by `@synchronize`
120-
# 2. Aggregate `@index` expressions
121-
# 3. Hoist allocations
122-
# 4. Hoist uniforms
123114

124115
current = Any[]
125116
allocations = Any[]
126-
private_allocations = Any[]
127117
new_stmts = Any[]
128118
for stmt in stmts
129119
has_sync = find_sync(stmt)
130120
if has_sync
131-
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private), is_sync(stmt))
121+
loop = WorkgroupLoop(current, allocations, is_sync(stmt))
132122
push!(new_stmts, emit(loop))
133123
allocations = Any[]
134-
private_allocations = Any[]
135124
current = Any[]
136125
is_sync(stmt) && continue
137126

@@ -143,7 +132,7 @@ function split(
143132
function recurse(expr::Expr)
144133
expr = unblock(expr)
145134
if is_scope_construct(expr) && any(find_sync, expr.args)
146-
new_args = unblock(split(expr.args, deepcopy(indicies), deepcopy(private)))
135+
new_args = unblock(split(expr.args))
147136
return Expr(expr.head, new_args...)
148137
else
149138
return Expr(expr.head, map(recurse, expr.args)...)
@@ -157,14 +146,10 @@ function split(
157146
push!(allocations, stmt)
158147
continue
159148
elseif @capture(stmt, @private lhs_ = rhs_)
160-
push!(private, lhs)
161-
push!(private_allocations, :($lhs = $rhs))
149+
push!(allocations, :($lhs = $rhs))
162150
continue
163151
elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
164-
if @capture(rhs, @index(args__))
165-
push!(indicies, stmt)
166-
continue
167-
elseif @capture(rhs, @localmem(args__) | @uniform(args__))
152+
if @capture(rhs, @localmem(args__) | @uniform(args__))
168153
push!(allocations, stmt)
169154
continue
170155
elseif @capture(rhs, @private(T_, dims_))
@@ -176,7 +161,6 @@ function split(
176161
end
177162
alloc = :($Scratchpad(__ctx__, $T, Val($dims)))
178163
push!(allocations, :($lhs = $alloc))
179-
push!(private, lhs)
180164
continue
181165
end
182166
end
@@ -186,7 +170,7 @@ function split(
186170

187171
# everything since the last `@synchronize`
188172
if !isempty(current)
189-
loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private), false)
173+
loop = WorkgroupLoop(current, allocations, false)
190174
push!(new_stmts, emit(loop))
191175
end
192176
return new_stmts
@@ -198,9 +182,7 @@ function emit(loop)
198182
body = Expr(:block, loop.stmts...)
199183
loopexpr = quote
200184
$(loop.allocations...)
201-
$(loop.private_allocations...)
202185
if __active_lane__
203-
$(loop.indicies...)
204186
$(unblock(body))
205187
end
206188
end

0 commit comments

Comments
 (0)