@@ -615,10 +615,20 @@ def test_pow_const(self):
615
615
helper_test_op ([(45 ,65 )], lambda x : 2.0 ** x )
616
616
helper_test_op ([()], lambda x : x ** 2.0 )
617
617
helper_test_op ([()], lambda x : 2.0 ** x )
618
- # TODO: fix backward
619
- helper_test_op (None , lambda x : 0 ** x , vals = [[- 2. ,- 1 ,0 ,1 ,2 ,3 ]], forward_only = True )
618
+ helper_test_op (None , lambda x : 0 ** x , vals = [[- 2. ,- 1 ,0 ,1 ,2 ,3 ]])
620
619
helper_test_op (None , lambda x : (- 2 )** x , vals = [[- 2. ,- 1 ,0 ,1 ,2 ,3 ]])
621
620
621
+ def test_pow_zero_tensor (self ):
622
+ helper_test_op (None , lambda x ,y : x ** y , vals = [[0.0 ], [0.3 ]])
623
+ helper_test_op (None , lambda x ,y : x ** y , vals = [[0.0 ], [0.0 ]])
624
+ # TODO: fix WEBGPU
625
+ if Device .DEFAULT != "WEBGPU" :
626
+ helper_test_op (None , lambda x ,y : x ** y , vals = [[0.0 ], [- 0.3 ]])
627
+ def test_pow_zero_const (self ):
628
+ helper_test_op (None , lambda x : x ** 0.3 , vals = [[0.0 ]])
629
+ helper_test_op (None , lambda x : x ** 0.0 , vals = [[0.0 ]])
630
+ helper_test_op (None , lambda x : x ** - 0.3 , vals = [[0.0 ]])
631
+
622
632
@unittest .skip ("not supported" )
623
633
def test_pow_int (self ):
624
634
def _test (base , exponent ): helper_test_op (None , lambda x ,y : x ** y , vals = [base , exponent ], forward_only = True )
0 commit comments