@@ -223,7 +223,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
223
223
if len (ctx .children [top_reduce ]) == 1 : del ctx .realizes [top_reduce ]
224
224
return ctx .realizes
225
225
226
- # break the SINK into kernels
226
+ # **** create kernels
227
227
228
228
@dataclass (frozen = True )
229
229
class Kernel :
@@ -243,6 +243,7 @@ def create_kernel(ctx:KernelContext, x:UOp, b:UOp):
243
243
return UOp (Ops .ASSIGN , x .dtype , (buffer , kernel )).reshape (x .shape )
244
244
245
245
DONT_PLACE_IN_KERNEL = {Ops .KERNEL , Ops .ASSIGN , Ops .BUFFER }
246
+
246
247
def append_to_kernel (ctx :KernelContext , x :UOp ):
247
248
new_srcs : list [UOp ] = []
248
249
metadata = dict .fromkeys (x .arg .metadata )
@@ -268,30 +269,7 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
268
269
(UPat (Ops .SINK , name = "x" ), lambda x :x .replace (src = tuple (s .base for s in x .src )) if any (s .op is Ops .VIEW for s in x .src ) else None ),
269
270
])
270
271
271
- # **** fix kernel AST
272
-
273
- # ** create buffer ops + enumerate buffers
274
-
275
- add_buffer_ops = PatternMatcher ([
276
- # LOAD
277
- (UPat (Ops .BUFFER , name = "x" ), lambda ctx ,x :UOp .load (UOp (Ops .DEFINE_GLOBAL , x .dtype .ptr (x .size ), (), ctx [1 ].index (x )), x .st .to_uop (), dtype = x .dtype )),
278
- # STORE (except for COPY/BUFFER_VIEW)
279
- (UPat (Ops .SINK , src = (UPat ((Ops .COPY , Ops .BUFFER_VIEW ), name = "x" ),)), lambda x :x ),
280
- # partial assign can store to a non-contiguous ShapeTracker
281
- (UPat (Ops .SINK , src = (UPat (Ops .ASSIGN , name = "x" ),)),
282
- lambda x : UOp .store (UOp (Ops .DEFINE_GLOBAL , x .dtype .ptr (x .size ), (), 0 ), x .src [0 ].st .to_uop (), x .src [1 ]).sink ()),
283
- # otherwise the store is contiguous
284
- (UPat (Ops .SINK , src = (UPat (GroupOp .All - {Ops .STORE }, name = "x" ),)),
285
- lambda x : UOp .store (UOp (Ops .DEFINE_GLOBAL , x .dtype .ptr (x .size ), (), 0 ), ShapeTracker .from_shape (x .shape ).to_uop (), x ).sink ()),
286
- # if the last child is a VIEW we merge the ShapeTrackers and store the base
287
- (UPat (Ops .STORE , src = (UPat .var ("b" ), UPat .var ("st" ), UPat (Ops .VIEW , src = (UPat (GroupOp .All - DONT_PUSH_VIEWS , name = "x" ),)))),
288
- lambda x ,b ,st : UOp .store (b , (st .arg + x .st ).to_uop (), x )),
289
- # remove CONTIGUOUS/DEVICE from kernel AST
290
- (UPat (Ops .CONTIGUOUS , src = (UPat .var ("x" ),)), lambda x : x ),
291
- (UPat (Ops .VIEW , src = (UPat (Ops .DEVICE ),), name = "view" ), lambda view : view .replace (src = ())),
292
- ])
293
-
294
- # ** push views to buffer ops
272
+ # **** swizzler
295
273
296
274
def apply_swizzle (u :UOp ) -> UOp :
297
275
with Context (TRACK_MATCH_STATS = 0 ): return graph_rewrite (u , view_left )
@@ -314,7 +292,7 @@ def reduceop_view_right(src:UOp, v:UOp, r:UOp):
314
292
assert unwrap (v .st ).contiguous and v .size == src .size , f"can't compute new axis for { src .shape } -> { r .shape } "
315
293
return src .r (r .arg [0 ], tuple (i for i ,(s ,u ) in enumerate (zip (src .shape , r .shape )) if s != u )).view (ShapeTracker .from_shape (r .shape ))
316
294
317
- def elementwise_view_right (root :UOp ) -> UOp | None :
295
+ def elementwise_view_right (root :UOp ):
318
296
if not (swizzles := [x for x in root .src if x .op is Ops .VIEW and x .base .op not in DONT_PUSH_VIEWS ]): return None
319
297
assert all_same ([x .base .size for x in swizzles ]), f"swizzle inputs must have the same size { swizzles } "
320
298
# place view after applying the elementwise op
@@ -323,7 +301,7 @@ def elementwise_view_right(root:UOp) -> UOp|None:
323
301
# reshape to match downstream shapes
324
302
return root .replace (src = tuple (new_src )).reshape (root .shape )
325
303
326
- def merge_double_reduce (root :UOp , first_reduce :UOp ) -> UOp :
304
+ def merge_double_reduce (root :UOp , first_reduce :UOp ):
327
305
assert root .arg [0 ] == first_reduce .arg [0 ], "can't merge reduceops with different alu"
328
306
assert not any (x .op is Ops .REDUCE_AXIS for x in first_reduce .src [0 ].toposort ), "can't merge more than two reduceops at a time"
329
307
return first_reduce .replace (arg = (first_reduce .arg [0 ], root .axis_arg + first_reduce .axis_arg ))
@@ -340,9 +318,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
340
318
(UPat (Ops .REDUCE_AXIS , src = (UPat (Ops .REDUCE_AXIS , name = "first_reduce" ),), name = "root" ), merge_double_reduce ),
341
319
])
342
320
343
- # ** unbind variables
321
+ # **** unbind variables
344
322
345
- def unbind_shapetracker (ctx :tuple [dict [Variable , int ], tuple [UOp , ...]], x :UOp ) -> UOp | None :
323
+ def unbind_shapetracker (ctx :tuple [dict [Variable , int ], tuple [UOp , ...]], x :UOp ):
346
324
st = unwrap (x .st ).simplify ()
347
325
if any (x .op is Ops .BIND for x in st .vars ()):
348
326
st , var_vals = st .unbind ()
@@ -354,7 +332,26 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
354
332
return var
355
333
unbind_vars = PatternMatcher ([(UPat (Ops .BIND , name = "bind" , src = (UPat .var ("var" ), UPat .cvar ("val" ))), unbind_variable ),])
356
334
357
- # ** fix_kernel_ops
335
+ # **** fix kernel AST
336
+
337
+ add_buffer_ops = PatternMatcher ([
338
+ # LOAD
339
+ (UPat (Ops .BUFFER , name = "x" ), lambda ctx ,x :UOp .load (UOp (Ops .DEFINE_GLOBAL , x .dtype .ptr (x .size ), (), ctx [1 ].index (x )), x .st .to_uop (), dtype = x .dtype )),
340
+ # STORE (except for COPY/BUFFER_VIEW)
341
+ (UPat (Ops .SINK , src = (UPat ((Ops .COPY , Ops .BUFFER_VIEW ), name = "x" ),)), lambda x :x ),
342
+ # partial assign can store to a non-contiguous ShapeTracker
343
+ (UPat (Ops .SINK , src = (UPat (Ops .ASSIGN , name = "x" ),)),
344
+ lambda x : UOp .store (UOp (Ops .DEFINE_GLOBAL , x .dtype .ptr (x .size ), (), 0 ), x .src [0 ].st .to_uop (), x .src [1 ]).sink ()),
345
+ # otherwise the store is contiguous
346
+ (UPat (Ops .SINK , src = (UPat (GroupOp .All - {Ops .STORE }, name = "x" ),)),
347
+ lambda x : UOp .store (UOp (Ops .DEFINE_GLOBAL , x .dtype .ptr (x .size ), (), 0 ), ShapeTracker .from_shape (x .shape ).to_uop (), x ).sink ()),
348
+ # if the last child is a VIEW we merge the ShapeTrackers and store the base
349
+ (UPat (Ops .STORE , src = (UPat .var ("b" ), UPat .var ("st" ), UPat (Ops .VIEW , src = (UPat (GroupOp .All - DONT_PUSH_VIEWS , name = "x" ),)))),
350
+ lambda x ,b ,st : UOp .store (b , (st .arg + x .st ).to_uop (), x )),
351
+ # remove CONTIGUOUS/DEVICE from kernel AST
352
+ (UPat (Ops .CONTIGUOUS , src = (UPat .var ("x" ),)), lambda x : x ),
353
+ (UPat (Ops .VIEW , src = (UPat (Ops .DEVICE ),), name = "view" ), lambda view : view .replace (src = ())),
354
+ ])
358
355
359
356
def check_load_st (glbl :UOp , view :UOp ):
360
357
if glbl .arg != 0 or (st := unwrap (view .st )).contiguous : return
0 commit comments