@@ -88,6 +88,7 @@ class ScheduleContext:
88
88
assigns : dict [UOp , None ] = field (default_factory = dict ) # this holds all the BUFFER uops we ASSIGN to in this schedule
89
89
realizes : dict [UOp , UOp ] = field (default_factory = dict ) # this holds all the BUFFER uops we mutate in this schedule
90
90
allbufs : dict [UOp , UOp ] = field (default_factory = dict ) # this maps BUFFER uops the actual op
91
+ var_vals : dict [Variable , int ] = field (default_factory = dict )
91
92
children : defaultdict [UOp , dict [UOp , None ]] = field (default_factory = lambda : defaultdict (dict ))
92
93
preloads : defaultdict [Buffer , dict [UOp , None ]] = field (default_factory = lambda : defaultdict (dict ))
93
94
@@ -230,7 +231,6 @@ def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
230
231
for reduceop in double_reduces :
231
232
top_reduce = uval (ctx .allbufs [reduceop ]).src [0 ].base .buf_uop
232
233
if len (ctx .children [top_reduce ]) == 1 : del ctx .realizes [top_reduce ]
233
- graph_rewrite (sink , break_sched , ctx )
234
234
return ctx .realizes
235
235
236
236
# break the SINK into stores
@@ -372,11 +372,11 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
372
372
return var
373
373
unbind_vars = PatternMatcher ([(UPat (Ops .BIND , name = "bind" , src = (UPat .var ("var" ), UPat .cvar ("val" ))), unbind_variable ),])
374
374
375
- def schedule_uop (pre :UOp , ctx :ScheduleContext , var_vals : dict [ UOp , int ] ) -> UOp :
375
+ def schedule_uop (pre :UOp , ctx :ScheduleContext ) -> UOp :
376
376
# unbind_vars + push views to edges
377
- sink = graph_rewrite (graph_rewrite (pre , unbind_vars + view_left , ctx = var_vals ), view_right )
377
+ sink = graph_rewrite (graph_rewrite (pre , unbind_vars + view_left , ctx = ctx . var_vals ), view_right )
378
378
# remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
379
- ast = graph_rewrite (sink , to_si , si_ctx := KernelContext (var_vals ))
379
+ ast = graph_rewrite (sink , to_si , si_ctx := KernelContext (ctx . var_vals ))
380
380
# deal with ASSIGN
381
381
if len (ctx .assigns ) != 0 :
382
382
assign_preloads = ctx .preloads [si_ctx .bufs [0 ].buffer ]
@@ -399,6 +399,11 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> UOp:
399
399
metadata = tuple (dedup (m for x in pre .toposort if x .op is not Ops .BUFFER and (m := ctx .ops_metadata .get (x )) is not None ))
400
400
return UOp (Ops .KERNEL , src = tuple (si_ctx .bufs ), arg = Kernel (ast , metadata ))
401
401
402
+ create_kernels = PatternMatcher ([
403
+ (UPat (Ops .SINK , name = "x" ), lambda ctx ,x : x .replace (src = tuple (schedule_uop (s .sink (), ctx ) for s in x .src ))
404
+ if any (s .op is not Ops .KERNEL for s in x .src ) else None ),
405
+ ])
406
+
402
407
# **** schedule creation and toposort
403
408
404
409
@track_rewrites (named = True )
@@ -425,20 +430,16 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
425
430
buf_tensors .setdefault (b , []).append (k )
426
431
ops_metadata [b ] = k .metadata
427
432
realize_map = group_realizes (sink , ctx := ScheduleContext (ops_metadata ))
433
+ if len (realize_map ) == 0 : return [], {}, becomes_map
428
434
429
- # create kernels + map buffers to realized tensors
430
- sinks : list [UOp ] = []
431
- var_vals : dict [Variable , int ] = {}
432
- for buf_uop ,store in realize_map .items ():
433
- assert store .op is Ops .STORE , f"expected a realized BUFFER to get a STORE { sink } "
434
- sinks .append (schedule_uop (store .sink (), ctx , var_vals ))
435
- # can only schedule once
435
+ # map buffers to realized tensors
436
+ for buf_uop in realize_map :
436
437
for tensor_uop in buf_tensors [buf_uop ]: becomes_map [tensor_uop ] = buf_uop .view (unwrap (tensor_uop .st ))
437
- # increment refcount for this buffer
438
438
buf_uop .buffer .ref (1 )
439
- sched_sink = UOp (Ops .SINK , src = tuple (sinks ))
440
- # display, TODO: this isn't a complete sched_sink yet
441
- if getenv ("VIZ" ): graph_rewrite (sched_sink , PatternMatcher ([]))
439
+
440
+ # create kernels, TODO: this should use the SINK from tensor_map
441
+ graph_rewrite (sink , break_sched , ctx )
442
+ sched_sink = graph_rewrite (UOp .sink (* realize_map .values ()), create_kernels , ctx )
442
443
type_verify (list (sched_sink .toposort ), kernel_spec )
443
444
444
445
# TODO: this should be the break between the "grouper" and the "linearizer"
@@ -479,4 +480,4 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
479
480
if CAPTURE_PROCESS_REPLAY :
480
481
with Context (PICKLE_BUFFERS = 0 ):
481
482
diskcache_put ("schedule_process_replay" , str (big_sink .key ), (big_sink , ContextVar ._cache , [x .ast for x in schedule ]))
482
- return schedule , var_vals , becomes_map
483
+ return schedule , ctx . var_vals , becomes_map
0 commit comments