Skip to content

Commit 78c0455

Browse files
S-Lykles and chenyuxyz authored
Better stable sigmoid (tinygrad#8806)
Uses `1/(x*x) -> 1/x * 1/x` together with `x/(1+x) -> 1-1/(1+x)` to rewrite sigmoid, instead of `x/((x+1)(x+1)) -> 1/(x+1)*(1-1/(x+1))`.

Co-authored-by: chenyu <chenyu@fastmail.com>
1 parent cac2b4e commit 78c0455

File tree

3 files changed

+6
-10
lines changed

3 files changed

+6
-10
lines changed

test/test_ops.py

-1
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,6 @@ def test_sigmoid_extreme(self):
835835
self.assertAlmostEqual(x.sigmoid()[0].gradient(x)[0].item(), 0.0)
836836
x = Tensor([-300.0])
837837
self.assertAlmostEqual(x.sigmoid()[0].gradient(x)[0].item(), 0.0)
838-
@unittest.skip("fix sigmoid stability")
839838
def test_sigmoid_alt_extreme(self):
840839
def sigmoid(x:Tensor): return x.exp() / (1 + x.exp())
841840
x = Tensor([300.0])

tinygrad/codegen/rewriter.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,6 @@ def threefry2x32(x: UOp, key: UOp):
156156

157157
return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)
158158

159-
# ***** other math rewrite ****
160-
161-
def sigmoid_like(x:UOp, y:UOp): return (t:=(1/(x+1))) * (1-t) * y
162-
163159
# ***** main rewriter *****
164160

165161
def loop_collapse(compval, multconst, rng:UOp, acc:UOp, idx2=None,idx3=None,extra=None,vec=None,ne=None,
@@ -315,10 +311,11 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
315311
(UPat(Ops.SINK, name="root"),
316312
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.UNROLL} else (x,) for x in root.src)), root.arg)
317313
if any(x.op in {Ops.SINK, Ops.UNROLL} for x in root.src) else None),
318-
# stable sigmoid
319-
(UPat.var("x")*(((UPat.var("x")+1)*(UPat.var("x")+1)).reciprocal()), lambda x: sigmoid_like(x, x.const_like(1))),
320-
(UPat.var("x")*(((UPat.var("x")+1)*(UPat.var("x")+1)).reciprocal()*UPat.var("y")), sigmoid_like),
321-
(UPat.var("x")*(((UPat.var("x")+1)*(UPat.var("x")+1)*(UPat.var("x")+1)).reciprocal()), lambda x: sigmoid_like(x, (x+1).reciprocal())),
314+
((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c
315+
((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()),
316+
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x/(1+x) -> 1-1/(1+x)
317+
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
318+
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
322319
])
323320

324321
# *** uop expander ***

tinygrad/ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1239,7 +1239,7 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
12391239
((UPat.var("y") + UPat.var("x")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: y+x*(c+1)),
12401240
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
12411241
((UPat.var("y") + UPat.var("x")) + UPat.var("x"), lambda y,x: y+x*2),
1242-
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
1242+
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3) if x2 is not x3 else None), # (x/x2)/x3 -> x/(x2*x3)
12431243
(-1 * (UPat.var("x") + UPat.cvar("c")), lambda x,c: (-x)+(-c)), # -(x+c) -> -x + -c
12441244
# a conditional with the same results either way is a noop, also fold const conditionals
12451245
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),

0 commit comments

Comments
 (0)