Skip to content

Commit ea5caef

Browse files
authored
gep should look at count, not vcount (tinygrad#9698)
* gep should look at count, not vcount * gep in order is a rule * min change * gep on void
1 parent f72a87f commit ea5caef

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

tinygrad/codegen/symbolic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,10 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
180180
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
181181
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
182182
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
183+
# GEP on void is skipped
184+
(UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x),
185+
# GEP in order is removed
186+
(UPat(Ops.GEP, name="g"), lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].dtype.count)) else None),
183187
# push all GEPs through ALUs (fix arange stuff)
184188
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
185189
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \

tinygrad/ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ def gep(self, i:Union[tuple[int, ...], int]):
387387
if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
388388
if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
389389
i = (i,)
390-
if (self.dtype.vcount == len(i) and i == tuple(range(len(i)))) or self.dtype == dtypes.void: return self
391390
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
392391
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
393392
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)

0 commit comments

Comments
 (0)