@@ -123,7 +123,6 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
123
123
def get_late_rewrite_patterns (ops , force_transcendental = False ):
124
124
pat : list [tuple [UPat , Callable ]] = [(UPat (op , dtype = TRANSCENDENTAL_SUPPORTED_DTYPES , src = (UPat .var ("d" ),)), f ) for op ,f in \
125
125
((Ops .EXP2 , xexp2 ), (Ops .LOG2 , xlog2 ), (Ops .SIN , xsin )) if op not in ops or force_transcendental ]
126
- pat .append ((UPat (Ops .POW , name = "p" ), lambda p : xpow (* p .src )))
127
126
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
128
127
if Ops .AND in ops :
129
128
pat += [(UPat .var ("x" , dtypes .ints )% UPat .cvar ("c" ), lambda x ,c : x & (c .arg - 1 ) if c .arg in powers_of_two else None )]
@@ -297,6 +296,8 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
297
296
# ** where **
298
297
# push cast to branches
299
298
(UPat .var ("s" ).where (UPat .var ("a" ), UPat .var ("b" )).cast ().named ("cast" ), lambda s ,a ,b ,cast : s .where (a .cast (cast .dtype ), b .cast (cast .dtype ))),
299
+ # ** pow **
300
+ ((UPat (Ops .POW , name = "p" ), lambda p : xpow (* p .src ))),
300
301
# ** load/store folding **
301
302
(UPat .store (UPat (Ops .INDEX , name = "index" ), UPat .load (UPat (Ops .INDEX , name = "index" ))), lambda index : UOp (Ops .NOOP )),
302
303
(UPat .store (UPat (Ops .INDEX , name = "index" ), UPat .var ("gate" ).where (UPat .var ("alt" ), UPat .load (UPat (Ops .INDEX , name = "index" )))),
0 commit comments