Skip to content

Commit 80be68a

Browse files
committed
feat: Update llama.cpp
1 parent 0580cf2 commit 80be68a

File tree

4 files changed

+626
-280
lines changed

4 files changed

+626
-280
lines changed

llama_cpp/_internals.py

Lines changed: 101 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,13 @@ def __init__(
5555
if model is None:
5656
raise ValueError(f"Failed to load model from file: {path_model}")
5757

58+
vocab = llama_cpp.llama_model_get_vocab(model)
59+
60+
if vocab is None:
61+
raise ValueError(f"Failed to get vocab from model: {path_model}")
62+
5863
self.model = model
64+
self.vocab = vocab
5965

6066
def free_model():
6167
if self.model is None:
@@ -75,7 +81,7 @@ def vocab_type(self) -> int:
7581
return llama_cpp.llama_vocab_type(self.model)
7682

7783
def n_vocab(self) -> int:
78-
return llama_cpp.llama_n_vocab(self.model)
84+
return llama_cpp.llama_n_vocab(self.vocab)
7985

8086
def n_ctx_train(self) -> int:
8187
return llama_cpp.llama_n_ctx_train(self.model)
@@ -84,7 +90,7 @@ def n_embd(self) -> int:
8490
return llama_cpp.llama_n_embd(self.model)
8591

8692
def rope_freq_scale_train(self) -> float:
87-
return llama_cpp.llama_rope_freq_scale_train(self.model)
93+
return llama_cpp.llama_model_rope_freq_scale_train(self.model)
8894

8995
def desc(self) -> str:
9096
buf = ctypes.create_string_buffer(1024)
@@ -98,67 +104,67 @@ def n_params(self) -> int:
98104
return llama_cpp.llama_model_n_params(self.model)
99105

100106
def get_tensor(self, name: str) -> ctypes.c_void_p:
101-
return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8"))
107+
raise NotImplementedError("get_tensor is not implemented in llama.cpp")
102108

103109
# Vocab
104110

105111
def token_get_text(self, token: int) -> str:
106-
return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8")
112+
return llama_cpp.llama_token_get_text(self.vocab, token).decode("utf-8")
107113

108114
def token_get_score(self, token: int) -> float:
109-
return llama_cpp.llama_token_get_score(self.model, token)
115+
return llama_cpp.llama_token_get_score(self.vocab, token)
110116

111117
def token_get_attr(self, token: int) -> int:
112-
return llama_cpp.llama_token_get_attr(self.model, token)
118+
return llama_cpp.llama_token_get_attr(self.vocab, token)
113119

114120
# Special tokens
115121

116122
def token_bos(self) -> int:
117-
return llama_cpp.llama_token_bos(self.model)
123+
return llama_cpp.llama_token_bos(self.vocab)
118124

119125
def token_eos(self) -> int:
120-
return llama_cpp.llama_token_eos(self.model)
126+
return llama_cpp.llama_token_eos(self.vocab)
121127

122128
def token_cls(self) -> int:
123-
return llama_cpp.llama_token_cls(self.model)
129+
return llama_cpp.llama_token_cls(self.vocab)
124130

125131
def token_sep(self) -> int:
126-
return llama_cpp.llama_token_sep(self.model)
132+
return llama_cpp.llama_token_sep(self.vocab)
127133

128134
def token_nl(self) -> int:
129-
return llama_cpp.llama_token_nl(self.model)
135+
return llama_cpp.llama_token_nl(self.vocab)
130136

131137
def token_prefix(self) -> int:
132-
return llama_cpp.llama_token_prefix(self.model)
138+
raise NotImplementedError("token_prefix is not implemented in llama.cpp")
133139

134140
def token_middle(self) -> int:
135-
return llama_cpp.llama_token_middle(self.model)
141+
raise NotImplementedError("token_middle is not implemented in llama.cpp")
136142

137143
def token_suffix(self) -> int:
138-
return llama_cpp.llama_token_suffix(self.model)
144+
raise NotImplementedError("token_suffix is not implemented in llama.cpp")
139145

140146
def token_eot(self) -> int:
141-
return llama_cpp.llama_token_eot(self.model)
147+
return llama_cpp.llama_token_eot(self.vocab)
142148

143149
def add_bos_token(self) -> bool:
144-
return llama_cpp.llama_add_bos_token(self.model)
150+
return llama_cpp.llama_add_bos_token(self.vocab)
145151

146152
def add_eos_token(self) -> bool:
147-
return llama_cpp.llama_add_eos_token(self.model)
153+
return llama_cpp.llama_add_eos_token(self.vocab)
148154

149155
# Tokenization
150156

151157
def tokenize(self, text: bytes, add_bos: bool, special: bool):
152158
n_ctx = self.n_ctx_train()
153159
tokens = (llama_cpp.llama_token * n_ctx)()
154160
n_tokens = llama_cpp.llama_tokenize(
155-
self.model, text, len(text), tokens, n_ctx, add_bos, special
161+
self.vocab, text, len(text), tokens, n_ctx, add_bos, special
156162
)
157163
if n_tokens < 0:
158164
n_tokens = abs(n_tokens)
159165
tokens = (llama_cpp.llama_token * n_tokens)()
160166
n_tokens = llama_cpp.llama_tokenize(
161-
self.model, text, len(text), tokens, n_tokens, add_bos, special
167+
self.vocab, text, len(text), tokens, n_tokens, add_bos, special
162168
)
163169
if n_tokens < 0:
164170
raise RuntimeError(
@@ -168,7 +174,7 @@ def tokenize(self, text: bytes, add_bos: bool, special: bool):
168174

169175
def token_to_piece(self, token: int, special: bool = False) -> bytes:
170176
buf = ctypes.create_string_buffer(32)
171-
llama_cpp.llama_token_to_piece(self.model, token, buf, 32, 0, special)
177+
llama_cpp.llama_token_to_piece(self.vocab, token, buf, 32, 0, special)
172178
return bytes(buf)
173179

174180
def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
@@ -177,7 +183,7 @@ def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
177183
buffer = (ctypes.c_char * size)()
178184
for token in tokens:
179185
n = llama_cpp.llama_token_to_piece(
180-
self.model, llama_cpp.llama_token(token), buffer, size, 0, special
186+
self.vocab, llama_cpp.llama_token(token), buffer, size, 0, special
181187
)
182188
assert n <= size
183189
output += bytes(buffer[:n])
@@ -320,7 +326,8 @@ def get_embeddings(self):
320326

321327
def set_rng_seed(self, seed: int):
322328
# TODO: Fix
323-
llama_cpp.llama_set_rng_seed(self.ctx, seed)
329+
# llama_cpp.llama_set_rng_seed(self.ctx, seed)
330+
raise NotImplementedError("set_rng_seed is not implemented in llama.cpp")
324331

325332
def sample_repetition_penalties(
326333
self,
@@ -331,55 +338,63 @@ def sample_repetition_penalties(
331338
penalty_freq: float,
332339
penalty_present: float,
333340
):
334-
llama_cpp.llama_sample_repetition_penalties(
335-
self.ctx,
336-
llama_cpp.byref(candidates.candidates),
337-
last_tokens_data,
338-
penalty_last_n,
339-
penalty_repeat,
340-
penalty_freq,
341-
penalty_present,
342-
)
341+
# llama_cpp.llama_sample_repetition_penalties(
342+
# self.ctx,
343+
# llama_cpp.byref(candidates.candidates),
344+
# last_tokens_data,
345+
# penalty_last_n,
346+
# penalty_repeat,
347+
# penalty_freq,
348+
# penalty_present,
349+
# )
350+
raise NotImplementedError("sample_repetition_penalties is not implemented in llama.cpp")
343351

344352
def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
345-
llama_cpp.llama_sample_softmax(
346-
self.ctx,
347-
llama_cpp.byref(candidates.candidates),
348-
)
353+
# llama_cpp.llama_sample_softmax(
354+
# self.ctx,
355+
# llama_cpp.byref(candidates.candidates),
356+
# )
357+
raise NotImplementedError("sample_softmax is not implemented in llama.cpp")
349358

350359
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
351-
llama_cpp.llama_sample_top_k(
352-
self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
353-
)
360+
# llama_cpp.llama_sample_top_k(
361+
# self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
362+
# )
363+
raise NotImplementedError("sample_top_k is not implemented in llama.cpp")
354364

355365
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
356-
llama_cpp.llama_sample_top_p(
357-
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
358-
)
366+
# llama_cpp.llama_sample_top_p(
367+
# self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
368+
# )
369+
raise NotImplementedError("sample_top_p is not implemented in llama.cpp")
359370

360371
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
361-
llama_cpp.llama_sample_min_p(
362-
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
363-
)
372+
# llama_cpp.llama_sample_min_p(
373+
# self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
374+
# )
375+
raise NotImplementedError("sample_min_p is not implemented in llama.cpp")
364376

365377
def sample_typical(
366378
self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
367379
):
368-
llama_cpp.llama_sample_typical(
369-
self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
370-
)
380+
# llama_cpp.llama_sample_typical(
381+
# self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep
382+
# )
383+
raise NotImplementedError("sample_typical is not implemented in llama.cpp")
371384

372385
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
373-
llama_cpp.llama_sample_temp(
374-
self.ctx, llama_cpp.byref(candidates.candidates), temp
375-
)
386+
# llama_cpp.llama_sample_temp(
387+
# self.ctx, llama_cpp.byref(candidates.candidates), temp
388+
# )
389+
raise NotImplementedError("sample_temp is not implemented in llama.cpp")
376390

377391
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
378-
llama_cpp.llama_sample_grammar(
379-
self.ctx,
380-
llama_cpp.byref(candidates.candidates),
381-
grammar.grammar,
382-
)
392+
# llama_cpp.llama_sample_grammar(
393+
# self.ctx,
394+
# llama_cpp.byref(candidates.candidates),
395+
# grammar.grammar,
396+
# )
397+
raise NotImplementedError("sample_grammar is not implemented in llama.cpp")
383398

384399
def sample_token_mirostat(
385400
self,
@@ -389,14 +404,15 @@ def sample_token_mirostat(
389404
m: int,
390405
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
391406
) -> int:
392-
return llama_cpp.llama_sample_token_mirostat(
393-
self.ctx,
394-
llama_cpp.byref(candidates.candidates),
395-
tau,
396-
eta,
397-
m,
398-
mu,
399-
)
407+
raise NotImplementedError("sample_token_mirostat is not implemented in llama.cpp")
408+
# return llama_cpp.llama_sample_token_mirostat(
409+
# self.ctx,
410+
# llama_cpp.byref(candidates.candidates),
411+
# tau,
412+
# eta,
413+
# m,
414+
# mu,
415+
# )
400416

401417
def sample_token_mirostat_v2(
402418
self,
@@ -405,29 +421,33 @@ def sample_token_mirostat_v2(
405421
eta: float,
406422
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
407423
) -> int:
408-
return llama_cpp.llama_sample_token_mirostat_v2(
409-
self.ctx,
410-
llama_cpp.byref(candidates.candidates),
411-
tau,
412-
eta,
413-
mu,
414-
)
424+
raise NotImplementedError("sample_token_mirostat_v2 is not implemented in llama.cpp")
425+
# return llama_cpp.llama_sample_token_mirostat_v2(
426+
# self.ctx,
427+
# llama_cpp.byref(candidates.candidates),
428+
# tau,
429+
# eta,
430+
# mu,
431+
# )
415432

416433
def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
417-
return llama_cpp.llama_sample_token_greedy(
418-
self.ctx,
419-
llama_cpp.byref(candidates.candidates),
420-
)
434+
raise NotImplementedError("sample_token_greedy is not implemented in llama.cpp")
435+
# return llama_cpp.llama_sample_token_greedy(
436+
# self.ctx,
437+
# llama_cpp.byref(candidates.candidates),
438+
# )
421439

422440
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
423-
return llama_cpp.llama_sample_token(
424-
self.ctx,
425-
llama_cpp.byref(candidates.candidates),
426-
)
441+
raise NotImplementedError("sample_token is not implemented in llama.cpp")
442+
# return llama_cpp.llama_sample_token(
443+
# self.ctx,
444+
# llama_cpp.byref(candidates.candidates),
445+
# )
427446

428447
# Grammar
429448
def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
430-
llama_cpp.llama_grammar_accept_token(grammar.grammar, self.ctx, token)
449+
raise NotImplementedError("grammar_accept_token is not implemented in llama.cpp")
450+
# llama_cpp.llama_grammar_accept_token(grammar.grammar, self.ctx, token)
431451

432452
def reset_timings(self):
433453
llama_cpp.llama_perf_context_reset(self.ctx)
@@ -788,7 +808,7 @@ def add_mirostat_v2(self, seed: int, tau: float, eta: float):
788808

789809
def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
790810
sampler = llama_cpp.llama_sampler_init_grammar(
791-
model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
811+
model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
792812
)
793813
self._add_sampler(sampler)
794814

@@ -842,6 +862,7 @@ def get_seed(self) -> int:
842862

843863
def sample(self, ctx: LlamaContext, idx: int) -> int:
844864
assert self.sampler is not None
865+
assert ctx.ctx is not None
845866
return llama_cpp.llama_sampler_sample(self.sampler, ctx.ctx, idx)
846867

847868
def close(self):

llama_cpp/llama.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -406,10 +406,10 @@ def __init__(
406406
)
407407
)
408408

409-
self._lora_adapter: Optional[llama_cpp.llama_lora_adapter_p] = None
409+
self._lora_adapter: Optional[llama_cpp.llama_adapter_lora_p] = None
410410

411411
if self.lora_path:
412-
self._lora_adapter = llama_cpp.llama_lora_adapter_init(
412+
self._lora_adapter = llama_cpp.llama_adapter_lora_init(
413413
self._model.model,
414414
self.lora_path.encode("utf-8"),
415415
)
@@ -421,12 +421,12 @@ def __init__(
421421
def free_lora_adapter():
422422
if self._lora_adapter is None:
423423
return
424-
llama_cpp.llama_lora_adapter_free(self._lora_adapter)
424+
llama_cpp.llama_adapter_lora_free(self._lora_adapter)
425425
self._lora_adapter = None
426426

427427
self._stack.callback(free_lora_adapter)
428428

429-
if llama_cpp.llama_lora_adapter_set(
429+
if llama_cpp.llama_set_adapter_lora(
430430
self._ctx.ctx, self._lora_adapter, self.lora_scale
431431
):
432432
raise RuntimeError(
@@ -1152,9 +1152,9 @@ def _create_completion(
11521152
bos_token_id: int = self.token_bos()
11531153
cls_token_id: int = self._model.token_cls()
11541154
sep_token_id: int = self._model.token_sep()
1155-
prefix_token_id: int = self._model.token_prefix()
1156-
middle_token_id: int = self._model.token_middle()
1157-
suffix_token_id: int = self._model.token_suffix()
1155+
prefix_token_id: int = 0 # self._model.token_prefix() # TODO: Fix
1156+
middle_token_id: int = 0 # self._model.token_middle() # TODO: Fix
1157+
suffix_token_id: int = 0 # self._model.token_suffix() # TODO: Fix
11581158
add_space_prefix: bool = (
11591159
self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
11601160
)
@@ -1332,7 +1332,7 @@ def logit_bias_processor(
13321332
logits_processor=logits_processor,
13331333
grammar=grammar,
13341334
):
1335-
if llama_cpp.llama_token_is_eog(self._model.model, token):
1335+
if llama_cpp.llama_token_is_eog(self._model.vocab, token):
13361336
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
13371337
finish_reason = "stop"
13381338
break

0 commit comments

Comments
 (0)