Skip to content

Commit d64af3c

Browse files
authored
reorder simplifier and grouper logic in scheduler [pr] (tinygrad#8861)
1 parent 83a904a commit d64af3c

File tree

1 file changed

+79
-85
lines changed

1 file changed

+79
-85
lines changed

tinygrad/engine/schedule.py

Lines changed: 79 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -203,7 +203,57 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> Sched
203203
def save_process_replay() -> None:
204204
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
205205

206-
# **** Schedule grouping
206+
# **** UOp realization
207+
208+
class UPatScheduled(UPat):
209+
def __init__(self, *args, **kwargs):
210+
super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
211+
212+
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
213+
214+
def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
215+
st = unwrap(view.st)
216+
# fold simple pads
217+
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
218+
return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
219+
# early realize before expand
220+
if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
221+
# otherwise safety check pads
222+
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
223+
224+
def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
225+
if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
226+
del ctx.realizes[b]
227+
return x.view(unwrap(view.st))
228+
229+
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
230+
if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
231+
buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
232+
return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
233+
234+
do_realize = PatternMatcher([
235+
# always realize SINK parents
236+
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
237+
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
238+
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
239+
# realize before expand or unsafe pad ops
240+
(UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
241+
# don't realize image to image casts
242+
(UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
243+
fold_img_cast),
244+
# realize before COPY or BUFFER_VIEW
245+
(UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
246+
(UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
247+
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
248+
(UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
249+
])
250+
251+
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
252+
ctx.allbufs[buf_uop] = view
253+
if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
254+
for x in op.base.src:
255+
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
256+
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
207257

208258
def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2 and u.src[0].op is Ops.BUFFER
209259
def uval(u:UOp) -> UOp:
@@ -228,8 +278,9 @@ def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:defaultdict[UOp, di
228278
if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r)
229279
recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache)
230280

231-
def group_realizes(ctx:ScheduleContext) -> None:
232-
"""search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop"""
281+
def group_realizes(sink:UOp, ctx:ScheduleContext) -> dict[UOp, UOp]:
282+
# start by adding uops that always realize
283+
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
233284
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
234285
reduce_for_op: dict[UOp, UOp] = {}
235286
double_reduces: list[UOp] = []
@@ -280,10 +331,28 @@ def group_realizes(ctx:ScheduleContext) -> None:
280331
for reduceop in double_reduces:
281332
top_reduce = uval(ctx.allbufs[reduceop]).src[0].base.buf_uop
282333
if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
334+
graph_rewrite(sink, break_sched, ctx)
335+
return ctx.realizes
283336

284-
# **** Schedule creation and BFS toposort
337+
# break the SINK into stores
285338

286-
# ** this is schedule level const folding
339+
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
340+
# NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
341+
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
342+
343+
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
344+
if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
345+
if b not in ctx.realizes: return x # collapse BUFFER
346+
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
347+
return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
348+
349+
break_sched = PatternMatcher([
350+
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
351+
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
352+
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
353+
])
354+
355+
# **** schedule simplifier
287356

288357
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
289358
if not all_int(x.shape): return None
@@ -338,80 +407,6 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
338407
if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
339408
])
340409

341-
# ** this decides which ops get realized
342-
343-
class UPatScheduled(UPat):
344-
def __init__(self, *args, **kwargs):
345-
super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
346-
347-
def realize(ctx:ScheduleContext, b:UOp, to_store:UOp, **kwargs) -> None: ctx.realizes[b] = to_store
348-
349-
def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> None:
350-
st = unwrap(view.st)
351-
# fold simple pads
352-
if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(src.shape) and resolve(prod(src.shape) >= prod([y-x for x,y in m])):
353-
return None if can_pad(src, ctx.realizes, set()) else realize(ctx, b, src)
354-
# early realize before expand
355-
if resolve(prod(src.shape) < prod(st.shape)) and not getenv("DONT_REALIZE_EXPAND"): return realize(ctx, b, src)
356-
# otherwise safety check pads
357-
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
358-
359-
def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) -> UOp|None:
360-
if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(x.base).op is Ops.COPY: return None
361-
del ctx.realizes[b]
362-
return x.view(unwrap(view.st))
363-
364-
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
365-
if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
366-
buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
367-
return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
368-
369-
do_realize = PatternMatcher([
370-
# always realize SINK parents
371-
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.realizes.update((x.buf_uop, x) for x in sink.src)),
372-
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
373-
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}), realize),
374-
# realize before expand or unsafe pad ops
375-
(UPat(Ops.VIEW, name="view", src=(UPatScheduled(name="src"),)), realize_before_view),
376-
# don't realize image to image casts
377-
(UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="x"),), dtype=dtypes.float),)),
378-
fold_img_cast),
379-
# realize before COPY or BUFFER_VIEW
380-
(UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
381-
(UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
382-
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
383-
(UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
384-
])
385-
386-
# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp
387-
388-
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
389-
# NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
390-
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
391-
392-
def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
393-
if (m:=ctx.tensor_uops[b][-1].metadata) is not None: ctx.ops_metadata[x] = m
394-
if b not in ctx.realizes: return x # collapse BUFFER
395-
ctx.realizes[b] = UOp.store(b, ShapeTracker.from_shape(st.shape).to_uop(), x)
396-
return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
397-
398-
break_sched = PatternMatcher([
399-
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
400-
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
401-
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
402-
])
403-
404-
# **** Schedule context builder
405-
406-
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
407-
ctx.allbufs[buf_uop] = view
408-
if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
409-
for x in op.base.src:
410-
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
411-
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
412-
413-
# **** movement ops
414-
415410
remove_movement_ops = merge_views+PatternMatcher([
416411
# NOTE: movement ops are always applied to base
417412
(UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
@@ -420,6 +415,8 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
420415
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
421416
])
422417

418+
# **** schedule creation and toposort
419+
423420
@track_rewrites(named=True)
424421
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
425422
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
@@ -438,11 +435,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
438435
for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
439436
# add BUFFER uops
440437
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
441-
# add realizes
442-
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
443-
# group realizes into kernels
444-
group_realizes(ctx)
445-
graph_rewrite(sink, break_sched, ctx)
438+
# get realizes
439+
realize_map = group_realizes(sink, ctx)
446440

447441
# TODO: this should be the break between the "grouper" and the "linearizer"
448442
# here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
@@ -451,7 +445,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
451445
# create schedule items + map buffers to realized tensors
452446
prescheduled: list[ScheduleItem] = []
453447
var_vals: dict[Variable, int] = {}
454-
for buf_uop,store in ctx.realizes.items():
448+
for buf_uop,store in realize_map.items():
455449
assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
456450
prescheduled.append(schedule_uop(store.sink(), ctx, var_vals))
457451
# can only schedule once

0 commit comments

Comments
 (0)