Skip to content

Commit 9df8e34

Browse files
authored
prereqs for giving BUFFER UOps a ShapeTracker [pr] (tinygrad#8809)
1 parent 78c0455 commit 9df8e34

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tinygrad/engine/schedule.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,7 @@ def elementwise_view_right(root:UOp) -> UOp|None:
142142
output_swizzle = swizzles[0]
143143
new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
144144
ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
145-
# NOTE: swizzle resolves once we hit STORE
146-
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(output_swizzle.shape))
145+
return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
147146

148147
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
149148
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
@@ -155,6 +154,8 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
155154
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
156155
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
157156
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
157+
# STORE is the last child, so we just merge the ShapeTrackers and store the base
158+
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
158159
# REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
159160
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
160161
# REDUCE(src.view()) -> REDUCE(src).view()
@@ -303,7 +304,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
303304
if (p_uop:=ctx.allbufs.get(p:=parents.pop())) is None: continue
304305
if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False
305306
if p in ctx.realizes: continue
306-
parents.extend([x.base.src[0] for x in p_uop.src if x.base.op is Ops.VIEW and len(x.base.src) != 0])
307+
parents.extend([x.base.buf_uop for x in p_uop.src if x.base.is_realized or (x.base.op is Ops.VIEW and len(x.base.src) != 0)])
307308
if forced_realize or not group:
308309
tr = r
309310
if can_chase:

0 commit comments

Comments
 (0)