Skip to content

Commit 308516e

Browse files
authored
fix viz paginate + cleanups [pr] (tinygrad#8973)
* fix viz paginate [pr] * cleanups * remove the extra font definition * more work * none for the first graph
1 parent 107e616 commit 308516e

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

tinygrad/viz/index.html

+9-6
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
padding: 0;
2727
width: 100%;
2828
height: 100%;
29-
font-family: "Noto Sans", sans-serif;
29+
font-family: sans-serif;
3030
font-optical-sizing: auto;
3131
font-weight: 400;
3232
font-style: normal;
@@ -356,25 +356,28 @@
356356
}
357357
else {
358358
const p = document.querySelector(".progress");
359+
p.style.display = "none";
359360
const limit = 100;
360361
const pageCount = Math.ceil((kernels[currentKernel][1][currentUOp].match_count+1)/limit);
361-
p.style.display = pageCount == 1 ? "none" : "flex";
362+
if (pageCount > 1) {
363+
p.style.display = "flex";
364+
d3.select("#graph-svg g").selectAll("*").remove();
365+
}
362366
for (let i=0; i < pageCount; i++) {
363367
p.innerText = `fetching data ${i+1}/${pageCount}`;
364368
const chunk = await (await fetch(`/kernels?kernel=${currentKernel}&idx=${currentUOp}&offset=${i*limit}&limit=${limit}`)).json();
365369
if (i === 0) ret = chunk
366370
else {
367371
// TODO: this shouldn't exist after the viz api refactor
368372
for (const [k,v] of Object.entries(chunk)) {
369-
if (["uops", "graphs"].includes(k)) v.splice(0, 1);
370373
if (Array.isArray(v) && k !== "loc") ret[k].push(...v);
371374
}
372375
}
373376
}
374377
p.style.display = "none";
375378
cache[cacheKey] = ret;
376379
}
377-
renderGraph(ret.graphs[currentRewrite], currentRewrite == 0 ? [] : ret.changed_nodes[currentRewrite-1]);
380+
renderGraph(ret.graphs[currentRewrite], ret.changed_nodes[currentRewrite]);
378381
}
379382
// ***** RHS metadata
380383
const metadata = document.querySelector(".container.metadata");
@@ -402,8 +405,8 @@
402405
if (i === currentRewrite) {
403406
gUl.classList.add("active");
404407
if (i !== 0) {
405-
const diff = ret.diffs[i-1];
406-
const [loc, pattern] = ret.upats[i-1];
408+
const diff = ret.diffs[i];
409+
const [loc, pattern] = ret.upats[i];
407410
const parts = loc.join(":").split("/");
408411
const div = Object.assign(document.createElement("div"), { className: "rewrite-container" });
409412
const link = vsCodeOpener(parts);

tinygrad/viz/serve.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal
33
from http.server import HTTPServer, BaseHTTPRequestHandler
44
from urllib.parse import parse_qs, urlparse
5-
from typing import Any, Callable, TypedDict
5+
from typing import Any, Callable, TypedDict, cast
66
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
77
from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp
88
from tinygrad.codegen.kernel import Kernel
@@ -29,7 +29,7 @@ class GraphRewriteDetails(GraphRewriteMetadata):
2929
changed_nodes: list[list[int]] # the changed UOp id + all its parents ids
3030
code_line: str # source code calling graph_rewrite
3131
kernel_code: str|None # optionally render the final kernel code
32-
upats: list[tuple[tuple[str, int], str]]
32+
upats: list[tuple[tuple[str, int], str]|None]
3333

3434
# NOTE: if any extra rendering in VIZ fails, we don't crash
3535
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
@@ -69,9 +69,10 @@ def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> li
6969
def _prg(k:Kernel): return k.to_program().src
7070
def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata, offset=0, limit=200) -> GraphRewriteDetails:
7171
ret:GraphRewriteDetails = {"uops":[pcall(str, sink:=ctx.sink)], "graphs":[uop_to_json(sink)], "code_line":lines(ctx.loc[0])[ctx.loc[1]-1].strip(),
72-
"kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None, "diffs":[], "upats":[], "changed_nodes":[], **metadata}
72+
"kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None, **metadata,
73+
"diffs":[[]], "upats":[None], "changed_nodes":[[]]} # NOTE: the first graph just renders the input UOp
7374
replaces: dict[UOp, UOp] = {}
74-
for i,(u0,u1,upat) in enumerate(tqdm(ctx.matches[offset:offset+limit])):
75+
for i,(u0,u1,upat) in enumerate(tqdm(ctx.matches)):
7576
replaces[u0] = u1
7677
new_sink = sink.substitute(replaces)
7778
ret["graphs"].append(new_sink_js:=uop_to_json(new_sink))
@@ -80,7 +81,9 @@ def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata, o
8081
ret["upats"].append((upat.location, upat.printable()))
8182
# TODO: this is O(n^2)!
8283
ret["uops"].append(str(sink:=new_sink))
83-
return ret
84+
# if the client requested a chunk we only send that chunk
85+
# TODO: is there a way to cache the replaces dict here?
86+
return cast(GraphRewriteDetails, {k:v[offset:offset+limit] if isinstance(v,list) else v for k,v in ret.items()})
8487

8588
# Profiler API
8689
devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}

0 commit comments

Comments
 (0)