Skip to content

Commit 0d2762c

Browse files
authored
prep refactor for adding buffer ops last [pr] (tinygrad#9383)
* prep refactor for adding buffer ops last [pr]
* freeze buffers
* add swizzle_reduceop
* shape for reduceop_view_right
* simpler elementwise_view_right
* add shapetracker to const
* only const
* from process replay
1 parent bde0347 commit 0d2762c

File tree

2 files changed

+21
-26
lines changed

2 files changed

+21
-26
lines changed

test/test_schedule.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -1871,7 +1871,7 @@ def test_simple_store_reshape(self):
18711871
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
18721872
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
18731873
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
1874-
r = r + 2
1874+
r = r + r.const_like(2).replace(src=(unwrap(r.st).to_uop(),))
18751875
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
18761876
rsink = graph_rewrite(sink, view_right)
18771877
# this AST first needs to swizzle, but it doesn't have implicit movementops

tinygrad/engine/schedule.py

+20-25
Original file line number | Diff line number | Diff line change
@@ -260,13 +260,9 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
260260

261261
# ** create buffer ops + enumerate buffers
262262

263-
def load_buf(ctx:list[UOp], x:UOp):
264-
if x not in ctx: ctx.append(x)
265-
return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
266-
267263
add_buffer_ops = PatternMatcher([
268264
# LOAD
269-
(UPat(Ops.BUFFER, name="x"), load_buf),
265+
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))),
270266
# STORE (except for COPY/BUFFER_VIEW)
271267
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
272268
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
@@ -278,8 +274,9 @@ def load_buf(ctx:list[UOp], x:UOp):
278274
def apply_swizzle(u:UOp) -> UOp:
279275
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
280276

281-
def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
282-
input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
277+
def swizzle_reduceop(r:UOp, src:UOp, view:UOp):
278+
if (st:=unwrap(view.st)).contiguous: return None
279+
input_st = ShapeTracker.from_shape(src.shape)
283280
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
284281
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
285282
strides = strides_for_shape(rshape)
@@ -290,20 +287,18 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
290287
new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
291288
return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
292289

293-
def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
294-
if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
295-
output_shape = swizzle_st.reduce(r.axis_arg)
296-
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
290+
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
291+
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
292+
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape))
297293

298294
def elementwise_view_right(root:UOp) -> UOp|None:
299-
if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
300-
assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
295+
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW]): return None
301296
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
302-
# push the swizzle from src to root
303-
output_swizzle = swizzles[0]
304-
new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
305-
ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
306-
return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
297+
# place view after applying the elementwise op
298+
new_shape = swizzles[0].base.shape
299+
ret = root.replace(src=tuple(x.base if x.base.shape == new_shape else apply_swizzle(x.view(ShapeTracker.from_shape(new_shape))) for x in root.src))
300+
# reshape to match downstream shapes
301+
return ret.reshape(root.shape)
307302

308303
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
309304
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
@@ -317,12 +312,12 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
317312
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
318313
# STORE is the last child, so we just merge the ShapeTrackers and store the base
319314
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
320-
# REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
321-
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
322-
# REDUCE(src.view()) -> REDUCE(src).view()
323-
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
324-
# ALU(src.view()) -> ALU(src).view()
325-
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
315+
# push a non contiguous ShapeTracker through reduceop
316+
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
317+
# apply view after reduceops
318+
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat.var("src"),), name="v"),), name="r"), reduceop_view_right),
319+
# apply view after elementwise ops
320+
(UPat(GroupOp.All-GroupOp.Buffer, name="root"), elementwise_view_right),
326321
# double reduce op collapses to a single reduce op
327322
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
328323
])
@@ -372,7 +367,7 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
372367
# substitute kernel sources for the target buffer
373368
ast = k.arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in k.src if s.op is Ops.ASSIGN}).sink()
374369
# add buffer ops
375-
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[s.buf_uop for s in k.src], bottom_up=True)
370+
ast = graph_rewrite(ast, add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
376371
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
377372
# unbind_vars + push views to edges
378373
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)

0 commit comments

Comments (0)