@@ -203,7 +203,57 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> Sched
 def save_process_replay() -> None:
   for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
 
-# **** Schedule grouping
+# **** UOp realization
+
+class UPatScheduled(UPat):
+  def __init__(self, *args, **kwargs):
+    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
+
+def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
+
+def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
+  st = unwrap(view.st)
+  # fold simple pads
+  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
+    return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
+  # early realize before expand
+  if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
+  # otherwise safety check pads
+  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
+
+def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
+  if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
+  del ctx.realizes[b]
+  return x.view(unwrap(view.st))
+
+def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
+  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
+  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
+  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
+
+do_realize = PatternMatcher([
+  # always realize SINK parents
+  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
+  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
+  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
+  # realize before expand or unsafe pad ops
+  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
+  # don't realize image to image casts
+  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
+   fold_img_cast),
+  # realize before COPY or BUFFER_VIEW
+  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
+  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
+  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
+  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
+])
+
+def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
+  ctx.allbufs[buf_uop] = view
+  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
+  for x in op.base.src:
+    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
+create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
 
 def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
 def uval(u:UOp) -> UOp:
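
Note: UPatScheduled is just sugar for matching a schedulable VIEW(BUFFER, op) triple, as its __init__ above shows. A minimal sketch of the equivalence (the import path is an assumption; in this tree UPat and Ops live alongside UOp):

    from tinygrad.ops import Ops, UPat

    # UPatScheduled(Ops.CAST) matches the same graphs as this hand-written pattern:
    # the VIEW binds as "base", its backing buffer as "b", the op to store as "to_store"
    manual = UPat(Ops.VIEW, name="base",
                  src=(UPat(Ops.BUFFER, name="b"), UPat(Ops.CAST, name="to_store")))
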
@@ -228,8 +278,9 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di
     if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
     recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
 
-def group_realizes(ctx:ScheduleContext) -> None:
-  """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
+def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
+  # start by adding uops that always realize
+  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
   # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
   reduce_for_op: dict[UOp, UOp] = {}
   double_reduces: list[UOp] = []
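
The new signature makes the grouper self-contained: callers pass the post-buffer sink and get the realize map back instead of reaching into ctx afterwards. A hedged call-site sketch, mirroring the change in create_schedule_with_vars further down:

    ctx = ScheduleContext()
    sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={})
    realize_map = group_realizes(sink, ctx)  # dict[UOp, UOp]: BUFFER uop -> op to store
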
@@ -280,10 +331,28 @@ def group_realizes(ctx:ScheduleContext) -> None:
   for reduceop in double_reduces:
     top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
     if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
+  graph_rewrite(sink, break_sched, ctx)
+  return ctx.realizes
 
-# **** Schedule creation and BFS toposort
+# break the SINK into stores
 
-# ** this is schedule level const folding
+def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
+  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
+  return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
+
+def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
+  if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
+  if b not in ctx.realizes: return x  # collapse BUFFER
+  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
+  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
+
+break_sched = PatternMatcher([
+  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
+  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
+  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
+])
+
+# **** schedule simplifier
 
 def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
   if not all_int(x.shape): return None
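
In break_sched terms, every VIEW of a BUFFER resolves one of three ways. A schematic summary of the two rules above (not runnable, shapes as matched):

    # read of a realized buffer:   VIEW(BUFFER b)    -> LOAD(b, st)
    #   ...or PRELOAD(b, st) when b is also the target of an ASSIGN, so toposort
    #   places the read before the assignment
    # computed value, realized:    VIEW(BUFFER b, x) -> STORE recorded in ctx.realizes[b],
    #   and the use site becomes LOAD(b, st)
    # computed value, unrealized:  VIEW(BUFFER b, x) -> collapses to x (fused into its consumer)
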
@@ -338,80 +407,6 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
     if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
 ])
 
-# ** this decides which ops get realized
-
-class UPatScheduled(UPat):
-  def __init__(self, *args, **kwargs):
-    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
-
-def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
-
-def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
-  st = unwrap(view.st)
-  # fold simple pads
-  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
-    return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
-  # early realize before expand
-  if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
-  # otherwise safety check pads
-  return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
-
-def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
-  if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
-  del ctx.realizes[b]
-  return x.view(unwrap(view.st))
-
-def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
-  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
-  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
-  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
-
-do_realize = PatternMatcher([
-  # always realize SINK parents
-  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
-  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
-  (UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
-  # realize before expand or unsafe pad ops
-  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
-  # don't realize image to image casts
-  (UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
-   fold_img_cast),
-  # realize before COPY or BUFFER_VIEW
-  (UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
-  (UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
-  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
-  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
-])
-
-# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp
-
-def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
-  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
-  return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
-
-def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
-  if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
-  if b not in ctx.realizes: return x  # collapse BUFFER
-  ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
-  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
-
-break_sched = PatternMatcher([
-  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
-  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
-  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
-])
-
-# **** Schedule context builder
-
-def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
-  ctx.allbufs[buf_uop] = view
-  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
-  for x in op.base.src:
-    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
-create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
-
-# **** movement ops
-
 remove_movement_ops = merge_views+PatternMatcher([
   # NOTE: movement ops are always applied to base
   (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
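
remove_movement_ops rewrites every movement op into a VIEW of its base, so later passes only ever see VIEWs. A small illustration (assuming GroupOp.Movement covers ops like RESHAPE/PERMUTE/EXPAND/PAD/SHRINK, which is not shown in this diff):

    # before: RESHAPE(x, arg)  -- a movement op node carrying its own ShapeTracker
    # after:  x.view(mov.st)   -- the same ShapeTracker applied as a VIEW of x's base
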
@@ -420,6 +415,8 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
    lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
 ])
 
+# **** schedule creation and toposort
+
 @track_rewrites(named=True)
 def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
   tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
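
The lambda rule in the context above is schedule-level const folding: a view whose final mask contains an empty interval (x[1]-x[0] == 0) can never select a valid element, so the whole view folds to a zero const of the same shape. A hypothetical illustration:

    # mask ((0,4), (2,2)): the second axis keeps zero valid elements,
    # so every read is masked out and the VIEW rewrites to view.const_like(0)
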
@@ -438,11 +435,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
   # add BUFFER uops
   sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
-  # add realizes
-  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
-  # group realizes into kernels
-  group_realizes(ctx)
-  graph_rewrite(sink, break_sched, ctx)
+  # get realizes
+  realize_map = group_realizes(sink, ctx)
 
   # TODO: this should be the break between the "grouper" and the "linearizer"
   # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
@@ -451,7 +445,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   # create schedule items + map buffers to realized tensors
   prescheduled: list[ScheduleItem] = []
   var_vals: dict[Variable, int] = {}
-  for buf_uop,store in ctx.realizes.items():
+  for buf_uop,store in realize_map.items():
     assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
     prescheduled.append(schedule_uop(store.sink(), ctx, var_vals))
   # can only schedule once
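
Each realized buffer's STORE is wrapped in a SINK and handed to schedule_uop, yielding one ScheduleItem per kernel. Schematically (schedule_uop's signature appears in the first hunk header; the SINK wrapping is presumably what it expects as its root):

    # realize_map: {BUFFER b -> STORE(b, st, x)}
    # for each entry: schedule_uop(store.sink(), ctx, var_vals) -> ScheduleItem
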