Skip to content

Commit be21616

Browse files
authored
reorder into swizzler + ast_fixup [pr] (tinygrad#9456)
1 parent cb7a7f6 commit be21616

File tree

1 file changed

+27
-30
lines changed

1 file changed

+27
-30
lines changed

tinygrad/engine/schedule.py

+27 -30
Original file line number | Diff line number | Diff line change
@@ -223,7 +223,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
223223
if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
224224
return ctx.realizes
225225

226-
# break the SINK into kernels
226+
# **** create kernels
227227

228228
@dataclass(frozen=True)
229229
class Kernel:
@@ -243,6 +243,7 @@ def create_kernel(ctx:KernelContext, x:UOp, b:UOp):
243243
return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
244244

245245
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
246+
246247
def append_to_kernel(ctx:KernelContext, x:UOp):
247248
new_srcs: list[UOp] = []
248249
metadata = dict.fromkeys(x.arg.metadata)
@@ -268,30 +269,7 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
268269
(UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None),
269270
])
270271

271-
# **** fix kernel AST
272-
273-
# ** create buffer ops + enumerate buffers
274-
275-
add_buffer_ops = PatternMatcher([
276-
# LOAD
277-
(UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)),
278-
# STORE (except for COPY/BUFFER_VIEW)
279-
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
280-
# partial assign can store to a non-contiguous ShapeTracker
281-
(UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
282-
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
283-
# otherwise the store is contiguous
284-
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
285-
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
286-
# if the last child is a VIEW we merge the ShapeTrackers and store the base
287-
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
288-
lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
289-
# remove CONTIGUOUS/DEVICE from kernel AST
290-
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
291-
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
292-
])
293-
294-
# ** push views to buffer ops
272+
# **** swizzler
295273

296274
def apply_swizzle(u:UOp) -> UOp:
297275
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
@@ -314,7 +292,7 @@ def reduceop_view_right(src:UOp, v:UOp, r:UOp):
314292
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
315293
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape))
316294

317-
def elementwise_view_right(root:UOp) -> UOp|None:
295+
def elementwise_view_right(root:UOp):
318296
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in DONT_PUSH_VIEWS]): return None
319297
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
320298
# place view after applying the elementwise op
@@ -323,7 +301,7 @@ def elementwise_view_right(root:UOp) -> UOp|None:
323301
# reshape to match downstream shapes
324302
return root.replace(src=tuple(new_src)).reshape(root.shape)
325303

326-
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
304+
def merge_double_reduce(root:UOp, first_reduce:UOp):
327305
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
328306
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
329307
return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
@@ -340,9 +318,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
340318
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
341319
])
342320

343-
# ** unbind variables
321+
# **** unbind variables
344322

345-
def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp) -> UOp|None:
323+
def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp):
346324
st = unwrap(x.st).simplify()
347325
if any(x.op is Ops.BIND for x in st.vars()):
348326
st, var_vals = st.unbind()
@@ -354,7 +332,26 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
354332
return var
355333
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
356334

357-
# ** fix_kernel_ops
335+
# **** fix kernel AST
336+
337+
add_buffer_ops = PatternMatcher([
338+
# LOAD
339+
(UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)),
340+
# STORE (except for COPY/BUFFER_VIEW)
341+
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
342+
# partial assign can store to a non-contiguous ShapeTracker
343+
(UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
344+
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
345+
# otherwise the store is contiguous
346+
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
347+
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
348+
# if the last child is a VIEW we merge the ShapeTrackers and store the base
349+
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
350+
lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
351+
# remove CONTIGUOUS/DEVICE from kernel AST
352+
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
353+
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
354+
])
358355

359356
def check_load_st(glbl:UOp, view:UOp):
360357
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return

0 commit comments

Comments (0)