8
8
from tinygrad .engine .realize import CompiledRunner
9
9
from tinygrad .helpers import dedup , flatten , prod
10
10
from tinygrad .renderer .cstyle import CStyleLanguage
11
+ from tinygrad .renderer .ptx import PTXRenderer
11
12
from tinygrad .ops import UOp , Ops
12
13
from tinygrad .renderer import ProgramSpec
13
14
from tinygrad .tensor import Tensor , _to_np_dtype
@@ -41,7 +42,7 @@ def test_inline_const_alu(self):
41
42
ret = _test_uop_result ([Tensor ([1 ])], uops )[0 ]
42
43
self .assertEqual (ret [0 ], 1 )
43
44
44
- @unittest .skipUnless ( Device [Device .DEFAULT ].renderer . has_local and Device . DEFAULT == "PTX" , "need local " )
45
+ @unittest .skipIf ( not isinstance ( Device [Device .DEFAULT ].renderer , PTXRenderer ) , "tests for ptx renderer " )
45
46
class TestPTXFailures (unittest .TestCase ):
46
47
def test_gated_store_with_alu (self ):
47
48
a = UOp (Ops .DEFINE_GLOBAL , dtypes .int .ptr (), (), 0 )
@@ -63,5 +64,12 @@ def test_gated_store_with_if(self):
63
64
ret = _test_uop_result ([], uops , local_size = [4 , 1 , 1 ])[0 ]
64
65
np .testing .assert_equal (ret , [0 , 1 , 1 , 1 ])
65
66
67
+ def test_gated_define_acc_with_half_dtype (self ):
68
+ a = Tensor .randn (32 , 32 , dtype = dtypes .half ).realize ()
69
+ b = Tensor .randn (34 , 32 , dtype = dtypes .half ).realize ()
70
+ result = a .pad ((1 ,1 )).matmul (b , acc_dtype = dtypes .half ).numpy ()
71
+ reference = a .pad ((1 ,1 )).matmul (b , acc_dtype = dtypes .float ).numpy ()
72
+ np .testing .assert_allclose (result , reference )
73
+
66
74
# Allow running this test file directly as a script.
if __name__ == "__main__":
  unittest.main()
0 commit comments