@@ -267,8 +267,8 @@ def inputs(self) -> tuple[Buffer, ...]:
267
267
def output_idxs (self ) -> tuple [int , ...]: return tuple (x .src [0 ].arg for x in self .ast .src ) if self .ast .op is Ops .SINK else (0 ,)
268
268
269
269
def kernel_to_si (k :UOp ) -> ScheduleItem :
270
- assert k .op is Ops .KERNEL , f"must be KERNEL { k } "
271
- return ScheduleItem (k .arg .ast , tuple (u .buf_uop .buffer for u in k .src ), k .arg . metadata )
270
+ assert k .op is Ops .KERNEL and isinstance ( k . metadata , tuple ) , f"must be KERNEL { k } "
271
+ return ScheduleItem (k .arg .ast , tuple (u .buf_uop .buffer for u in k .src ), k .metadata )
272
272
273
273
# **** Kernel creation
274
274
@@ -433,7 +433,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
433
433
for k ,v in tensor_map .items ():
434
434
if (b := buffer_map .get (v )) is not None :
435
435
buf_tensors .setdefault (b , []).append (k )
436
- ops_metadata [b ] = k .metadata
436
+ if isinstance ( k . metadata , Metadata ): ops_metadata [b ] = k .metadata
437
437
realize_map = group_realizes (sink , ctx := ScheduleContext (ops_metadata ))
438
438
if len (realize_map ) == 0 : return [], {}, becomes_map
439
439
@@ -460,7 +460,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
460
460
in_degree : defaultdict [ScheduleItem , int ] = defaultdict (int )
461
461
for si in prescheduled :
462
462
# realize outputs before a parent is assigned to
463
- parents_assigns = dedup (xsi for x in ctx .preloads [si .bufs [0 ]] if (xsi := schedule_targets .get (x .buffer )) and xsi is not si )
463
+ parents_assigns = dedup (xsi for x in ctx .preloads [si .bufs [0 ]] if (xsi := schedule_targets .get (x .buffer )) is not None and xsi is not si )
464
464
for assign in parents_assigns :
465
465
graph [si ].append (assign )
466
466
in_degree [assign ] += 1
0 commit comments