
Commit e7edadd

construct the sched_sink with graph_rewrite [pr] (tinygrad#8903)
* construct the sched_sink with graph_rewrite
* diff
* move break_sched
1 parent ef7ad3f commit e7edadd
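
Note: the commit message is terse, so here is a minimal, self-contained Python sketch of the idea behind "construct the sched_sink with graph_rewrite": replace an imperative loop that builds the schedule SINK with a pattern-driven rewrite that turns each SINK source into a KERNEL node. Node, Ctx, and the toy graph_rewrite/schedule_uop below are hypothetical stand-ins, not tinygrad's real UOp, ScheduleContext, PatternMatcher, or graph_rewrite APIs (those appear in the diff further down).

from __future__ import annotations
from dataclasses import dataclass, field

@dataclass(frozen=True)
class Node:
  op: str             # e.g. "SINK", "STORE", "KERNEL"
  src: tuple = ()      # child nodes
  arg: object = None   # payload, e.g. the scheduled kernel

@dataclass
class Ctx:
  var_vals: dict = field(default_factory=dict)  # mirrors the new ScheduleContext.var_vals field

def schedule_uop(store: Node, ctx: Ctx) -> Node:
  # stand-in for the real schedule_uop: wrap a realized STORE into a KERNEL node
  return Node("KERNEL", src=store.src, arg=f"kernel_for({store.arg})")

def create_kernels(ctx: Ctx, x: Node) -> Node | None:
  # like the new create_kernels rule: fire on a SINK whose sources are not all KERNELs yet
  if x.op == "SINK" and any(s.op != "KERNEL" for s in x.src):
    return Node("SINK", src=tuple(schedule_uop(s, ctx) for s in x.src))
  return None  # no match, rewrite is done

def graph_rewrite(root: Node, rule, ctx: Ctx) -> Node:
  # toy fixpoint loop: keep applying the rule until it stops matching
  while (new := rule(ctx, root)) is not None:
    root = new
  return root

stores = [Node("STORE", arg=f"buf{i}") for i in range(3)]
sched_sink = graph_rewrite(Node("SINK", src=tuple(stores)), create_kernels, Ctx())
print([s.op for s in sched_sink.src])  # ['KERNEL', 'KERNEL', 'KERNEL']

In the actual diff, the same shape appears as the new create_kernels PatternMatcher plus sched_sink = graph_rewrite(UOp.sink(*realize_map.values()), create_kernels, ctx).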

File tree: 1 file changed (+17, -16)


tinygrad/engine/schedule.py

Lines changed: 17 additions & 16 deletions
@@ -88,6 +88,7 @@ class ScheduleContext:
   assigns: dict[UOp, None] = field(default_factory=dict)  # this holds all the BUFFER uops we ASSIGN to in this schedule
   realizes: dict[UOp, UOp] = field(default_factory=dict)  # this holds all the BUFFER uops we mutate in this schedule
   allbufs: dict[UOp, UOp] = field(default_factory=dict)  # this maps BUFFER uops the actual op
+  var_vals: dict[Variable, int] = field(default_factory=dict)
   children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
   preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))

@@ -230,7 +231,6 @@ def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
   for reduceop in double_reduces:
     top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
     if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
-  graph_rewrite(sink, break_sched, ctx)
   return ctx.realizes

 # break the SINK into stores

@@ -372,11 +372,11 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
   return var
 unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])

-def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> UOp:
+def schedule_uop(pre:UOp, ctx:ScheduleContext) -> UOp:
   # unbind_vars + push views to edges
-  sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
+  sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=ctx.var_vals), view_right)
   # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
-  ast = graph_rewrite(sink, to_si, si_ctx:=KernelContext(var_vals))
+  ast = graph_rewrite(sink, to_si, si_ctx:=KernelContext(ctx.var_vals))
   # deal with ASSIGN
   if len(ctx.assigns) != 0:
     assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]

@@ -399,6 +399,11 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> UOp:
   metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None))
   return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata))

+create_kernels = PatternMatcher([
+  (UPat(Ops.SINK, name="x"), lambda ctx,x: x.replace(src=tuple(schedule_uop(s.sink(), ctx) for s in x.src))
+    if any(s.op is not Ops.KERNEL for s in x.src) else None),
+])
+
 # **** schedule creation and toposort

 @track_rewrites(named=True)

@@ -425,20 +430,16 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
     buf_tensors.setdefault(b, []).append(k)
     ops_metadata[b] = k.metadata
   realize_map = group_realizes(sink, ctx:=ScheduleContext(ops_metadata))
+  if len(realize_map) == 0: return [], {}, becomes_map

-  # create kernels + map buffers to realized tensors
-  sinks: list[UOp] = []
-  var_vals: dict[Variable, int] = {}
-  for buf_uop,store in realize_map.items():
-    assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
-    sinks.append(schedule_uop(store.sink(), ctx, var_vals))
-    # can only schedule once
+  # map buffers to realized tensors
+  for buf_uop in realize_map:
     for tensor_uop in buf_tensors[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
-    # increment refcount for this buffer
     buf_uop.buffer.ref(1)
-  sched_sink = UOp(Ops.SINK, src=tuple(sinks))
-  # display, TODO: this isn't a complete sched_sink yet
-  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]))
+
+  # create kernels, TODO: this should use the SINK from tensor_map
+  graph_rewrite(sink, break_sched, ctx)
+  sched_sink = graph_rewrite(UOp.sink(*realize_map.values()), create_kernels, ctx)
   type_verify(list(sched_sink.toposort), kernel_spec)

   # TODO: this should be the break between the "grouper" and the "linearizer"

@@ -479,4 +480,4 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   if CAPTURE_PROCESS_REPLAY:
     with Context(PICKLE_BUFFERS=0):
       diskcache_put("schedule_process_replay", str(big_sink.key), (big_sink, ContextVar._cache, [x.ast for x in schedule]))
-  return schedule, var_vals, becomes_map
+  return schedule, ctx.var_vals, becomes_map
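
Taken together, the diff makes three moves: var_vals becomes a field on ScheduleContext instead of a dict threaded through schedule_uop; the break_sched rewrite moves out of group_realizes and into create_schedule_with_vars; and the schedule SINK is now produced by graph_rewrite with the new create_kernels PatternMatcher rather than an explicit loop over realize_map. The toy sketch near the top of this page mirrors that last step under simplified, hypothetical types.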
