|
45 | 45 |
|
46 | 46 | # ***** uop type spec *****
|
47 | 47 |
|
| 48 | +def validate_index(idx:UOp, mask:UOp|None=None): |
| 49 | + # this checks for out of bounds access. it is not complete but should catch some issues |
| 50 | + if mask is None and not isinstance(idx.dtype, ImageDType): |
| 51 | + # WEBGPU has a BITCAST in the index. TODO: fix |
| 52 | + if any(x.op in {Ops.DEFINE_VAR, Ops.BITCAST} or (x.op is Ops.SPECIAL and any(not isinstance(y, int) for y in x.arg[1:])) for x in idx.toposort): |
| 53 | + return True |
| 54 | + vmin, vmax, sz = idx.src[1].vmin, idx.src[1].vmax, cast(PtrDType, idx.src[0].dtype).size |
| 55 | + if sz != -1 and (vmin < 0 or vmax >= sz): |
| 56 | + print(f"OUT OF BOUNDS ACCESS in INDEX {vmin} - {vmax} not in 0 - {sz}") |
| 57 | + return False |
| 58 | + return True |
| 59 | + |
48 | 60 | # this is the matcher for the final rendered UOps
|
49 | 61 | # matcher functions returns True or False (or None to not match)
|
50 | 62 | spec = PatternMatcher([
|
|
75 | 87 |
|
76 | 88 | # INDEX is used in new style load/store
|
77 | 89 | # INDEX takes a <buf, alu, gate?>
|
78 |
| - (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True), |
79 |
| - (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(dtype=dtypes.bool))), lambda: True), |
| 90 | + (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat()), name="idx"), validate_index), |
| 91 | + (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(dtype=dtypes.bool, name="mask")), name="idx"), validate_index), |
80 | 92 |
|
81 | 93 | # LOAD takes a <bufidx, alt?, barrier?>
|
82 | 94 | (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
|
|
0 commit comments