@@ -142,8 +142,7 @@ def elementwise_view_right(root:UOp) -> UOp|None:
142
142
output_swizzle = swizzles [0 ]
143
143
new_input_st = ShapeTracker .from_shape (output_swizzle .base .shape )
144
144
ret = root .replace (src = tuple (x if x .st is None else x .base if x in swizzles else apply_swizzle (x .view (new_input_st )) for x in root .src ))
145
- # NOTE: swizzle resolves once we hit STORE
146
- return ret if ret .op is Ops .STORE else ret .view (ShapeTracker .from_shape (output_swizzle .shape ))
145
+ return ret .view (ShapeTracker .from_shape (output_swizzle .shape ))
147
146
148
147
def merge_double_reduce (root :UOp , first_reduce :UOp ) -> UOp :
149
148
assert root .arg [0 ] == first_reduce .arg [0 ], "can't merge reduceops with different alu"
@@ -155,6 +154,8 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
155
154
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
156
155
(UPat (Ops .STORE , src = (UPat .var ("b" ), UPat .var ("st" ), UPat .assign (UPat .var ("target" ), UPat .var ("val" )))),
157
156
lambda b ,target ,st ,val : apply_swizzle (UOp .store (b , st , val ).view (target .st ))),
157
+ # STORE is the last child, so we just merge the ShapeTrackers and store the base
158
+ (UPat (Ops .STORE , src = (UPat .var ("b" ), UPat .var ("st" ), UPat (Ops .VIEW , src = (UPat .var ("val" ),)))), lambda b ,st ,val : UOp .store (b , st .view (val .st ), val )),
158
159
# REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
159
160
(UPat (Ops .REDUCE_AXIS , src = (UPat .var ("src" ),), name = "r" ).view (name = "v" ), lambda v ,r ,src : None if v .st .contiguous else swizzle_r (r , src , v .st )),
160
161
# REDUCE(src.view()) -> REDUCE(src).view()
@@ -303,7 +304,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
303
304
if (p_uop := ctx .allbufs .get (p := parents .pop ())) is None : continue
304
305
if (p_uop := uval (p_uop )).op is Ops .ASSIGN and p not in group : forced_realize , can_chase = True , False
305
306
if p in ctx .realizes : continue
306
- parents .extend ([x .base .src [ 0 ] for x in p_uop .src if x .base .op is Ops .VIEW and len (x .base .src ) != 0 ])
307
+ parents .extend ([x .base .buf_uop for x in p_uop .src if x .base .is_realized or ( x . base . op is Ops .VIEW and len (x .base .src ) != 0 ) ])
307
308
if forced_realize or not group :
308
309
tr = r
309
310
if can_chase :
0 commit comments