@@ -235,10 +235,6 @@ def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
235
235
236
236
# break the SINK into stores
237
237
238
- def load_realized (ctx :ScheduleContext , b :UOp , st :UOp ):
239
- # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
240
- return UOp (Ops .PRELOAD if b in ctx .assigns else Ops .LOAD , b .dtype .base , (b , unwrap (st .st ).to_uop ()))
241
-
242
238
def store_or_fuse (ctx :ScheduleContext , b :UOp , x :UOp , st :UOp ):
243
239
if (m := ctx .ops_metadata .get (b )) is not None : ctx .ops_metadata [x ] = m
244
240
if b not in ctx .realizes : return x # collapse BUFFER
@@ -247,7 +243,8 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
247
243
248
244
break_sched = PatternMatcher ([
249
245
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
250
- (UPat (Ops .VIEW , name = "st" , src = (UPat (Ops .BUFFER , name = "b" ),)), load_realized ),
246
+ (UPat (Ops .VIEW , name = "st" , src = (UPat (Ops .BUFFER , name = "b" ),)),
247
+ lambda ctx ,st ,b : UOp (Ops .PRELOAD if b in ctx .assigns else Ops .LOAD , b .dtype .base , (b , st .st .to_uop ()))),
251
248
(UPat (Ops .VIEW , name = "st" , src = (UPat (Ops .BUFFER , name = "b" ), UPat .var ("x" ))), store_or_fuse ),
252
249
])
253
250
0 commit comments