Skip to content

Commit 0b20f91

Browse files
authored
remove move_mask from the devectorizer (tinygrad#9511)
* remove move_mask from the devectorizer * add (wrong) ptx * reason * enable index addition in PTX, we won't have the INDEX anyways * space
1 parent 9302738 commit 0b20f91

File tree

10 files changed

+30
-25
lines changed

10 files changed

+30
-25
lines changed

test/test_renderer_failures.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def test_gated_store_with_alu(self):
5353
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
5454
np.testing.assert_equal(ret, [0, 1, 1, 1])
5555

56+
@unittest.skip("INDEX can only have a gate ALU parent, not an IF")
5657
def test_gated_store_with_if(self):
5758
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
5859
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)

test/unit/test_simplify_valid_idx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_is_increasing(self):
4646
class TestValidIdxSimplification(unittest.TestCase):
4747
def check(self, load, sidx, svalid):
4848
load = full_graph_rewrite(load.sink()).src[0]
49-
idx, valid = load.src[0].src[1], load.src[2]
49+
idx, valid = load.src[0].src[1], load.src[0].src[2]
5050
self.assertEqual(idx.render(simplify=False), sidx)
5151
self.assertEqual(valid.render(simplify=False), svalid)
5252

@@ -133,7 +133,7 @@ def check(self, load, svalid, sidx0, sidx1):
133133
idx0, idx1 = idx.src[0], idx.src[1]
134134
self.assertEqual(idx0.render(simplify=False), sidx0)
135135
self.assertEqual(idx1.render(simplify=False), sidx1)
136-
if svalid is not None: self.assertEqual(load.src[2].render(simplify=False), svalid)
136+
if svalid is not None: self.assertEqual(load.src[0].src[2].render(simplify=False), svalid)
137137

138138
def test_idx_gt_c(self):
139139
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid

tinygrad/codegen/devectorizer.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -249,24 +249,18 @@ def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|N
249249
UPat.var("val"))), delete_redundant_gates),
250250
])
251251

252-
def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:UOp|None=None) -> UOp:
253-
# this moves the mask from the indexing to the load/store op for rendering
254-
nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
255-
return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
256-
257252
pm_render = PatternMatcher([
258253
# for rendering, we use explicit VECTORIZE
259254
(UPat(Ops.CONST, name='c'),
260255
lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
261256
(UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
262257
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
263258
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
264-
# move masks of loads/stores
265-
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("mask")))
266-
.or_casted("cast"),), allow_any_len=True, name="x"), move_mask),
259+
# give any loads that are masked an alt value
260+
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), name="x"), lambda x: x.replace(src=x.src+(x.const_like(0),))),
267261
# gate any stores that aren't gated with ifs
268-
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
269-
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
262+
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store"),
263+
lambda store,idx: UOp(Ops.STORE, src=store.src+(UOp(Ops.IF, src=(idx.src[2],)),))),
270264
])
271265

272266
# *** uop graph ***

tinygrad/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.a
735735

736736
@staticmethod
737737
def any(*src): return UPatAny(src=src)
738-
def or_casted(self, name:str|None=None): return UPat.any(self, UPat(Ops.CAST, name=name, src=(self,)))
738+
def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,)))
739739

740740
@staticmethod
741741
@functools.lru_cache(None)

tinygrad/renderer/cstyle.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,10 @@
4343
# default const render
4444
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
4545
# new load/store
46-
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
46+
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx')), allow_any_len=True),
4747
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"),
48-
(UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
48+
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted('bidx'), UPat.var("var"))),
49+
lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
4950
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
5051
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
5152
# alu/gep
@@ -235,7 +236,7 @@ class OpenCLRenderer(CStyleLanguage):
235236
string_rewrite = PatternMatcher([
236237
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
237238
# load/store image (OpenCL)
238-
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
239+
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2)), UPat.var("gate")), UPat.var("var"))),
239240
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
240241
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
241242
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),

tinygrad/renderer/llvmir.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1
5656
# memory load/store
5757
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
5858
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
59-
(UPat(Ops.LOAD, src=(UPat.var('idx'), UPat.var('alt'), UPat.var('mask')), name="x"), lambda ctx,x,idx,alt,mask:
59+
(UPat(Ops.LOAD, src=(UPat.or_casted(name='idx', self=UPat(src=(UPat(), UPat(), UPat.var('mask')))), UPat.var('alt')), name="x"),
60+
lambda ctx,x,idx,alt,mask:
6061
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
6162
f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
6263
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"

tinygrad/renderer/ptx.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ def render_val(x, dtype):
4949
# load/store use pointer arithmetic, and the cast does nothing
5050
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
5151
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
52+
# move mask from INDEX to the load/store to enable pointer arithmetic
53+
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat.var("gate"))), UPat.var("alt"))),
54+
lambda buf,idx,gate,alt: UOp(Ops.LOAD, alt.dtype, (buf.index(idx), alt, gate))),
55+
(UPat(Ops.STORE, src=(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"), UPat())), UPat.var("val"), UPat.var("gate"))),
56+
lambda buf,idx,val,gate: UOp.store(buf.index(idx), val, gate)),
5257
# ptx shr and shl instructions require y to be uint
5358
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.SHL, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
5459
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),

tinygrad/renderer/wgsl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def packed_store(bidx:UOp, var:UOp):
2020
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
2121
div_idx = bidx.src[1]//(4//dtype.itemsize)
2222
shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//dtype.itemsize))*UOp.const(dtypes.uint32, 8*dtype.itemsize)
23-
if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), var, root.src[2], dtype=dtypes.uint32, arg=root.arg)
23+
if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx, bidx.src[2])), var, dtype=dtypes.uint32, arg=root.arg)
2424
else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=dtypes.uint32, arg=root.arg)
2525
val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
2626
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
@@ -31,7 +31,7 @@ def is_packed(dt:DType) -> bool: return dt.itemsize < 4 and dt.base != dtypes.ha
3131
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
3232
lambda a,b,c: a.cast(dtypes.int).alu(c.op, b.cast(dtypes.int)).cast(dtypes.bool)),
3333
(UPat(Ops.LOAD, name="l", src=(UPat.var("b"),)), lambda l,b: packed_load(l, b, l.dtype) if is_packed(l.dtype) else None),
34-
(UPat(Ops.LOAD, name="l", src=(UPat.var("b"), UPat.cvar("c"), UPat())),
34+
(UPat(Ops.LOAD, name="l", src=(UPat.var("b"), UPat.cvar("c"))),
3535
lambda l,b,c: packed_load(l,b,l.dtype,c.cast(dtypes.uint32)) if is_packed(l.dtype) else None),
3636
(UPat.store(UPat.var("bidx"), UPat.var("var"), allow_any_len=True), lambda bidx,var: packed_store(bidx,var) if is_packed(var.dtype) else None),
3737
# TODO: why is this needed, and only for this MUL order
@@ -64,13 +64,13 @@ class WGSLRenderer(CStyleLanguage):
6464
(UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"),lambda ctx,x:f"bitcast<{ctx.type_map[x.dtype]}>(vec2<f16>({ctx[x.src[0]]},0))" \
6565
if x.src[0].dtype == dtypes.half else f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
6666
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]})"),
67-
(UPat.load(UPat.var("b"), UPat.cvar("v"), UPat.var("g")),lambda ctx,b,v,g: f"select({ctx[v]}, {ctx.render_load(ctx[b], b.dtype)}, {ctx[g]})"),
67+
(UPat.load(UPat.var("b"), UPat.cvar("v")),lambda ctx,b,v: f"select({ctx[v]}, {ctx.render_load(ctx[b],b.src[0].dtype)}, {ctx[b.src[2]]})"),
6868
(UPat.load(UPat.var("b"), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.dtype)),
6969
(UPat.store(UPat.var("b"), UPat.var("v"), allow_any_len=True),lambda ctx,b,v:\
7070
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
7171
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if is_packed(b.src[0].dtype) \
7272
else f"{ctx[b]} = {ctx[v]};"),
73-
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"))),
73+
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx")), allow_any_len=True),
7474
lambda ctx,b,idx: f"{ctx[b]}[{strip_parens(ctx[idx]) if idx.arg is Ops.ADD else ctx[idx]}]"),
7575
# fix nan check: 'a != a -> is_nan()'
7676
(UPat.var("a") != UPat.var("a"), lambda ctx,a: f"(min({ctx[a]}, 1.0) == 1.0 && max({ctx[a]}, -1.0) == -1.0)"),

tinygrad/runtime/ops_python.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def _load(m, i):
1717
return m[i]
1818

1919
def load(inp, j=0):
20-
if len(inp) == 3: return [_load(m, x+j if x is not None else None) if gate else default for (m,x),default,gate in zip(*inp)]
20+
if len(inp) == 2: return [_load(m, x+j if x is not None else None) if gate else default for (m,x,gate),default in zip(*inp)]
2121
return [_load(m, x+j if x is not None else None) for m,x in inp[0]]
2222

2323
def _store(m, i, v):
@@ -80,13 +80,14 @@ def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tup
8080
elif uop is Ops.DEFINE_ACC:
8181
ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
8282
elif uop is Ops.INDEX:
83-
ret = []
83+
ret:list = []
8484
if isinstance(dtp[0], ImageDType):
8585
for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
8686
if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append((m, None))
8787
else: ret.append((m, ox*4 + oy*dtp[0].shape[1]*4))
8888
else:
8989
for m,o in zip(inp[0], inp[1]): ret.append((m,o))
90+
if len(inp) == 3: ret = [(m,o,g) for (m,o),g in zip(ret, inp[2])] # set the gate last
9091
ul[i] = ret
9192
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
9293
ul[i] = inp[0]

tinygrad/spec.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,14 @@
7474
# **** new style load/store ****
7575

7676
# INDEX is used in new style load/store
77+
# INDEX takes a <buf, alu, gate?>
7778
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
79+
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
7880

79-
# LOAD takes a <bufidx, alt?, gate?, barrier?>
81+
# LOAD takes a <bufidx, alt?, barrier?>
8082
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
8183
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
82-
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat.var("alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
84+
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat.var("alt")), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
8385

8486
# STORE takes a <bufidx, val, gate?>
8587
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),

0 commit comments

Comments
 (0)