Skip to content

Commit 05e3202

Browse files
authored
remove unused memsize_to_str and minor cleanups [pr] (tinygrad#9211)
* fix edge cases in memsize_to_str() Inputs <= 1 now return "0.00 B" for 0 and "1.00 B" for 1, avoiding an IndexError. Also, memsize_to_str(1000) now returns "1.00 KB" instead of "1000.00 B". Replaced the list comprehension with a next(...) generator for conciseness and efficiency. * simplify code using idiomatic python - Remove the unused `memsize_to_str()` function in helpers. - Use a tuple for checking multiple string prefixes/suffixes. - Avoid unnecessary list construction by using iterables directly. - Check None in @diskcache to ensure proper caching of falsy values. * revert generators back to list comprehension Sometimes building list first could be faster. Keep it as is.
1 parent 81a71ae commit 05e3202

File tree

3 files changed

+5
-7
lines changed

3 files changed

+5
-7
lines changed

examples/llama3.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def convert(name) -> Tensor:
4747
disk_tensors: List[Tensor] = [model[name] for model in models]
4848
if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
4949
return disk_tensors[0].to(device=device)
50-
axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
50+
axis = 1 if name.endswith((".attention.wo.weight", ".feed_forward.w2.weight")) else 0
5151
lazy_tensors = [data.to(device=device) for data in disk_tensors]
5252
return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
5353
return {name: convert(name) for name in {name: None for model in models for name in model}}

tinygrad/helpers.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def all_same(items:Union[tuple[T, ...], list[T]]): return all(x == items[0] for
2727
def all_int(t: Sequence[Any]) -> TypeGuard[tuple[int, ...]]: return all(isinstance(s, int) for s in t)
2828
def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
2929
def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
30-
def memsize_to_str(_bytes: int) -> str: return [f"{(_bytes / d):.2f} {pr}" for d,pr in [(1e9,"GB"),(1e6,"MB"),(1e3,"KB"),(1,"B")] if _bytes > d][0]
3130
def time_to_str(t:float, w=8) -> str: return next((f"{t * d:{w}.2f}{pr}" for d,pr in [(1, "s "),(1e3, "ms")] if t > 10/d), f"{t * 1e6:{w}.2f}us")
3231
def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
3332
def ansilen(s:str): return len(ansistrip(s))
@@ -191,8 +190,7 @@ def diskcache_clear():
191190
def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
192191
if CACHELEVEL < 1: return None
193192
if isinstance(key, (str,int)): key = {"key": key}
194-
conn = db_connection()
195-
cur = conn.cursor()
193+
cur = db_connection().cursor()
196194
try:
197195
res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
198196
except sqlite3.OperationalError:
@@ -211,15 +209,15 @@ def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=Fals
211209
ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
212210
cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
213211
_db_tables.add(table)
214-
cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
212+
cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key))}, ?)", tuple(key.values()) + (val if prepickled else pickle.dumps(val), )) # noqa: E501
215213
conn.commit()
216214
cur.close()
217215
return val
218216

219217
def diskcache(func):
220218
def wrapper(*args, **kwargs) -> bytes:
221219
table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
222-
if (ret:=diskcache_get(table, key)): return ret
220+
if (ret:=diskcache_get(table, key)) is not None: return ret
223221
return diskcache_put(table, key, func(*args, **kwargs))
224222
return wrapper
225223

tinygrad/viz/serve.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def do_GET(self):
124124
with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
125125
elif (url:=urlparse(self.path)).path == "/profiler":
126126
with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read()
127-
elif (self.path.startswith("/assets/") or self.path.startswith("/lib/")) and '/..' not in self.path:
127+
elif self.path.startswith(("/assets/", "/lib/")) and '/..' not in self.path:
128128
try:
129129
with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
130130
if url.path.endswith(".js"): content_type = "application/javascript"

0 commit comments

Comments
 (0)