Skip to content

Commit 73af42a

Browse files
authored
fix pow backward when base is 0 (tinygrad#9075)
1 parent 2d04a75 commit 73af42a

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

test/test_ops.py

+12 -2
Original file line number | Diff line number | Diff line change
@@ -615,10 +615,20 @@ def test_pow_const(self):
615615
helper_test_op([(45,65)], lambda x: 2.0**x)
616616
helper_test_op([()], lambda x: x**2.0)
617617
helper_test_op([()], lambda x: 2.0**x)
618-
# TODO: fix backward
619-
helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)
618+
helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]])
620619
helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]])
621620

621+
def test_pow_zero_tensor(self):
622+
helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.3]])
623+
helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [0.0]])
624+
# TODO: fix WEBGPU
625+
if Device.DEFAULT != "WEBGPU":
626+
helper_test_op(None, lambda x,y: x**y, vals=[[0.0], [-0.3]])
627+
def test_pow_zero_const(self):
628+
helper_test_op(None, lambda x: x**0.3, vals=[[0.0]])
629+
helper_test_op(None, lambda x: x**0.0, vals=[[0.0]])
630+
helper_test_op(None, lambda x: x**-0.3, vals=[[0.0]])
631+
622632
@unittest.skip("not supported")
623633
def test_pow_int(self):
624634
def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)

tinygrad/gradient.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,9 @@ def reduce_gradient(ctx:UOp, ret:UOp):
2222
(UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
2323
(UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
2424
(UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
25-
(UPat(Ops.POW, name="ret"), lambda ctx, ret: (ctx*ret*ret.src[1]/ret.src[0], ctx*ret*ret.src[0].log2()*math.log(2.0))),
25+
(UPat(Ops.POW, name="ret"), lambda ctx, ret:
26+
(ret.src[0].eq(0).where(ret.src[1].eq(0).where(ret.src[1], ret.src[1]*math.inf), ctx*ret*ret.src[1]/ret.src[0]),
27+
ret.src[0].eq(0).where((ret.src[1]<0).where(ret.const_like(-math.inf), ret.const_like(0)), ctx*ret*ret.src[0].log2()*math.log(2.0)))),
2628
(UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
2729
(ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
2830
(UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),

0 commit comments

Comments (0)