Skip to content

Commit af4f9d1

Browse files
authored
use matchers to verify AST shape [pr] (tinygrad#8828)
* use matchers to verify kernel AST [pr]
* work
* use swizzle_cnt
* add comment
* imports
* modified_ast comment
* brief
1 parent 643c09a commit af4f9d1

File tree

4 files changed

+32
-53
lines changed

4 files changed

+32
-53
lines changed

test/test_schedule.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
from tinygrad.shape.shapetracker import ShapeTracker
1515
from tinygrad.shape.view import View
1616
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
17+
from tinygrad.spec import type_verify, shape_spec
1718
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
18-
from tinygrad.codegen.kernel import verify_ast
1919
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
2020
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
2121
from extra.models.llama import precompute_freqs_cis
2222

23+
def verify_ast(sink:UOp): return type_verify(list(sink.toposort), shape_spec)
2324
class KernelCountException(Exception): pass
2425
def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
2526
if to_prerealize:
@@ -1824,7 +1825,7 @@ def test_simple_store_reshape(self):
18241825
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
18251826
rsink = graph_rewrite(sink, view_right)
18261827
# this AST first needs to swizzle, but it doesn't have implicit movementops
1827-
with self.assertRaisesRegex(AssertionError, "swizzle"): verify_ast(sink)
1828+
self.assertEqual(swizzle_cnt(sink), 1)
18281829
verify_ast(rsink)
18291830

18301831
def test_no_reshape_reduceop(self):

test/unit/test_verify_ast.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from tinygrad.codegen.kernel import Kernel
66
from tinygrad.helpers import DEBUG
77
from tinygrad.ops import UOp, Ops, print_uops
8-
from tinygrad.codegen.kernel import verify_ast
8+
from tinygrad.spec import type_verify, shape_spec
99
from tinygrad.shape.shapetracker import ShapeTracker
1010
from tinygrad import dtypes
1111
from tinygrad.shape.view import View
@@ -15,8 +15,8 @@ def helper_test_verify_ast(*stores:UOp) -> Kernel:
1515
sink = UOp(Ops.SINK, dtypes.void, stores)
1616
if DEBUG >= 3:
1717
for op in stores: print(op)
18-
try: verify_ast(sink)
19-
except AssertionError as e: raise InvalidASTException(e.args)
18+
try: type_verify(list(sink.toposort), shape_spec)
19+
except RuntimeError as e: raise InvalidASTException(e.args)
2020
k = Kernel(sink)
2121
k.linearize()
2222
if DEBUG >= 6: print_uops(k.uops)
@@ -64,23 +64,24 @@ def test_reduce_store(self):
6464
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
6565
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
6666
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
67-
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
67+
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
6868

6969
def test_reduce_add_store(self):
7070
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
7171
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
7272
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
7373
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
74-
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
74+
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
7575

7676
def test_buffer_uops_st(self):
7777
a = Tensor.randn(4, 4)+2
78-
verify_ast(ast:=a.schedule()[-1].ast)
78+
helper_test_verify_ast(ast:=a.schedule()[-1].ast)
7979
store_st = [u.st for u in ast.toposort if u.op is Ops.STORE][0]
8080
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
8181
const_st = [u.st for u in ast.toposort if u.op is Ops.CONST][0]
8282
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
8383

84+
@unittest.skip("questionable if we want this")
8485
def test_assert_swizzle(self):
8586
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
8687
a = UOp(Ops.LOAD, dtypes.float, (buf, ShapeTracker.from_shape((32, 1)).to_uop()))

tinygrad/codegen/kernel.py

+7-43
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from enum import Enum, auto
77

88
from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, view_left, print_uops
9-
from tinygrad.spec import type_verify
9+
from tinygrad.spec import type_verify, shape_spec
1010
from tinygrad.device import Device
1111
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
1212
from tinygrad.dtype import ImageDType
@@ -57,11 +57,8 @@ def __init__(self, ast:UOp, opts:Optional[Renderer]=None):
5757
if ast.op is Ops.SINK: self.ast = ast
5858

5959
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
60-
try: verify_ast(self.ast)
61-
except AssertionError as e:
62-
print("INVALID AST")
63-
print(self.ast)
64-
raise e
60+
# verify AST matches the spec
61+
if __debug__: type_verify(list(self.ast.toposort), shape_spec)
6562

6663
self.reduceops = [x for x in self.ast.toposort if x.op is Ops.REDUCE_AXIS]
6764

@@ -673,7 +670,10 @@ def linearize(self) -> Kernel:
673670
if getenv("RAWAST"): print(self.ast)
674671
print(modified_ast)
675672
print(self.applied_opts)
676-
verify_ast(modified_ast)
673+
# verify AST matches the spec after applying opts
674+
if __debug__: type_verify(list(modified_ast.toposort))
675+
# TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
676+
#if __debug__: type_verify(list(modified_ast.toposort), shape_spec)
677677

678678
self.uops:list[UOp] = linearize_uop(full_graph_rewrite(rewrite_shapetracker_with_index(modified_ast, self.opts), self.opts))
679679
if DEBUG >= 5: print_uops(self.uops)
@@ -693,39 +693,3 @@ def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
693693
key=lambda x: (x.op, x.src[0].arg)))
694694
return ProgramSpec(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
695695
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
696-
697-
# the living definition of intermediate UOps
698-
699-
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) -> None:
700-
if uop in sts: return
701-
# restore globals from the two stage reduce
702-
# this is because this LOAD has an implicit movement op
703-
if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
704-
_assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
705-
sts[uop] = sts[local_reduce]
706-
return
707-
for x in uop.src: _assert_valid_uop(x, st, sts)
708-
# only reduceuop is allowed to change shape, limited to turning n to 1
709-
if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
710-
# movementops are pushed to VIEW
711-
elif uop.op is Ops.VIEW:
712-
# NOTE: we disallow VIEW in the middle of the AST, if it has a DEVICE source it's fine
713-
assert len(uop.src) == 0 or uop.src[0].op is Ops.DEVICE, f"can't swizzle in kernel yet {uop}"
714-
st = uop.arg
715-
# everything else inherits shape
716-
else:
717-
if len(src_sts:=[sts[x] for x in uop.src if x in sts]) == 0: return None
718-
st = src_sts[0]
719-
if not all_same(shapes:=[x.shape for x in src_sts]):
720-
if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
721-
raise AssertionError(f"found implicit expand {sizes} {shapes}")
722-
sts[uop] = st
723-
724-
def verify_ast(ast:UOp) -> None:
725-
assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
726-
assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
727-
sts: dict[UOp, ShapeTracker] = {}
728-
for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
729-
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
730-
assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
731-
type_verify(list(sts))

tinygrad/spec.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import cast
22
from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
33
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
4-
from tinygrad.helpers import all_int, prod
4+
from tinygrad.helpers import all_int, all_same, dedup, prod
55

66
# *** this is the spec of a Tensor in UOp ***
77

@@ -61,7 +61,7 @@
6161
(UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype),
6262

6363
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
64-
(UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
64+
(UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
6565

6666
# early LOAD has a <buf, shapetracker, store?>
6767
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
@@ -121,6 +121,19 @@
121121
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
122122
])
123123

124+
# *** this is the UOp shape spec ***
125+
126+
def verify_sink_dims(sink:UOp):
127+
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sink.toposort if x.op is not Ops.SINK and x.st is not None])]
128+
return all_same([x.st_arg.size for x in sink.src]) and all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims)
129+
130+
shape_spec = PatternMatcher([
131+
# shapes must have either 1 or n in each dimension
132+
(UPat(Ops.SINK, src=UPat(Ops.STORE), allow_any_len=True, name="sink"), verify_sink_dims),
133+
# all parent UOps must have the same shape
134+
(UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
135+
])
136+
124137
# ***** uop helpers *****
125138

126139
def type_verify(uops:list[UOp], *extra_specs:PatternMatcher):

0 commit comments

Comments (0)