1
1
from __future__ import annotations
2
2
3
+ import logging
3
4
import tempfile
5
+
4
6
from pathlib import Path
5
- from typing import Union , List , Optional , cast
7
+ from typing import List , Optional , Union , cast
6
8
7
9
from tokenizers import Tokenizer
8
10
9
- from ai21_tokenizer import BaseTokenizer , AsyncBaseTokenizer
10
- from ai21_tokenizer .file_utils import PathLike
11
+ from ai21_tokenizer import AsyncBaseTokenizer , BaseTokenizer
11
12
from ai21_tokenizer .base_jamba_tokenizer import BaseJambaTokenizer
13
+ from ai21_tokenizer .file_utils import PathLike
14
+
15
+
16
+ _logger = logging .getLogger (__name__ )
12
17
13
18
# File name appended to the cache directory when a cached tokenizer is loaded
# (see the `cache_dir / _TOKENIZER_FILE` path passed to `_load_from_cache`).
_TOKENIZER_FILE = "tokenizer.json"
14
19
# Fallback cache location used when the caller supplies no `cache_dir`:
# a "jamba_1_5" directory under the system temporary directory.
_DEFAULT_MODEL_CACHE_DIR = Path (tempfile .gettempdir ()) / "jamba_1_5"
@@ -31,8 +36,11 @@ def __init__(
31
36
self ._tokenizer = self ._init_tokenizer (model_path = model_path , cache_dir = cache_dir or _DEFAULT_MODEL_CACHE_DIR )
32
37
33
38
def _init_tokenizer (self , model_path : PathLike , cache_dir : PathLike ) -> Tokenizer :
34
- if self ._is_cached (cache_dir ):
35
- return self ._load_from_cache (cache_dir / _TOKENIZER_FILE )
39
+ try :
40
+ if self ._is_cached (cache_dir ):
41
+ return self ._load_from_cache (cache_dir / _TOKENIZER_FILE )
42
+ except Exception as e :
43
+ _logger .error (f"Error loading tokenizer from cache. Trying to download: { e } " )
36
44
37
45
tokenizer = cast (
38
46
Tokenizer ,
@@ -127,9 +135,13 @@ def vocab_size(self) -> int:
127
135
return self ._tokenizer .get_vocab_size ()
128
136
129
137
async def _init_tokenizer (self ):
130
- if self ._is_cached (self ._cache_dir ):
131
- self ._tokenizer = await self ._load_from_cache (self ._cache_dir / _TOKENIZER_FILE )
132
- else :
138
+ try :
139
+ if self ._is_cached (self ._cache_dir ):
140
+ self ._tokenizer = await self ._load_from_cache (self ._cache_dir / _TOKENIZER_FILE )
141
+ except Exception as e :
142
+ _logger .error (f"Error loading tokenizer from cache. Trying to download: { e } " )
143
+
144
+ if self ._tokenizer is None :
133
145
tokenizer_from_pretrained = await self ._make_async_call (
134
146
callback_func = Tokenizer .from_pretrained ,
135
147
identifier = self ._model_path ,
0 commit comments