2
2
import multiprocessing , pickle , functools , difflib , os , threading , json , time , sys , webbrowser , socket , argparse , decimal
3
3
from http .server import HTTPServer , BaseHTTPRequestHandler
4
4
from urllib .parse import parse_qs , urlparse
5
- from typing import Any , Callable , TypedDict
5
+ from typing import Any , Callable , TypedDict , cast
6
6
from tinygrad .helpers import colored , getenv , to_function_name , tqdm , unwrap , word_wrap
7
7
from tinygrad .ops import TrackedGraphRewrite , UOp , Ops , lines , GroupOp
8
8
from tinygrad .codegen .kernel import Kernel
@@ -29,7 +29,7 @@ class GraphRewriteDetails(GraphRewriteMetadata):
29
29
changed_nodes : list [list [int ]] # the changed UOp id + all its parents ids
30
30
code_line : str # source code calling graph_rewrite
31
31
kernel_code : str | None # optionally render the final kernel code
32
- upats : list [tuple [tuple [str , int ], str ]]
32
+ upats : list [tuple [tuple [str , int ], str ]| None ]
33
33
34
34
# NOTE: if any extra rendering in VIZ fails, we don't crash
35
35
def pcall (fxn :Callable [..., str ], * args , ** kwargs ) -> str :
@@ -69,9 +69,10 @@ def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> li
69
69
def _prg (k :Kernel ): return k .to_program ().src
70
70
def get_details (k :Any , ctx :TrackedGraphRewrite , metadata :GraphRewriteMetadata , offset = 0 , limit = 200 ) -> GraphRewriteDetails :
71
71
ret :GraphRewriteDetails = {"uops" :[pcall (str , sink := ctx .sink )], "graphs" :[uop_to_json (sink )], "code_line" :lines (ctx .loc [0 ])[ctx .loc [1 ]- 1 ].strip (),
72
- "kernel_code" :pcall (_prg , k ) if isinstance (k , Kernel ) else None , "diffs" :[], "upats" :[], "changed_nodes" :[], ** metadata }
72
+ "kernel_code" :pcall (_prg , k ) if isinstance (k , Kernel ) else None , ** metadata ,
73
+ "diffs" :[[]], "upats" :[None ], "changed_nodes" :[[]]} # NOTE: the first graph just renders the input UOp
73
74
replaces : dict [UOp , UOp ] = {}
74
- for i ,(u0 ,u1 ,upat ) in enumerate (tqdm (ctx .matches [ offset : offset + limit ] )):
75
+ for i ,(u0 ,u1 ,upat ) in enumerate (tqdm (ctx .matches )):
75
76
replaces [u0 ] = u1
76
77
new_sink = sink .substitute (replaces )
77
78
ret ["graphs" ].append (new_sink_js := uop_to_json (new_sink ))
@@ -80,7 +81,9 @@ def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata, o
80
81
ret ["upats" ].append ((upat .location , upat .printable ()))
81
82
# TODO: this is O(n^2)!
82
83
ret ["uops" ].append (str (sink := new_sink ))
83
- return ret
84
+ # if the client requested a chunk we only send that chunk
85
+ # TODO: is there a way to cache the replaces dict here?
86
+ return cast (GraphRewriteDetails , {k :v [offset :offset + limit ] if isinstance (v ,list ) else v for k ,v in ret .items ()})
84
87
85
88
# Profiler API
86
89
devices :dict [str , tuple [decimal .Decimal , decimal .Decimal , int ]] = {}
0 commit comments