Skip to content

Commit cf21e27

Browse files
authored
little better VIEW simplifier pattern [pr] (tinygrad#8954)
1 parent 329013f commit cf21e27

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tinygrad/ops.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1267,7 +1267,8 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
12671267
merge_views = PatternMatcher([
12681268
# VIEW(VIEW) merges to a single VIEW
12691269
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
1270-
(UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None),
1270+
# remove VIEW if it's contiguous and same as the base shape
1271+
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda vm,x: x if vm.st.contiguous and x.shape == vm.shape else None),
12711272
# merge unmasked const views
12721273
(UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
12731274
lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),

0 commit comments

Comments
 (0)