Commit 3a7d83b

lapp0 authored and brandonwillard committed
make LlamaCppTokenizer an outlines Tokenizer
1 parent 6696cb5 commit 3a7d83b

File tree

5 files changed: +165 −64 lines changed


outlines/integrations/llamacpp.py

+2 −37

@@ -26,7 +26,7 @@
 """
 
 import math
-from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union
+from typing import TYPE_CHECKING, Optional, Type, Union
 
 import numpy as np
 import torch
@@ -36,47 +36,12 @@
 from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
 from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.integrations.utils import convert_json_schema_to_str
+from outlines.models.llamacpp import LlamaCppTokenizer
 
 if TYPE_CHECKING:
     from llama_cpp import Llama
 
 
-class LlamaCppTokenizer:
-    def __init__(self, model: "Llama"):
-        self.eos_token_id = model.token_eos()
-        self.eos_token = model.tokenizer().decode([self.eos_token_id])
-        self.pad_token_id = self.eos_token_id
-        self.special_tokens: Set[int] = set()
-
-        self.vocabulary: Dict[str, int] = dict()
-
-        tokenizer = model.tokenizer()
-
-        self.decode = tokenizer.decode
-
-        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
-        try:
-            self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
-        except AttributeError:
-            # ###
-            for t in range(model.n_vocab()):
-                token_piece = model.tokenizer().decode([t])
-                self.vocabulary[token_piece] = t
-
-    def convert_token_to_string(self, token: str) -> str:
-        return token
-
-    def __getstate__(self):
-        """Allow tokenizer to be used as hash key by excluding self.decode"""
-        return (
-            self.vocabulary.items(),
-            self.eos_token_id,
-            self.eos_token,
-            self.pad_token_id,
-            sorted(self.special_tokens),
-        )
-
-
 class LogitsProcessor:
     """Bias LlamaCpp generation using a finite state machine.

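This change removes the local class and re-imports it from outlines.models.llamacpp, so code that imported the tokenizer from the integrations module keeps working. A minimal sketch of that expectation (illustrative, not part of the commit):

    from outlines.integrations.llamacpp import LlamaCppTokenizer as old_path
    from outlines.models.llamacpp import LlamaCppTokenizer as new_path

    # The integrations module now re-exports the moved class, so both
    # import paths resolve to the same object.
    assert old_path is new_path
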
outlines/models/llamacpp.py

+88 −1

@@ -1,15 +1,102 @@
 import dataclasses
+import pickle
 import warnings
-from typing import TYPE_CHECKING, Iterator, List, Optional, TypedDict, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypedDict,
+    Union,
+)
 
 from typing_extensions import Unpack
 
 from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
     from llama_cpp import Llama, LogitsProcessorList
 
 
+class LlamaCppTokenizer(Tokenizer):
+    def __init__(self, model: "Llama"):
+        self.eos_token_id = model.token_eos()
+        self.eos_token = model.tokenizer().decode([self.eos_token_id])
+        self.pad_token_id = self.eos_token_id
+        self.special_tokens: Set[int] = set()
+
+        self.vocabulary: Dict[str, int] = dict()
+
+        self.tokenizer = model.tokenizer()
+
+        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
+        try:
+            self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
+        except AttributeError:
+            # ###
+            for t in range(model.n_vocab()):
+                token_piece = model.tokenizer().decode([t])
+                self.vocabulary[token_piece] = t
+
+        # ensure stable ordering of vocabulary
+        self.vocabulary = {
+            tok: tok_id
+            for tok, tok_id in sorted(self.vocabulary.items(), key=lambda x: x[1])
+        }
+
+        self._hash = None
+
+    def decode(self, token_ids: List[int]) -> List[str]:
+        decoded_bytes = self.tokenizer.detokenize(token_ids)
+        return [decoded_bytes.decode("utf-8", errors="ignore")]
+
+    def encode(
+        self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True
+    ) -> Tuple[List[int], List[int]]:
+        if isinstance(prompt, list):
+            raise NotImplementedError(
+                "llama-cpp-python tokenizer doesn't support batch tokenization"
+            )
+        token_ids = self.tokenizer.tokenize(
+            prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
+        )
+        # generate attention mask, missing from llama-cpp-python
+        attention_mask = [
+            1 if token_id != self.pad_token_id else 0 for token_id in token_ids
+        ]
+        return token_ids, attention_mask
+
+    def convert_token_to_string(self, token: str) -> str:
+        return token
+
+    def __eq__(self, other):
+        if not isinstance(other, LlamaCppTokenizer):
+            return False
+        return self.__getstate__() == other.__getstate__()
+
+    def __hash__(self):
+        if self._hash is None:
+            self._hash = hash(pickle.dumps(self))
+        return self._hash
+
+    def __getstate__(self):
+        """Create a stable representation for outlines.caching"""
+        return (
+            self.vocabulary,
+            self.eos_token_id,
+            self.eos_token,
+            self.pad_token_id,
+            sorted(self.special_tokens),
+        )
+
+    def __setstate__(self, state):
+        raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")
+
+
 class LlamaCppParams(TypedDict, total=False):
     suffix: Optional[str]
     temperature: float

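A minimal usage sketch of the relocated tokenizer, assuming llama-cpp-python is installed and a GGUF model file exists at ./model.gguf (both the path and the prompt are placeholders, not taken from the commit):

    from llama_cpp import Llama

    from outlines.models.llamacpp import LlamaCppTokenizer

    llm = Llama("./model.gguf")  # hypothetical local model path
    tokenizer = LlamaCppTokenizer(llm)

    # encode() returns token ids plus a synthesized attention mask
    token_ids, attention_mask = tokenizer.encode("What is the capital of France?")

    # decode() detokenizes the whole sequence and returns a single-element list
    print(tokenizer.decode(token_ids))
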
tests/generate/conftest.py

+24 −0

@@ -0,0 +1,24 @@
+from importlib import reload
+
+import pytest
+
+
+@pytest.fixture
+def temp_cache_dir():
+    import os
+    import tempfile
+
+    import outlines.caching
+    import outlines.fsm.guide
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        os.environ["OUTLINES_CACHE_DIR"] = tempdir
+        outlines.caching.get_cache.cache_clear()
+        reload(outlines)
+        reload(outlines.fsm.guide)
+        cache_status = outlines.caching._caching_enabled
+        try:
+            outlines.caching._caching_enabled = True
+            yield
+        finally:
+            outlines.caching._caching_enabled = cache_status

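A sketch of how a test would opt into this fixture; the test name and body below are hypothetical, only the fixture name and the outlines.caching attributes come from the commit:

    def test_uses_isolated_cache(temp_cache_dir):
        import outlines.caching

        # The cache now lives in a throwaway directory and caching is forced
        # on, so hit/miss statistics start from a clean slate for this test.
        cache = outlines.caching.get_cache()
        assert outlines.caching._caching_enabled
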
tests/generate/test_integration_llamacpp.py

+51 −4

@@ -281,9 +281,56 @@ def test_llama_cpp_pre_tokenizer_remains_broken():
     generate.choice(model, ["skirt", "dress", "pen", "jacket"])
 
 
-def test_create_states_mapping_llamacpp_tokenizer_regression(model):
-    """Minimal reproducer for #922, error passing llamacpp tokenizer to create_states_mapping"""
+def test_RegexGuide_caching(model, temp_cache_dir):
+    import llama_cpp
+
+    import outlines.caching
     from outlines.fsm.guide import create_states_mapping
-    from outlines.integrations.llamacpp import LlamaCppTokenizer
 
-    create_states_mapping("a", LlamaCppTokenizer(model.model))
+    assert outlines.caching._caching_enabled
+
+    regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+    prompt = "What is the IP address of the Google DNS servers? "
+
+    cache = outlines.caching.get_cache()
+
+    # Returns (hits, misses)
+    _ = cache.stats(enable=True)
+    assert cache.statistics
+
+    assert create_states_mapping.__memory__ is cache
+
+    generator = generate.regex(model, regex, sampler=samplers.greedy())
+    assert cache.stats() == (0, 1)
+
+    model_2 = models.llamacpp(
+        "Qwen/Qwen1.5-0.5B-Chat-GGUF",
+        "*q2*.gguf",
+        tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
+            "Qwen/Qwen1.5-0.5B-Chat"
+        ),
+    )
+    generator_2 = generate.regex(model_2, regex, sampler=samplers.greedy())
+    assert cache.stats() == (0, 2)
+
+    # These two different models and tokenizers should not have the same state
+    # mapping results
+    assert (
+        generator.logits_processor.fsm.states_to_token_maps
+        != generator_2.logits_processor.fsm.states_to_token_maps
+    )
+
+    generator_3 = generate.regex(model_2, regex, sampler=samplers.greedy())
+    assert cache.stats() == (1, 2)
+    assert (
+        generator_2.logits_processor.fsm.states_to_token_maps
+        == generator_3.logits_processor.fsm.states_to_token_maps
+    )
+
+    # Just for fun...
+    structured = generator(prompt, max_tokens=30)
+    structured_2 = generator_2(prompt, max_tokens=30)
+
+    assert re.fullmatch(regex, structured)
+    assert re.fullmatch(regex, structured_2)
+    assert structured != structured_2

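The cache reuse asserted above rests on LlamaCppTokenizer comparing and hashing by value rather than by object identity. A small sketch of that property, with a hypothetical local model path:

    from llama_cpp import Llama

    from outlines.models.llamacpp import LlamaCppTokenizer

    llm = Llama("./model.gguf")  # hypothetical local GGUF file
    tok_a = LlamaCppTokenizer(llm)
    tok_b = LlamaCppTokenizer(llm)

    # __eq__ compares the __getstate__ tuples and __hash__ pickles them, so
    # two tokenizers built from the same model end up sharing one cache key.
    assert tok_a == tok_b
    assert hash(tok_a) == hash(tok_b)
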
tests/generate/test_integration_transformers.py

+0 −22

@@ -1,7 +1,6 @@
 import datetime
 import re
 from enum import Enum
-from importlib import reload
 from typing import List, Union
 
 import pytest
@@ -15,27 +14,6 @@
 from outlines.samplers import beam_search, greedy, multinomial
 
 
-@pytest.fixture
-def temp_cache_dir():
-    import os
-    import tempfile
-
-    import outlines.caching
-    import outlines.fsm.guide
-
-    with tempfile.TemporaryDirectory() as tempdir:
-        os.environ["OUTLINES_CACHE_DIR"] = tempdir
-        outlines.caching.get_cache.cache_clear()
-        reload(outlines)
-        reload(outlines.fsm.guide)
-        cache_status = outlines.caching._caching_enabled
-        try:
-            outlines.caching._caching_enabled = True
-            yield
-        finally:
-            outlines.caching._caching_enabled = cache_status
-
-
 def test_transformers_integration_text():
     rng = torch.Generator()
     rng.manual_seed(10000)  # Choosen so <EOS> is generated

0 commit comments
