Skip to content

Commit e618efc

Browse files
authored
COMMUTATIVE flipping is only for ints (tinygrad#8996)
* COMMUTATIVE flipping is only for ints [pr] * no pr * comm fixes this
1 parent 2983285 commit e618efc

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

test/test_linearizer_failures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,7 +1368,7 @@ def test_failure_56(self):
13681368
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
13691369
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),))
13701370
opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=2, arg=32)]
1371-
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"])
1371+
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
13721372

13731373
def test_failure_57(self):
13741374
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
@@ -1409,7 +1409,7 @@ def test_failure_57(self):
14091409
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
14101410
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),))
14111411
opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32)]
1412-
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"])
1412+
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
14131413

14141414
if __name__ == '__main__':
14151415
unittest.main()

tinygrad/ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,8 +1170,8 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
11701170
])
11711171

11721172
symbolic = symbolic_simple+PatternMatcher([
1173-
# ** COMMUTATIVE flipping **
1174-
(UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
1173+
# ** COMMUTATIVE flipping (only for ints) **
1174+
(UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
11751175
# ** boolean algebra **
11761176
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
11771177
# ** combine terms **

0 commit comments

Comments
 (0)