@@ -1,12 +1,21 @@
 from pathlib import Path
-from typing import Union, List
+from typing import List, Union
 from unittest.mock import patch
 
 import pytest
+
 from pytest_lazyfixture import lazy_fixture
 
-from ai21_tokenizer.jamba_1_5_tokenizer import Jamba1_5Tokenizer, AsyncJamba1_5Tokenizer
-from ai21_tokenizer.tokenizer_factory import JAMBA_1_5_MINI_TOKENIZER_HF_PATH, JAMBA_1_5_LARGE_TOKENIZER_HF_PATH
+from ai21_tokenizer.jamba_1_5_tokenizer import (
+    AsyncJamba1_5Tokenizer,
+    AsyncJambaTokenizer,
+    Jamba1_5Tokenizer,
+    SyncJambaTokenizer,
+)
+from ai21_tokenizer.tokenizer_factory import (
+    JAMBA_LARGE_1_6_TOKENIZER_HF_PATH,
+    JAMBA_MINI_1_6_TOKENIZER_HF_PATH,
+)
 
 
 @pytest.mark.parametrize(
@@ -20,7 +29,7 @@
         (lazy_fixture("jamba_1_5_large_tokenizer"),),
     ],
 )
-def test_tokenizer_mini_encode_decode(tokenizer: Jamba1_5Tokenizer):
+def test_tokenizer_mini_encode_decode(tokenizer: SyncJambaTokenizer):
     text = "Hello world!"
     encoded = tokenizer.encode(text)
     decoded = tokenizer.decode(encoded)
@@ -46,7 +55,7 @@ def test_tokenizer_mini_encode_decode(tokenizer: Jamba1_5Tokenizer):
 def test_tokenizer_mini__convert_ids_to_tokens(
     ids: Union[int, List[int]],
     expected_tokens: Union[str, List[str]],
-    tokenizer: Jamba1_5Tokenizer,
+    tokenizer: SyncJambaTokenizer,
 ):
     actual_tokens = tokenizer.convert_ids_to_tokens(ids)
 
@@ -111,13 +120,13 @@ def test_tokenizer__decode_with_start_of_line(
     ],
     argnames=["hf_path"],
     argvalues=[
-        (JAMBA_1_5_MINI_TOKENIZER_HF_PATH,),
-        (JAMBA_1_5_LARGE_TOKENIZER_HF_PATH,),
+        (JAMBA_MINI_1_6_TOKENIZER_HF_PATH,),
+        (JAMBA_LARGE_1_6_TOKENIZER_HF_PATH,),
     ],
 )
 def test_tokenizer__when_cache_dir_not_exists__should_save_tokenizer_in_cache_dir(tmp_path: Path, hf_path: str):
     assert not (tmp_path / "tokenizer.json").exists()
-    Jamba1_5Tokenizer(hf_path, tmp_path)
+    SyncJambaTokenizer(hf_path, tmp_path)
 
     assert (tmp_path / "tokenizer.json").exists()
 
@@ -129,18 +138,18 @@ def test_tokenizer__when_cache_dir_not_exists__should_save_tokenizer_in_cache_dir
     ],
     argnames=["hf_path"],
     argvalues=[
-        (JAMBA_1_5_MINI_TOKENIZER_HF_PATH,),
-        (JAMBA_1_5_LARGE_TOKENIZER_HF_PATH,),
+        (JAMBA_MINI_1_6_TOKENIZER_HF_PATH,),
+        (JAMBA_LARGE_1_6_TOKENIZER_HF_PATH,),
     ],
 )
 def test_tokenizer__when_cache_dir_exists__should_load_from_cache(tmp_path: Path, hf_path: str):
     # Creating tokenizer once from repo
     assert not (tmp_path / "tokenizer.json").exists()
-    Jamba1_5Tokenizer(hf_path, tmp_path)
+    SyncJambaTokenizer(hf_path, tmp_path)
 
     # Creating tokenizer again to load from cache
-    with patch.object(Jamba1_5Tokenizer, Jamba1_5Tokenizer._load_from_cache.__name__) as mock_load_from_cache:
-        Jamba1_5Tokenizer(hf_path, tmp_path)
+    with patch.object(SyncJambaTokenizer, SyncJambaTokenizer._load_from_cache.__name__) as mock_load_from_cache:
+        SyncJambaTokenizer(hf_path, tmp_path)
 
     # assert load_from_cache was called
     mock_load_from_cache.assert_called_once()
@@ -253,19 +262,19 @@ async def test_async_tokenizer__decode_with_start_of_line(
     ],
     argnames=["hf_path"],
     argvalues=[
-        (JAMBA_1_5_MINI_TOKENIZER_HF_PATH,),
-        (JAMBA_1_5_LARGE_TOKENIZER_HF_PATH,),
+        (JAMBA_MINI_1_6_TOKENIZER_HF_PATH,),
+        (JAMBA_LARGE_1_6_TOKENIZER_HF_PATH,),
     ],
 )
 async def test_async_tokenizer_encode_caches_tokenizer__should_have_tokenizer_in_cache_dir(
     tmp_path: Path, hf_path: str
 ):
     assert not (tmp_path / "tokenizer.json").exists()
-    jamba_tokenizer = await AsyncJamba1_5Tokenizer.create(hf_path, tmp_path)
+    jamba_tokenizer = await AsyncJambaTokenizer.create(hf_path, tmp_path)
     _ = await jamba_tokenizer.encode("Hello world!")
     assert (tmp_path / "tokenizer.json").exists()
 
 
 def test_async_tokenizer_initialized_directly__should_raise_error():
     with pytest.raises(ValueError):
-        AsyncJamba1_5Tokenizer()
+        AsyncJambaTokenizer()
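
For reference, a minimal usage sketch of the renamed classes, based only on the calls exercised in the tests above (SyncJambaTokenizer(hf_path, cache_dir), AsyncJambaTokenizer.create(hf_path, cache_dir), and encode/decode). The cache directory path is illustrative, and the async decode call is assumed to mirror the awaited encode shown in the tests.

import asyncio
from pathlib import Path

from ai21_tokenizer.jamba_1_5_tokenizer import AsyncJambaTokenizer, SyncJambaTokenizer
from ai21_tokenizer.tokenizer_factory import JAMBA_MINI_1_6_TOKENIZER_HF_PATH

cache_dir = Path("/tmp/jamba-tokenizer-cache")  # illustrative path

# Synchronous: per the cache tests, constructing the tokenizer writes
# tokenizer.json into cache_dir on first use and loads from the cache
# (via _load_from_cache) on subsequent instantiations.
tokenizer = SyncJambaTokenizer(JAMBA_MINI_1_6_TOKENIZER_HF_PATH, cache_dir)
ids = tokenizer.encode("Hello world!")
text = tokenizer.decode(ids)

# Asynchronous: must be built via the create() factory; calling
# AsyncJambaTokenizer() directly raises ValueError, as the last test asserts.
async def main() -> None:
    async_tokenizer = await AsyncJambaTokenizer.create(JAMBA_MINI_1_6_TOKENIZER_HF_PATH, cache_dir)
    ids = await async_tokenizer.encode("Hello world!")
    print(await async_tokenizer.decode(ids))  # assumed awaitable, like encode

asyncio.run(main())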