 import dataclasses
+import pickle
 import warnings
-from typing import TYPE_CHECKING, Iterator, List, Optional, TypedDict, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypedDict,
+    Union,
+)
 
 from typing_extensions import Unpack
 
 from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
     from llama_cpp import Llama, LogitsProcessorList
 
 
+class LlamaCppTokenizer(Tokenizer):
+    def __init__(self, model: "Llama"):
+        self.eos_token_id = model.token_eos()
+        self.eos_token = model.tokenizer().decode([self.eos_token_id])
+        self.pad_token_id = self.eos_token_id
+        self.special_tokens: Set[int] = set()
+
+        self.vocabulary: Dict[str, int] = dict()
+
+        self.tokenizer = model.tokenizer()
+
+        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
+        try:
+            self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
+        except AttributeError:
+            # fall back to decoding every token id with the llama.cpp tokenizer
+            for t in range(model.n_vocab()):
+                token_piece = model.tokenizer().decode([t])
+                self.vocabulary[token_piece] = t
+
+        # ensure stable ordering of vocabulary
+        self.vocabulary = {
+            tok: tok_id
+            for tok, tok_id in sorted(self.vocabulary.items(), key=lambda x: x[1])
+        }
+
+        self._hash = None
+
+    def decode(self, token_ids: List[int]) -> List[str]:
+        decoded_bytes = self.tokenizer.detokenize(token_ids)
+        return [decoded_bytes.decode("utf-8", errors="ignore")]
+
+    def encode(
+        self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True
+    ) -> Tuple[List[int], List[int]]:
+        if isinstance(prompt, list):
+            raise NotImplementedError(
+                "llama-cpp-python tokenizer doesn't support batch tokenization"
+            )
+        token_ids = self.tokenizer.tokenize(
+            prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
+        )
+        # generate attention mask, missing from llama-cpp-python
+        attention_mask = [
+            1 if token_id != self.pad_token_id else 0 for token_id in token_ids
+        ]
+        return token_ids, attention_mask
+
+    def convert_token_to_string(self, token: str) -> str:
+        return token
+
+    def __eq__(self, other):
+        if not isinstance(other, LlamaCppTokenizer):
+            return False
+        return self.__getstate__() == other.__getstate__()
+
+    def __hash__(self):
+        if self._hash is None:
+            self._hash = hash(pickle.dumps(self))
+        return self._hash
+
+    def __getstate__(self):
+        """Create a stable representation for outlines.caching"""
+        return (
+            self.vocabulary,
+            self.eos_token_id,
+            self.eos_token,
+            self.pad_token_id,
+            sorted(self.special_tokens),
+        )
+
+    def __setstate__(self, state):
+        raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")
+
+
 class LlamaCppParams(TypedDict, total=False):
     suffix: Optional[str]
     temperature: float