@@ -156,10 +156,6 @@ def threefry2x32(x: UOp, key: UOp):
156
156
157
157
return xr [1 ].cast (dtypes .uint64 ) * 2 ** 32 | xr [0 ].cast (dtypes .uint64 )
158
158
159
- # ***** other math rewrite ****
160
-
161
- def sigmoid_like (x :UOp , y :UOp ): return (t := (1 / (x + 1 ))) * (1 - t ) * y
162
-
163
159
# ***** main rewriter *****
164
160
165
161
def loop_collapse (compval , multconst , rng :UOp , acc :UOp , idx2 = None ,idx3 = None ,extra = None ,vec = None ,ne = None ,
@@ -315,10 +311,11 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
315
311
(UPat (Ops .SINK , name = "root" ),
316
312
lambda root : UOp (Ops .SINK , root .dtype , tuple (flatten (x .src if x .op in {Ops .SINK , Ops .UNROLL } else (x ,) for x in root .src )), root .arg )
317
313
if any (x .op in {Ops .SINK , Ops .UNROLL } for x in root .src ) else None ),
318
- # stable sigmoid
319
- (UPat .var ("x" )* (((UPat .var ("x" )+ 1 )* (UPat .var ("x" )+ 1 )).reciprocal ()), lambda x : sigmoid_like (x , x .const_like (1 ))),
320
- (UPat .var ("x" )* (((UPat .var ("x" )+ 1 )* (UPat .var ("x" )+ 1 )).reciprocal ()* UPat .var ("y" )), sigmoid_like ),
321
- (UPat .var ("x" )* (((UPat .var ("x" )+ 1 )* (UPat .var ("x" )+ 1 )* (UPat .var ("x" )+ 1 )).reciprocal ()), lambda x : sigmoid_like (x , (x + 1 ).reciprocal ())),
314
+ ((UPat .var ("x" ) * UPat .var ("x" )).reciprocal (), lambda x : x .reciprocal ()* x .reciprocal ()), # 1/(x^c) -> (1/x)^c
315
+ ((UPat .var ("x" ) * UPat .var ("x" ) * UPat .var ("x" )).reciprocal (), lambda x : x .reciprocal ()* x .reciprocal ()* x .reciprocal ()),
316
+ (UPat .var ("x" ) * ((1 + UPat .var ("x" )).reciprocal ().named ("d" )), lambda x ,d : 1 - d ), # x*/(1+x) -> 1-1/(1+x)
317
+ (UPat .var ("x" ) * ((1 + UPat .var ("x" )).reciprocal ().named ("d" )* UPat .var ("y" )), lambda x ,y ,d : y * (1 - d )),
318
+ (UPat .var ("x" ) * ((1 + UPat .var ("x" )).reciprocal ().named ("d" )+ UPat .var ("y" )), lambda x ,y ,d : (1 - d )+ x * y ),
322
319
])
323
320
324
321
# *** uop expander ***
0 commit comments