Skip to content

Commit 2845f87

Browse files
authored
failed test cases for rsqrt at 0 and similar ones (tinygrad#9035)
* failed test cases for rsqrt at 0 and similar ones related to 0*inf * this failed
1 parent 101652a commit 2845f87

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

test/test_ops.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -641,9 +641,14 @@ def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, ex
641641

642642
def test_sqrt(self):
643643
helper_test_op([(45,65)], lambda x: x.sqrt())
644+
if Device.DEFAULT not in ("LLVM", "DSP"):
645+
# TODO: fix backward
646+
helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]])
644647
helper_test_op([()], lambda x: x.sqrt())
645648
def test_rsqrt(self):
646649
helper_test_op([(45,65)], lambda x: x.rsqrt())
650+
# TODO: fix backward
651+
helper_test_op(None, lambda x: x.rsqrt(), vals=[[0.0]], forward_only=True)
647652
helper_test_op([()], lambda x: x.rsqrt())
648653

649654
def test_xor(self):
@@ -1274,6 +1279,7 @@ def test_var_zero_in_axis(self):
12741279
helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=0))
12751280
helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=5))
12761281
def test_var_one_in_axis(self):
1282+
helper_test_op([(1,)], lambda x: x.var(axis=(0,), correction=0))
12771283
helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3)))
12781284
helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3), correction=0))
12791285
helper_test_op([(1,2,3,1,5)], lambda x: x.var(axis=(0,3), correction=5))
@@ -1301,8 +1307,9 @@ def test_std_zero_in_axis(self):
13011307
helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=0))
13021308
helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=5))
13031309
def test_std_one_in_axis(self):
1310+
# TODO: fix backward
1311+
helper_test_op([(1,)], lambda x: x.std(axis=(0,), correction=0), forward_only=True)
13041312
helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,3)))
1305-
# TODO: this one broke with correction=0 in new gradient
13061313
helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,3), correction=0), forward_only=True)
13071314
helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,3), correction=5))
13081315
helper_test_op([(1,2,3,1,5)], lambda x: x.std(axis=(0,4)))

0 commit comments

Comments
 (0)