
Commit d32f5e9

improve rendering of shapes in viz + investigate symbolic [pr] (tinygrad#10091)
1 parent dbb7aee commit d32f5e9

File tree: 4 files changed (+20, -9 lines)


test/test_symbolic_ops.py

Lines changed: 11 additions & 4 deletions
@@ -1,6 +1,6 @@
 import unittest
 from tinygrad import Variable
-from tinygrad.helpers import Context
+from tinygrad.helpers import Context, GlobalCounters
 from tinygrad.tensor import Tensor
 from examples.gpt2 import Attention
 import numpy as np
@@ -43,17 +43,24 @@ def f(a, b): return (a@b).realize()
     expected = f(a, b).numpy()
     np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
-  def test_attention(self, dropout_p=0.0):
+  def test_attention(self, dropout_p=0.0, imin=1, imax=5, use_symbolic=True):
     def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p).realize()
-    for i in range(1, 5):
-      vi = Variable("i", 1, 10).bind(i)
+    for i in range(imin, imax):
+      vi = Variable("i", 1, 10).bind(i) if use_symbolic else i
       q = Tensor.rand(2, 1, 4, 8)
       k = Tensor.rand(2, i, 4, 8)
       v = Tensor.rand(2, i, 4, 8)
+      Tensor.realize(q, k, v)
+      GlobalCounters.reset()
       symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
       expected = f(q, k, v).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
+  def test_attention_cmp_symbolic(self):
+    # symbolic isn't seeing if i == i, so it's not putting them on the same axis
+    self.test_attention(imin=4, imax=5, use_symbolic=False)
+    self.test_attention(imin=4, imax=5, use_symbolic=True)
+
   def test_attention_training(self):
     with Tensor.train():
       self.test_attention(dropout_p=0.0)
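
For context, the symbolic pattern this test exercises: a Variable declares the legal range of a dimension, bind() attaches the concrete value for this run, and reshape() swaps a concrete axis for the symbolic one, so one compiled kernel can serve every size in range. A minimal sketch (sizes illustrative, not from this diff):

from tinygrad import Tensor, Variable

i = 3                               # concrete sequence length for this step
vi = Variable("i", 1, 10).bind(i)   # symbolic dim, bound to 3 at runtime
k = Tensor.rand(2, i, 4, 8)         # concrete tensor
k_sym = k.reshape(2, vi, 4, 8)      # same data, symbolic second axis
print(k_sym.shape)                  # the second axis is now the bound Variable, not 3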

tinygrad/ops.py

Lines changed: 2 additions & 0 deletions
@@ -207,6 +207,7 @@ def _suop(lst, uop_fxn, python_fxn):
   return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
 def smax(*lst): return _suop(argfix(*lst), UOp.maximum, max)
 def smin(*lst): return _suop(argfix(*lst), UOp.minimum, min)
+def srender(x) -> str: return x.render() if isinstance(x, UOp) else str(x)
 
 def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
 def sym_infer(uop: Union[UOp, int], var_vals: dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
@@ -1001,6 +1002,7 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
   (UPat(Ops.CAST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"({str(x.dtype)[7:]})({x.src[0].arg})")),
   (UPat(Ops.LOAD), lambda: UOp(Ops.NOOP, arg="load")),
   (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
+  #(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}[={x.src[1].arg}]")),
   (UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
   (UPat(Ops.MAX, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
   (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
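
The new srender rounds out the sint helpers beside it (ssimplify, sym_infer): given a plain int it falls back to str(), given a symbolic UOp it calls render(). A rough sketch of the intent, assuming this commit's tinygrad.ops (exact render() strings depend on the simplifier):

from tinygrad import Variable
from tinygrad.ops import srender

vi = Variable("i", 1, 10)
print(srender(4))       # "4"  -- plain ints go through str()
print(srender(vi))      # "i"  -- UOps go through UOp.render()
print(srender(vi * 2))  # something like "(i*2)"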

tinygrad/runtime/support/am/amdev.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def fmt_ver(hwip): return '_'.join(map(str, adev.ip_ver[hwip]))
     self.descs += [self.desc(blob, imu_i_off, imu_i_sz, am.GFX_FW_TYPE_IMU_I), self.desc(blob, imu_i_off + imu_i_sz, imu_d_sz, am.GFX_FW_TYPE_IMU_D)]
 
     # RLC firmware
-    blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0,
+    blob, hdr0, _hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0,
       am.struct_rlc_firmware_header_v2_1, am.struct_rlc_firmware_header_v2_2, am.struct_rlc_firmware_header_v2_3)
 
     for mem,fmem in [('IRAM', 'iram'), ('DRAM_BOOT', 'dram')]:

tinygrad/viz/serve.py

Lines changed: 6 additions & 4 deletions
@@ -4,7 +4,7 @@
 from urllib.parse import parse_qs, urlparse
 from typing import Any, Callable, TypedDict, Generator
 from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap
-from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp
+from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp, srender, sint
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent
 from tinygrad.dtype import dtypes
@@ -51,6 +51,8 @@ class GraphRewriteDetails(TypedDict):
   changed_nodes: list[int]|None # the changed UOp id + all its parents ids
   upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat
 
+def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")"
+
 def uop_to_json(x:UOp) -> dict[int, dict]:
   assert isinstance(x, UOp)
   graph: dict[int, dict] = {}
@@ -65,8 +67,8 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
     if u in excluded: continue
     argst = str(u.arg)
     if u.op is Ops.VIEW:
-      argst = ("\n".join([f"{v.shape} / {v.strides}"+(f"\nMASK {v.mask}" if v.mask is not None else "")+
-                          ("" if v.offset == 0 else f" / {v.offset}") for v in unwrap(u.st).views]))
+      argst = ("\n".join([f"{shape_to_str(v.shape)} / {shape_to_str(v.strides)}"+(f"\nMASK {v.mask}" if v.mask is not None else "")+
+                          ("" if v.offset == 0 else f" / {srender(v.offset)}") for v in unwrap(u.st).views]))
     label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
     if u.dtype != dtypes.void: label += f"\n{u.dtype}"
     for idx,x in enumerate(u.src):
@@ -75,7 +77,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
     else: label += f"\n{x.op.name}{idx} {x.arg}"
     try:
       if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
-        label += f"\n{repr(u.shape)}"
+        label += f"\n{shape_to_str(u.shape)}"
     except Exception:
       label += "\n<ISSUE GETTING SHAPE>"
     graph[id(u)] = {"label":label, "src":[id(x) for x in u.src if x not in excluded], "color":uops_colors.get(u.op, "#ffffff")}
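
shape_to_str is the piece that actually changes what viz shows: each dimension goes through srender, so a symbolic dim renders by name instead of as a full UOp repr (which is what repr(u.shape) produced before). The one-line helper copied from the diff, shown standalone; the printed string is the intended form, not a verified transcript:

from tinygrad import Variable
from tinygrad.ops import srender, sint

def shape_to_str(s: tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")"

vi = Variable("i", 1, 10)
print(shape_to_str((2, vi, 4, 8)))  # "(2,i,4,8)" -- compact even with symbolic dims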
