Skip to content

Commit 6306dea

Browse files
authored
add a graph_rewrite pass for creating asts [pr] (tinygrad#9765)
* add a graph_rewrite pass for creating asts [pr]
* disk
* benchmark
1 parent 07eea56 commit 6306dea

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

tinygrad/engine/schedule.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -369,8 +369,8 @@ def check_load_st(glbl:UOp, view:UOp):
369369
(UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
370370
])
371371

372-
def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
373-
assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}"
372+
def fix_kernel_ast(ctx:dict[Variable, int], k:UOp) -> UOp|None:
373+
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
374374
# substitute kernel sources for the target buffer + apply reshapes
375375
parents_rep: dict[UOp, UOp] = {}
376376
for s in k.src:
@@ -380,11 +380,10 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
380380
# push views to edges
381381
ast = graph_rewrite(graph_rewrite(ast, view_left), view_right)
382382
# add buffer ops + fix_kernel_ops
383-
ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True)
383+
ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, ctx=(ctx, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True)
384384
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
385-
# create subbuffer (TODO: this does not belong here)
386-
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
387385
return k.replace(arg=Kernel(ast, k.arg.metadata))
386+
create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
388387

389388
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
390389
if CAPTURE_PROCESS_REPLAY:
@@ -449,6 +448,10 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
449448
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
450449
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph")
451450

451+
# unbind var_vals and fix kernel ast
452+
var_vals: dict[Variable, int] = {}
453+
sched_sink = graph_rewrite(sched_sink, create_ast, ctx=var_vals, bottom_up=True)
454+
452455
# final toposort (bfs)
453456
children: dict[UOp, list[UOp]] = {}
454457
in_degree: dict[UOp, int] = {}
@@ -462,11 +465,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
462465

463466
queue = deque(k for k,v in in_degree.items() if v == 0)
464467
schedule: list[ScheduleItem] = []
465-
var_vals: dict[Variable, int] = {}
466468
while queue:
467469
u = queue.popleft()
468-
# TODO: move this to create_kernels
469-
k = fix_kernel_ast(u.src[1], var_vals)
470+
# map the BUFFER UOp to a subbuffer if it's a BUFFER_VIEW
471+
if (k:=u.src[1]).arg.ast.op is Ops.BUFFER_VIEW:
472+
buffers[k.src[0]] = (base:=k.src[1].buf_uop.buffer).view(k.size, k.arg.ast.dtype, k.arg.ast.arg[1]*base.dtype.itemsize)
470473
schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
471474
for x in children.get(u, []):
472475
in_degree[x] -= 1

0 commit comments

Comments
 (0)