6
6
from enum import Enum , auto
7
7
8
8
from tinygrad .ops import GroupOp , KernelInfo , UOp , Ops , can_pad , resolve , Variable , sint , graph_rewrite , track_rewrites , view_left , print_uops
9
- from tinygrad .spec import type_verify
9
+ from tinygrad .spec import type_verify , shape_spec
10
10
from tinygrad .device import Device
11
11
from tinygrad .renderer import Renderer , TensorCore , ProgramSpec
12
12
from tinygrad .dtype import ImageDType
@@ -57,11 +57,8 @@ def __init__(self, ast:UOp, opts:Optional[Renderer]=None):
57
57
if ast .op is Ops .SINK : self .ast = ast
58
58
59
59
self .opts = opts if opts is not None else Device [Device .DEFAULT ].renderer
60
- try : verify_ast (self .ast )
61
- except AssertionError as e :
62
- print ("INVALID AST" )
63
- print (self .ast )
64
- raise e
60
+ # verify AST matches the spec
61
+ if __debug__ : type_verify (list (self .ast .toposort ), shape_spec )
65
62
66
63
self .reduceops = [x for x in self .ast .toposort if x .op is Ops .REDUCE_AXIS ]
67
64
@@ -673,7 +670,10 @@ def linearize(self) -> Kernel:
673
670
if getenv ("RAWAST" ): print (self .ast )
674
671
print (modified_ast )
675
672
print (self .applied_opts )
676
- verify_ast (modified_ast )
673
+ # verify AST matches the spec after applying opts
674
+ if __debug__ : type_verify (list (modified_ast .toposort ))
675
+ # TODO: sadly modified_ast doesn't pass the shape spec because of how group_for_reduces constructs UOps, there's probably a way to fix this
676
+ #if __debug__: type_verify(list(modified_ast.toposort), shape_spec)
677
677
678
678
self .uops :list [UOp ] = linearize_uop (full_graph_rewrite (rewrite_shapetracker_with_index (modified_ast , self .opts ), self .opts ))
679
679
if DEBUG >= 5 : print_uops (self .uops )
@@ -693,39 +693,3 @@ def to_program(self, name_override:Optional[str]=None) -> ProgramSpec:
693
693
key = lambda x : (x .op , x .src [0 ].arg )))
694
694
return ProgramSpec (ansiname , src , self .opts .device , self .uops , mem_estimate = mem_bytes ,
695
695
global_size = [1 ,1 ,1 ] if self .opts .has_local else None , local_size = [1 ,1 ,1 ] if self .opts .has_local else None )
696
-
697
- # the living definition of intermediate UOps
698
-
699
- def _assert_valid_uop (uop :UOp , st :ShapeTracker , sts :dict [UOp , ShapeTracker ]) -> None :
700
- if uop in sts : return
701
- # restore globals from the two stage reduce
702
- # this is because this LOAD has an implicit movement op
703
- if uop .op is Ops .LOAD and uop .src [0 ].op is Ops .DEFINE_LOCAL :
704
- _assert_valid_uop (local_reduce := uop .src [2 ].src [2 ], uop .st_arg , sts )
705
- sts [uop ] = sts [local_reduce ]
706
- return
707
- for x in uop .src : _assert_valid_uop (x , st , sts )
708
- # only reduceuop is allowed to change shape, limited to turning n to 1
709
- if uop .op in {Ops .REDUCE_AXIS , Ops .WMMA }: st = ShapeTracker .from_shape (sts [uop .src [0 ]].reduce (uop .axis_arg ))
710
- # movementops are pushed to VIEW
711
- elif uop .op is Ops .VIEW :
712
- # NOTE: we disallow VIEW in the middle of the AST, if it has a DEVICE source it's fine
713
- assert len (uop .src ) == 0 or uop .src [0 ].op is Ops .DEVICE , f"can't swizzle in kernel yet { uop } "
714
- st = uop .arg
715
- # everything else inherits shape
716
- else :
717
- if len (src_sts := [sts [x ] for x in uop .src if x in sts ]) == 0 : return None
718
- st = src_sts [0 ]
719
- if not all_same (shapes := [x .shape for x in src_sts ]):
720
- if all_same (sizes := [prod (x ) for x in shapes ]): raise AssertionError (f"found implicit reshape { shapes } " )
721
- raise AssertionError (f"found implicit expand { sizes } { shapes } " )
722
- sts [uop ] = st
723
-
724
- def verify_ast (ast :UOp ) -> None :
725
- assert ast .op is Ops .SINK and all (x .op is Ops .STORE for x in ast .src ), "must be SINK"
726
- assert all_same ([x .st_arg .size for x in ast .src ]), "outputs must be exactly the same size"
727
- sts : dict [UOp , ShapeTracker ] = {}
728
- for out in ast .src : _assert_valid_uop (out , out .st_arg , sts )
729
- shape_dims = [sorted (dedup (dims )) for dims in zip (* [x .shape for x in sts .values ()])]
730
- assert all (len (x ) == 1 or (len (x ) == 2 and x [0 ] == 1 ) for x in shape_dims ), f"shapes must have either 1 or n in each dimension, { shape_dims } "
731
- type_verify (list (sts ))
0 commit comments