Commit 219b8c9

return becomes_map in scheduler [pr] (tinygrad#9766)
* add a graph_rewrite pass for creating asts [pr]
* disk
* benchmark
* return becomes_map in scheduler
* reorder schedule.py into grouper and linearizer [pr]
* comments
1 parent 6306dea commit 219b8c9
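
For orientation, here is a rough, hypothetical sketch of the call structure this commit leaves behind in tinygrad/engine/schedule.py. The function and variable names come from the hunks below; the bodies are simplified placeholders standing in for the real rewrite passes, not the actual implementation.

from typing import Any

def get_becomes_map(big_sink: Any) -> tuple[dict[Any, Any], dict[Any, int]]:
  # "grouper" half: run the graph rewrites (merge_views+sym+reorder_view+
  # replace_contiguous, then create_ast), collect unbound Variable values,
  # and map each pre-rewrite UOp (including big_sink) to what it becomes.
  becomes_map: dict[Any, Any] = {big_sink: big_sink}  # placeholder for the rewritten graph
  var_vals: dict[Any, int] = {}
  return becomes_map, var_vals

def create_schedule_with_vars(big_sink: Any) -> tuple[list[Any], dict[Any, int], dict[Any, Any]]:
  # "linearizer" half: pop the scheduled sink back out of becomes_map, then
  # flatten its kernel graph into an ordered list of ScheduleItems (bfs toposort).
  becomes_map, var_vals = get_becomes_map(big_sink)
  sched_sink = becomes_map.pop(big_sink)
  schedule: list[Any] = []  # placeholder for the bfs toposort over sched_sink
  return schedule, var_vals, becomes_map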

File tree

1 file changed: +25 -17 lines changed


tinygrad/engine/schedule.py

Lines changed: 25 additions & 17 deletions
@@ -385,22 +385,8 @@ def fix_kernel_ast(ctx:dict[Variable, int], k:UOp) -> UOp|None:
   return k.replace(arg=Kernel(ast, k.arg.metadata))
 create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
 
-PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
-if CAPTURE_PROCESS_REPLAY:
-  @atexit.register
-  def save_process_replay():
-    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
-
-# **** schedule creation and toposort
-
-@dataclass(frozen=True)
-class ScheduleItem:
-  ast: UOp
-  bufs: tuple[Buffer, ...]
-  metadata: tuple[Metadata, ...] = ()
-
 @track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
-def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
+def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
   # merge_views + sym + reorder_view + replace_contiguous
   tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous, ctx={})
   sink = tensor_map[big_sink]
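
The dict[UOp, UOp] that get_becomes_map now returns is a substitution table: UOps the rewrite touched are keys, each pointing at the UOp it "becomes" after scheduling, and anything absent passes through unchanged. A toy, tinygrad-free illustration of that shape (the Node class here is hypothetical, not tinygrad's UOp):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  name: str

x, y, z = Node("x"), Node("y"), Node("z")
becomes_map: dict[Node, Node] = {y: Node("y_scheduled")}  # only y was rewritten
resolved = [becomes_map.get(n, n) for n in (x, y, z)]
assert resolved == [x, Node("y_scheduled"), z]            # x and z stay as they were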
@@ -451,8 +437,28 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
   # unbind var_vals and fix kernel ast
   var_vals: dict[Variable, int] = {}
   sched_sink = graph_rewrite(sched_sink, create_ast, ctx=var_vals, bottom_up=True)
+  becomes_map[big_sink] = sched_sink
+  return becomes_map, var_vals
 
-  # final toposort (bfs)
+# **** schedule linearizer
+
+@dataclass(frozen=True)
+class ScheduleItem:
+  ast: UOp
+  bufs: tuple[Buffer, ...]
+  metadata: tuple[Metadata, ...] = ()
+
+PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
+if CAPTURE_PROCESS_REPLAY:
+  @atexit.register
+  def save_process_replay():
+    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
+
+def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
+  becomes_map, var_vals = get_becomes_map(big_sink)
+  sched_sink = becomes_map.pop(big_sink)
+
+  # bfs toposort
   children: dict[UOp, list[UOp]] = {}
   in_degree: dict[UOp, int] = {}
   for u in sched_sink.toposort:
@@ -476,9 +482,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
       if in_degree[x] == 0: queue.append(x)
 
   # confirm everything was scheduled correctly
-  if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
+  if len(schedule) != len(in_degree): raise RuntimeError(f"created {len(in_degree)} kernels but only scheduled {len(schedule)}")
   if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
+
   # capture process replay
   if CAPTURE_PROCESS_REPLAY:
     with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
+
   return schedule, var_vals, becomes_map
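
The "bfs toposort" the new create_schedule_with_vars performs is a Kahn-style traversal: count incoming edges per kernel, start from the zero in-degree ones, and enqueue children as their counts drop to zero; with a cycle, some nodes never reach zero, which is exactly what the created-vs-scheduled check in the last hunk catches. A generic sketch on a plain dependency dict (bfs_toposort and deps are illustrative names, not the tinygrad code, which walks sched_sink.toposort and UOp sources instead):

from collections import deque

def bfs_toposort(deps: dict[str, list[str]]) -> list[str]:
  # deps maps each node to the nodes it depends on (its parents)
  children: dict[str, list[str]] = {k: [] for k in deps}
  in_degree: dict[str, int] = {k: 0 for k in deps}
  for node, parents in deps.items():
    for p in parents:
      children[p].append(node)
      in_degree[node] += 1
  queue = deque(k for k, d in in_degree.items() if d == 0)
  order: list[str] = []
  while queue:
    n = queue.popleft()
    order.append(n)
    for x in children[n]:
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(x)
  # with a cycle, some nodes never hit in_degree 0, so the order comes up short
  if len(order) != len(in_degree): raise RuntimeError(f"created {len(in_degree)} kernels but only scheduled {len(order)}")
  return order

assert bfs_toposort({"a": [], "b": ["a"], "c": ["a", "b"]}) == ["a", "b", "c"]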
