Skip to content

Commit d488bbb

Browse files
authored
share merge_views/valid creation for CONST/DEFINE_VAR (tinygrad#8758)
* share valid creation behavior for CONST/DEFINE_VAR * work
1 parent bbb2dd8 commit d488bbb

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

Diff for: tinygrad/engine/schedule.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
200200
# once images are loaded they become the base dtype
201201
(UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
202202
# CONST(VIEW) becomes VALID too, TODO: doesn't have to
203-
(UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)),
203+
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: x.replace(src=()).valid(st.st)),
204204
])
205205

206206
# LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel
@@ -444,8 +444,8 @@ def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
444444

445445
def unbind_variable(ctx:ScheduleContext, bind:UOp, var:UOp, val:UOp):
446446
assert isinstance(val.const_arg, int), f"expected BIND value to be int {val}"
447-
ctx.var_vals[ret:=var.replace(src=())] = val.const_arg
448-
return ret.valid(unwrap(bind.st))
447+
ctx.var_vals[var.replace(src=())] = val.const_arg
448+
return var
449449

450450
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
451451
# NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN

Diff for: tinygrad/ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1323,14 +1323,14 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
13231323
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
13241324
(UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None),
13251325
# merge unmasked const views
1326-
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
1326+
(UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
13271327
lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
13281328
])
13291329

13301330
# push VIEW to parents
13311331
view_left = merge_views+PatternMatcher([
13321332
# VIEW(CONST) becomes VALID
1333-
(UPat(Ops.VIEW, name="vm", src=(UPat.cvar("x"),)), lambda vm,x: UOp.const(x.dtype, x.const_arg).valid(vm.st)),
1333+
(UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.replace(src=()).valid(vm.st)),
13341334
# VIEW before elementwise/buffer ops
13351335
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
13361336
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),

0 commit comments

Comments
 (0)