Commit 4924455

feat: Make saved state more compact on-disk (#1296)
* State load/save changes
  - Only store up to `n_tokens` logits instead of the full `(n_ctx, n_vocab)` sized array.
  - Difference between ~350MB and ~1500MB for an example prompt with ~300 tokens (makes sense lol)
  - Auto-formatting changes
* Back out formatting changes
1 parent 9842cbf commit 4924455
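Why the saved state shrinks: the `scores` (logits) buffer is allocated for the entire context window, but after evaluating a prompt only the first `n_tokens` rows contain real data, so persisting just those rows drops the dead weight. A minimal numpy sketch of the arithmetic; the sizes below are assumptions for illustration, not values taken from the commit (it does not say which model produced the ~350MB/~1500MB figures):

```python
import numpy as np

# Hypothetical sizes chosen for illustration only.
n_ctx, n_vocab, n_tokens = 8192, 32000, 300

# The logits buffer spans the whole context window...
full = np.zeros((n_ctx, n_vocab), dtype=np.single)
# ...but after a ~300-token prompt only the first n_tokens rows hold data.
compact = full[:n_tokens, :]

print(f"full buffer:  {full.nbytes / 2**20:7.1f} MiB")     # 1000.0 MiB
print(f"compact rows: {compact.nbytes / 2**20:7.1f} MiB")   #   36.6 MiB
```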


llama_cpp/llama.py: 5 additions & 3 deletions

```diff
@@ -18,6 +18,7 @@
     Iterator,
     Deque,
     Callable,
+    Dict,
 )
 from collections import deque
 from pathlib import Path
@@ -1791,7 +1792,7 @@ def save_state(self) -> LlamaState:
             file=sys.stderr,
         )
         return LlamaState(
-            scores=self.scores.copy(),
+            scores=self._scores.copy(),
             input_ids=self.input_ids.copy(),
             n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
@@ -1800,7 +1801,9 @@ def save_state(self) -> LlamaState:
 
     def load_state(self, state: LlamaState) -> None:
         assert self._ctx.ctx is not None
-        self.scores = state.scores.copy()
+        # Only filling in up to `n_tokens` and then zero-ing out the rest
+        self.scores[: state.n_tokens, :] = state.scores.copy()
+        self.scores[state.n_tokens :, :] = 0.0
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
         state_size = state.llama_state_size
@@ -1951,7 +1954,6 @@ def from_pretrained(
             local_dir_use_symlinks=local_dir_use_symlinks,
             cache_dir=cache_dir,
             local_files_only=True,
-
         )
     else:
         model_path = os.path.join(local_dir, filename)
```
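On the save side, `save_state` now copies `self._scores`, which (judging from the diff) holds only the populated rows rather than the full `(n_ctx, n_vocab)` buffer exposed as `self.scores`. On the load side the compact array no longer matches the buffer shape, so `load_state` writes the saved rows back in place and zeroes everything past `n_tokens`, keeping stale logits from a previous run out of the restored context. A standalone numpy sketch of that restore step, using toy dimensions rather than the library's actual buffers:

```python
import numpy as np

# Toy stand-ins for the live logits buffer and a saved compact state;
# dimensions are arbitrary, not llama.cpp defaults.
n_ctx, n_vocab, n_tokens = 16, 8, 5
scores = np.random.rand(n_ctx, n_vocab).astype(np.single)           # live buffer, stale data
saved_scores = np.random.rand(n_tokens, n_vocab).astype(np.single)  # plays the role of state.scores

# Mirrors the new load_state logic: restore the first n_tokens rows,
# then zero the remainder.
scores[:n_tokens, :] = saved_scores.copy()
scores[n_tokens:, :] = 0.0

assert np.array_equal(scores[:n_tokens], saved_scores)
assert not scores[n_tokens:].any()
```

Callers are unaffected by the change: `Llama.save_state()` still returns a `LlamaState` and `Llama.load_state(state)` restores it; only the size of the state object, and of anything serialized from it, changes.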
