Skip to content

Commit 8206c72

Browse files
authored
move const multiply after REDUCE (tinygrad#9730)
1 parent 6b3480e commit 8206c72

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

tinygrad/codegen/devectorizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def no_vectorized_wmma(wmma:UOp):
229229
wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
230230
return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
231231

232-
def no_vectorized_alu(alu):
232+
def no_vectorized_alu(alu:UOp):
233233
if alu.dtype.vcount == 1: return None
234234
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
235235
return UOp(Ops.VECTORIZE, alu.dtype, alus)

tinygrad/codegen/symbolic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
486486
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
487487
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
488488
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
489-
# move const multiply after REDUCE. TODO: enable later
490-
#(UPat(Ops.REDUCE, src=(UPat.var("x")*UPat.cvar("c", vec=False),), arg=Ops.ADD, name="r", allow_any_len=True),
491-
# lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
489+
# move const multiply after REDUCE
490+
(UPat(Ops.REDUCE, src=(UPat.var("x")*UPat.cvar("c", vec=False),), arg=Ops.ADD, name="r", allow_any_len=True),
491+
lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
492492
])

0 commit comments

Comments
 (0)