@@ -369,8 +369,8 @@ def check_load_st(glbl:UOp, view:UOp):
369
369
(UPat (Ops .LOAD , src = (UPat .var ("glbl" ), UPat .var ("view" ))), check_load_st ),
370
370
])
371
371
372
- def fix_kernel_ast (k : UOp , var_vals : dict [Variable , int ]) -> UOp :
373
- assert k .op is Ops . KERNEL , f"kernel isn't kernel, it's { k } "
372
+ def fix_kernel_ast (ctx : dict [Variable , int ], k : UOp ) -> UOp | None :
373
+ if k .arg . ast . op in GroupOp . Meta or all ( s . op is Ops . STORE for s in k . arg . ast . src ): return None
374
374
# substitute kernel sources for the target buffer + apply reshapes
375
375
parents_rep : dict [UOp , UOp ] = {}
376
376
for s in k .src :
@@ -380,11 +380,10 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
380
380
# push views to edges
381
381
ast = graph_rewrite (graph_rewrite (ast , view_left ), view_right )
382
382
# add buffer ops + fix_kernel_ops
383
- ast = graph_rewrite (ast , merge_views + add_buffer_ops + fix_kernel_ops , ctx = (var_vals , bufs := tuple (s .buf_uop for s in k .src )), bottom_up = True )
383
+ ast = graph_rewrite (ast , merge_views + add_buffer_ops + fix_kernel_ops , ctx = (ctx , bufs := tuple (s .buf_uop for s in k .src )), bottom_up = True )
384
384
if ast .op is Ops .SINK and not all_same (dev := [x .device for x in bufs ]): raise RuntimeError (f"all buffers must be on the same device: { dev } " )
385
- # create subbuffer (TODO: this does not belong here)
386
- if ast .op is Ops .BUFFER_VIEW : buffers [bufs [0 ]] = (base := bufs [1 ].buffer ).view (ast .size , ast .dtype , ast .arg [1 ]* base .dtype .itemsize )
387
385
return k .replace (arg = Kernel (ast , k .arg .metadata ))
386
+ create_ast = PatternMatcher ([(UPat (Ops .KERNEL , name = "k" ), fix_kernel_ast ),])
388
387
389
388
PROCESS_REPLAY_CAPTURE :dict [str , bytes ] = {}
390
389
if CAPTURE_PROCESS_REPLAY :
@@ -449,6 +448,10 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
449
448
if getenv ("VIZ" ): graph_rewrite (sched_sink , PatternMatcher ([]), name = "View Kernel Graph" )
450
449
if getenv ("VIZ" ): graph_rewrite (sched_sink , PatternMatcher ([]), name = "View Memory Graph" )
451
450
451
+ # unbind var_vals and fix kernel ast
452
+ var_vals : dict [Variable , int ] = {}
453
+ sched_sink = graph_rewrite (sched_sink , create_ast , ctx = var_vals , bottom_up = True )
454
+
452
455
# final toposort (bfs)
453
456
children : dict [UOp , list [UOp ]] = {}
454
457
in_degree : dict [UOp , int ] = {}
@@ -462,11 +465,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
462
465
463
466
queue = deque (k for k ,v in in_degree .items () if v == 0 )
464
467
schedule : list [ScheduleItem ] = []
465
- var_vals : dict [Variable , int ] = {}
466
468
while queue :
467
469
u = queue .popleft ()
468
- # TODO: move this to create_kernels
469
- k = fix_kernel_ast (u .src [1 ], var_vals )
470
+ # map the BUFFER UOp to a subbuffer if it's a BUFFER_VIEW
471
+ if (k := u .src [1 ]).arg .ast .op is Ops .BUFFER_VIEW :
472
+ buffers [k .src [0 ]] = (base := k .src [1 ].buf_uop .buffer ).view (k .size , k .arg .ast .dtype , k .arg .ast .arg [1 ]* base .dtype .itemsize )
470
473
schedule .append (ScheduleItem (k .arg .ast , tuple (s .buf_uop .buffer for s in k .src ), k .arg .metadata ))
471
474
for x in children .get (u , []):
472
475
in_degree [x ] -= 1
0 commit comments