Skip to content

Commit bbb2dd8

Browse files
authored
move VALID creation after merging the views (tinygrad#8757)
* do valid creation later * work for view_left * only view(const) makes valids in view_left * cleaner bind diff
1 parent a6e496b commit bbb2dd8

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

Diff for: tinygrad/engine/schedule.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
199199
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
200200
# once images are loaded they become the base dtype
201201
(UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
202+
# CONST(VIEW) becomes VALID too, TODO: doesn't have to
203+
(UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)),
202204
])
203205

204206
# LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel
@@ -438,11 +440,11 @@ def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
438440
(UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
439441
])
440442

441-
# **** rewrite VIEW into LOAD/STORE/VALID or fuse the underlying UOp
443+
# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp
442444

443445
def unbind_variable(ctx:ScheduleContext, bind:UOp, var:UOp, val:UOp):
444-
assert isinstance(val.src[1].const_arg, int), f"expected BIND value to be int {val}"
445-
ctx.var_vals[ret:=var.replace(src=())] = val.src[1].const_arg
446+
assert isinstance(val.const_arg, int), f"expected BIND value to be int {val}"
447+
ctx.var_vals[ret:=var.replace(src=())] = val.const_arg
446448
return ret.valid(unwrap(bind.st))
447449

448450
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
@@ -456,8 +458,6 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
456458
return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
457459

458460
break_sched = PatternMatcher([
459-
# CONST is always fused and generated
460-
(UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)),
461461
(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.var("val"))), unbind_variable),
462462
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
463463
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
@@ -481,9 +481,6 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
481481
# some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
482482
(UPat(Ops.VIEW, name="view"),
483483
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
484-
# merge unmasked const views
485-
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
486-
lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
487484
])
488485

489486
@track_rewrites(named=True)

Diff for: tinygrad/ops.py

+5
Original file line numberDiff line numberDiff line change
@@ -1322,10 +1322,15 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
13221322
# VIEW(VIEW) merges to a single VIEW
13231323
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
13241324
(UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None),
1325+
# merge unmasked const views
1326+
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
1327+
lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
13251328
])
13261329

13271330
# push VIEW to parents
13281331
view_left = merge_views+PatternMatcher([
1332+
# VIEW(CONST) becomes VALID
1333+
(UPat(Ops.VIEW, name="vm", src=(UPat.cvar("x"),)), lambda vm,x: UOp.const(x.dtype, x.const_arg).valid(vm.st)),
13291334
# VIEW before elementwise/buffer ops
13301335
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
13311336
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),

0 commit comments

Comments
 (0)