@@ -234,10 +234,10 @@ def __repr__(self):
234
234
@dataclass(frozen=True)
class KernelContext:
  # Immutable context passed as `ctx` to the create_kernels rewrite rules.

  # UOps used as a membership set (the values are always None). In append_to_kernel a
  # source found in this dict is kept as a source of the kernel instead of being
  # replaced by its own sources.
  realizes: dict[UOp, None]
  # Metadata keyed by UOp. Looked up with .get() when a Kernel arg is built, so a UOp
  # without an entry simply contributes no metadata.
  metadata: dict[UOp, Metadata]
238
238
239
239
def create_kernel(ctx: KernelContext, x: UOp, b: UOp):
  """Wrap `x` in a KERNEL UOp and return it as an ASSIGN into `b`'s buffer.

  Args:
    ctx: rewrite context; only `ctx.metadata` is read here.
    x: the UOp whose `sink()` becomes the kernel's AST and whose sources become kernel sources.
    b: the buffer UOp the kernel writes to; it is placed first in the kernel's sources.

  Returns:
    An ASSIGN UOp with dtype `x.dtype` and sources `(buffer, kernel)`, reshaped to `x.shape`.
  """
  # The Kernel arg carries a one-element metadata tuple when ctx.metadata has a truthy
  # entry for x, and an empty tuple otherwise.
  kernel = UOp(Ops.KERNEL, src=(b,) + x.src, arg=Kernel(x.sink(), (m,) if (m := ctx.metadata.get(x)) else ()))
  # Assign straight into the base buffer when b covers all of it; otherwise build a
  # BUFFER_VIEW over the base using b's size and the offset of its first view.
  if b.size == b.base.size:
    buffer = b.base
  else:
    buffer = UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
  return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
243
243
@@ -250,7 +250,7 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
250
250
if s .op in DONT_PLACE_IN_KERNEL or s in ctx .realizes : new_srcs .append (s )
251
251
else :
252
252
new_srcs .extend (s .src )
253
- if (m := ctx .ops_metadata .get (s )) is not None : metadata [m ] = None
253
+ if (m := ctx .metadata .get (s )) is not None : metadata [m ] = None
254
254
if (new_src := tuple (dedup (new_srcs ))) != x .src : return x .replace (src = new_src , arg = Kernel (x .arg .ast , tuple (metadata )))
255
255
256
256
create_kernels = PatternMatcher ([
@@ -404,17 +404,15 @@ class ScheduleItem:
404
404
def create_schedule_with_vars (big_sink :UOp ) -> tuple [list [ScheduleItem ], dict [Variable , int ], dict [UOp , UOp ]]:
405
405
# merge_views + sym + reorder_view + replace_contiguous
406
406
tensor_map = graph_rewrite_map (big_sink , merge_views + sym + reorder_view + replace_contiguous , ctx = {})
407
+ sink = tensor_map [big_sink ]
408
+ metadata = {v :k .metadata for k ,v in tensor_map .items () if k .base .op not in {Ops .CONST , Ops .DEVICE } and isinstance (k .metadata , Metadata )}
407
409
408
410
# display the cleaned up tensor graph
409
- if getenv ("VIZ" ): graph_rewrite (tensor_map [ big_sink ] , PatternMatcher ([]), name = "View Tensor Graph" )
411
+ if getenv ("VIZ" ): graph_rewrite (sink , PatternMatcher ([]), name = "View Tensor Graph" )
410
412
411
- # get realizes
412
- sink = tensor_map [big_sink ]
413
+ # group into kernels
413
414
realize_map = group_realizes (sink )
414
- # map tensor metadata to simplified ops
415
- ops_metadata = {v :k .metadata for k ,v in tensor_map .items () if k .base .op not in {Ops .CONST , Ops .DEVICE } and isinstance (k .metadata , Metadata )}
416
- # merge_views + create_kernels
417
- kernel_map = graph_rewrite_map (sink , merge_views + create_kernels , ctx = KernelContext (realize_map , ops_metadata ), bottom_up = True )
415
+ kernel_map = graph_rewrite_map (sink , merge_views + create_kernels , ctx = KernelContext (realize_map , metadata ), bottom_up = True )
418
416
sched_sink = kernel_map [sink ]
419
417
type_verify (list (sched_sink .toposort ), kernel_spec )
420
418
0 commit comments