Skip to content

Commit ba17786

Browse files
Qazalin and geohot authored
do not construct unmasked VALID (tinygrad#8759)
* new lines that exist in codegen/ops
* update tests
* update sops.gz (13071 -> 13070 asts)
* fix viz too
* remove that TODO
* diff pruning
* mask assert + device
* work
* diff pruning
* re: fix viz too

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
1 parent 3417bc1 commit ba17786

11 files changed

+58
-65
lines changed

extra/datasets/sops.gz

-66 KB
Binary file not shown.

test/helpers.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tinygrad.engine.realize import Runner
99
from tinygrad.dtype import ConstType, DType
1010
from tinygrad.nn.state import get_parameters
11-
from tinygrad.helpers import T
11+
from tinygrad.helpers import T, unwrap
1212
from tinygrad.codegen.linearize import linearize_uop
1313
from tinygrad.codegen.rewriter import full_graph_rewrite
1414
from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler, PythonAllocator
@@ -43,7 +43,9 @@ def rand_for_dtype(dt:DType, size:int):
4343
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
4444
if st_src is None:
4545
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
46-
return UOp(Ops.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
46+
st = unwrap(st_src[0].st)
47+
if all(v.mask is None for v in st.views): return UOp.const(dtype, val).replace(src=(st.to_uop(),))
48+
return UOp.const(dtype, val).valid(st)
4749

4850
def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
4951
st = time.perf_counter_ns()

test/test_linearizer.py

+2
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def test_const_alu_indexing(self):
121121
x = Tensor.randn(4,).realize()
122122
helper_linearizer_ast(store.sink(), [x], wanna_output=[x.numpy()+1*-1], opts=[])
123123

124+
# shapeless CONST in AST is not supported
125+
@unittest.expectedFailure
124126
def test_const_alu_indexing_one_const_fine(self):
125127
st = ShapeTracker.from_shape((4,)).to_uop()
126128
load = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), st, dtype=dtypes.float)

test/test_linearizer_failures.py

+30-50
Original file line numberDiff line numberDiff line change
@@ -1206,18 +1206,13 @@ def test_failure_51(self):
12061206
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(1024, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
12071207
UOp(Ops.RECIP, dtypes.half, arg=None, src=(
12081208
UOp(Ops.ADD, dtypes.half, arg=None, src=(
1209-
UOp(Ops.WHERE, dtypes.half, arg=None, src=(
1210-
x6:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
1211-
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
1212-
UOp(Ops.CONST, dtypes.half, arg=1.0, src=()),
1213-
x9:=UOp(Ops.CONST, dtypes.half, arg=0.0, src=()),)),
1209+
UOp(Ops.CONST, dtypes.half, arg=1.0, src=(
1210+
x6:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
12141211
UOp(Ops.EXP2, dtypes.half, arg=None, src=(
12151212
UOp(Ops.MUL, dtypes.half, arg=None, src=(
12161213
UOp(Ops.MUL, dtypes.half, arg=None, src=(
1217-
UOp(Ops.WHERE, dtypes.half, arg=None, src=(
1218-
x6,
1219-
UOp(Ops.CONST, dtypes.half, arg=2.0, src=()),
1220-
x9,)),
1214+
UOp(Ops.CONST, dtypes.half, arg=2.0, src=(
1215+
x6,)),
12211216
UOp(Ops.ADD, dtypes.half, arg=None, src=(
12221217
UOp(Ops.CAST, dtypes.half, arg=None, src=(
12231218
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
@@ -1232,10 +1227,8 @@ def test_failure_51(self):
12321227
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
12331228
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()),
12341229
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
1235-
UOp(Ops.WHERE, dtypes.half, arg=None, src=(
1236-
x6,
1237-
UOp(Ops.CONST, dtypes.half, arg=-1.4426950408889634, src=()),
1238-
x9,)),)),)),)),)),)),))
1230+
UOp(Ops.CONST, dtypes.half, arg=-1.4426950408889634, src=(
1231+
x6,)),)),)),)),)),)),))
12391232
opts = [Opt(op=OptOps.TC, axis=0, arg=2)]
12401233
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
12411234

@@ -1283,17 +1276,14 @@ def test_failure_53(self):
12831276
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
12841277
UOp(Ops.VALID, dtypes.bool, arg=None, src=(
12851278
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50001, 99999), strides=(0, 0), offset=0, mask=((0, 50001), (49999, 99999)), contiguous=False), View(shape=(1024, 50000, 50000), strides=(0, 1, 100000), offset=0, mask=None, contiguous=False))), src=()),)),
1286-
UOp(Ops.CONST, dtypes.int, arg=1, src=()),
1287-
x20:=UOp(Ops.CONST, dtypes.int, arg=0, src=()),)),)),
1288-
UOp(Ops.WHERE, dtypes.int, arg=None, src=(
1289-
x22:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
1290-
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
1291-
UOp(Ops.CONST, dtypes.int, arg=-1, src=()),
1292-
x20,)),)),)),
1293-
UOp(Ops.WHERE, dtypes.bool, arg=None, src=(
1294-
x22,
1295-
UOp(Ops.CONST, dtypes.bool, arg=True, src=()),
1296-
UOp(Ops.CONST, dtypes.bool, arg=False, src=()),)),)),)),)),)),)),))
1279+
UOp(Ops.CONST, dtypes.int, arg=1, src=(
1280+
x20:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 50000), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
1281+
UOp(Ops.CONST, dtypes.int, arg=0, src=(
1282+
x20,)),)),)),
1283+
UOp(Ops.CONST, dtypes.int, arg=-1, src=(
1284+
x23:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
1285+
UOp(Ops.CONST, dtypes.bool, arg=True, src=(
1286+
x23,)),)),)),)),)),)),))
12971287
opts = [Opt(op=OptOps.GROUPTOP, axis=1, arg=16)]
12981288
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["AMD", "GPU", "METAL", "NV", "CUDA"])
12991289

@@ -1348,11 +1338,8 @@ def test_failure_56(self):
13481338
UOp(Ops.MUL, dtypes.float, arg=None, src=(
13491339
UOp(Ops.CAST, dtypes.float, arg=None, src=(
13501340
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
1351-
x7:=UOp(Ops.WHERE, dtypes.float, arg=None, src=(
1352-
x8:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
1353-
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
1354-
x10:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),
1355-
x10,)),
1341+
x7:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=(
1342+
x8:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
13561343
UOp(Ops.MAX, dtypes.float, arg=None, src=(
13571344
UOp(Ops.ADD, dtypes.float, arg=None, src=(
13581345
UOp(Ops.MUL, dtypes.float, arg=None, src=(
@@ -1364,20 +1351,18 @@ def test_failure_56(self):
13641351
UOp(Ops.MUL, dtypes.float, arg=None, src=(
13651352
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
13661353
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
1367-
x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
1368-
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
1369-
x8,
1370-
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()),
1371-
x10,)),)),)),
1354+
x20:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
1355+
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=(
1356+
x8,)),)),)),
13721357
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
13731358
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
1374-
x22,)),)),
1359+
x20,)),)),
13751360
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
13761361
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
1377-
x22,)),)),
1362+
x20,)),)),
13781363
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
13791364
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
1380-
x22,)),)),
1365+
x20,)),)),
13811366
x7,)),)),)),
13821367
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
13831368
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
@@ -1394,11 +1379,8 @@ def test_failure_57(self):
13941379
UOp(Ops.MUL, dtypes.float, arg=None, src=(
13951380
UOp(Ops.CAST, dtypes.float, arg=None, src=(
13961381
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
1397-
x7:=UOp(Ops.WHERE, dtypes.float, arg=None, src=(
1398-
x8:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
1399-
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
1400-
x10:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),
1401-
x10,)),
1382+
x7:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=(
1383+
x8:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
14021384
UOp(Ops.MAX, dtypes.float, arg=None, src=(
14031385
UOp(Ops.ADD, dtypes.float, arg=None, src=(
14041386
UOp(Ops.MUL, dtypes.float, arg=None, src=(
@@ -1410,20 +1392,18 @@ def test_failure_57(self):
14101392
UOp(Ops.MUL, dtypes.float, arg=None, src=(
14111393
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
14121394
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
1413-
x22:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
1414-
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
1415-
x8,
1416-
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()),
1417-
x10,)),)),)),
1395+
x20:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 11, 11), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
1396+
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=(
1397+
x8,)),)),)),
14181398
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
14191399
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
1420-
x22,)),)),
1400+
x20,)),)),
14211401
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
14221402
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
1423-
x22,)),)),
1403+
x20,)),)),
14241404
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
14251405
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
1426-
x22,)),)),
1406+
x20,)),)),
14271407
x7,)),)),)),
14281408
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
14291409
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),

test/test_schedule.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -2026,8 +2026,7 @@ def test_late_fusion_post_permute_simpler(self):
20262026
self.assertEqual(swizzle_cnt(ret), 1)
20272027

20282028
def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
2029-
# TODO: we only need valid on ast consts if it's masked, can fold this early to UOp.const
2030-
zero_pm = UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat(Ops.CONST, arg=0), UPat.cvar()))
2029+
zero_pm = UPat(Ops.CONST, arg=0)
20312030
class TestView(unittest.TestCase):
20322031
def test_all_masked_out(self):
20332032
# start with non CONST Ops
@@ -2193,7 +2192,7 @@ def test_unmasked_const_ast(self):
21932192
a = Tensor.ones((4,)).contiguous()
21942193
sched = a.schedule()
21952194
print(sched[0].ast)
2196-
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat.cvar("x"), UPat(Ops.CONST, arg=0)))),))
2195+
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(Ops.CONST)),))
21972196
self.assertEqual(len(const_ast_pattern.match(sched[0].ast, {})), 1)
21982197
run_schedule(sched)
21992198
self.assertListEqual(a.tolist(), [1, 1, 1, 1])

test/test_uops.py

-1
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,6 @@ def test_expanded_const(self):
524524
a = Tensor.ones((4, 4)).lazydata
525525
self.assertEqual(a.st, ShapeTracker.from_shape(()).reshape((1,1)).expand((4,4)))
526526

527-
@unittest.expectedFailure
528527
def test_padded_const(self):
529528
a = Tensor.ones((1, 1)).pad(((1, 1), (1, 1)))
530529
ast = a.contiguous().schedule()[0].ast

test/unit/test_verify_ast.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_buffer_uops_st(self):
7878
verify_ast(ast:=a.schedule()[-1].ast)
7979
store_st = [u.st for u in ast.toposort if u.op is Ops.STORE][0]
8080
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
81-
const_st = [u.st for u in ast.toposort if u.op is Ops.VALID][0]
81+
const_st = [u.st for u in ast.toposort if u.op is Ops.CONST][0]
8282
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
8383

8484
def test_assert_swizzle(self):

tinygrad/codegen/kernel.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,10 @@ def fixup_ast(op:UOp) -> UOp:
582582
ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
583583
if op.op in GroupOp.Buffer and op in self.bufs:
584584
st_uop = self.sts[self.bufs.index(op)].to_uop()
585-
return ret.replace(src=(st_uop,)) if op.op is Ops.VALID else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
585+
# NOTE: if CONST got masked after applying opts, we create a new VALID
586+
if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.valid(unwrap(st_uop.st))
587+
# otherwise we just replace the VIEW source
588+
return ret.replace(src=(st_uop,)) if len(op.src) == 1 else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
586589
if op.op is Ops.SINK: return ret.replace(arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals))
587590
if op.op is Ops.REDUCE_AXIS:
588591
reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2

tinygrad/engine/schedule.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,12 @@ def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
194194
# don't need contiguous or assign anymore
195195
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
196196
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
197+
# don't need DEVICE anymore
198+
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
197199
# PRELOAD becomes LOAD
198200
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
199201
# once images are loaded they become the base dtype
200202
(UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
201-
# CONST(VIEW) becomes VALID too, TODO: doesn't have to
202-
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: x.replace(src=()).valid(st.st)),
203203
])
204204

205205
# LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel

tinygrad/ops.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ class GroupOp:
162162
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
163163
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.STRIDE}
164164

165-
Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
165+
Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
166166
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
167167

168168
# BinaryOps that can be flipped
@@ -305,7 +305,9 @@ def st(self) -> ShapeTracker|None:
305305
def full_shape(self) -> tuple[sint, ...]:
306306
if self.op is Ops.VIEW: return self.shape
307307
# TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
308-
return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL,Ops.DEFINE_VAR,Ops.CONST}]))
308+
return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL} \
309+
# TODO: this exists because wmma creates consts without ShapeTracker in the AST, there's probably a way to fix this
310+
and not (x.op is Ops.CONST and x.st is None)]))
309311
@property
310312
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
311313
@property
@@ -385,7 +387,10 @@ def const(dtype:DType, b:ConstLike):
385387
return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
386388
def valid(self, st:ShapeTracker):
387389
assert self.op in {Ops.CONST, Ops.DEFINE_VAR}, f"can only create VALID from a constant, got {self.op}"
388-
return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self, 0)
390+
from tinygrad.shape.shapetracker import ShapeTracker
391+
# NOTE: only VALID has a masked ShapeTracker, the CONST operands are unmasked
392+
unmasked_st = ShapeTracker.from_shape(()).reshape((1,)*len(st.shape)).expand(st.shape).to_uop()
393+
return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self.replace(src=(unmasked_st,)), UOp.const(self.dtype, 0).replace(src=(unmasked_st,)))
389394
@staticmethod
390395
def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
391396
def _reduce_op(self, op:Ops, axis:tuple[int, ...]):
@@ -1330,7 +1335,7 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
13301335
# push VIEW to parents
13311336
view_left = merge_views+PatternMatcher([
13321337
# VIEW(CONST) becomes VALID
1333-
(UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.replace(src=()).valid(vm.st)),
1338+
(UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)),
13341339
# VIEW before elementwise/buffer ops
13351340
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
13361341
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),

tinygrad/viz/serve.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, list[int], str]]:
4242
graph: dict[int, tuple[str, list[int], str]] = {}
4343
excluded: set[UOp] = set()
4444
for u in (toposort:=x.toposort):
45-
if u.op in {Ops.CONST, Ops.DEVICE}: excluded.update((u,) + u.src)
45+
# always exclude DEVICE/CONST
46+
if u.op in {Ops.DEVICE, Ops.CONST}: excluded.add(u)
47+
# only exclude CONST VIEW source if it has no other children
48+
if u.op is Ops.CONST and len(u.src) != 0 and all((cr:=c()) is None or cr.op is Ops.CONST for c in u.src[0].children): excluded.update(u.src)
4649
for u in toposort:
4750
if u in excluded: continue
4851
argst = str(u.arg)

0 commit comments

Comments (0)