@@ -86,11 +86,8 @@ function transform_gpu!(def, constargs, force_inbounds, unsafe_indices)
86
86
end
87
87
88
88
"""
    WorkgroupLoop(stmts, allocations, terminated_in_sync)

One chunk of kernel-body statements produced by `split`, i.e. the statements
collected between two `@synchronize` points (or after the last one).

# Fields
- `stmts`: the statements of this chunk; `emit` splices them into a block that
  is guarded by `if __active_lane__`.
- `allocations`: statements that `emit` splices *before* the `__active_lane__`
  guard. `split` collects here the `@localmem` / `@uniform` assignments, the
  rewritten `@private lhs = rhs` assignments, and the `Scratchpad` allocations
  generated for `@private(T, dims)`.
- `terminated_in_sync`: the result of `is_sync(stmt)` for the statement that
  closed this chunk; `false` for the trailing chunk that follows the last
  synchronization point.
"""
struct WorkgroupLoop
    stmts::Vector{Any}
    allocations::Vector{Any}
    terminated_in_sync::Bool
end
96
93
@@ -111,26 +108,18 @@ function find_sync(stmt)
111
108
end
112
109
113
110
# TODO proper handling of LineInfo
114
- function split (
115
- stmts,
116
- indices = Any[], private = Set {Symbol} (),
117
- )
111
+ function split (stmts)
118
112
# 1. Split the code into blocks separated by `@synchronize`
119
- # 2. Aggregate `@index` expressions
120
- # 3. Hoist allocations
121
- # 4. Hoist uniforms
122
113
123
114
current = Any[]
124
115
allocations = Any[]
125
- private_allocations = Any[]
126
116
new_stmts = Any[]
127
117
for stmt in stmts
128
118
has_sync = find_sync (stmt)
129
119
if has_sync
130
- loop = WorkgroupLoop (deepcopy (indices), current, allocations, private_allocations, deepcopy (private) , is_sync (stmt))
120
+ loop = WorkgroupLoop (current, allocations, is_sync (stmt))
131
121
push! (new_stmts, emit (loop))
132
122
allocations = Any[]
133
- private_allocations = Any[]
134
123
current = Any[]
135
124
is_sync (stmt) && continue
136
125
@@ -142,7 +131,7 @@ function split(
142
131
function recurse (expr:: Expr )
143
132
expr = unblock (expr)
144
133
if is_scope_construct (expr) && any (find_sync, expr. args)
145
- new_args = unblock (split (expr. args, deepcopy (indices), deepcopy (private) ))
134
+ new_args = unblock (split (expr. args))
146
135
return Expr (expr. head, new_args... )
147
136
else
148
137
return Expr (expr. head, map (recurse, expr. args)... )
@@ -156,14 +145,10 @@ function split(
156
145
push! (allocations, stmt)
157
146
continue
158
147
elseif @capture (stmt, @private lhs_ = rhs_)
159
- push! (private, lhs)
160
- push! (private_allocations, :($ lhs = $ rhs))
148
+ push! (allocations, :($ lhs = $ rhs))
161
149
continue
162
150
elseif @capture (stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
163
- if @capture (rhs, @index (args__))
164
- push! (indices, stmt)
165
- continue
166
- elseif @capture (rhs, @localmem (args__) | @uniform (args__))
151
+ if @capture (rhs, @localmem (args__) | @uniform (args__))
167
152
push! (allocations, stmt)
168
153
continue
169
154
elseif @capture (rhs, @private (T_, dims_))
@@ -175,7 +160,6 @@ function split(
175
160
end
176
161
alloc = :($ Scratchpad (__ctx__, $ T, Val ($ dims)))
177
162
push! (allocations, :($ lhs = $ alloc))
178
- push! (private, lhs)
179
163
continue
180
164
end
181
165
end
@@ -185,7 +169,7 @@ function split(
185
169
186
170
# everything since the last `@synchronize`
187
171
if ! isempty (current)
188
- loop = WorkgroupLoop (deepcopy (indices), current, allocations, private_allocations, deepcopy (private) , false )
172
+ loop = WorkgroupLoop (current, allocations, false )
189
173
push! (new_stmts, emit (loop))
190
174
end
191
175
return new_stmts
@@ -197,9 +181,7 @@ function emit(loop)
197
181
body = Expr (:block , loop. stmts... )
198
182
loopexpr = quote
199
183
$ (loop. allocations... )
200
- $ (loop. private_allocations... )
201
184
if __active_lane__
202
- $ (loop. indices... )
203
185
$ (unblock (body))
204
186
end
205
187
end
0 commit comments