@@ -87,11 +87,8 @@ function transform_gpu!(def, constargs, force_inbounds, unsafe_indicies)
87
87
end
88
88
89
89
struct WorkgroupLoop
90
- indicies:: Vector{Any}
91
90
stmts:: Vector{Any}
92
91
allocations:: Vector{Any}
93
- private_allocations:: Vector{Any}
94
- private:: Set{Symbol}
95
92
terminated_in_sync:: Bool
96
93
end
97
94
@@ -112,26 +109,18 @@ function find_sync(stmt)
112
109
end
113
110
114
111
# TODO proper handling of LineInfo
115
- function split (
116
- stmts,
117
- indicies = Any[], private = Set {Symbol} (),
118
- )
112
+ function split (stmts)
119
113
# 1. Split the code into blocks separated by `@synchronize`
120
- # 2. Aggregate `@index` expressions
121
- # 3. Hoist allocations
122
- # 4. Hoist uniforms
123
114
124
115
current = Any[]
125
116
allocations = Any[]
126
- private_allocations = Any[]
127
117
new_stmts = Any[]
128
118
for stmt in stmts
129
119
has_sync = find_sync (stmt)
130
120
if has_sync
131
- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private) , is_sync (stmt))
121
+ loop = WorkgroupLoop (current, allocations, is_sync (stmt))
132
122
push! (new_stmts, emit (loop))
133
123
allocations = Any[]
134
- private_allocations = Any[]
135
124
current = Any[]
136
125
is_sync (stmt) && continue
137
126
@@ -143,7 +132,7 @@ function split(
143
132
function recurse (expr:: Expr )
144
133
expr = unblock (expr)
145
134
if is_scope_construct (expr) && any (find_sync, expr. args)
146
- new_args = unblock (split (expr. args, deepcopy (indicies), deepcopy (private) ))
135
+ new_args = unblock (split (expr. args))
147
136
return Expr (expr. head, new_args... )
148
137
else
149
138
return Expr (expr. head, map (recurse, expr. args)... )
@@ -157,14 +146,10 @@ function split(
157
146
push! (allocations, stmt)
158
147
continue
159
148
elseif @capture (stmt, @private lhs_ = rhs_)
160
- push! (private, lhs)
161
- push! (private_allocations, :($ lhs = $ rhs))
149
+ push! (allocations, :($ lhs = $ rhs))
162
150
continue
163
151
elseif @capture (stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
164
- if @capture (rhs, @index (args__))
165
- push! (indicies, stmt)
166
- continue
167
- elseif @capture (rhs, @localmem (args__) | @uniform (args__))
152
+ if @capture (rhs, @localmem (args__) | @uniform (args__))
168
153
push! (allocations, stmt)
169
154
continue
170
155
elseif @capture (rhs, @private (T_, dims_))
@@ -176,7 +161,6 @@ function split(
176
161
end
177
162
alloc = :($ Scratchpad (__ctx__, $ T, Val ($ dims)))
178
163
push! (allocations, :($ lhs = $ alloc))
179
- push! (private, lhs)
180
164
continue
181
165
end
182
166
end
@@ -186,7 +170,7 @@ function split(
186
170
187
171
# everything since the last `@synchronize`
188
172
if ! isempty (current)
189
- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private) , false )
173
+ loop = WorkgroupLoop (current, allocations, false )
190
174
push! (new_stmts, emit (loop))
191
175
end
192
176
return new_stmts
@@ -198,9 +182,7 @@ function emit(loop)
198
182
body = Expr (:block , loop. stmts... )
199
183
loopexpr = quote
200
184
$ (loop. allocations... )
201
- $ (loop. private_allocations... )
202
185
if __active_lane__
203
- $ (loop. indicies... )
204
186
$ (unblock (body))
205
187
end
206
188
end
0 commit comments