Commit bdda7c8: Accept custom pattern string and special tokens

Differential Revision: D74264910
Pull Request resolved: #69

1 parent: 9ceef56

File tree: 7 files changed, +128 -44 lines changed

include/pytorch/tokenizers/tiktoken.h

Lines changed: 22 additions & 7 deletions

@@ -32,20 +32,33 @@ static constexpr size_t kEOSTokenIndex = 1;
 class Tiktoken : public detail::BPETokenizerBase {
  public:
   explicit Tiktoken(
+      std::string pattern,
       std::unique_ptr<std::vector<std::string>> special_tokens,
       size_t bos_token_index,
       size_t eos_token_index)
-      : _special_tokens(std::move(special_tokens)),
+      : _pattern(std::move(pattern)),
+        _special_tokens(std::move(special_tokens)),
        _bos_token_index(bos_token_index),
        _eos_token_index(eos_token_index) {
     if (_bos_token_index >= _special_tokens->size() ||
         _eos_token_index >= _special_tokens->size()) {
       abort();
     }
-  };
+  }
+
+  explicit Tiktoken(
+      std::unique_ptr<std::vector<std::string>> special_tokens,
+      size_t bos_token_index,
+      size_t eos_token_index)
+      : Tiktoken(
+            _get_default_patern(),
+            std::move(special_tokens),
+            bos_token_index,
+            eos_token_index) {}
 
   explicit Tiktoken()
-      : _special_tokens(_get_default_special_tokens()),
+      : _pattern(_get_default_patern()),
+        _special_tokens(_get_default_special_tokens()),
        _bos_token_index(kBOSTokenIndex),
        _eos_token_index(kEOSTokenIndex){};
 
@@ -77,6 +90,11 @@ class Tiktoken : public detail::BPETokenizerBase {
     return special_tokens;
   }
 
+  static inline std::string _get_default_patern() {
+    // Removed negative lookahead \s+(?!\S) since it's not supported by RE2.
+    return R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";
+  }
+
   Error _encode(
       const std::string& input,
       std::vector<uint64_t>& ret,
@@ -86,14 +104,11 @@ class Tiktoken : public detail::BPETokenizerBase {
 
   detail::TokenMap _build_special_token_map(ssize_t num_base_tokens) const;
 
+  std::string _pattern;
   std::unique_ptr<std::vector<std::string>> _special_tokens;
   size_t _bos_token_index;
   size_t _eos_token_index;
 
-  // Removed negative lookahead \s+(?!\S) since it's not supported by RE2.
-  const std::string _pattern =
-      R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";
-
   std::unique_ptr<IRegex> _regex;
 };
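The dropped negative lookahead is the one behavioral difference between this RE2-safe C++ default and the full cl100k-style pattern kept on the Python side (CL100K_PAT_STR below): with `\s+(?!\S)`, a trailing space attaches to the following word; without it, a whitespace run is consumed whole. A small sketch using the third-party `regex` package makes this visible; it is illustrative only and not part of the commit (both pattern strings are copied from the diff):

import regex  # third-party 'regex' package: supports \p{...} and lookahead

WITH_LOOKAHEAD = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
RE2_SAFE = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+"

text = "hello   world"
# The lookahead stops \s+ one space early, so the word keeps a leading space:
print(regex.findall(WITH_LOOKAHEAD, text))  # ['hello', '  ', ' world']
# Without it, the whitespace run is matched whole and the word stands alone:
print(regex.findall(RE2_SAFE, text))        # ['hello', '   ', 'world']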

pytorch_tokenizers/constants.py

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# @lint-ignore-every LICENSELINT
+
+CL100K_PAT_STR = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501
+
+LLAMA_BASIC_SPECIAL_TOKENS = [
+    "<|begin_of_text|>",
+    "<|end_of_text|>",
+    "<|reserved_special_token_0|>",
+    "<|reserved_special_token_1|>",
+    "<|finetune_right_pad_id|>",
+    "<|step_id|>",
+    "<|start_header_id|>",
+    "<|end_header_id|>",
+    "<|eom_id|>",  # end of message
+    "<|eot_id|>",  # end of turn
+    "<|python_tag|>",
+    "<|image|>",
+]
+
+LLAMA_NUM_RESERVED_SPECIAL_TOKENS = 256
+LLAMA_RESERVED_SPECIAL_TOKENS = [
+    f"<|reserved_special_token_{2 + i}|>"
+    for i in range(LLAMA_NUM_RESERVED_SPECIAL_TOKENS - len(LLAMA_BASIC_SPECIAL_TOKENS))
+]
+
+LLAMA_SPECIAL_TOKENS = LLAMA_BASIC_SPECIAL_TOKENS + LLAMA_RESERVED_SPECIAL_TOKENS
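Taken together, these constants pin the Llama special-token list at exactly LLAMA_NUM_RESERVED_SPECIAL_TOKENS entries: 12 named tokens plus 244 generated reserved slots, numbered from 2 because reserved tokens 0 and 1 already appear in the basic list. A quick sanity-check sketch (the asserts are illustrative, not part of the commit):

from pytorch_tokenizers.constants import (
    LLAMA_BASIC_SPECIAL_TOKENS,
    LLAMA_RESERVED_SPECIAL_TOKENS,
    LLAMA_SPECIAL_TOKENS,
)

# 12 named tokens + (256 - 12) generated slots = 256 tokens total.
assert len(LLAMA_BASIC_SPECIAL_TOKENS) == 12
assert len(LLAMA_RESERVED_SPECIAL_TOKENS) == 244
assert len(LLAMA_SPECIAL_TOKENS) == 256

# Generated slots start at index 2 and run through 245.
assert LLAMA_RESERVED_SPECIAL_TOKENS[0] == "<|reserved_special_token_2|>"
assert LLAMA_RESERVED_SPECIAL_TOKENS[-1] == "<|reserved_special_token_245|>"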

pytorch_tokenizers/targets.bzl

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ def define_common_targets():
         name = "tokenizers",
         srcs = [
             "__init__.py",
+            "constants.py",
             "llama2c.py",
             "tiktoken.py",
             "hf_tokenizer.py",

pytorch_tokenizers/tiktoken.py

Lines changed: 9 additions & 36 deletions

@@ -25,6 +25,8 @@
 
 from tiktoken.load import load_tiktoken_bpe
 
+from .constants import CL100K_PAT_STR, LLAMA_SPECIAL_TOKENS
+
 logger = getLogger(__name__)
 
 
@@ -47,12 +49,6 @@ class TiktokenTokenizer:
     WARNING: The regex and special tokens are hardcoded from Llama 3+.
     """
 
-    special_tokens: Dict[str, int]
-
-    num_reserved_special_tokens = 256
-
-    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501
-
     @classmethod
     def get_instance(cls):
         global _INSTANCE
@@ -63,7 +59,12 @@ def get_instance(cls):
             )
         return _INSTANCE
 
-    def __init__(self, model_path: str):
+    def __init__(
+        self,
+        model_path: str,
+        pat_str: str = CL100K_PAT_STR,
+        special_tokens: List[str] = LLAMA_SPECIAL_TOKENS,
+    ):
         """
         Initializes the Tokenizer with a Tiktoken model.
 
@@ -74,32 +75,13 @@ def __init__(self, model_path: str):
 
         mergeable_ranks = load_tiktoken_bpe(model_path)
         num_base_tokens = len(mergeable_ranks)
-        special_tokens = [
-            "<|begin_of_text|>",
-            "<|end_of_text|>",
-            "<|reserved_special_token_0|>",
-            "<|reserved_special_token_1|>",
-            "<|finetune_right_pad_id|>",
-            "<|step_id|>",
-            "<|start_header_id|>",
-            "<|end_header_id|>",
-            "<|eom_id|>",  # end of message
-            "<|eot_id|>",  # end of turn
-            "<|python_tag|>",
-            "<|image|>",
-        ]
-        reserved_tokens = [
-            f"<|reserved_special_token_{2 + i}|>"
-            for i in range(self.num_reserved_special_tokens - len(special_tokens))
-        ]
-        special_tokens = special_tokens + reserved_tokens
 
         self.special_tokens = {
             token: num_base_tokens + i for i, token in enumerate(special_tokens)
         }
         self.model = tiktoken.Encoding(
             name=Path(model_path).name,
-            pat_str=self.pat_str,
+            pat_str=pat_str,
             mergeable_ranks=mergeable_ranks,
             special_tokens=self.special_tokens,
         )
@@ -108,15 +90,6 @@ def __init__(self, model_path: str):
         # BOS / EOS token IDs
         self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
         self.eos_id: int = self.special_tokens["<|end_of_text|>"]
-        self.eot_id: int = self.special_tokens["<|eot_id|>"]
-        self.eom_id: int = self.special_tokens["<|eom_id|>"]
-        self.python_tag_id = self.special_tokens["<|python_tag|>"]
-        self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
-        self.stop_tokens = [
-            self.eos_id,
-            self.special_tokens["<|eom_id|>"],
-            self.special_tokens["<|eot_id|>"],
-        ]
 
     def encode(
         self,
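With this refactor, TiktokenTokenizer keeps its Llama 3 defaults (CL100K_PAT_STR and LLAMA_SPECIAL_TOKENS) but accepts overrides, and the convenience attributes removed above (eot_id, eom_id, python_tag_id, pad_id, stop_tokens) now have to be looked up through the special_tokens mapping. A minimal usage sketch; the model path is a placeholder and the custom pattern is illustrative, not from the commit:

from pytorch_tokenizers.tiktoken import TiktokenTokenizer

# Default construction: cl100k pattern + Llama special tokens.
tok = TiktokenTokenizer("path/to/tokenizer.model")  # placeholder path

# IDs that __init__ no longer caches are one dict lookup away.
eot_id = tok.special_tokens["<|eot_id|>"]
stop_tokens = [tok.eos_id, tok.special_tokens["<|eom_id|>"], eot_id]

# Custom construction: any regex and special-token list, as long as
# <|begin_of_text|> and <|end_of_text|> are present (bos_id/eos_id read them).
custom = TiktokenTokenizer(
    "path/to/tokenizer.model",  # placeholder path
    pat_str=r"\p{N}{1,3}|\s+|[^\s\p{N}]+",  # illustrative pattern
    special_tokens=["<|begin_of_text|>", "<|end_of_text|>"],
)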

test/targets.bzl

Lines changed: 13 additions & 0 deletions

@@ -108,3 +108,16 @@ def define_common_targets():
         src = "resources/test_tiktoken_tokenizer.model",
         visibility = ["@EXECUTORCH_CLIENTS", "//pytorch/tokenizers/..."],
     )
+
+    runtime.python_test(
+        name = "test_tiktoken_py",
+        srcs = [
+            "test_tiktoken.py",
+        ],
+        deps = [
+            "//pytorch/tokenizers/pytorch_tokenizers:tokenizers",
+        ],
+        resources = {
+            ":test_tiktoken_tokenizer_model": "test_tiktoken_tokenizer.model",
+        },
+    )

test/test_tiktoken.cpp

Lines changed: 4 additions & 1 deletion

@@ -16,6 +16,8 @@ namespace tokenizers {
 
 namespace {
 // Test case based on Llama 2
+const std::string kPattern =
+    R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";
 static constexpr int32_t kSpecialTokensSize = 256;
 static inline std::unique_ptr<std::vector<std::string>> _get_special_tokens() {
   auto special_tokens =
@@ -50,7 +52,8 @@ static inline std::string _get_resource_path(const std::string& name) {
 class TiktokenTest : public Test {
  public:
   void SetUp() override {
-    tokenizer_ = std::make_unique<Tiktoken>(_get_special_tokens(), 0, 1);
+    tokenizer_ =
+        std::make_unique<Tiktoken>(kPattern, _get_special_tokens(), 0, 1);
     modelPath_ = _get_resource_path("test_tiktoken_tokenizer.model");
   }

test/test_tiktoken.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# @lint-ignore-every LICENSELINT
+
+import unittest
+
+import pkg_resources
+
+from pytorch_tokenizers.tiktoken import TiktokenTokenizer
+
+
+class TestTiktokenTokenizer(unittest.TestCase):
+    def test_default(self):
+        model_path = pkg_resources.resource_filename(
+            "pytorch.tokenizers.test", "test_tiktoken_tokenizer.model"
+        )
+        tiktoken = TiktokenTokenizer(model_path)
+        s = "<|begin_of_text|> hellow world."
+        self.assertEqual(s, tiktoken.decode(tiktoken.encode(s, bos=False, eos=False)))
+
+    def test_custom_pattern_and_special_tokens(self):
+        o220k_pattern = r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+        model_path = pkg_resources.resource_filename(
+            "pytorch.tokenizers.test", "test_tiktoken_tokenizer.model"
+        )
+        tiktoken = TiktokenTokenizer(
+            model_path,
+            pat_str=o220k_pattern,
+            special_tokens=[
+                "<|begin_of_text|>",
+                "<|end_of_text|>",
+                "<|custom_token|>",
+            ],
+        )
+        custom_token_id = tiktoken.special_tokens["<|custom_token|>"]
+
+        s = "<|begin_of_text|> hellow world, this is a custom token: <|custom_token|>."
+        encoding = tiktoken.encode(
+            s,
+            bos=False,
+            eos=False,
+            allowed_special="all",
+        )
+        self.assertTrue(custom_token_id in encoding)
+        self.assertEqual(s, tiktoken.decode(encoding))
