Skip to content

Commit 07eea56

Browse files
authored
reorder tensor_map and grouper parts [pr] (tinygrad#9764)
1 parent 8ddb135 commit 07eea56

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

tinygrad/engine/schedule.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,10 @@ def __repr__(self):
234234
@dataclass(frozen=True)
235235
class KernelContext:
236236
realizes: dict[UOp, None]
237-
ops_metadata: dict[UOp, Metadata]
237+
metadata: dict[UOp, Metadata]
238238

239239
def create_kernel(ctx:KernelContext, x:UOp, b:UOp):
240-
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), (m,) if (m:=ctx.ops_metadata.get(x)) else ()))
240+
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), (m,) if (m:=ctx.metadata.get(x)) else ()))
241241
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
242242
return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
243243

@@ -250,7 +250,7 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
250250
if s.op in DONT_PLACE_IN_KERNEL or s in ctx.realizes: new_srcs.append(s)
251251
else:
252252
new_srcs.extend(s.src)
253-
if (m:=ctx.ops_metadata.get(s)) is not None: metadata[m] = None
253+
if (m:=ctx.metadata.get(s)) is not None: metadata[m] = None
254254
if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(metadata)))
255255

256256
create_kernels = PatternMatcher([
@@ -404,17 +404,15 @@ class ScheduleItem:
404404
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
405405
# merge_views + sym + reorder_view + replace_contiguous
406406
tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous, ctx={})
407+
sink = tensor_map[big_sink]
408+
metadata = {v:k.metadata for k,v in tensor_map.items() if k.base.op not in {Ops.CONST, Ops.DEVICE} and isinstance(k.metadata, Metadata)}
407409

408410
# display the cleaned up tensor graph
409-
if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
411+
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Tensor Graph")
410412

411-
# get realizes
412-
sink = tensor_map[big_sink]
413+
# group into kernels
413414
realize_map = group_realizes(sink)
414-
# map tensor metadata to simplified ops
415-
ops_metadata = {v:k.metadata for k,v in tensor_map.items() if k.base.op not in {Ops.CONST, Ops.DEVICE} and isinstance(k.metadata, Metadata)}
416-
# merge_views + create_kernels
417-
kernel_map = graph_rewrite_map(sink, merge_views+create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
415+
kernel_map = graph_rewrite_map(sink, merge_views+create_kernels, ctx=KernelContext(realize_map, metadata), bottom_up=True)
418416
sched_sink = kernel_map[sink]
419417
type_verify(list(sched_sink.toposort), kernel_spec)
420418

0 commit comments

Comments
 (0)