Skip to content

Commit 3c5161b

Browse files
geohot and Qazalin authored
add validation of the bounds of Ops.INDEX (tinygrad#9503)
* add validation of the bounds of Ops.INDEX
* do mask properly
* more validation
* correct
* fix gated
* add CAST support to vmin/vmax
* fix ptx and image
* ptx no diff
* upat.index also stays

---------

Co-authored-by: qazal <qazal.software@gmail.com>
1 parent 0b20f91 commit 3c5161b

File tree

6 files changed: +50 -4 lines changed

test/test_uop_graph.py

+6-1
Original file line number | Diff line number | Diff line change
@@ -442,6 +442,11 @@ def test_bitcast_to_same_dtype_fold(self):
442442
uops = to_uops_list([v.bitcast(dt)])
443443
self.assertEqual(len([x for x in uops if x.op is Ops.BITCAST]), 0, f"dtype = {dt}")
444444

445+
def test_out_of_bounds_access(self):
446+
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
447+
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42)),))
448+
with self.assertRaises(RuntimeError): to_uops_list([ld0])
449+
445450
def test_fold_gated_load(self):
446451
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
447452
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
@@ -456,7 +461,7 @@ def test_fold_gated_load(self):
456461

457462
def test_fold_gated_load_local(self):
458463
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
459-
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=1, local=True), (), "temp")
464+
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, local=True), (), "temp")
460465
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
461466
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
462467
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))

test/test_uops.py

+1
Original file line number | Diff line number | Diff line change
@@ -322,6 +322,7 @@ def test_local_packed(self):
322322
self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)
323323

324324
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
325+
@unittest.skip("tinygrad doesn't support this behavior")
325326
def test_local_indirect(self):
326327
uops = []
327328
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.int32.ptr(size=16, local=True), (), 'smem')

test/unit/test_uop_vmin_vmax.py

+25
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,31 @@ def test_vmin_vmax_addition_with_variable(self):
2525
self.assertEqual(uop.vmin, 15)
2626
self.assertEqual(uop.vmax, 25)
2727

28+
def test_vmin_vmax_subtraction_with_variable(self):
29+
x = UOp.variable('x', 10, 20)
30+
uop = x - 5
31+
self.assertEqual(uop.vmin, 5)
32+
self.assertEqual(uop.vmax, 15)
33+
uop = 5 - x
34+
self.assertEqual(uop.vmin, -15)
35+
self.assertEqual(uop.vmax, -5)
36+
37+
def test_vmin_vmax_and_with_variable(self):
38+
x = UOp.variable('x', 10, 20)
39+
uop = x & 5
40+
self.assertEqual(uop.vmin, 0)
41+
self.assertEqual(uop.vmax, 5)
42+
43+
# this can be improved
44+
uop = x & 15
45+
self.assertEqual(uop.vmin, 0)
46+
self.assertEqual(uop.vmax, 15)
47+
48+
# this can be improved
49+
uop = x & 32
50+
self.assertEqual(uop.vmin, 0)
51+
self.assertEqual(uop.vmax, 20)
52+
2853
def test_vmin_vmax_multiplication_with_variable(self):
2954
# vmin and vmax for multiplication with a variable
3055
x = UOp.variable('x', -3, 4)

tinygrad/engine/schedule.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -334,7 +334,7 @@ def unbind_variable(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], var:UOp, va
334334
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
335335
# partial assign can store to a non-contiguous ShapeTracker
336336
(UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
337-
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
337+
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.src[0].base.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
338338
# otherwise the store is contiguous
339339
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
340340
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),

tinygrad/ops.py

+3
Original file line number | Diff line number | Diff line change
@@ -603,6 +603,8 @@ def _min_max(self) -> tuple[ConstType, ConstType]:
603603
if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
604604
(s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
605605
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
606+
if self.op is Ops.SUB: return s0_vmin-s1_vmax, s0_vmax-s1_vmin
607+
if self.op is Ops.AND and s1_vmin == s1_vmax and s0_vmin >= 0 and s1_vmin >= 0: return min(0, s0_vmin), min(s0_vmax, s1_vmax)
606608
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
607609
# SHL/SHR on consts only
608610
if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
@@ -633,6 +635,7 @@ def _min_max(self) -> tuple[ConstType, ConstType]:
633635
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
634636
if self.op is Ops.CONST: return self.arg, self.arg
635637
if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
638+
if self.op is Ops.CAST: return max(dtypes.min(self.dtype), self.src[0].vmin), min(self.src[0].vmax, dtypes.max(self.dtype))
636639
return dtypes.min(self.dtype), dtypes.max(self.dtype)
637640

638641
@functools.cached_property

tinygrad/spec.py

+14-2
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,18 @@
4545

4646
# ***** uop type spec *****
4747

48+
def validate_index(idx:UOp, mask:UOp|None=None):
49+
# this checks for out of bounds access. it is not complete but should catch some issues
50+
if mask is None and not isinstance(idx.dtype, ImageDType):
51+
# WEBGPU has a BITCAST in the index. TODO: fix
52+
if any(x.op in {Ops.DEFINE_VAR, Ops.BITCAST} or (x.op is Ops.SPECIAL and any(not isinstance(y, int) for y in x.arg[1:])) for x in idx.toposort):
53+
return True
54+
vmin, vmax, sz = idx.src[1].vmin, idx.src[1].vmax, cast(PtrDType, idx.src[0].dtype).size
55+
if sz != -1 and (vmin < 0 or vmax >= sz):
56+
print(f"OUT OF BOUNDS ACCESS in INDEX {vmin} - {vmax} not in 0 - {sz}")
57+
return False
58+
return True
59+
4860
# this is the matcher for the final rendered UOps
4961
# matcher functions returns True or False (or None to not match)
5062
spec = PatternMatcher([
@@ -75,8 +87,8 @@
7587

7688
# INDEX is used in new style load/store
7789
# INDEX takes a <buf, alu, gate?>
78-
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
79-
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
90+
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat()), name="idx"), validate_index),
91+
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(dtype=dtypes.bool, name="mask")), name="idx"), validate_index),
8092

8193
# LOAD takes a <bufidx, alt?, barrier?>
8294
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),

0 commit comments

Comments
 (0)