@@ -260,13 +260,9 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
260
260
261
261
# ** create buffer ops + enumerate buffers
262
262
263
- def load_buf (ctx :list [UOp ], x :UOp ):
264
- if x not in ctx : ctx .append (x )
265
- return UOp (Ops .LOAD , x .dtype , (UOp (Ops .DEFINE_GLOBAL , x .dtype .ptr (x .size ), (), ctx .index (x )), unwrap (x .st ).to_uop ()))
266
-
267
263
add_buffer_ops = PatternMatcher ([
268
264
# LOAD
269
- (UPat (Ops .BUFFER , name = "x" ), load_buf ),
265
+ (UPat (Ops .BUFFER , name = "x" ), lambda ctx , x : UOp ( Ops . LOAD , x . dtype , ( UOp ( Ops . DEFINE_GLOBAL , x . dtype . ptr ( x . size ), (), ctx . index ( x )), x . st . to_uop ())) ),
270
266
# STORE (except for COPY/BUFFER_VIEW)
271
267
(UPat (Ops .SINK , src = (UPat ((Ops .COPY , Ops .BUFFER_VIEW ), name = "x" ),)), lambda x :x ),
272
268
(UPat (Ops .SINK , src = (UPat (GroupOp .All - {Ops .STORE }, name = "x" ),)),
@@ -278,8 +274,9 @@ def load_buf(ctx:list[UOp], x:UOp):
278
274
def apply_swizzle (u :UOp ) -> UOp :
279
275
with Context (TRACK_MATCH_STATS = 0 ): return graph_rewrite (u , view_left )
280
276
281
- def swizzle_r (r :UOp , src :UOp , st :ShapeTracker ) -> UOp :
282
- input_st = ShapeTracker .from_shape (unwrap (src .st ).shape )
277
+ def swizzle_reduceop (r :UOp , src :UOp , view :UOp ):
278
+ if (st := unwrap (view .st )).contiguous : return None
279
+ input_st = ShapeTracker .from_shape (src .shape )
283
280
tmp = input_st .permute (tuple (i for i in range (len (input_st .shape )) if i not in r .axis_arg )+ r .axis_arg )
284
281
prshape = prod (rshape := tmp .shape [- len (r .axis_arg ):])
285
282
strides = strides_for_shape (rshape )
@@ -290,20 +287,18 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
290
287
new_axis = tuple (range (len (st .shape ), len (st .shape ) + len (r .axis_arg )))
291
288
return apply_swizzle (src .view (new_input_st )).r (r .arg [0 ], new_axis ).view (ShapeTracker .from_shape (st .shape ))
292
289
293
- def reduceop_view_right (r :UOp , v :UOp , src :UOp ) -> UOp :
294
- if not (swizzle_st := unwrap (v .st )).contiguous or v .size != src .size : raise AssertionError (f"can't push { v } down through { src } " )
295
- output_shape = swizzle_st .reduce (r .axis_arg )
296
- return src .r (r .arg [0 ], tuple (i for i ,(s ,u ) in enumerate (zip (src .shape , output_shape )) if s != u )).view (ShapeTracker .from_shape (output_shape ))
290
+ def reduceop_view_right (src :UOp , v :UOp , r :UOp ):
291
+ assert unwrap (v .st ).contiguous and v .size == src .size , f"can't compute new axis for { src .shape } -> { r .shape } "
292
+ return src .r (r .arg [0 ], tuple (i for i ,(s ,u ) in enumerate (zip (src .shape , r .shape )) if s != u )).view (ShapeTracker .from_shape (r .shape ))
297
293
298
294
def elementwise_view_right (root :UOp ) -> UOp | None :
299
- if len (swizzles := [x for x in root .src if x .base is not x ]) == 0 : return None
300
- assert all (x .base .st is not None for x in swizzles ), f"found shapeless VIEW src in { root } "
295
+ if not (swizzles := [x for x in root .src if x .op is Ops .VIEW ]): return None
301
296
assert all_same ([x .base .size for x in swizzles ]), f"swizzle inputs must have the same size { swizzles } "
302
- # push the swizzle from src to root
303
- output_swizzle = swizzles [0 ]
304
- new_input_st = ShapeTracker . from_shape ( output_swizzle .base . shape )
305
- ret = root . replace ( src = tuple ( x if x . st is None else x . base if x in swizzles else apply_swizzle ( x . view ( new_input_st )) for x in root . src ))
306
- return ret .view ( ShapeTracker . from_shape ( output_swizzle . shape ) )
297
+ # place view after applying the elementwise op
298
+ new_shape = swizzles [0 ]. base . shape
299
+ ret = root . replace ( src = tuple ( x .base if x . base . shape == new_shape else apply_swizzle ( x . view ( ShapeTracker . from_shape ( new_shape ))) for x in root . src ) )
300
+ # reshape to match downstream shapes
301
+ return ret .reshape ( root . shape )
307
302
308
303
def merge_double_reduce (root :UOp , first_reduce :UOp ) -> UOp :
309
304
assert root .arg [0 ] == first_reduce .arg [0 ], "can't merge reduceops with different alu"
@@ -317,12 +312,12 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
317
312
lambda b ,target ,st ,val : apply_swizzle (UOp .store (b , st , val ).view (target .st ))),
318
313
# STORE is the last child, so we just merge the ShapeTrackers and store the base
319
314
(UPat (Ops .STORE , src = (UPat .var ("b" ), UPat .var ("st" ), UPat (Ops .VIEW , src = (UPat .var ("val" ),)))), lambda b ,st ,val : UOp .store (b , st .view (val .st ), val )),
320
- # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view( contiguous=True)).view()
321
- (UPat (Ops .REDUCE_AXIS , src = (UPat .var ("src" ),), name = "r" ). view ( name = "v " ), lambda v , r , src : None if v . st . contiguous else swizzle_r ( r , src , v . st ) ),
322
- # REDUCE(src. view()) -> REDUCE(src).view()
323
- (UPat (Ops .REDUCE_AXIS , src = (UPat . var ("src" ). view ( name = "v" ),), name = "r" ), reduceop_view_right ),
324
- # ALU(src. view()) -> ALU(src).view()
325
- (UPat (( * GroupOp .ALU , Ops . CAST , Ops . BITCAST , Ops . ASSIGN , Ops . CONTIGUOUS , Ops . STORE ) , name = "root" ), elementwise_view_right ),
315
+ # push a non contiguous ShapeTracker through reduceop
316
+ (UPat (Ops .VIEW , src = ( UPat ( Ops . REDUCE_AXIS , src = (UPat .var ("src" ),), name = "r" ),), name = "view " ), swizzle_reduceop ),
317
+ # apply view after reduceops
318
+ (UPat (Ops .REDUCE_AXIS , src = (UPat ( Ops . VIEW , src = ( UPat . var ("src" ),), name = "v" ),), name = "r" ), reduceop_view_right ),
319
+ # apply view after elementwise ops
320
+ (UPat (GroupOp .All - GroupOp . Buffer , name = "root" ), elementwise_view_right ),
326
321
# double reduce op collapses to a single reduce op
327
322
(UPat (Ops .REDUCE_AXIS , src = (UPat (Ops .REDUCE_AXIS , name = "first_reduce" ),), name = "root" ), merge_double_reduce ),
328
323
])
@@ -372,7 +367,7 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
372
367
# substitute kernel sources for the target buffer
373
368
ast = k .arg .ast .substitute ({s .src [1 ].arg .ast :s .src [0 ] for s in k .src if s .op is Ops .ASSIGN }).sink ()
374
369
# add buffer ops
375
- ast = graph_rewrite (ast , add_buffer_ops , bufs := [ s .buf_uop for s in k .src ] , bottom_up = True )
370
+ ast = graph_rewrite (ast , add_buffer_ops , bufs := tuple ( s .buf_uop for s in k .src ) , bottom_up = True )
376
371
if ast .op is Ops .SINK and not all_same (dev := [x .device for x in bufs ]): raise RuntimeError (f"all buffers must be on the same device: { dev } " )
377
372
# unbind_vars + push views to edges
378
373
ast = graph_rewrite (graph_rewrite (ast , unbind_vars + view_left , ctx = var_vals ), view_right )
0 commit comments