Skip to content

Commit aefbc26

Browse files
authored
test fixups from unmasked valid deletion [pr] (tinygrad#8776)
1 parent ed67288 commit aefbc26

File tree

3 files changed

+12
-11
lines changed

3 files changed

+12
-11
lines changed

test/test_uops.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -389,12 +389,11 @@ def test_compare_alu_same_src_different_arg(self):
389389

390390
def test_uop_variables(self):
391391
a = UOp.variable("a", 1, 10)
392-
uop_var = UOp.const(dtypes.int, a)
393-
st_var = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0),
394-
ShapeTracker.from_shape((2, a)).to_uop()))
395-
ast_vars = (st_var+uop_var).variables()
396-
self.assertEqual(len(ast_vars), 1)
397-
self.assertEqual(ast_vars[0], a)
392+
uop_var = Tensor(a.bind(1))
393+
st_var = Tensor.empty((2, 1)).reshape((2, a.bind(1)))
394+
_, var_vals = (uop_var+st_var).schedule_with_vars()
395+
self.assertEqual(len(var_vals), 1)
396+
self.assertEqual(list(var_vals)[0], a)
398397

399398
def test_const_factor(self):
400399
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 8))

test/unit/test_verify_ast.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,12 @@ def test_assert_swizzle(self):
8888
st = UOp.store(buf, ShapeTracker.from_shape((32, 1)).to_uop(), r.view(r.st.expand((32, 1)))+a)
8989
with self.assertRaisesRegex(InvalidASTException, "swizzle"): helper_test_verify_ast(st)
9090

91-
def test_flat_const_always_valid(self):
91+
def test_const_view_always_valid(self):
9292
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
93-
a = UOp.const(dtypes.int, 0).cast(dtypes.float)
94-
st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a)
95-
helper_test_verify_ast(st)
93+
a = UOp.const(dtypes.int, 0).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg="CLANG"),), ShapeTracker.from_shape(())),))
94+
st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a.cast(dtypes.float))
95+
# lowerer asserts because it does not remove ShapeTracker on CONST(VIEW(DEVICE))
96+
with self.assertRaises(AssertionError): helper_test_verify_ast(st)
9697

9798
if __name__ == '__main__':
9899
unittest.main()

tinygrad/codegen/kernel.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) ->
697697
if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
698698
# movementops are pushed to VIEW
699699
elif uop.op is Ops.VIEW:
700-
assert len(uop.src) == 0, f"can't swizzle in kernel yet {uop}"
700+
# NOTE: we disallow VIEW in the middle of the AST, if it has a DEVICE source it's fine
701+
assert len(uop.src) == 0 or uop.src[0].op is Ops.DEVICE, f"can't swizzle in kernel yet {uop}"
701702
st = uop.arg
702703
# everything else inherits shape
703704
else:

0 commit comments

Comments
 (0)