Commit 0b8cfbc

fix: Paths to huggingface
1 parent b5629cf commit 0b8cfbc

File tree

4 files changed: +89 -34 lines changed


ai21_tokenizer/__init__.py

Lines changed: 20 additions & 5 deletions
@@ -1,10 +1,23 @@
-from ai21_tokenizer.base_tokenizer import BaseTokenizer, AsyncBaseTokenizer
-from ai21_tokenizer.jamba_instruct_tokenizer import JambaInstructTokenizer, AsyncJambaInstructTokenizer
-from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer, AsyncJurassicTokenizer
-from ai21_tokenizer.tokenizer_factory import TokenizerFactory as Tokenizer, PreTrainedTokenizers
-from ai21_tokenizer.jamba_1_5_tokenizer import Jamba1_5Tokenizer, AsyncJamba1_5Tokenizer
+from ai21_tokenizer.base_tokenizer import AsyncBaseTokenizer, BaseTokenizer
+from ai21_tokenizer.jamba_1_5_tokenizer import (
+    AsyncJamba1_5Tokenizer,
+    AsyncJambaTokenizer,
+    Jamba1_5Tokenizer,
+    SyncJambaTokenizer,
+)
+from ai21_tokenizer.jamba_instruct_tokenizer import (
+    AsyncJambaInstructTokenizer,
+    JambaInstructTokenizer,
+)
+from ai21_tokenizer.jurassic_tokenizer import AsyncJurassicTokenizer, JurassicTokenizer
+from ai21_tokenizer.tokenizer_factory import (
+    PreTrainedTokenizers,
+    TokenizerFactory as Tokenizer,
+)
+
 from .version import VERSION

+
 __version__ = VERSION

 __all__ = [
@@ -19,4 +32,6 @@
     "AsyncJambaInstructTokenizer",
     "Jamba1_5Tokenizer",
     "AsyncJamba1_5Tokenizer",
+    "SyncJambaTokenizer",
+    "AsyncJambaTokenizer",
 ]
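
A note on what this re-export means for callers: both the old 1.5 names and the renamed Jamba classes are now importable from the package root. A minimal sketch, using only the names exported above:

import ai21_tokenizer

# SyncJambaTokenizer is a subclass of Jamba1_5Tokenizer (see
# jamba_1_5_tokenizer.py below), so either spelling resolves to the
# same implementation; likewise for the async pair.
assert issubclass(ai21_tokenizer.SyncJambaTokenizer, ai21_tokenizer.Jamba1_5Tokenizer)
assert issubclass(ai21_tokenizer.AsyncJambaTokenizer, ai21_tokenizer.AsyncJamba1_5Tokenizer)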

ai21_tokenizer/jamba_1_5_tokenizer.py

Lines changed: 8 additions & 0 deletions
@@ -158,3 +158,11 @@ async def _init_tokenizer(self):
     async def _load_from_cache(self, cache_file: Path) -> Tokenizer:
         tokenizer_from_file = await self._make_async_call(callback_func=Tokenizer.from_file, path=str(cache_file))
         return cast(Tokenizer, tokenizer_from_file)
+
+
+class SyncJambaTokenizer(Jamba1_5Tokenizer):
+    pass
+
+
+class AsyncJambaTokenizer(AsyncJamba1_5Tokenizer):
+    pass
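
The new classes are deliberately empty subclasses, so existing code that constructs or type-checks against the 1.5 names keeps working unchanged. A hedged sketch of direct construction; the model_path value is the constant defined in tokenizer_factory.py, and the cache directory is only illustrative:

from pathlib import Path

from ai21_tokenizer import SyncJambaTokenizer

tokenizer = SyncJambaTokenizer(
    model_path="ai21labs/AI21-Jamba-Mini-1.6",
    cache_dir=Path("/tmp/ai21-tokenizer-cache"),  # illustrative location
)

ids = tokenizer.encode("Hello world!")
text = tokenizer.decode(ids)  # round-trips, per the tests in this commit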

ai21_tokenizer/tokenizer_factory.py

Lines changed: 35 additions & 12 deletions
@@ -1,17 +1,25 @@
 import os
 import tempfile
+
 from pathlib import Path

-from ai21_tokenizer.base_tokenizer import BaseTokenizer, AsyncBaseTokenizer
-from ai21_tokenizer.jamba_instruct_tokenizer import JambaInstructTokenizer, AsyncJambaInstructTokenizer
-from ai21_tokenizer.jamba_1_5_tokenizer import Jamba1_5Tokenizer, AsyncJamba1_5Tokenizer
-from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer, AsyncJurassicTokenizer
+from ai21_tokenizer.base_tokenizer import AsyncBaseTokenizer, BaseTokenizer
+from ai21_tokenizer.jamba_1_5_tokenizer import (
+    AsyncJambaTokenizer,
+    SyncJambaTokenizer,
+)
+from ai21_tokenizer.jamba_instruct_tokenizer import (
+    AsyncJambaInstructTokenizer,
+    JambaInstructTokenizer,
+)
+from ai21_tokenizer.jurassic_tokenizer import AsyncJurassicTokenizer, JurassicTokenizer
+

 _LOCAL_RESOURCES_PATH = Path(__file__).parent / "resources"
 _ENV_CACHE_DIR_KEY = "AI21_TOKENIZER_CACHE_DIR"
 JAMBA_TOKENIZER_HF_PATH = "ai21labs/Jamba-v0.1"
-JAMBA_1_5_MINI_TOKENIZER_HF_PATH = "ai21labs/AI21-Jamba-1.5-Mini"
-JAMBA_1_5_LARGE_TOKENIZER_HF_PATH = "ai21labs/AI21-Jamba-1.5-Large"
+JAMBA_MINI_1_6_TOKENIZER_HF_PATH = "ai21labs/AI21-Jamba-Mini-1.6"
+JAMBA_LARGE_1_6_TOKENIZER_HF_PATH = "ai21labs/AI21-Jamba-Large-1.6"


 def _get_cache_dir(tokenizer_name: str) -> Path:
@@ -27,12 +35,17 @@ def _get_cache_dir(tokenizer_name: str) -> Path:


 class PreTrainedTokenizers:
+    # deprecated tokenizers
     J2_TOKENIZER = "j2-tokenizer"
     JAMBA_INSTRUCT_TOKENIZER = "jamba-instruct-tokenizer"
     JAMBA_TOKENIZER = "jamba-tokenizer"
     JAMBA_1_5_MINI_TOKENIZER = "jamba-1.5-mini-tokenizer"
     JAMBA_1_5_LARGE_TOKENIZER = "jamba-1.5-large-tokenizer"

+    # active tokenizers
+    JAMBA_MINI_1_6_TOKENIZER = "jamba-mini-1.6-tokenizer"
+    JAMBA_LARGE_1_6_TOKENIZER = "jamba-large-1.6-tokenizer"
+

 class TokenizerFactory:
     """
@@ -48,10 +61,16 @@ def get_tokenizer(
         cache_dir = _get_cache_dir(tokenizer_name=tokenizer_name)

         if tokenizer_name == PreTrainedTokenizers.JAMBA_1_5_MINI_TOKENIZER:
-            return Jamba1_5Tokenizer(model_path=JAMBA_1_5_MINI_TOKENIZER_HF_PATH, cache_dir=cache_dir)
+            return SyncJambaTokenizer(model_path=JAMBA_MINI_1_6_TOKENIZER_HF_PATH, cache_dir=cache_dir)

         if tokenizer_name == PreTrainedTokenizers.JAMBA_1_5_LARGE_TOKENIZER:
-            return Jamba1_5Tokenizer(model_path=JAMBA_1_5_LARGE_TOKENIZER_HF_PATH, cache_dir=cache_dir)
+            return SyncJambaTokenizer(model_path=JAMBA_LARGE_1_6_TOKENIZER_HF_PATH, cache_dir=cache_dir)
+
+        if tokenizer_name == PreTrainedTokenizers.JAMBA_MINI_1_6_TOKENIZER:
+            return SyncJambaTokenizer(model_path=JAMBA_MINI_1_6_TOKENIZER_HF_PATH, cache_dir=cache_dir)
+
+        if tokenizer_name == PreTrainedTokenizers.JAMBA_LARGE_1_6_TOKENIZER:
+            return SyncJambaTokenizer(model_path=JAMBA_LARGE_1_6_TOKENIZER_HF_PATH, cache_dir=cache_dir)

         if (
             tokenizer_name == PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER
@@ -72,12 +91,16 @@ async def get_async_tokenizer(
         cache_dir = _get_cache_dir(tokenizer_name=tokenizer_name)

         if tokenizer_name == PreTrainedTokenizers.JAMBA_1_5_MINI_TOKENIZER:
-            return await AsyncJamba1_5Tokenizer.create(model_path=JAMBA_1_5_MINI_TOKENIZER_HF_PATH, cache_dir=cache_dir)
+            return await AsyncJambaTokenizer.create(model_path=JAMBA_MINI_1_6_TOKENIZER_HF_PATH, cache_dir=cache_dir)

         if tokenizer_name == PreTrainedTokenizers.JAMBA_1_5_LARGE_TOKENIZER:
-            return await AsyncJamba1_5Tokenizer.create(
-                model_path=JAMBA_1_5_LARGE_TOKENIZER_HF_PATH, cache_dir=cache_dir
-            )
+            return await AsyncJambaTokenizer.create(model_path=JAMBA_LARGE_1_6_TOKENIZER_HF_PATH, cache_dir=cache_dir)
+
+        if tokenizer_name == PreTrainedTokenizers.JAMBA_MINI_1_6_TOKENIZER:
+            return await AsyncJambaTokenizer.create(model_path=JAMBA_MINI_1_6_TOKENIZER_HF_PATH, cache_dir=cache_dir)
+
+        if tokenizer_name == PreTrainedTokenizers.JAMBA_LARGE_1_6_TOKENIZER:
+            return await AsyncJambaTokenizer.create(model_path=JAMBA_LARGE_1_6_TOKENIZER_HF_PATH, cache_dir=cache_dir)

         if (
             tokenizer_name == PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER
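
For reference, a sketch of how callers reach the new branches; it assumes the factory methods are invoked as classmethods. Constant and class names are exactly those defined above, and note that the deprecated 1.5 names now resolve to the same 1.6 Hugging Face repos:

import asyncio

from ai21_tokenizer import PreTrainedTokenizers, Tokenizer

# Sync path: resolves to a SyncJambaTokenizer backed by
# ai21labs/AI21-Jamba-Mini-1.6, cached under the per-tokenizer cache dir.
tokenizer = Tokenizer.get_tokenizer(tokenizer_name=PreTrainedTokenizers.JAMBA_MINI_1_6_TOKENIZER)

# Async path: AsyncJambaTokenizer instances are built via create(), never
# by calling the constructor directly (that raises ValueError, per the tests).
async def main() -> None:
    async_tokenizer = await Tokenizer.get_async_tokenizer(
        tokenizer_name=PreTrainedTokenizers.JAMBA_LARGE_1_6_TOKENIZER
    )
    ids = await async_tokenizer.encode("Hello world!")
    print(len(ids))

asyncio.run(main())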

tests/test_jamba_1_5_tokenizer.py

Lines changed: 26 additions & 17 deletions
@@ -1,12 +1,21 @@
 from pathlib import Path
-from typing import Union, List
+from typing import List, Union
 from unittest.mock import patch

 import pytest
+
 from pytest_lazyfixture import lazy_fixture

-from ai21_tokenizer.jamba_1_5_tokenizer import Jamba1_5Tokenizer, AsyncJamba1_5Tokenizer
-from ai21_tokenizer.tokenizer_factory import JAMBA_1_5_MINI_TOKENIZER_HF_PATH, JAMBA_1_5_LARGE_TOKENIZER_HF_PATH
+from ai21_tokenizer.jamba_1_5_tokenizer import (
+    AsyncJamba1_5Tokenizer,
+    AsyncJambaTokenizer,
+    Jamba1_5Tokenizer,
+    SyncJambaTokenizer,
+)
+from ai21_tokenizer.tokenizer_factory import (
+    JAMBA_LARGE_1_6_TOKENIZER_HF_PATH,
+    JAMBA_MINI_1_6_TOKENIZER_HF_PATH,
+)


 @pytest.mark.parametrize(
@@ -20,7 +29,7 @@
         (lazy_fixture("jamba_1_5_large_tokenizer"),),
     ],
 )
-def test_tokenizer_mini_encode_decode(tokenizer: Jamba1_5Tokenizer):
+def test_tokenizer_mini_encode_decode(tokenizer: SyncJambaTokenizer):
     text = "Hello world!"
     encoded = tokenizer.encode(text)
     decoded = tokenizer.decode(encoded)
@@ -46,7 +55,7 @@ def test_tokenizer_mini_encode_decode(tokenizer: Jamba1_5Tokenizer):
 def test_tokenizer_mini__convert_ids_to_tokens(
     ids: Union[int, List[int]],
     expected_tokens: Union[str, List[str]],
-    tokenizer: Jamba1_5Tokenizer,
+    tokenizer: SyncJambaTokenizer,
 ):
     actual_tokens = tokenizer.convert_ids_to_tokens(ids)

@@ -111,13 +120,13 @@ def test_tokenizer__decode_with_start_of_line(
     ],
     argnames=["hf_path"],
     argvalues=[
-        (JAMBA_1_5_MINI_TOKENIZER_HF_PATH,),
-        (JAMBA_1_5_LARGE_TOKENIZER_HF_PATH,),
+        (JAMBA_MINI_1_6_TOKENIZER_HF_PATH,),
+        (JAMBA_LARGE_1_6_TOKENIZER_HF_PATH,),
     ],
 )
 def test_tokenizer__when_cache_dir_not_exists__should_save_tokenizer_in_cache_dir(tmp_path: Path, hf_path: str):
     assert not (tmp_path / "tokenizer.json").exists()
-    Jamba1_5Tokenizer(hf_path, tmp_path)
+    SyncJambaTokenizer(hf_path, tmp_path)

     assert (tmp_path / "tokenizer.json").exists()

@@ -129,18 +138,18 @@ def test_tokenizer__when_cache_dir_not_exists__should_save_tokenizer_in_cache_dir(
     ],
     argnames=["hf_path"],
     argvalues=[
-        (JAMBA_1_5_MINI_TOKENIZER_HF_PATH,),
-        (JAMBA_1_5_LARGE_TOKENIZER_HF_PATH,),
+        (JAMBA_MINI_1_6_TOKENIZER_HF_PATH,),
+        (JAMBA_LARGE_1_6_TOKENIZER_HF_PATH,),
     ],
 )
 def test_tokenizer__when_cache_dir_exists__should_load_from_cache(tmp_path: Path, hf_path: str):
     # Creating tokenizer once from repo
     assert not (tmp_path / "tokenizer.json").exists()
-    Jamba1_5Tokenizer(hf_path, tmp_path)
+    SyncJambaTokenizer(hf_path, tmp_path)

     # Creating tokenizer again to load from cache
-    with patch.object(Jamba1_5Tokenizer, Jamba1_5Tokenizer._load_from_cache.__name__) as mock_load_from_cache:
-        Jamba1_5Tokenizer(hf_path, tmp_path)
+    with patch.object(SyncJambaTokenizer, SyncJambaTokenizer._load_from_cache.__name__) as mock_load_from_cache:
+        SyncJambaTokenizer(hf_path, tmp_path)

     # assert load_from_cache was called
     mock_load_from_cache.assert_called_once()
@@ -253,19 +262,19 @@ async def test_async_tokenizer__decode_with_start_of_line(
     ],
     argnames=["hf_path"],
     argvalues=[
-        (JAMBA_1_5_MINI_TOKENIZER_HF_PATH,),
-        (JAMBA_1_5_LARGE_TOKENIZER_HF_PATH,),
+        (JAMBA_MINI_1_6_TOKENIZER_HF_PATH,),
+        (JAMBA_LARGE_1_6_TOKENIZER_HF_PATH,),
     ],
 )
 async def test_async_tokenizer_encode_caches_tokenizer__should_have_tokenizer_in_cache_dir(
     tmp_path: Path, hf_path: str
 ):
     assert not (tmp_path / "tokenizer.json").exists()
-    jamba_tokenizer = await AsyncJamba1_5Tokenizer.create(hf_path, tmp_path)
+    jamba_tokenizer = await AsyncJambaTokenizer.create(hf_path, tmp_path)
     _ = await jamba_tokenizer.encode("Hello world!")
     assert (tmp_path / "tokenizer.json").exists()


 def test_async_tokenizer_initialized_directly__should_raise_error():
     with pytest.raises(ValueError):
-        AsyncJamba1_5Tokenizer()
+        AsyncJambaTokenizer()
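
The caching contract these tests pin down, condensed into a sketch (the Hugging Face path is the constant value from tokenizer_factory.py; the temp directory is illustrative):

from pathlib import Path

from ai21_tokenizer import SyncJambaTokenizer

cache_dir = Path("/tmp/ai21-jamba-cache")  # illustrative location

# First construction downloads from Hugging Face and writes tokenizer.json
# into cache_dir; constructing again goes through _load_from_cache instead.
SyncJambaTokenizer("ai21labs/AI21-Jamba-Mini-1.6", cache_dir)
assert (cache_dir / "tokenizer.json").exists()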
