Skip to content

Commit bfc68d1

Browse files
authored
add gep rules to simplify (tinygrad#9419)
* add gep rules to simplify * ws * flipped direction
1 parent 0bed9b6 commit bfc68d1

File tree

2 files changed

+15
-25
lines changed

2 files changed

+15
-25
lines changed

tinygrad/codegen/devectorizer.py

+4-14
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,16 @@
1111

1212
# ***** load/store grouping *****
1313

14-
def fancy_gep(vec:UOp, i:int):
15-
# if there's a vectorized ADD here, expand through it
16-
if vec.op is Ops.ADD:
17-
if vec.src[0].op is Ops.VECTORIZE and vec.src[1].op is Ops.VCONST: return vec.src[0].gep(i) + vec.src[1].gep(i)
18-
if vec.src[1].op is Ops.VECTORIZE and vec.src[0].op is Ops.VCONST: return vec.src[1].gep(i) + vec.src[0].gep(i)
19-
# if there's a vectorized AND here, expand through it
20-
if vec.op is Ops.AND:
21-
if vec.src[0].op is Ops.VECTORIZE and vec.src[1].op is Ops.VCONST: return vec.src[0].gep(i) & vec.src[1].gep(i)
22-
if vec.src[1].op is Ops.VECTORIZE and vec.src[0].op is Ops.VCONST: return vec.src[1].gep(i) & vec.src[0].gep(i)
23-
return vec.gep(i)
24-
2514
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
2615
# first, extract all the relevant offsets
2716
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
2817
for i in range(vec.dtype.count):
29-
idx = fancy_gep(vec, i)
18+
idx = vec.gep(i).simplify()
3019
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
20+
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
3121
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
3222
else: root_src, arg = idx, 0
33-
if mask is not None: root_src = (fancy_gep(mask, i), root_src)
23+
if mask is not None: root_src = (mask.gep(i).simplify(), root_src)
3424
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
3525

3626
# the buf.dtype is always a pointer
@@ -44,7 +34,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
4434
grouped_offsets = [[x for _,x in group] for _,group in itertools.groupby(enumerate(sorted(offsets.keys())), lambda x: x[1]-x[0])]
4535
for grp in grouped_offsets:
4636
# get the index offset for this element. using [0] is okay, because they are the same
47-
oidx = fancy_gep(vec, offsets[grp[0]][0])
37+
oidx = vec.gep(offsets[grp[0]][0])
4838
lidx = UOp(Ops.INDEX, buf.dtype, (buf, oidx, rootsrc[0]) if mask is not None else (buf, oidx))
4939
if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
5040
# set the idxs of the output

tinygrad/codegen/symbolic.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,17 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
230230
# ** mod **
231231
# mod folding
232232
(UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
233+
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
234+
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
235+
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
236+
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
237+
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
238+
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
239+
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
240+
# push all GEPs through ALUs (fix arange stuff)
241+
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
242+
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
243+
if not isinstance(gep.dtype, PtrDType) else None),
233244
])
234245

235246
symbolic_flat = symbolic+PatternMatcher([
@@ -399,17 +410,6 @@ def gep_through_wmma(gep:UOp, wmma:UOp):
399410
# VECTORIZE void is SINK
400411
(UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
401412
(UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
402-
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
403-
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
404-
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
405-
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
406-
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
407-
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
408-
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
409-
# push all GEPs through ALUs (fix arange stuff)
410-
(UPat(Ops.GEP, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
411-
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
412-
if not isinstance(gep.dtype, PtrDType) else None),
413413
# push some GEPs through WMMAs
414414
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
415415
# CAT can't be rendered. It's a VECTORIZE on vectors; we expand it to a single VECTORIZE with GEPs (TODO: move this later)

0 commit comments

Comments
 (0)