
Commit 8a60766

Merge branch 'main' into feat/chat_function_template
2 parents: 71c372e + 610a592

6 files changed: +539 additions, −108 deletions

llama_cpp/llama.py

Lines changed: 6 additions & 6 deletions
@@ -18,6 +18,7 @@
     Iterator,
     Deque,
     Callable,
+    Dict,
 )
 from collections import deque
 from pathlib import Path
@@ -262,9 +263,7 @@ def __init__(

         self.n_batch = min(n_ctx, n_batch)  # ???
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
-        self.n_threads_batch = n_threads_batch or max(
-            multiprocessing.cpu_count() // 2, 1
-        )
+        self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()

         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
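
Not part of the diff itself: a minimal sketch of how the two thread defaults now resolve when the caller passes None. The helper name below is illustrative, not part of the library; only the standard library is assumed.

import multiprocessing

def resolve_thread_defaults(n_threads=None, n_threads_batch=None):
    # Generation threads: half the logical cores, at least 1, unless overridden.
    threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
    # Batch (prompt-processing) threads: after this change, all logical cores by default.
    threads_batch = n_threads_batch or multiprocessing.cpu_count()
    return threads, threads_batch

# On an 8-core machine this prints (4, 8); explicit values always win.
print(resolve_thread_defaults())
print(resolve_thread_defaults(n_threads=2, n_threads_batch=4))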
@@ -1793,7 +1792,7 @@ def save_state(self) -> LlamaState:
                 file=sys.stderr,
             )
         return LlamaState(
-            scores=self.scores.copy(),
+            scores=self._scores.copy(),
             input_ids=self.input_ids.copy(),
             n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
@@ -1802,7 +1801,9 @@ def save_state(self) -> LlamaState:

     def load_state(self, state: LlamaState) -> None:
         assert self._ctx.ctx is not None
-        self.scores = state.scores.copy()
+        # Only filling in up to `n_tokens` and then zero-ing out the rest
+        self.scores[: state.n_tokens, :] = state.scores.copy()
+        self.scores[state.n_tokens :, :] = 0.0
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
         state_size = state.llama_state_size
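
For context (not from the commit): a minimal NumPy sketch of the restore pattern in load_state above, assuming the live scores buffer is preallocated at (n_ctx, n_vocab) while the saved state may cover fewer tokens.

import numpy as np

n_ctx, n_vocab = 8, 5
# Live buffer, always full context size.
scores = np.random.rand(n_ctx, n_vocab).astype(np.float32)

# Pretend a state was saved after only 3 tokens.
saved_n_tokens = 3
saved_scores = scores[:saved_n_tokens, :].copy()

# Restore: fill the first n_tokens rows, zero everything past them,
# so stale logits from a longer earlier run cannot leak through.
scores[:saved_n_tokens, :] = saved_scores
scores[saved_n_tokens:, :] = 0.0

assert np.all(scores[saved_n_tokens:] == 0.0)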
@@ -1953,7 +1954,6 @@ def from_pretrained(
                 local_dir_use_symlinks=local_dir_use_symlinks,
                 cache_dir=cache_dir,
                 local_files_only=True,
-
             )
         else:
             model_path = os.path.join(local_dir, filename)

llama_cpp/llama_chat_format.py

Lines changed: 2 additions & 1 deletion
@@ -2852,4 +2852,5 @@ def vicuna_function_calling(
         "{% if add_generation_prompt %}</s>ASSISTANT\n{% endif %}"  # Vicuna adds the role for prompt continuation
     )
     return base_function_calling(end_token="</s>",
-                                 **locals())
+                                 **locals())
+
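
A hedged aside on the **locals() pattern used above; the names and signature below are illustrative, not the real API. Each chat-format function defines its template pieces as local variables and forwards them wholesale to a shared helper.

def base_function_calling(end_token, **template_parts):
    # Hypothetical shared helper: receives whatever locals the caller defined.
    return {"end_token": end_token, **template_parts}

def vicuna_style_template():
    system_prefix = "SYSTEM:"
    generation_prompt = "{% if add_generation_prompt %}</s>ASSISTANT\n{% endif %}"
    # locals() here is {"system_prefix": ..., "generation_prompt": ...},
    # so both become keyword arguments of the helper.
    return base_function_calling(end_token="</s>", **locals())

print(sorted(vicuna_style_template().keys()))

One caveat of this pattern: **locals() also captures the function's parameters and any other locals defined before the call, so the helper has to tolerate extra keywords.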
