Skip to content

Commit 865f23d

Browse files
committed
olmoe memory usage cleanups
1 parent 2c87a22 commit 865f23d

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

examples/olmoe.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# https://arxiv.org/pdf/2409.02060
2+
import time
23
import numpy as np
34
np.set_printoptions(suppress=True, linewidth=1000)
45
import functools
@@ -53,13 +54,17 @@ def fetch_weights() -> dict[str, Tensor]:
5354
model_state_dict = nn.state.get_state_dict(model)
5455
del model_state_dict['freqs_cis']
5556

56-
with Timing("fetch and load weights: "):
57-
state = fetch_weights()
58-
nhf_state = convert_from_huggingface(state, model, 16, 16)
57+
with Timing("load weights to GPU: "):
58+
nhf_state = convert_from_huggingface(fetch_weights(), model, 16, 16)
5959
# NOTE: i'm not sure this actually needs float32, it may just change the type of things downstream from it. but doesn't match torch w/o this
6060
for needs_float32 in ['tok_embeddings.weight']: nhf_state[needs_float32] = nhf_state[needs_float32].float()
61+
print(f"ram used: {GlobalCounters.mem_used/1e9:.2f} GB")
62+
63+
with Timing("unpack weights: "):
6164
nn.state.load_state_dict(model, nhf_state, verbose=False, strict=False, consume=True, realize=False)
6265
assert len(nhf_state) == 0
66+
Tensor.realize(*list(nn.state.get_state_dict(model).values()))
67+
print(f"ram used: {GlobalCounters.mem_used/1e9:.2f} GB")
6368

6469
count = 30
6570
temperature = 0
@@ -70,13 +75,17 @@ def fetch_weights() -> dict[str, Tensor]:
7075

7176
toks = [12092]
7277
start_pos = 0
78+
timings = []
7379
for i in range(count):
7480
GlobalCounters.reset()
81+
st = time.perf_counter()
7582
tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
83+
timings.append(time.perf_counter()-st)
7684
toks.append(tok)
7785
start_pos += 1
7886
print(toks)
7987
print(tokenizer.decode(toks))
88+
print(f"fastest token {min(timings)*1e3:.2f} ms, {1/min(timings):.1f} tok/s")
8089

8190
if temperature == 0:
8291
# Hello, I am a newbie to this forum and I am trying to get a better understanding of the different types of data that can be stored in a

0 commit comments

Comments
 (0)