
Commit f7ca1fc

fix: Load from huggingface if cache fails
1 parent 499ff81

File tree

1 file changed (+20 -8 lines)


ai21_tokenizer/jamba_1_5_tokenizer.py

Lines changed: 20 additions & 8 deletions
@@ -1,14 +1,19 @@
 from __future__ import annotations
 
+import logging
 import tempfile
+
 from pathlib import Path
-from typing import Union, List, Optional, cast
+from typing import List, Optional, Union, cast
 
 from tokenizers import Tokenizer
 
-from ai21_tokenizer import BaseTokenizer, AsyncBaseTokenizer
-from ai21_tokenizer.file_utils import PathLike
+from ai21_tokenizer import AsyncBaseTokenizer, BaseTokenizer
 from ai21_tokenizer.base_jamba_tokenizer import BaseJambaTokenizer
+from ai21_tokenizer.file_utils import PathLike
+
+
+_logger = logging.getLogger(__name__)
 
 _TOKENIZER_FILE = "tokenizer.json"
 _DEFAULT_MODEL_CACHE_DIR = Path(tempfile.gettempdir()) / "jamba_1_5"
@@ -31,8 +36,11 @@ def __init__(
         self._tokenizer = self._init_tokenizer(model_path=model_path, cache_dir=cache_dir or _DEFAULT_MODEL_CACHE_DIR)
 
     def _init_tokenizer(self, model_path: PathLike, cache_dir: PathLike) -> Tokenizer:
-        if self._is_cached(cache_dir):
-            return self._load_from_cache(cache_dir / _TOKENIZER_FILE)
+        try:
+            if self._is_cached(cache_dir):
+                return self._load_from_cache(cache_dir / _TOKENIZER_FILE)
+        except Exception as e:
+            _logger.error(f"Error loading tokenizer from cache. Trying to download: {e}")
 
         tokenizer = cast(
             Tokenizer,
@@ -127,9 +135,13 @@ def vocab_size(self) -> int:
         return self._tokenizer.get_vocab_size()
 
     async def _init_tokenizer(self):
-        if self._is_cached(self._cache_dir):
-            self._tokenizer = await self._load_from_cache(self._cache_dir / _TOKENIZER_FILE)
-        else:
+        try:
+            if self._is_cached(self._cache_dir):
+                self._tokenizer = await self._load_from_cache(self._cache_dir / _TOKENIZER_FILE)
+        except Exception as e:
+            _logger.error(f"Error loading tokenizer from cache. Trying to download: {e}")
+
+        if self._tokenizer is None:
             tokenizer_from_pretrained = await self._make_async_call(
                 callback_func=Tokenizer.from_pretrained,
                 identifier=self._model_path,
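
Both the sync and async `_init_tokenizer` get the same treatment: the cache read is wrapped in try/except, the failure is logged, and the code falls through to `Tokenizer.from_pretrained`. Below is a minimal standalone sketch of that flow. `load_jamba_tokenizer` and the re-caching step via `tokenizer.save` are illustrative assumptions; the helpers the class actually uses (`_is_cached`, `_load_from_cache`) are not shown in this diff.

import logging
import tempfile
from pathlib import Path

from tokenizers import Tokenizer

_logger = logging.getLogger(__name__)

_TOKENIZER_FILE = "tokenizer.json"
_DEFAULT_MODEL_CACHE_DIR = Path(tempfile.gettempdir()) / "jamba_1_5"


def load_jamba_tokenizer(model_path: str, cache_dir: Path = _DEFAULT_MODEL_CACHE_DIR) -> Tokenizer:
    """Load tokenizer.json from the local cache, falling back to a Hugging Face download."""
    tokenizer_file = cache_dir / _TOKENIZER_FILE
    try:
        # Before this commit, a corrupt or unreadable cached file raised here and
        # aborted initialization; now the error is logged and we fall through.
        if tokenizer_file.is_file():
            return Tokenizer.from_file(str(tokenizer_file))
    except Exception as e:
        _logger.error(f"Error loading tokenizer from cache. Trying to download: {e}")

    # Cache miss or cache failure: download from the Hugging Face Hub.
    # Re-caching here is an assumption, not part of the shown diff.
    tokenizer = Tokenizer.from_pretrained(model_path)
    cache_dir.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(tokenizer_file))
    return tokenizer

With this behavior, a truncated or corrupted cached tokenizer.json no longer breaks tokenizer construction: the error is logged and the file is fetched again from the Hub, which is what the commit title describes.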
