 # https://arxiv.org/pdf/2409.02060
+import time
 import numpy as np
 np.set_printoptions(suppress=True, linewidth=1000)
 import functools
@@ -53,13 +54,17 @@ def fetch_weights() -> dict[str, Tensor]:
 model_state_dict = nn.state.get_state_dict(model)
 del model_state_dict['freqs_cis']

-with Timing("fetch and load weights: "):
-  state = fetch_weights()
-  nhf_state = convert_from_huggingface(state, model, 16, 16)
+with Timing("load weights to GPU: "):
+  nhf_state = convert_from_huggingface(fetch_weights(), model, 16, 16)
 # NOTE: i'm not sure this actually needs float32, it may just change the type of things downstream from it. but doesn't match torch w/o this
 for needs_float32 in ['tok_embeddings.weight']: nhf_state[needs_float32] = nhf_state[needs_float32].float()
+print(f"ram used: {GlobalCounters.mem_used/1e9:.2f} GB")
+
+with Timing("unpack weights: "):
   nn.state.load_state_dict(model, nhf_state, verbose=False, strict=False, consume=True, realize=False)
   assert len(nhf_state) == 0
+  Tensor.realize(*list(nn.state.get_state_dict(model).values()))
+print(f"ram used: {GlobalCounters.mem_used/1e9:.2f} GB")

 count = 30
 temperature = 0
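The hunk above splits the old single "fetch and load weights" phase in two: first the converted weights are moved to the GPU, then they are attached to the model with realize=False and materialized by one batched Tensor.realize, with GlobalCounters.mem_used printed after each phase. A minimal sketch of that unpack phase as a standalone helper, using only the tinygrad calls already visible in the diff (the model and nhf_state arguments are assumed to come from the surrounding script):

from tinygrad import Tensor, nn, GlobalCounters
from tinygrad.helpers import Timing

def unpack_weights(model, nhf_state: dict):
  with Timing("unpack weights: "):
    # attach lazy tensors to the model; consume=True drains nhf_state as it goes
    nn.state.load_state_dict(model, nhf_state, verbose=False, strict=False, consume=True, realize=False)
    assert len(nhf_state) == 0
    # one batched realize materializes all weights on device together
    Tensor.realize(*list(nn.state.get_state_dict(model).values()))
  print(f"ram used: {GlobalCounters.mem_used/1e9:.2f} GB")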
@@ -70,13 +75,17 @@ def fetch_weights() -> dict[str, Tensor]:

 toks = [12092]
 start_pos = 0
+timings = []
 for i in range(count):
   GlobalCounters.reset()
+  st = time.perf_counter()
   tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
+  timings.append(time.perf_counter()-st)
   toks.append(tok)
   start_pos += 1
 print(toks)
 print(tokenizer.decode(toks))
+print(f"fastest token {min(timings)*1e3:.2f} ms, {1/min(timings):.1f} tok/s")

 if temperature == 0:
   # Hello, I am a newbie to this forum and I am trying to get a better understanding of the different types of data that can be stored in a
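The decode loop above times each model(...) call with time.perf_counter() and reports the minimum, presumably so the first iterations (which include kernel compilation and warmup) don't skew the figure. A self-contained sketch of the same measurement pattern, with a hypothetical decode_step() standing in for the real model call:

import time

def decode_step() -> int:
  # stand-in for model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
  time.sleep(0.01)
  return 0

timings = []
for _ in range(30):
  st = time.perf_counter()
  tok = decode_step()
  timings.append(time.perf_counter() - st)
# min() picks the best case; 1/min gives peak tokens-per-second
print(f"fastest token {min(timings)*1e3:.2f} ms, {1/min(timings):.1f} tok/s")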