Skip to content

Commit 9a20063

Browse files
authored
create subbuffer immediately before constructing ScheduleItem [pr] (tinygrad#9162)
1 parent 1c92534 commit 9a20063

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tinygrad/engine/schedule.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp
110110
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
111111
# track the buffer uop for the simplified uop
112112
buffer_map[buf] = buf_uop
113-
if op.op is Ops.BUFFER_VIEW: buffers[buf_uop] = (x:=op.src[0]).buf_uop.buffer.view(op.size, op.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
114113
# (early) bufferize
115114
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
116115
return ret
@@ -391,6 +390,8 @@ def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
391390
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
392391
# fix_kernel_ops
393392
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
393+
# create subbuffer
394+
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = bufs[1].buffer.view(ast.size, ast.dtype, (x:=ast.src[0]).st_arg.views[0].offset*x.dtype.itemsize)
394395
return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
395396

396397
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
@@ -454,8 +455,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
454455
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort):
455456
raise RuntimeError(f"cycle detected in graph, kernel must either depend on ASSIGN or BUFFER for {k}")
456457
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
457-
# increment the refcount of the target buf (this is required by the JIT and memory planner)
458-
u.buf_uop.buffer.ref(1)
459458
if assign_rep: sched_sink = sched_sink.substitute(assign_rep)
460459

461460
# display the final graph
@@ -478,6 +477,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
478477
while queue:
479478
u = queue.popleft()
480479
schedule.append(schedule_uop(u, var_vals))
480+
# increment the refcount of the target buf (this is required by the JIT and memory planner)
481+
u.buf_uop.buffer.ref(1)
481482
for x in children.get(u, []):
482483
in_degree[x] -= 1
483484
if in_degree[x] == 0: queue.append(x)

0 commit comments

Comments
 (0)