@@ -86,22 +86,24 @@ function transform_gpu!(def, constargs, force_inbounds)
86
86
end
87
87
end
88
88
pushfirst! (def[:args ], :__ctx__ )
89
- body = def[:body ]
89
+ new_stmts = Expr[]
90
+ body = MacroTools. flatten (def[:body ])
91
+ stmts = body. args
92
+ push! (new_stmts, Expr (:aliasscope ))
93
+ push! (new_stmts, :(__active_lane__ = $ __validindex (__ctx__)))
90
94
if force_inbounds
91
- body = quote
92
- @inbounds $ (body)
93
- end
95
+ push! (new_stmts, Expr (:inbounds , true ))
94
96
end
95
- body = quote
96
- if $ __validindex (__ctx__)
97
- $ (body)
98
- end
99
- return nothing
97
+ append! (new_stmts, split (emit_gpu, body. args))
98
+ if force_inbounds
99
+ push! (new_stmts, Expr (:inbounds , :pop ))
100
100
end
101
+ push! (new_stmts, Expr (:popaliasscope ))
102
+ push! (new_stmts, :(return nothing ))
101
103
def[:body ] = Expr (
102
104
:let ,
103
105
Expr (:block , let_constargs... ),
104
- body ,
106
+ Expr ( :block , new_stmts ... ) ,
105
107
)
106
108
return
107
109
end
@@ -127,7 +129,7 @@ function transform_cpu!(def, constargs, force_inbounds)
127
129
if force_inbounds
128
130
push! (new_stmts, Expr (:inbounds , true ))
129
131
end
130
- append! (new_stmts, split (body. args))
132
+ append! (new_stmts, split (emit_cpu, body. args))
131
133
if force_inbounds
132
134
push! (new_stmts, Expr (:inbounds , :pop ))
133
135
end
167
169
168
170
# TODO proper handling of LineInfo
169
171
function split (
172
+ emit,
170
173
stmts,
171
174
indicies = Any[], private = Set {Symbol} (),
172
175
)
@@ -197,7 +200,7 @@ function split(
197
200
function recurse (expr:: Expr )
198
201
expr = unblock (expr)
199
202
if is_scope_construct (expr) && any (find_sync, expr. args)
200
- new_args = unblock (split (expr. args, deepcopy (indicies), deepcopy (private)))
203
+ new_args = unblock (split (emit, expr. args, deepcopy (indicies), deepcopy (private)))
201
204
return Expr (expr. head, new_args... )
202
205
else
203
206
return Expr (expr. head, map (recurse, expr. args)... )
@@ -246,7 +249,7 @@ function split(
246
249
return new_stmts
247
250
end
248
251
249
- function emit (loop)
252
+ function emit_cpu (loop)
250
253
idx = gensym (:I )
251
254
for stmt in loop. indicies
252
255
# splice index into the i = @index(Cartesian, $idx)
@@ -300,3 +303,38 @@ function emit(loop)
300
303
301
304
return unblock (Expr (:block , stmts... ))
302
305
end
306
+
307
"""
    emit_gpu(loop)

Assemble the GPU-side expression for a single `loop` record.

`loop` is expected to provide the fields `allocations`, `private_allocations`,
`stmts`, `indicies` and `private`. The returned expression contains, in order:

1. every entry of `loop.allocations`, unchanged;
2. every entry of `loop.private_allocations`, re-emitted as a plain `lhs = rhs`
   assignment;
3. if `loop.stmts` contains anything other than `LineNumberNode`s, a block that
   runs `loop.indicies` followed by the statements under `if __active_lane__`,
   and then calls `__synchronize`.

Raises an error if a private allocation is not an assignment, or if any statement
assigns to a name contained in `loop.private`.
"""
function emit_gpu(loop)
    emitted = Any[]
    append!(emitted, loop.allocations)

    # Private allocations must be assignments; they are emitted as `lhs = rhs`.
    for stmt in loop.private_allocations
        if !(@capture(stmt, lhs_ = rhs_))
            error(" @private $stmt not an assignment")
        end
        push!(emitted, :($lhs = $rhs))
    end

    # don't emit empty loops: skip when there are no statements at all, or when
    # the statements consist solely of line-number nodes
    if any(s -> !(s isa LineNumberNode), loop.stmts)
        # Walk the statements only to reject assignments to private names;
        # every node is handed back untouched.
        checked = postwalk(Expr(:block, loop.stmts...)) do node
            if @capture(node, lhs_ = rhs_) && lhs in loop.private
                error(" Can't assign to variables marked private")
            end
            return node
        end

        # The work is guarded by `__active_lane__`, while the call to
        # `__synchronize` sits outside the guard.
        guarded = quote
            if __active_lane__
                $(loop.indicies...)
                $(unblock(checked))
            end
            $__synchronize()
        end
        push!(emitted, guarded)
    end

    return unblock(Expr(:block, emitted...))
end
0 commit comments