Skip to content

Commit f0924e0

Browse files
fix and test (tinygrad#8814)
Co-authored-by: chenyu <chenyu@fastmail.com>
1 parent f5da275 commit f0924e0

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

test/test_renderer_failures.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from tinygrad.engine.realize import CompiledRunner
99
from tinygrad.helpers import dedup, flatten, prod
1010
from tinygrad.renderer.cstyle import CStyleLanguage
11+
from tinygrad.renderer.ptx import PTXRenderer
1112
from tinygrad.ops import UOp, Ops
1213
from tinygrad.renderer import ProgramSpec
1314
from tinygrad.tensor import Tensor, _to_np_dtype
@@ -41,7 +42,7 @@ def test_inline_const_alu(self):
4142
ret = _test_uop_result([Tensor([1])], uops)[0]
4243
self.assertEqual(ret[0], 1)
4344

44-
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local and Device.DEFAULT == "PTX", "need local")
45+
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "tests for ptx renderer")
4546
class TestPTXFailures(unittest.TestCase):
4647
def test_gated_store_with_alu(self):
4748
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
@@ -63,5 +64,12 @@ def test_gated_store_with_if(self):
6364
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
6465
np.testing.assert_equal(ret, [0, 1, 1, 1])
6566

67+
def test_gated_define_acc_with_half_dtype(self):
68+
a = Tensor.randn(32, 32, dtype=dtypes.half).realize()
69+
b = Tensor.randn(34, 32, dtype=dtypes.half).realize()
70+
result = a.pad((1,1)).matmul(b, acc_dtype=dtypes.half).numpy()
71+
reference = a.pad((1,1)).matmul(b, acc_dtype=dtypes.float).numpy()
72+
np.testing.assert_allclose(result, reference)
73+
6674
if __name__ == '__main__':
6775
unittest.main()

tinygrad/renderer/ptx.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,7 @@ def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str:
176176
if u.op in {Ops.CAST, Ops.BITCAST} and (u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType)):
177177
r[u] = r[u.src[0]]
178178
continue
179-
if u.op is Ops.DEFINE_ACC and u.dtype in [dtypes.half, dtypes.bool]: r[u.src[0]] = ssa("const", u.src[0])
180-
elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
179+
if u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0]
181180
elif u.op is Ops.DEFINE_VAR: bufs.append((u.arg[0], u.dtype))
182181
elif u.op is Ops.LOAD:
183182
assert u.src[0].dtype == dtypes.int64, "load isn't int64"

0 commit comments

Comments
 (0)