Skip to content

Commit 0cac941

Browse files
authored
move xpow to sym instead of late_rewrite (tinygrad#8968)
does not need to be in late_rewrite and can be simplified further
1 parent e7182bb commit 0cac941

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tinygrad/codegen/rewriter.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
123123
def get_late_rewrite_patterns(ops, force_transcendental=False):
124124
pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
125125
((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
126-
pat.append((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src)))
127126
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
128127
if Ops.AND in ops:
129128
pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
@@ -297,6 +296,8 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
297296
# ** where **
298297
# push cast to branches
299298
(UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
299+
# ** pow **
300+
((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))),
300301
# ** load/store folding **
301302
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
302303
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),

0 commit comments

Comments
 (0)