@@ -18,6 +18,7 @@
     Iterator,
     Deque,
     Callable,
+    Dict,
 )
 from collections import deque
 from pathlib import Path
@@ -262,9 +263,7 @@ def __init__(
 
         self.n_batch = min(n_ctx, n_batch)  # ???
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
-        self.n_threads_batch = n_threads_batch or max(
-            multiprocessing.cpu_count() // 2, 1
-        )
+        self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
 
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
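
Note on the hunk above: the default for `n_threads_batch` changes from half the logical cores to all of them, while `n_threads` (used for token-by-token generation) keeps the halved default. A minimal standalone sketch of the resulting defaults; the helper name and the rationale comments are illustrative assumptions, not part of the commit:

import multiprocessing

def default_thread_counts(n_threads=None, n_threads_batch=None):
    # Generation threads: half the logical cores, but never fewer than 1.
    n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
    # Batch (prompt-processing) threads: all logical cores; batched
    # evaluation generally scales with more threads (assumed rationale).
    n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
    return n_threads, n_threads_batch

print(default_thread_counts())  # e.g. (4, 8) on an 8-core machine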
@@ -1793,7 +1792,7 @@ def save_state(self) -> LlamaState:
                 file=sys.stderr,
             )
         return LlamaState(
-            scores=self.scores.copy(),
+            scores=self._scores.copy(),
             input_ids=self.input_ids.copy(),
             n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
@@ -1802,7 +1801,9 @@ def save_state(self) -> LlamaState:
 
     def load_state(self, state: LlamaState) -> None:
         assert self._ctx.ctx is not None
-        self.scores = state.scores.copy()
+        # Only filling in up to `n_tokens` and then zero-ing out the rest
+        self.scores[: state.n_tokens, :] = state.scores.copy()
+        self.scores[state.n_tokens :, :] = 0.0
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
         state_size = state.llama_state_size
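
Taken together with the save_state hunk, the commit now snapshots `self._scores` (judging by the load-side comment, the score rows for the first `n_tokens` tokens only) and, on load, writes them back into the preallocated buffer while zeroing the unused tail rather than replacing the whole array. A toy NumPy sketch of that fill-and-zero pattern, assuming `scores` is a preallocated `(n_ctx, n_vocab)` array and the snapshot covers only the first `n_tokens` rows:

import numpy as np

n_ctx, n_vocab = 8, 4  # toy sizes for illustration
scores = np.ones((n_ctx, n_vocab), dtype=np.float32)  # buffer holding stale logits
saved = np.full((3, n_vocab), 2.0, dtype=np.float32)  # snapshot of n_tokens rows
n_tokens = 3

# Restore the saved rows, then zero the rest so no stale logits from a
# previous generation survive the load.
scores[:n_tokens, :] = saved
scores[n_tokens:, :] = 0.0

assert (scores[:n_tokens] == 2.0).all() and (scores[n_tokens:] == 0.0).all()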
@@ -1953,7 +1954,6 @@ def from_pretrained(
                 local_dir_use_symlinks=local_dir_use_symlinks,
                 cache_dir=cache_dir,
                 local_files_only=True,
-
             )
         else:
             model_path = os.path.join(local_dir, filename)