diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 9b1a3d263..45cbd7bca 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -578,6 +578,8 @@ def tokenize( Args: text: The utf-8 encoded string to tokenize. + add_bos: Whether to add a beginning of sequence token. + special: Whether to tokenize special tokens. Raises: RuntimeError: If the tokenization failed. @@ -588,18 +590,19 @@ def tokenize( return self.tokenizer_.tokenize(text, add_bos, special) def detokenize( - self, tokens: List[int], prev_tokens: Optional[List[int]] = None + self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False ) -> bytes: """Detokenize a list of tokens. Args: tokens: The list of tokens to detokenize. - prev_tokens: The list of previous tokens. Offset mapping will be performed if provided + prev_tokens: The list of previous tokens. Offset mapping will be performed if provided. + special: Whether to detokenize special tokens. Returns: The detokenized string. """ - return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens) + return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens, special=special) def set_cache(self, cache: Optional[BaseLlamaCache]): """Set the cache. diff --git a/llama_cpp/llama_tokenizer.py b/llama_cpp/llama_tokenizer.py index 029bf2acc..2e7590d14 100644 --- a/llama_cpp/llama_tokenizer.py +++ b/llama_cpp/llama_tokenizer.py @@ -19,20 +19,22 @@ def tokenize( """Tokenize the text into tokens. Args: - text: The text to tokenize. + text: The utf-8 encoded string to tokenize. add_bos: Whether to add a beginning of sequence token. - special: Whether to tokenize text literally or as special tokens.""" + special: Whether to tokenize special tokens. + """ raise NotImplementedError @abc.abstractmethod def detokenize( - self, tokens: List[int], prev_tokens: Optional[List[int]] = None + self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False ) -> bytes: """Detokenize the tokens into text. Args: - tokens: The tokens to detokenize. - prev_tokens: If tokens is a continuation of a previous sequence, the previous tokens. + tokens: The list of tokens to detokenize. + prev_tokens: The list of previous tokens. Offset mapping will be performed if provided. + special: Whether to detokenize special tokens. """ raise NotImplementedError @@ -47,9 +49,9 @@ def tokenize( return self._model.tokenize(text, add_bos=add_bos, special=special) def detokenize( - self, tokens: List[int], prev_tokens: Optional[List[int]] = None + self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False ) -> bytes: - return self._model.detokenize(tokens) + return self._model.detokenize(tokens, special=special) def encode( self, text: str, add_bos: bool = True, special: bool = True @@ -78,18 +80,19 @@ def tokenize( ) def detokenize( - self, tokens: List[int], prev_tokens: Optional[List[int]] = None + self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False ) -> bytes: + skip_special_tokens = not special if prev_tokens is not None: - text = self.hf_tokenizer.decode(prev_tokens + tokens).encode( + text = self.hf_tokenizer.decode(prev_tokens + tokens, skip_special_tokens=skip_special_tokens).encode( "utf-8", errors="ignore" ) - prev_text = self.hf_tokenizer.decode(prev_tokens).encode( + prev_text = self.hf_tokenizer.decode(prev_tokens, skip_special_tokens=skip_special_tokens).encode( "utf-8", errors="ignore" ) return text[len(prev_text) :] else: - return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore") + return self.hf_tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens).encode("utf-8", errors="ignore") @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":