@@ -385,22 +385,8 @@ def fix_kernel_ast(ctx:dict[Variable, int], k:UOp) -> UOp|None:
  return k.replace(arg=Kernel(ast, k.arg.metadata))
create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])

-PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
-if CAPTURE_PROCESS_REPLAY:
-  @atexit.register
-  def save_process_replay():
-    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
-
-# **** schedule creation and toposort
-
-@dataclass(frozen=True)
-class ScheduleItem:
-  ast: UOp
-  bufs: tuple[Buffer, ...]
-  metadata: tuple[Metadata, ...] = ()
-
@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
-def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
+def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
  # merge_views + sym + reorder_view + replace_contiguous
  tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous, ctx={})
  sink = tensor_map[big_sink]
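(For readers new to the rewrite machinery: `graph_rewrite_map` returns an old→new mapping for every node it touches, which is why `tensor_map[big_sink]` retrieves the rewritten sink. Below is a minimal sketch of that mapping idea under toy assumptions; `Node` and `rewrite_map` are hypothetical stand-ins, not tinygrad's `UOp` or `graph_rewrite_map`.)

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass(frozen=True)
class Node:
  # hypothetical stand-in for a UOp: an op name plus child nodes
  op: str
  src: tuple["Node", ...] = ()

def rewrite_map(root: Node, rule: Callable[[Node], Optional[Node]]) -> dict[Node, Node]:
  # rewrite bottom-up and record old -> new for every visited node,
  # analogous to the tensor_map returned above
  out: dict[Node, Node] = {}
  def visit(n: Node) -> Node:
    if n in out: return out[n]
    new = Node(n.op, tuple(visit(s) for s in n.src))
    out[n] = rule(new) or new  # rule returns None when it doesn't match
    return out[n]
  visit(root)
  return out

# toy rule: fold neg(neg(x)) -> x
double_neg = lambda n: n.src[0].src[0] if n.op == "neg" and n.src and n.src[0].op == "neg" else None
sink = Node("neg", (Node("neg", (Node("x"),)),))
print(rewrite_map(sink, double_neg)[sink].op)  # x
```

Returning the whole mapping rather than just the rewritten root matters here: callers need old→new entries for every tensor, which is what `becomes_map` carries back out of `get_becomes_map`.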
@@ -451,8 +437,28 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
  # unbind var_vals and fix kernel ast
  var_vals: dict[Variable, int] = {}
  sched_sink = graph_rewrite(sched_sink, create_ast, ctx=var_vals, bottom_up=True)
+  becomes_map[big_sink] = sched_sink
+  return becomes_map, var_vals

-  # final toposort (bfs)
+# **** schedule linearizer
+
+@dataclass(frozen=True)
+class ScheduleItem:
+  ast: UOp
+  bufs: tuple[Buffer, ...]
+  metadata: tuple[Metadata, ...] = ()
+
+PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
+if CAPTURE_PROCESS_REPLAY:
+  @atexit.register
+  def save_process_replay():
+    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
+
+def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
+  becomes_map, var_vals = get_becomes_map(big_sink)
+  sched_sink = becomes_map.pop(big_sink)
+
+  # bfs toposort
  children: dict[UOp, list[UOp]] = {}
  in_degree: dict[UOp, int] = {}
  for u in sched_sink.toposort:
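(The relocated `PROCESS_REPLAY_CAPTURE` block above uses a capture-then-flush-at-exit pattern: payloads are pickled eagerly during the run and written out once by an `atexit` hook. A stdlib-only sketch of the same pattern; the `capture`/`save_all` names and the on-disk layout are illustrative, not tinygrad's `diskcache_put`.)

```python
import atexit, os, pickle, tempfile

REPLAY_CAPTURE: dict[str, bytes] = {}

def capture(key: str, payload) -> None:
  # pickle eagerly: later mutation of payload can't change what gets flushed
  REPLAY_CAPTURE[key] = pickle.dumps(payload)

@atexit.register
def save_all() -> None:
  # one flush at interpreter exit, mirroring save_process_replay above
  outdir = os.path.join(tempfile.gettempdir(), "replay_capture")
  os.makedirs(outdir, exist_ok=True)
  for k, v in REPLAY_CAPTURE.items():
    with open(os.path.join(outdir, f"{k}.pkl"), "wb") as f:
      f.write(v)

capture("demo", {"kernels": 3, "vars": 0})
```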
@@ -476,9 +482,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
      if in_degree[x] == 0: queue.append(x)

  # confirm everything was scheduled correctly
-  if len(schedule) != (kc:=len(in_degree)): raise RuntimeError(f"cycle detected in graph, created {kc} kernels but only scheduled {len(schedule)}")
+  if len(schedule) != len(in_degree): raise RuntimeError(f"created {len(in_degree)} kernels but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
+
  # capture process replay
  if CAPTURE_PROCESS_REPLAY:
    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(big_sink.key)] = pickle.dumps((big_sink, ContextVar._cache, [x.ast for x in schedule]))
+
  return schedule, var_vals, becomes_map
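(The `# bfs toposort` plus the `len(schedule) != len(in_degree)` guard is Kahn's algorithm: seed a queue with zero in-degree kernels, pop in BFS order, and treat leftovers as evidence of a cycle. A self-contained sketch over plain string nodes, assuming nothing from tinygrad:)

```python
from collections import deque

def bfs_toposort(children: dict[str, list[str]]) -> list[str]:
  # in_degree[n] = number of parents of n that are not yet scheduled
  nodes = set(children) | {c for cs in children.values() for c in cs}
  in_degree = {n: 0 for n in nodes}
  for cs in children.values():
    for c in cs: in_degree[c] += 1
  queue = deque(sorted(n for n in nodes if in_degree[n] == 0))
  schedule: list[str] = []
  while queue:
    n = queue.popleft()
    schedule.append(n)
    for c in children.get(n, []):
      in_degree[c] -= 1
      if in_degree[c] == 0: queue.append(c)
  # same guard as above: leftover nodes mean the graph had a cycle
  if len(schedule) != len(nodes): raise RuntimeError(f"created {len(nodes)} kernels but only scheduled {len(schedule)}")
  return schedule

print(bfs_toposort({"a": ["b", "c"], "b": ["d"], "c": ["d"]}))  # ['a', 'b', 'c', 'd']
```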