Skip to content

Commit ef7ad3f

Browse files
authored
simpler subbuffer construction + copyin is always base (tinygrad#8900)
* realize copy
* cleanup buffer_view
* smaller
1 parent 6f0cc2e commit ef7ad3f

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

tinygrad/engine/schedule.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,9 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
6363
# support for using a contiguous permuted view instead of the parent view if one exists
6464
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
6565
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
66+
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
67+
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
68+
lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
6669
# remove CONST/BIND/BUFFER from SINK
6770
(UPat(Ops.SINK, name="root"),
6871
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
@@ -112,6 +115,7 @@ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp
112115
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
113116
# track the buffer uop for the simplified uop
114117
buffer_map[buf] = buf_uop
118+
if op.op is Ops.BUFFER_VIEW: buffers[buf_uop] = (x:=op.src[0]).buf_uop.buffer.view(op.size, op.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
115119
# (early) bufferize
116120
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
117121
return ret
@@ -132,23 +136,15 @@ def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs)
132136
# otherwise safety check pads
133137
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, dict())) else realize(ctx, b, src)
134138

135-
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
136-
if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
137-
buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
138-
return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
139-
140139
do_realize = PatternMatcher([
141140
# always realize SINK parents
142141
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
143142
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
144143
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
145144
# realize before expand or unsafe pad ops
146145
(UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
147-
# realize before COPY or BUFFER_VIEW
148-
(UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
149-
(UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
150-
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
151-
(UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
146+
# realize before COPY
147+
(UPat(Ops.COPY, src=(UPat(), UPatScheduled())), realize),
152148
])
153149

154150
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:

0 commit comments

Comments
 (0)